Build uploaded using `kernels` (batch 8/10).
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +1 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h +111 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h +541 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h +591 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h +157 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h +38 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +472 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp +570 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp +341 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h +135 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h +94 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h +1549 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h +385 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h +350 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +311 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp +146 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +162 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h +168 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +159 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h +355 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h +250 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h +2075 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +142 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h +514 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h +141 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h +186 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp +782 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h +802 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h +66 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h +531 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h +210 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +228 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp +916 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h +261 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h +318 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h +234 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h +285 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h +319 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h +616 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp +101 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h +256 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h +341 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h +1718 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp +432 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h +134 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h +42 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h +203 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp +203 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h +215 -0
- build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h +262 -0
.gitattributes
CHANGED
|
@@ -16,3 +16,4 @@ build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=
|
|
| 16 |
build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 17 |
build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 18 |
build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 16 |
build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 17 |
build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 18 |
build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief reorder data from the host side
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/coord.h"
|
| 39 |
+
#include "cutlass/util/host_tensor.h"
|
| 40 |
+
#include "cutlass/tensor_view.h"
|
| 41 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 42 |
+
#include "cutlass/util/reference/host/gemm.h"
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
|
| 46 |
+
/// This is needed for the interleaved integer tensor core kernels. The purpose
|
| 47 |
+
/// is to use skip the shared memory part in the epilogue.
|
| 48 |
+
template <int Interleaved, typename Element, typename Layout>
|
| 49 |
+
void reorder_column(TensorRef<Element, Layout> dest,
|
| 50 |
+
TensorRef<Element, Layout> src,
|
| 51 |
+
cutlass::gemm::GemmCoord problem_size) {
|
| 52 |
+
const int InstructionShapeCol = 8;
|
| 53 |
+
// 4 threads per Quad
|
| 54 |
+
const int ElementsPerThread = InstructionShapeCol / 4;
|
| 55 |
+
// 4 threads per Quad
|
| 56 |
+
const int ReorderedElementsPerThread =
|
| 57 |
+
Interleaved / 4;
|
| 58 |
+
|
| 59 |
+
for (int n = 0; n < problem_size.n(); n++) {
|
| 60 |
+
for (int k = 0; k < problem_size.k(); k++) {
|
| 61 |
+
dest.at({k, (n / Interleaved) * Interleaved +
|
| 62 |
+
((n % ReorderedElementsPerThread) / ElementsPerThread) *
|
| 63 |
+
InstructionShapeCol +
|
| 64 |
+
((n % Interleaved) / ReorderedElementsPerThread) *
|
| 65 |
+
ElementsPerThread +
|
| 66 |
+
(n % ElementsPerThread)}) = src.at({k, n});
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
template <int ColumnInterleaved, int LayoutInterleaved = ColumnInterleaved, typename Element, typename Layout>
|
| 72 |
+
void reorder_convK(TensorRef<Element, Layout> dest,
|
| 73 |
+
TensorRef<Element, Layout> src,
|
| 74 |
+
cutlass::gemm::GemmCoord problem_size) {
|
| 75 |
+
|
| 76 |
+
TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedDest(dest.data(), dest.stride(0));
|
| 77 |
+
TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedSrc(src.data(), src.stride(0));
|
| 78 |
+
|
| 79 |
+
reorder_column<ColumnInterleaved>(
|
| 80 |
+
mappedDest, mappedSrc, problem_size);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// This is needed for the sparse tensor core kernels. The purpose
|
| 84 |
+
/// is to use ldmatrix to load from shared memory to the register file.
|
| 85 |
+
template <typename Element, typename LayoutDest, typename LayoutSrc>
|
| 86 |
+
void reorder_meta(TensorRef<Element, LayoutDest> dest,
|
| 87 |
+
TensorRef<Element, LayoutSrc> src,
|
| 88 |
+
cutlass::gemm::GemmCoord problem_size) {
|
| 89 |
+
for (int m = 0; m < problem_size.m(); m++) {
|
| 90 |
+
for (int k = 0; k < problem_size.k(); k++) {
|
| 91 |
+
// First reorder the rows.
|
| 92 |
+
int group = (sizeof(Element) == 2) ? 32 : 16;
|
| 93 |
+
int interweave = (sizeof(Element) == 2) ? 4 : 2;
|
| 94 |
+
|
| 95 |
+
int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
|
| 96 |
+
int dest_col = k;
|
| 97 |
+
|
| 98 |
+
// Next swizzle the 2x2 blocks from Z to N.
|
| 99 |
+
if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
|
| 100 |
+
++dest_row;
|
| 101 |
+
--dest_col;
|
| 102 |
+
} else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
|
| 103 |
+
--dest_row;
|
| 104 |
+
++dest_col;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
dest.at({dest_row, dest_col}) = src.at({m, k});
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
/*! \file
|
| 34 |
+
\brief HostTensor contributes management for both host and device memory.
|
| 35 |
+
|
| 36 |
+
HostTensor allocates host and device memory upon construction. Basic element-wise operations on
|
| 37 |
+
host memory synchronize device memory automatically. Explicit copy operations provide abstractions
|
| 38 |
+
for CUDA memcpy operations.
|
| 39 |
+
|
| 40 |
+
Call {host, device}_{data, ref, view}() for accessing host or device memory.
|
| 41 |
+
|
| 42 |
+
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
|
| 43 |
+
*/
|
| 44 |
+
|
| 45 |
+
#include <vector>
|
| 46 |
+
|
| 47 |
+
#include "cutlass/cutlass.h"
|
| 48 |
+
#include "cutlass/tensor_ref.h"
|
| 49 |
+
#include "cutlass/tensor_view.h"
|
| 50 |
+
#include "cutlass/fast_math.h"
|
| 51 |
+
|
| 52 |
+
#include "device_memory.h"
|
| 53 |
+
|
| 54 |
+
namespace cutlass {
|
| 55 |
+
|
| 56 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
/// Host tensor
|
| 59 |
+
template <
|
| 60 |
+
/// Data type of element stored within tensor (concept: NumericType)
|
| 61 |
+
typename Element_,
|
| 62 |
+
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
| 63 |
+
typename Layout_
|
| 64 |
+
>
|
| 65 |
+
class HostTensor {
|
| 66 |
+
public:
|
| 67 |
+
|
| 68 |
+
/// Data type of individual access
|
| 69 |
+
using Element = Element_;
|
| 70 |
+
|
| 71 |
+
/// Mapping function from logical coordinate to linear memory
|
| 72 |
+
using Layout = Layout_;
|
| 73 |
+
|
| 74 |
+
/// Logical rank of tensor index space
|
| 75 |
+
static int const kRank = Layout::kRank;
|
| 76 |
+
|
| 77 |
+
/// Index type
|
| 78 |
+
using Index = typename Layout::Index;
|
| 79 |
+
|
| 80 |
+
/// Long index used for pointer offsets
|
| 81 |
+
using LongIndex = typename Layout::LongIndex;
|
| 82 |
+
|
| 83 |
+
/// Coordinate in logical tensor space
|
| 84 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 85 |
+
|
| 86 |
+
/// Layout's stride vector
|
| 87 |
+
using Stride = typename Layout::Stride;
|
| 88 |
+
|
| 89 |
+
/// Tensor reference to device memory
|
| 90 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 91 |
+
|
| 92 |
+
/// Tensor reference to constant device memory
|
| 93 |
+
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 94 |
+
|
| 95 |
+
/// Tensor reference to device memory
|
| 96 |
+
using TensorView = TensorView<Element, Layout>;
|
| 97 |
+
|
| 98 |
+
/// Tensor reference to constant device memory
|
| 99 |
+
using ConstTensorView = typename TensorView::ConstTensorView;
|
| 100 |
+
|
| 101 |
+
/// Reference to element in tensor
|
| 102 |
+
using Reference = typename TensorRef::Reference;
|
| 103 |
+
|
| 104 |
+
/// Constant reference to element in tensor
|
| 105 |
+
using ConstReference = typename ConstTensorRef::Reference;
|
| 106 |
+
|
| 107 |
+
private:
|
| 108 |
+
using StorageUnit = typename platform::conditional_t<std::is_same_v<Element, bool>, uint8_t, // Avoid the std::vector<bool> specialization
|
| 109 |
+
typename platform::conditional_t<sizeof_bits<Element>::value % 8 == 0, // Handle subbyte types
|
| 110 |
+
Element, uint8_t>>;
|
| 111 |
+
using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
|
| 112 |
+
static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits;
|
| 113 |
+
static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements;
|
| 114 |
+
static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes;
|
| 115 |
+
static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit;
|
| 116 |
+
|
| 117 |
+
//
|
| 118 |
+
// Data members
|
| 119 |
+
//
|
| 120 |
+
|
| 121 |
+
/// Extent of tensor in logical dimensions
|
| 122 |
+
TensorCoord extent_;
|
| 123 |
+
|
| 124 |
+
/// Layout object
|
| 125 |
+
Layout layout_;
|
| 126 |
+
|
| 127 |
+
/// Host-side memory allocation
|
| 128 |
+
std::vector<StorageUnit> host_;
|
| 129 |
+
|
| 130 |
+
/// Device-side memory
|
| 131 |
+
device_memory::allocation<StorageUnit> device_;
|
| 132 |
+
|
| 133 |
+
/// number of containers
|
| 134 |
+
size_t count_to_container_storage_unit_count(size_t count) {
|
| 135 |
+
return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
public:
|
| 139 |
+
//
|
| 140 |
+
// Device and Host Methods
|
| 141 |
+
//
|
| 142 |
+
|
| 143 |
+
/// Default constructor
|
| 144 |
+
HostTensor() {}
|
| 145 |
+
|
| 146 |
+
/// Constructs a tensor given an extent. Assumes a packed layout
|
| 147 |
+
HostTensor(
|
| 148 |
+
TensorCoord const &extent,
|
| 149 |
+
bool device_backed = true
|
| 150 |
+
) {
|
| 151 |
+
|
| 152 |
+
this->reset(extent, Layout::packed(extent), device_backed);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
/// Constructs a tensor given an extent and layout
|
| 156 |
+
HostTensor(
|
| 157 |
+
TensorCoord const &extent,
|
| 158 |
+
Layout const &layout,
|
| 159 |
+
bool device_backed = true
|
| 160 |
+
) {
|
| 161 |
+
|
| 162 |
+
this->reset(extent, layout, device_backed);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
~HostTensor() { }
|
| 166 |
+
|
| 167 |
+
/// Clears the HostTensor allocation to size/capacity = 0
|
| 168 |
+
void reset() {
|
| 169 |
+
extent_ = TensorCoord();
|
| 170 |
+
layout_ = Layout::packed(extent_);
|
| 171 |
+
|
| 172 |
+
host_.clear();
|
| 173 |
+
device_.reset();
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/// Resizes internal memory allocations without affecting layout or extent
|
| 177 |
+
void reserve(
|
| 178 |
+
size_t count, ///< size of tensor in elements
|
| 179 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated
|
| 180 |
+
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 181 |
+
CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")");
|
| 182 |
+
#endif
|
| 183 |
+
|
| 184 |
+
device_.reset();
|
| 185 |
+
host_.clear();
|
| 186 |
+
|
| 187 |
+
size_t count_container = count_to_container_storage_unit_count(count);
|
| 188 |
+
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 189 |
+
CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")");
|
| 190 |
+
#endif
|
| 191 |
+
host_.resize(count_container);
|
| 192 |
+
|
| 193 |
+
// Allocate memory
|
| 194 |
+
StorageUnit* device_memory = nullptr;
|
| 195 |
+
if (device_backed_) {
|
| 196 |
+
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
|
| 197 |
+
CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")");
|
| 198 |
+
#endif
|
| 199 |
+
device_memory = device_memory::allocate<StorageUnit>(count_container);
|
| 200 |
+
}
|
| 201 |
+
device_.reset(device_memory, device_backed_ ? count_container : 0);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 205 |
+
/// extent and layout.
|
| 206 |
+
void reset(
|
| 207 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 208 |
+
Layout const &layout, ///< layout object of tensor
|
| 209 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 210 |
+
|
| 211 |
+
extent_ = extent;
|
| 212 |
+
layout_ = layout;
|
| 213 |
+
|
| 214 |
+
reserve(size_t(layout_.capacity(extent_)), device_backed_);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 218 |
+
/// extent and layout. Assumes a packed tensor configuration.
|
| 219 |
+
void reset(
|
| 220 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 221 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 222 |
+
|
| 223 |
+
reset(extent, Layout::packed(extent), device_backed_);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 227 |
+
/// To force allocation, call reset().
|
| 228 |
+
void resize(
|
| 229 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 230 |
+
Layout const &layout, ///< layout object of tensor
|
| 231 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 232 |
+
|
| 233 |
+
extent_ = extent;
|
| 234 |
+
layout_ = layout;
|
| 235 |
+
|
| 236 |
+
LongIndex new_size = size_t(layout_.capacity(extent_));
|
| 237 |
+
LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));
|
| 238 |
+
|
| 239 |
+
if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
|
| 240 |
+
reserve(new_size, device_backed_);
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 245 |
+
/// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
|
| 246 |
+
void resize(
|
| 247 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 248 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 249 |
+
|
| 250 |
+
resize(extent, Layout::packed(extent), device_backed_);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
/// Returns the logical number of elements stored in the host tensor
|
| 254 |
+
size_t size() const {
|
| 255 |
+
return layout_.capacity(extent_);
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/// Returns the logical capacity in terms of number of elements. May be larger than the size().
|
| 259 |
+
LongIndex capacity() const {
|
| 260 |
+
return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
/// Gets pointer to host data
|
| 264 |
+
Element * host_data() { return reinterpret_cast<Element *>(host_.data()); }
|
| 265 |
+
|
| 266 |
+
/// Gets pointer to host data with a pointer offset
|
| 267 |
+
Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
|
| 268 |
+
|
| 269 |
+
/// Gets a reference to an element in host memory
|
| 270 |
+
Reference host_data(LongIndex idx) {
|
| 271 |
+
return ReferenceFactory<Element>::get(host_data(), idx);
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
/// Gets pointer to host data
|
| 275 |
+
Element const * host_data() const { return reinterpret_cast<Element const *>(host_.data()); }
|
| 276 |
+
|
| 277 |
+
/// Gets pointer to host data with a pointer offset
|
| 278 |
+
Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
|
| 279 |
+
|
| 280 |
+
/// Gets a constant reference to an element in host memory
|
| 281 |
+
ConstReference host_data(LongIndex idx) const {
|
| 282 |
+
return ReferenceFactory<Element const>::get(host_data(), idx);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
/// Gets pointer to device data
|
| 286 |
+
Element * device_data() { return reinterpret_cast<Element *>(device_.get()); }
|
| 287 |
+
|
| 288 |
+
/// Gets pointer to device data
|
| 289 |
+
Element const * device_data() const { return reinterpret_cast<Element const *>(device_.get()); }
|
| 290 |
+
|
| 291 |
+
/// Gets pointer to device data with a pointer offset
|
| 292 |
+
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
|
| 293 |
+
|
| 294 |
+
/// Gets pointer to device data with a pointer offset
|
| 295 |
+
Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
|
| 296 |
+
|
| 297 |
+
/// Accesses the tensor reference pointing to data
|
| 298 |
+
TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
|
| 299 |
+
|
| 300 |
+
/// Accesses the tensor reference pointing to data
|
| 301 |
+
ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
|
| 302 |
+
|
| 303 |
+
/// Accesses the tensor reference pointing to data
|
| 304 |
+
TensorRef device_ref(LongIndex ptr_element_offset=0) {
|
| 305 |
+
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
/// Accesses the tensor reference pointing to data
|
| 309 |
+
ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
|
| 310 |
+
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
/// Accesses the tensor reference pointing to data
|
| 314 |
+
TensorView host_view(LongIndex ptr_element_offset=0) {
|
| 315 |
+
return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
/// Accesses the tensor reference pointing to data
|
| 319 |
+
ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
|
| 320 |
+
return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
/// Accesses the tensor reference pointing to data
|
| 324 |
+
TensorView device_view(LongIndex ptr_element_offset=0) {
|
| 325 |
+
return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
/// Accesses the tensor reference pointing to data
|
| 329 |
+
ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
|
| 330 |
+
return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
/// Returns true if device memory is allocated
|
| 334 |
+
bool device_backed() const {
|
| 335 |
+
return (device_.get() == nullptr) ? false : true;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
/// Returns the layout object
|
| 340 |
+
Layout & layout() {
|
| 341 |
+
return layout_;
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
/// Returns the layout object
|
| 345 |
+
Layout layout() const {
|
| 346 |
+
return layout_;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
/// Returns the layout object's stride vector
|
| 350 |
+
Stride stride() const {
|
| 351 |
+
return layout_.stride();
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
/// Returns the layout object's stride vector
|
| 355 |
+
Stride & stride() {
|
| 356 |
+
return layout_.stride();
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 360 |
+
LongIndex stride(int dim) const {
|
| 361 |
+
return layout_.stride().at(dim);
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 365 |
+
LongIndex & stride(int dim) {
|
| 366 |
+
return layout_.stride().at(dim);
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
/// Computes the offset of an index from the origin of the tensor
|
| 370 |
+
LongIndex offset(TensorCoord const& coord) const {
|
| 371 |
+
return layout_(coord);
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
/// Returns a reference to the element at the logical Coord in host memory
|
| 375 |
+
Reference at(TensorCoord const& coord) {
|
| 376 |
+
return host_data(offset(coord));
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
/// Returns a const reference to the element at the logical Coord in host memory
|
| 380 |
+
ConstReference at(TensorCoord const& coord) const {
|
| 381 |
+
return host_data(offset(coord));
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
/// Returns the extent of the tensor
|
| 385 |
+
TensorCoord extent() const {
|
| 386 |
+
return extent_;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
/// Returns the extent of the tensor
|
| 390 |
+
TensorCoord & extent() {
|
| 391 |
+
return extent_;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
/// Copies data from device to host
|
| 395 |
+
void sync_host() {
|
| 396 |
+
if (device_backed()) {
|
| 397 |
+
device_memory::copy_to_host(
|
| 398 |
+
host_.data(), device_.get(), device_.size());
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
/// Copies data from host to device
|
| 403 |
+
void sync_device() {
|
| 404 |
+
if (device_backed()) {
|
| 405 |
+
device_memory::copy_to_device(
|
| 406 |
+
device_.get(), host_.data(), host_.size());
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 411 |
+
void copy_in_device_to_host(
|
| 412 |
+
Element const* ptr_device, ///< source device memory
|
| 413 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 414 |
+
|
| 415 |
+
if (count < 0) {
|
| 416 |
+
count = capacity();
|
| 417 |
+
}
|
| 418 |
+
else {
|
| 419 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 420 |
+
}
|
| 421 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 422 |
+
device_memory::copy_to_host(
|
| 423 |
+
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 427 |
+
void copy_in_device_to_device(
|
| 428 |
+
Element const* ptr_device, ///< source device memory
|
| 429 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 430 |
+
|
| 431 |
+
if (count < 0) {
|
| 432 |
+
count = capacity();
|
| 433 |
+
}
|
| 434 |
+
else {
|
| 435 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 436 |
+
}
|
| 437 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 438 |
+
device_memory::copy_device_to_device(
|
| 439 |
+
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 443 |
+
void copy_in_host_to_device(
|
| 444 |
+
Element const* ptr_host, ///< source host memory
|
| 445 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 446 |
+
|
| 447 |
+
if (count < 0) {
|
| 448 |
+
count = capacity();
|
| 449 |
+
}
|
| 450 |
+
else {
|
| 451 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 452 |
+
}
|
| 453 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 454 |
+
device_memory::copy_to_device(
|
| 455 |
+
device_.get(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 459 |
+
void copy_in_host_to_host(
|
| 460 |
+
Element const* ptr_host, ///< source host memory
|
| 461 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 462 |
+
|
| 463 |
+
if (count < 0) {
|
| 464 |
+
count = capacity();
|
| 465 |
+
}
|
| 466 |
+
else {
|
| 467 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 468 |
+
}
|
| 469 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 470 |
+
device_memory::copy_host_to_host(
|
| 471 |
+
host_.data(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 475 |
+
void copy_out_device_to_host(
|
| 476 |
+
Element * ptr_host, ///< source device memory
|
| 477 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 478 |
+
|
| 479 |
+
if (count < 0) {
|
| 480 |
+
count = capacity();
|
| 481 |
+
}
|
| 482 |
+
else {
|
| 483 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 484 |
+
}
|
| 485 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 486 |
+
device_memory::copy_to_host(
|
| 487 |
+
reinterpret_cast<StorageUnit *>(ptr_host), device_.get(), container_count);
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 491 |
+
void copy_out_device_to_device(
|
| 492 |
+
Element * ptr_device, ///< source device memory
|
| 493 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 494 |
+
|
| 495 |
+
if (count < 0) {
|
| 496 |
+
count = capacity();
|
| 497 |
+
}
|
| 498 |
+
else {
|
| 499 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 500 |
+
}
|
| 501 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 502 |
+
device_memory::copy_device_to_device(
|
| 503 |
+
reinterpret_cast<StorageUnit *>(ptr_device), device_.get(), container_count);
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 507 |
+
void copy_out_host_to_device(
|
| 508 |
+
Element * ptr_device, ///< source host memory
|
| 509 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 510 |
+
|
| 511 |
+
if (count < 0) {
|
| 512 |
+
count = capacity();
|
| 513 |
+
}
|
| 514 |
+
else {
|
| 515 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 516 |
+
}
|
| 517 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 518 |
+
device_memory::copy_to_device(
|
| 519 |
+
reinterpret_cast<StorageUnit *>(ptr_device), host_.data(), container_count);
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 523 |
+
void copy_out_host_to_host(
|
| 524 |
+
Element * ptr_host, ///< source host memory
|
| 525 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 526 |
+
|
| 527 |
+
if (count < 0) {
|
| 528 |
+
count = capacity();
|
| 529 |
+
}
|
| 530 |
+
else {
|
| 531 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 532 |
+
}
|
| 533 |
+
size_t container_count = count_to_container_storage_unit_count(count);
|
| 534 |
+
device_memory::copy_host_to_host(
|
| 535 |
+
reinterpret_cast<StorageUnit *>(ptr_host), host_.data(), container_count);
|
| 536 |
+
}
|
| 537 |
+
};
|
| 538 |
+
|
| 539 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 540 |
+
|
| 541 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
/*! \file
|
| 34 |
+
\brief HostTensor contributes management for both host and device memory.
|
| 35 |
+
|
| 36 |
+
HostTensor allocates host and device memory upon construction. Basic element-wise operations on
|
| 37 |
+
host memory synchronize device memory automatically. Explicit copy operations provide abstractions
|
| 38 |
+
for CUDA memcpy operations.
|
| 39 |
+
|
| 40 |
+
Call {host, device}_{data, ref, view}() for accessing host or device memory.
|
| 41 |
+
|
| 42 |
+
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
|
| 43 |
+
*/
|
| 44 |
+
|
| 45 |
+
#include <vector>
|
| 46 |
+
|
| 47 |
+
#include "cutlass/cutlass.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/tensor_ref_planar_complex.h"
|
| 50 |
+
#include "cutlass/tensor_view_planar_complex.h"
|
| 51 |
+
|
| 52 |
+
#include "device_memory.h"
|
| 53 |
+
|
| 54 |
+
namespace cutlass {
|
| 55 |
+
|
| 56 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
/// Host tensor
|
| 59 |
+
template <
|
| 60 |
+
/// Data type of element stored within tensor (concept: NumericType)
|
| 61 |
+
typename Element_,
|
| 62 |
+
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
| 63 |
+
typename Layout_
|
| 64 |
+
>
|
| 65 |
+
class HostTensorPlanarComplex {
|
| 66 |
+
public:
|
| 67 |
+
|
| 68 |
+
/// Data type of individual access
|
| 69 |
+
using Element = Element_;
|
| 70 |
+
|
| 71 |
+
/// Mapping function from logical coordinate to linear memory
|
| 72 |
+
using Layout = Layout_;
|
| 73 |
+
|
| 74 |
+
/// Logical rank of tensor index space
|
| 75 |
+
static int const kRank = Layout::kRank;
|
| 76 |
+
|
| 77 |
+
/// Index type
|
| 78 |
+
using Index = typename Layout::Index;
|
| 79 |
+
|
| 80 |
+
/// Long index used for pointer offsets
|
| 81 |
+
using LongIndex = typename Layout::LongIndex;
|
| 82 |
+
|
| 83 |
+
/// Coordinate in logical tensor space
|
| 84 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 85 |
+
|
| 86 |
+
/// Layout's stride vector
|
| 87 |
+
using Stride = typename Layout::Stride;
|
| 88 |
+
|
| 89 |
+
/// Tensor reference to device memory
|
| 90 |
+
using TensorRef = TensorRefPlanarComplex<Element, Layout>;
|
| 91 |
+
|
| 92 |
+
/// Tensor reference to constant device memory
|
| 93 |
+
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 94 |
+
|
| 95 |
+
/// Tensor reference to device memory
|
| 96 |
+
using TensorView = TensorViewPlanarComplex<Element, Layout>;
|
| 97 |
+
|
| 98 |
+
/// Tensor reference to constant device memory
|
| 99 |
+
using ConstTensorView = typename TensorView::ConstTensorView;
|
| 100 |
+
|
| 101 |
+
/// Reference to element in tensor
|
| 102 |
+
using Reference = typename TensorRef::Reference;
|
| 103 |
+
|
| 104 |
+
/// Constant reference to element in tensor
|
| 105 |
+
using ConstReference = typename ConstTensorRef::Reference;
|
| 106 |
+
|
| 107 |
+
private:
|
| 108 |
+
|
| 109 |
+
//
|
| 110 |
+
// Data members
|
| 111 |
+
//
|
| 112 |
+
|
| 113 |
+
/// Extent of tensor in logical dimensions
|
| 114 |
+
TensorCoord extent_;
|
| 115 |
+
|
| 116 |
+
/// Layout object
|
| 117 |
+
Layout layout_;
|
| 118 |
+
|
| 119 |
+
/// Host-side memory allocation
|
| 120 |
+
std::vector<Element> host_;
|
| 121 |
+
|
| 122 |
+
/// Device-side memory
|
| 123 |
+
device_memory::allocation<Element> device_;
|
| 124 |
+
|
| 125 |
+
public:
|
| 126 |
+
//
|
| 127 |
+
// Device and Host Methods
|
| 128 |
+
//
|
| 129 |
+
|
| 130 |
+
/// Default constructor
|
| 131 |
+
HostTensorPlanarComplex() {}
|
| 132 |
+
|
| 133 |
+
/// Constructs a tensor given an extent. Assumes a packed layout
|
| 134 |
+
HostTensorPlanarComplex(
|
| 135 |
+
TensorCoord const &extent,
|
| 136 |
+
bool device_backed = true
|
| 137 |
+
) {
|
| 138 |
+
|
| 139 |
+
this->reset(extent, Layout::packed(extent), device_backed);
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
/// Constructs a tensor given an extent and layout
|
| 143 |
+
HostTensorPlanarComplex(
|
| 144 |
+
TensorCoord const &extent,
|
| 145 |
+
Layout const &layout,
|
| 146 |
+
bool device_backed = true
|
| 147 |
+
) {
|
| 148 |
+
|
| 149 |
+
this->reset(extent, layout, device_backed);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
~HostTensorPlanarComplex() { }
|
| 153 |
+
|
| 154 |
+
/// Clears the HostTensor allocation to size/capacity = 0
|
| 155 |
+
void reset() {
|
| 156 |
+
extent_ = TensorCoord();
|
| 157 |
+
layout_ = Layout::packed(extent_);
|
| 158 |
+
|
| 159 |
+
host_.clear();
|
| 160 |
+
device_.reset();
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/// Resizes internal memory allocations without affecting layout or extent
|
| 164 |
+
void reserve(
|
| 165 |
+
size_t count, ///< size of tensor in elements
|
| 166 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated
|
| 167 |
+
|
| 168 |
+
device_.reset();
|
| 169 |
+
host_.clear();
|
| 170 |
+
|
| 171 |
+
host_.resize(count * 2);
|
| 172 |
+
|
| 173 |
+
// Allocate memory
|
| 174 |
+
Element* device_memory = nullptr;
|
| 175 |
+
if (device_backed_) {
|
| 176 |
+
device_memory = device_memory::allocate<Element>(count * 2);
|
| 177 |
+
}
|
| 178 |
+
device_.reset(device_memory, device_backed_ ? count * 2 : 0);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 182 |
+
/// extent and layout.
|
| 183 |
+
void reset(
|
| 184 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 185 |
+
Layout const &layout, ///< layout object of tensor
|
| 186 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 187 |
+
|
| 188 |
+
extent_ = extent;
|
| 189 |
+
layout_ = layout;
|
| 190 |
+
|
| 191 |
+
reserve(size_t(layout_.capacity(extent_)), device_backed_);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
/// Updates the extent and layout of the HostTensor. Allocates memory according to the new
|
| 195 |
+
/// extent and layout. Assumes a packed tensor configuration.
|
| 196 |
+
void reset(
|
| 197 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 198 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 199 |
+
|
| 200 |
+
reset(extent, Layout::packed(extent), device_backed_);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 204 |
+
/// To force allocation, call reset().
|
| 205 |
+
void resize(
|
| 206 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 207 |
+
Layout const &layout, ///< layout object of tensor
|
| 208 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 209 |
+
|
| 210 |
+
extent_ = extent;
|
| 211 |
+
layout_ = layout;
|
| 212 |
+
|
| 213 |
+
LongIndex new_size = size_t(layout_.capacity(extent_));
|
| 214 |
+
|
| 215 |
+
if (static_cast<decltype(host_.size())>(new_size * 2) > host_.size()) {
|
| 216 |
+
reserve(new_size);
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
|
| 221 |
+
/// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
|
| 222 |
+
void resize(
|
| 223 |
+
TensorCoord const &extent, ///< extent of logical tensor
|
| 224 |
+
bool device_backed_ = true) { ///< if true, device memory is also allocated.
|
| 225 |
+
|
| 226 |
+
resize(extent, Layout::packed(extent), device_backed_);
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
/// Returns the number of elements stored in the host tensor
|
| 230 |
+
size_t size() const {
|
| 231 |
+
return host_.size() / 2;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
/// Returns the logical capacity based on extent and layout. May differ from size().
|
| 235 |
+
LongIndex capacity() const {
|
| 236 |
+
return layout_.capacity(extent_);
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
/// Stride between real and imaginary parts
|
| 240 |
+
LongIndex imaginary_stride() const {
|
| 241 |
+
return host_.size() / 2;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/// Gets pointer to host data
|
| 245 |
+
Element * host_data() { return host_.data(); }
|
| 246 |
+
|
| 247 |
+
/// Gets pointer to host data imaginary part
|
| 248 |
+
Element * host_data_imag() { return host_.data() + imaginary_stride(); }
|
| 249 |
+
|
| 250 |
+
/// Gets pointer to host data with a pointer offset
|
| 251 |
+
Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; }
|
| 252 |
+
|
| 253 |
+
/// Gets pointer to host data with a pointer offset
|
| 254 |
+
Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; }
|
| 255 |
+
|
| 256 |
+
/// Gets a reference to an element in host memory
|
| 257 |
+
Reference host_data(LongIndex idx) {
|
| 258 |
+
return PlanarComplexReference<Element>(host_data() + idx, host_data_imag() + idx);
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
/// Gets pointer to host data
|
| 262 |
+
Element const * host_data() const { return host_.data(); }
|
| 263 |
+
|
| 264 |
+
/// Gets pointer to host data imaginary part
|
| 265 |
+
Element const * host_data_imag() const { return host_.data() + imaginary_stride(); }
|
| 266 |
+
|
| 267 |
+
/// Gets a constant reference to an element in host memory
|
| 268 |
+
ConstReference host_data(LongIndex idx) const {
|
| 269 |
+
return PlanarComplexReference<Element const>(host_data() + idx, host_data_imag() + idx);
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
/// Gets pointer to device data
|
| 273 |
+
Element * device_data() { return device_.get(); }
|
| 274 |
+
|
| 275 |
+
/// Gets pointer to device data with a pointer offset
|
| 276 |
+
Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; }
|
| 277 |
+
|
| 278 |
+
/// Gets pointer to device data
|
| 279 |
+
Element const * device_data() const { return device_.get(); }
|
| 280 |
+
|
| 281 |
+
/// Gets pointer to device data with a pointer offset
|
| 282 |
+
Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }
|
| 283 |
+
|
| 284 |
+
/// Gets a pointer to the device data imaginary part
|
| 285 |
+
Element * device_data_imag() { return device_.get() + imaginary_stride(); }
|
| 286 |
+
|
| 287 |
+
/// Accesses the tensor reference pointing to data
|
| 288 |
+
TensorRef host_ref(LongIndex ptr_element_offset=0) {
|
| 289 |
+
return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
/// Returns a tensor reference to the real part of the tensor
|
| 293 |
+
cutlass::TensorRef<Element, Layout> host_ref_real() {
|
| 294 |
+
return cutlass::TensorRef<Element, Layout>(host_data(), layout_);
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
/// Returns a tensor reference to the real part of the tensor
|
| 298 |
+
cutlass::TensorRef<Element, Layout> host_ref_imag() {
|
| 299 |
+
return cutlass::TensorRef<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_);
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// Accesses the tensor reference pointing to data
|
| 303 |
+
ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const {
|
| 304 |
+
return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
/// Accesses the tensor reference pointing to data
|
| 308 |
+
TensorRef device_ref(LongIndex ptr_element_offset=0) {
|
| 309 |
+
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
/// Accesses the tensor reference pointing to data
|
| 313 |
+
ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
|
| 314 |
+
return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
/// Returns a tensor reference to the real part of the tensor
|
| 318 |
+
cutlass::TensorRef<Element, Layout> device_ref_real() {
|
| 319 |
+
return cutlass::TensorRef<Element, Layout>(device_data(), layout_);
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
/// Returns a tensor reference to the real part of the tensor
|
| 323 |
+
cutlass::TensorRef<Element, Layout> device_ref_imag() {
|
| 324 |
+
return cutlass::TensorRef<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_);
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
/// Accesses the tensor reference pointing to data
|
| 328 |
+
TensorView host_view(LongIndex ptr_element_offset=0) {
|
| 329 |
+
return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
/// Accesses the tensor reference pointing to data
|
| 333 |
+
ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
|
| 334 |
+
return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
/// Accesses the tensor reference pointing to data
|
| 338 |
+
cutlass::TensorView<Element, Layout> host_view_real() {
|
| 339 |
+
return cutlass::TensorView<Element, Layout>(host_data(), layout_, extent_);
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
/// Accesses the tensor reference pointing to data
|
| 343 |
+
cutlass::TensorView<Element, Layout> host_view_imag() {
|
| 344 |
+
return cutlass::TensorView<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_, extent_);
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/// Accesses the tensor reference pointing to data
|
| 348 |
+
TensorView device_view(LongIndex ptr_element_offset=0) {
|
| 349 |
+
return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
/// Accesses the tensor reference pointing to data
|
| 353 |
+
ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
|
| 354 |
+
return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
/// Accesses the tensor reference pointing to data
|
| 358 |
+
cutlass::TensorView<Element, Layout> device_view_real() {
|
| 359 |
+
return cutlass::TensorView<Element, Layout>(device_data(), layout_, extent_);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
/// Accesses the tensor reference pointing to data
|
| 363 |
+
cutlass::TensorView<Element, Layout> device_view_imag() {
|
| 364 |
+
return cutlass::TensorView<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_, extent_);
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
/// Returns true if device memory is allocated
|
| 368 |
+
bool device_backed() const {
|
| 369 |
+
return (device_.get() == nullptr) ? false : true;
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
/// Returns the layout object
|
| 373 |
+
Layout layout() const {
|
| 374 |
+
return layout_;
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
/// Returns the layout object's stride vector
|
| 378 |
+
Stride stride() const {
|
| 379 |
+
return layout_.stride();
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 383 |
+
Index stride(int dim) const {
|
| 384 |
+
return layout_.stride().at(dim);
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
/// Computes the offset of an index from the origin of the tensor
|
| 388 |
+
LongIndex offset(TensorCoord const& coord) const {
|
| 389 |
+
return layout_(coord);
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
/// Returns a reference to the element at the logical Coord in host memory
|
| 393 |
+
Reference at(TensorCoord const& coord) {
|
| 394 |
+
return host_data(offset(coord));
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
/// Returns a const reference to the element at the logical Coord in host memory
|
| 398 |
+
ConstReference at(TensorCoord const& coord) const {
|
| 399 |
+
return host_data(offset(coord));
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
/// Returns the extent of the tensor
|
| 403 |
+
TensorCoord extent() const {
|
| 404 |
+
return extent_;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
/// Returns the extent of the tensor
|
| 408 |
+
TensorCoord & extent() {
|
| 409 |
+
return extent_;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
/// Copies data from device to host
|
| 413 |
+
void sync_host() {
|
| 414 |
+
if (device_backed()) {
|
| 415 |
+
device_memory::copy_to_host(
|
| 416 |
+
host_data(), device_data(), imaginary_stride() * 2);
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
/// Copies data from host to device
|
| 421 |
+
void sync_device() {
|
| 422 |
+
if (device_backed()) {
|
| 423 |
+
device_memory::copy_to_device(
|
| 424 |
+
device_data(), host_data(), imaginary_stride() * 2);
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 429 |
+
void copy_in_device_to_host(
|
| 430 |
+
Element const* ptr_device_real, ///< source device memory
|
| 431 |
+
Element const* ptr_device_imag, ///< source device memory
|
| 432 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 433 |
+
|
| 434 |
+
if (count < 0) {
|
| 435 |
+
count = capacity();
|
| 436 |
+
}
|
| 437 |
+
else {
|
| 438 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
device_memory::copy_to_host(
|
| 442 |
+
host_data(), ptr_device_real, count);
|
| 443 |
+
|
| 444 |
+
device_memory::copy_to_host(
|
| 445 |
+
host_data_imag(), ptr_device_imag, count);
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 449 |
+
void copy_in_device_to_device(
|
| 450 |
+
Element const* ptr_device_real, ///< source device memory
|
| 451 |
+
Element const* ptr_device_imag, ///< source device memory
|
| 452 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 453 |
+
|
| 454 |
+
if (count < 0) {
|
| 455 |
+
count = capacity();
|
| 456 |
+
}
|
| 457 |
+
else {
|
| 458 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
device_memory::copy_device_to_device(
|
| 462 |
+
device_data(), ptr_device_real, count);
|
| 463 |
+
|
| 464 |
+
device_memory::copy_device_to_device(
|
| 465 |
+
device_data_imag(), ptr_device_imag, count);
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 469 |
+
void copy_in_host_to_device(
|
| 470 |
+
Element const* ptr_host_real, ///< source host memory
|
| 471 |
+
Element const* ptr_host_imag, ///< source host memory
|
| 472 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 473 |
+
|
| 474 |
+
if (count < 0) {
|
| 475 |
+
count = capacity();
|
| 476 |
+
}
|
| 477 |
+
else {
|
| 478 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
device_memory::copy_to_device(
|
| 482 |
+
device_data(), ptr_host_real, count);
|
| 483 |
+
|
| 484 |
+
device_memory::copy_to_device(
|
| 485 |
+
device_data_imag(), ptr_host_imag, count);
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 489 |
+
void copy_in_host_to_host(
|
| 490 |
+
Element const* ptr_host_real, ///< source host memory
|
| 491 |
+
Element const* ptr_host_imag, ///< source host memory
|
| 492 |
+
LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 493 |
+
|
| 494 |
+
if (count < 0) {
|
| 495 |
+
count = capacity();
|
| 496 |
+
}
|
| 497 |
+
else {
|
| 498 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
device_memory::copy_host_to_host(
|
| 502 |
+
host_data(), ptr_host_real, count);
|
| 503 |
+
|
| 504 |
+
device_memory::copy_host_to_host(
|
| 505 |
+
host_data_imag(), ptr_host_imag, count);
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 509 |
+
void copy_out_device_to_host(
|
| 510 |
+
Element * ptr_host_real, ///< source device memory
|
| 511 |
+
Element * ptr_host_imag, ///< source device memory
|
| 512 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 513 |
+
|
| 514 |
+
if (count < 0) {
|
| 515 |
+
count = capacity();
|
| 516 |
+
}
|
| 517 |
+
else {
|
| 518 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
device_memory::copy_to_host(
|
| 522 |
+
ptr_host_real, device_data(), count);
|
| 523 |
+
|
| 524 |
+
device_memory::copy_to_host(
|
| 525 |
+
ptr_host_imag, device_data_imag(), count);
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 529 |
+
void copy_out_device_to_device(
|
| 530 |
+
Element * ptr_device_real, ///< source device memory
|
| 531 |
+
Element * ptr_device_imag, ///< source device memory
|
| 532 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 533 |
+
|
| 534 |
+
if (count < 0) {
|
| 535 |
+
count = capacity();
|
| 536 |
+
}
|
| 537 |
+
else {
|
| 538 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
device_memory::copy_device_to_device(
|
| 542 |
+
ptr_device_real, device_data(), count);
|
| 543 |
+
|
| 544 |
+
device_memory::copy_device_to_device(
|
| 545 |
+
ptr_device_imag, device_data_imag(), count);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 549 |
+
void copy_out_host_to_device(
|
| 550 |
+
Element * ptr_device_real, ///< source device memory
|
| 551 |
+
Element * ptr_device_imag, ///< source device memory
|
| 552 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 553 |
+
|
| 554 |
+
if (count < 0) {
|
| 555 |
+
count = capacity();
|
| 556 |
+
}
|
| 557 |
+
else {
|
| 558 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
device_memory::copy_to_device(
|
| 562 |
+
ptr_device_real, host_data(), count);
|
| 563 |
+
|
| 564 |
+
device_memory::copy_to_device(
|
| 565 |
+
ptr_device_imag, host_data_imag(), count);
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
/// Copy data from a caller-supplied device pointer into host memory.
|
| 569 |
+
void copy_out_host_to_host(
|
| 570 |
+
Element * ptr_host_real, ///< source host memory
|
| 571 |
+
Element * ptr_host_imag, ///< source host memory
|
| 572 |
+
LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
|
| 573 |
+
|
| 574 |
+
if (count < 0) {
|
| 575 |
+
count = capacity();
|
| 576 |
+
}
|
| 577 |
+
else {
|
| 578 |
+
count = __NV_STD_MIN(capacity(), count);
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
device_memory::copy_host_to_host(
|
| 582 |
+
ptr_host_real, host_data(), count);
|
| 583 |
+
|
| 584 |
+
device_memory::copy_host_to_host(
|
| 585 |
+
ptr_host_imag, host_data_imag(), count);
|
| 586 |
+
}
|
| 587 |
+
};
|
| 588 |
+
|
| 589 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 590 |
+
|
| 591 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief uncompress sparse matrix from the host side
|
| 34 |
+
*/
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/util/host_tensor.h"
|
| 39 |
+
#include "cutlass/tensor_view.h"
|
| 40 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 41 |
+
#include "cutlass/util/reference/host/gemm.h"
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
|
| 45 |
+
// uncompress sparse tensor core A matrix
|
| 46 |
+
template <typename ElementA, typename LayoutA, typename ElementE,
|
| 47 |
+
typename LayoutE>
|
| 48 |
+
void uncompress(TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
|
| 49 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 50 |
+
TensorRef<ElementE, LayoutE> tensor_e, int row, int col) {
|
| 51 |
+
// How many uncompressed data we can get with ElementE meta data
|
| 52 |
+
int DecompressedElementsPerElementE =
|
| 53 |
+
256 / cutlass::sizeof_bits<ElementA>::value;
|
| 54 |
+
|
| 55 |
+
// Process 4bit meta data a time
|
| 56 |
+
int step;
|
| 57 |
+
|
| 58 |
+
// 1:2 or 2:4 or 4:8
|
| 59 |
+
int a, b;
|
| 60 |
+
|
| 61 |
+
if (cutlass::sizeof_bits<ElementA>::value == 4) {
|
| 62 |
+
step = 8;
|
| 63 |
+
a = 4;
|
| 64 |
+
b = 8;
|
| 65 |
+
} else if (cutlass::sizeof_bits<ElementA>::value == 8) {
|
| 66 |
+
step = 4;
|
| 67 |
+
a = 2;
|
| 68 |
+
b = 4;
|
| 69 |
+
} else if (cutlass::sizeof_bits<ElementA>::value == 16) {
|
| 70 |
+
step = 4;
|
| 71 |
+
a = 2;
|
| 72 |
+
b = 4;
|
| 73 |
+
} else if (cutlass::sizeof_bits<ElementA>::value == 32) {
|
| 74 |
+
step = 2;
|
| 75 |
+
a = 1;
|
| 76 |
+
b = 2;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
int ElementsPerE = (cutlass::sizeof_bits<ElementA>::value == 4) ? 2 : 1;
|
| 80 |
+
|
| 81 |
+
for (int r = 0; r < row; ++r) {
|
| 82 |
+
for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) {
|
| 83 |
+
|
| 84 |
+
ElementE meta = tensor_e.at(MatrixCoord(r, c));
|
| 85 |
+
|
| 86 |
+
for (int i = 0; i < DecompressedElementsPerElementE; i += step) {
|
| 87 |
+
int e = (meta >> (i / step * 4)) & 0xf;
|
| 88 |
+
int idx0 = e & 0x3;
|
| 89 |
+
int idx1 = e >> 2;
|
| 90 |
+
|
| 91 |
+
if (a == 1) idx0 = idx0 / 2;
|
| 92 |
+
|
| 93 |
+
for (int ii = 0; ii < step; ii += ElementsPerE) {
|
| 94 |
+
int real_col =
|
| 95 |
+
c * DecompressedElementsPerElementE + i + ii;
|
| 96 |
+
int compressed_col = (real_col / b) * a;
|
| 97 |
+
|
| 98 |
+
if (ii == (idx0 * ElementsPerE)) {
|
| 99 |
+
uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
|
| 100 |
+
tensor_a.at(MatrixCoord(r, compressed_col));
|
| 101 |
+
if (ElementsPerE == 2)
|
| 102 |
+
uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
|
| 103 |
+
tensor_a.at(MatrixCoord(r, compressed_col + 1));
|
| 104 |
+
} else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) {
|
| 105 |
+
uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
|
| 106 |
+
tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE));
|
| 107 |
+
if (ElementsPerE == 2)
|
| 108 |
+
uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
|
| 109 |
+
tensor_a.at(
|
| 110 |
+
MatrixCoord(r, compressed_col + ElementsPerE + 1));
|
| 111 |
+
} else {
|
| 112 |
+
uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
|
| 113 |
+
ElementA(0);
|
| 114 |
+
if (ElementsPerE == 2)
|
| 115 |
+
uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
|
| 116 |
+
ElementA(0);
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// uncompress ELL block sparse matrix
|
| 125 |
+
template <typename ElementA, typename LayoutA,
|
| 126 |
+
typename ElementE, typename LayoutE>
|
| 127 |
+
void uncompress_ell_block_sparse(
|
| 128 |
+
TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
|
| 129 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 130 |
+
TensorRef<ElementE, LayoutE> ell_idx,
|
| 131 |
+
int rows, int cols,
|
| 132 |
+
int ell_num_cols, int ell_blocksize) {
|
| 133 |
+
|
| 134 |
+
for (int r = 0; r < rows / ell_blocksize; ++r) {
|
| 135 |
+
for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) {
|
| 136 |
+
|
| 137 |
+
ElementE idx = ell_idx.at(MatrixCoord(r, c));
|
| 138 |
+
|
| 139 |
+
if (idx != -1) {
|
| 140 |
+
int row_begin = r * ell_blocksize;
|
| 141 |
+
int col_begin_real = idx * ell_blocksize;
|
| 142 |
+
int col_begin = c * ell_blocksize;
|
| 143 |
+
|
| 144 |
+
for (int i = 0; i < ell_blocksize; ++i) {
|
| 145 |
+
for (int j = 0; j < ell_blocksize; ++j) {
|
| 146 |
+
uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) =
|
| 147 |
+
tensor_a.at(
|
| 148 |
+
MatrixCoord(row_begin + i, col_begin +j));
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
} // namespace cutlass
|
| 157 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include "cutlass/cutlass.h"
|
| 35 |
+
#include "cutlass/numeric_types.h"
|
| 36 |
+
|
| 37 |
+
// integer_sequence moved to cutlass/numeric_types.h
|
| 38 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Utilities for mixed input data type kernels.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include <cuda.h>
|
| 38 |
+
#include "cute/layout.hpp"
|
| 39 |
+
#include "cute/tensor.hpp"
|
| 40 |
+
#include "cute/arch/mma_sm90.hpp"
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/util/device_memory.h"
|
| 43 |
+
#include "cutlass/util/reference/device/tensor_fill.h"
|
| 44 |
+
#include "cute/util/type_traits.hpp"
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
|
| 48 |
+
#define CUDA_CHECK(status) \
|
| 49 |
+
{ \
|
| 50 |
+
cudaError_t error = status; \
|
| 51 |
+
if (error != cudaSuccess) { \
|
| 52 |
+
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
|
| 53 |
+
<< " at line: " << __LINE__ << std::endl; \
|
| 54 |
+
exit(EXIT_FAILURE); \
|
| 55 |
+
} \
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <
|
| 59 |
+
class QuantizedElement,
|
| 60 |
+
class DequantizedElement,
|
| 61 |
+
class OperandLayout,
|
| 62 |
+
class ElementScale,
|
| 63 |
+
class ElementZero,
|
| 64 |
+
class ScaleBroadCastLayout,
|
| 65 |
+
class ThrLayout>
|
| 66 |
+
__global__ void dequantize_kernel(DequantizedElement* dq_buffer,
|
| 67 |
+
QuantizedElement const* q_buffer,
|
| 68 |
+
OperandLayout const operand_layout,
|
| 69 |
+
ElementScale const* scale_buffer,
|
| 70 |
+
ElementZero const* zero_buffer,
|
| 71 |
+
ScaleBroadCastLayout const broadcasted_scale_layout,
|
| 72 |
+
ThrLayout thr_layout) {
|
| 73 |
+
using namespace cute;
|
| 74 |
+
|
| 75 |
+
// Represent the full tensors to gmem elements.
|
| 76 |
+
// These are expected to have shape [MN, K, L]
|
| 77 |
+
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
|
| 78 |
+
cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr<QuantizedElement const>(q_buffer), operand_layout);
|
| 79 |
+
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
|
| 80 |
+
// It is expected that K % G == 0
|
| 81 |
+
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
|
| 82 |
+
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
| 83 |
+
|
| 84 |
+
// Assign 1 thread per element in the thread block
|
| 85 |
+
auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
| 86 |
+
auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
|
| 87 |
+
|
| 88 |
+
// Tile across the block
|
| 89 |
+
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
|
| 90 |
+
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
| 91 |
+
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
| 92 |
+
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
|
| 93 |
+
|
| 94 |
+
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
|
| 95 |
+
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
|
| 96 |
+
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
|
| 97 |
+
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
|
| 98 |
+
|
| 99 |
+
// Make a fragment of registers to hold gmem loads
|
| 100 |
+
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
| 101 |
+
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
|
| 102 |
+
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
|
| 103 |
+
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
| 104 |
+
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
|
| 105 |
+
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
|
| 106 |
+
|
| 107 |
+
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
|
| 108 |
+
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
|
| 109 |
+
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
|
| 110 |
+
|
| 111 |
+
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
|
| 112 |
+
|
| 113 |
+
for (int ii = 0; ii < num_iters; ++ii) {
|
| 114 |
+
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
|
| 115 |
+
if (thread_offset < cute::size<0>(operand_layout)) {
|
| 116 |
+
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
|
| 117 |
+
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
|
| 118 |
+
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
|
| 119 |
+
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
|
| 120 |
+
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
|
| 121 |
+
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{});
|
| 122 |
+
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{});
|
| 123 |
+
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
|
| 124 |
+
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template <
|
| 130 |
+
class QuantizedElement,
|
| 131 |
+
class DequantizedElement,
|
| 132 |
+
class OperandLayout,
|
| 133 |
+
class ElementScale,
|
| 134 |
+
class ElementZero,
|
| 135 |
+
class ScaleLayout>
|
| 136 |
+
static void dequantize(DequantizedElement* dq_buffer,
|
| 137 |
+
QuantizedElement const* q_buffer,
|
| 138 |
+
OperandLayout const operand_layout,
|
| 139 |
+
ElementScale const* scale_buffer,
|
| 140 |
+
ElementZero const* zero_buffer,
|
| 141 |
+
ScaleLayout const scale_layout,
|
| 142 |
+
int const group_size,
|
| 143 |
+
cudaStream_t &stream) {
|
| 144 |
+
using namespace cute;
|
| 145 |
+
|
| 146 |
+
constexpr int tpb = 128;
|
| 147 |
+
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
| 148 |
+
|
| 149 |
+
const auto num_rows = get<0>(shape(operand_layout));
|
| 150 |
+
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
| 151 |
+
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
| 152 |
+
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
|
| 153 |
+
|
| 154 |
+
if (num_rows != size<0>(scale_layout)) {
|
| 155 |
+
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
| 156 |
+
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
| 157 |
+
<< std::endl;
|
| 158 |
+
exit(-1);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
const auto scale_stride0 = get<0>(stride(scale_layout));
|
| 162 |
+
const auto scale_stride1 = get<1>(stride(scale_layout));
|
| 163 |
+
const auto scale_stride2 = get<2>(stride(scale_layout));
|
| 164 |
+
|
| 165 |
+
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
|
| 166 |
+
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
|
| 167 |
+
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
|
| 168 |
+
|
| 169 |
+
const auto blocks_x = gemm_k;
|
| 170 |
+
const auto blocks_y = batches;
|
| 171 |
+
|
| 172 |
+
dim3 blocks(blocks_x, blocks_y, 1);
|
| 173 |
+
dequantize_kernel<<<blocks, tpb, 0, stream>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
| 174 |
+
CUDA_CHECK(cudaStreamSynchronize(stream));
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
// Packs eight scaled multiples of one 8-bit value into two 32-bit words.
// The mixed-dtype mainloop uses this as a byte lookup table (e.g. via PRMT):
// each packed byte holds k*val for k = 8..1 (negated for signed types).
template <typename T>
class packed_scale_t {
public:
  // Each packed slot is exactly one byte, so only 8-bit element types work.
  static_assert(cute::is_same_v<T, cutlass::int8_t> ||
                cute::is_same_v<T, cutlass::uint8_t> ||
                cute::is_same_v<T, cutlass::float_e4m3_t> ||
                cute::is_same_v<T, cutlass::float_e5m2_t>,
                "only 8 bit arithmetic types are supported.");
  // Builds the table {8v,7v,...,1v} (negated for signed T). Multiplication is
  // done in float and converted back to T, so saturation/rounding follows T's
  // float conversion semantics.
  CUTLASS_HOST_DEVICE
  explicit packed_scale_t(T val) {
    if constexpr (!cute::is_unsigned_v<T>) {
      // Only pack negative values. The positive values are generated in flight in the mainloop.
      storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
      storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
    }
    else {
      storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
      storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
    }
  }
  CUTLASS_HOST_DEVICE
  packed_scale_t() = default;
  // Converts the representative packed element (see get()) to float.
  CUTLASS_HOST_DEVICE
  explicit operator float() const {
    return float(get());
  }
  // Bitwise equality of the full packed table.
  CUTLASS_HOST_DEVICE
  bool operator==(packed_scale_t const& rhs) const {
    return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
  }
  CUTLASS_HOST_DEVICE
  bool operator!=(packed_scale_t const& rhs) const {
    return !(*this == rhs);
  }
  // Arithmetic operators act on the representative element and re-pack the result.
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() + rhs.get());
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() - rhs.get());
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() * rhs.get());
  }
  CUTLASS_HOST_DEVICE
  friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
    return packed_scale_t(lhs.get() / rhs.get());
  }

private:
  using Storage = uint32_t;  // one packed word holding four byte-sized elements
  using Stage = uint8_t;     // raw byte used for bit-level extraction/insertion

  Storage storage[2] {};

  // Packs four byte-sized values into one word, c1 in the least significant byte.
  // reinterpret_cast via Stage reads the raw object representation of T.
  CUTLASS_HOST_DEVICE
  static Storage pack4(T c1, T c2, T c3, T c4) {
    Storage result = 0;
    result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
    result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
    result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
    result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
    return result;
  }
  // Returns byte 1 of storage[0], i.e. the slot packed as (+/-)7*val.
  // NOTE(review): confirm this slot is the intended representative element for
  // operator float() and the arithmetic operators.
  CUTLASS_HOST_DEVICE
  T get() const {
    auto stage = static_cast<Stage>(storage[0] >> 8);
#if defined(__CUDA_ARCH__)
    return reinterpret_cast<T const&>(stage);
#else
    // Host path uses memcpy to avoid strict-aliasing UB.
    T tmp;
    std::memcpy(&tmp, &stage, sizeof(Stage));
    return tmp;
#endif
  }
  // Returns the idx-th packed byte; idx in [0,8), bytes 0..3 live in
  // storage[0] and bytes 4..7 in storage[1].
  CUTLASS_HOST_DEVICE
  T get(int idx) const {
    Stage stage;
    if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
    else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
#if defined(__CUDA_ARCH__)
    return reinterpret_cast<T const&>(stage);
#else
    T tmp;
    std::memcpy(&tmp, &stage, sizeof(Stage));
    return tmp;
#endif
  }
};
|
| 268 |
+
|
| 269 |
+
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
|
| 270 |
+
// Here the encodings of positive values and negative values are unified (except for the sign bit).
|
| 271 |
+
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
|
| 272 |
+
static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) {
|
| 273 |
+
|
| 274 |
+
using StorageType = cutlass::int4b_t::Storage;
|
| 275 |
+
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
|
| 276 |
+
const size_t host_buf_size = block_size / pack;
|
| 277 |
+
std::vector<StorageType> host_buf(host_buf_size);
|
| 278 |
+
cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size);
|
| 279 |
+
|
| 280 |
+
for (auto&& d : host_buf) {
|
| 281 |
+
StorageType out = 0;
|
| 282 |
+
StorageType mask = 0x0f;
|
| 283 |
+
for (int i = 0; i < pack; i++) {
|
| 284 |
+
cutlass::int4b_t curr;
|
| 285 |
+
curr.storage = (d >> (i * 4)) & 0x0f;
|
| 286 |
+
switch (curr) {
|
| 287 |
+
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
|
| 288 |
+
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
|
| 289 |
+
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
|
| 290 |
+
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
|
| 291 |
+
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
|
| 292 |
+
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
|
| 293 |
+
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
|
| 294 |
+
default: break;
|
| 295 |
+
}
|
| 296 |
+
out |= (curr.storage << (4 * i)) & mask;
|
| 297 |
+
mask <<= 4;
|
| 298 |
+
}
|
| 299 |
+
d = out;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size);
|
| 303 |
+
return true;
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
template <class ElementScale>
|
| 307 |
+
static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array<ElementScale, 8> *block_out, const size_t block_size) {
|
| 308 |
+
std::vector<ElementScale> data_in(block_size);
|
| 309 |
+
std::vector<cutlass::Array<ElementScale, 8>> data_out(block_size);
|
| 310 |
+
|
| 311 |
+
try {
|
| 312 |
+
cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size);
|
| 313 |
+
}
|
| 314 |
+
catch (cutlass::cuda_exception const& e) {
|
| 315 |
+
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
| 316 |
+
return false;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
for (size_t i = 0; i < block_size; i++) {
|
| 320 |
+
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
| 321 |
+
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
try {
|
| 325 |
+
cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size);
|
| 326 |
+
}
|
| 327 |
+
catch (cutlass::cuda_exception const& e) {
|
| 328 |
+
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
| 329 |
+
return false;
|
| 330 |
+
}
|
| 331 |
+
return true;
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
// Trait that unwraps a container-like type to its element type.
// Primary template: T has no nested Element, so T itself is the element.
template <class T, class = void>
struct UnderlyingElement {
  using type = T;
};

// Specialization: when T declares a nested ::Element (detected via void_t),
// use that as the element type.
template <class T>
struct UnderlyingElement<T, cute::void_t<typename T::Element>> {
  using type = typename T::Element;
};
|
| 343 |
+
|
| 344 |
+
// Given a type of MMA instruction, compute a memory reordering atom that places all values
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
// In addition, we can reorder the values across several MMA instructions to get even wider
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
// more optimal conversion instruction sequences (ValLayout parameter).
template <class ElementMma,
          class AtomLayout = cute::Layout<cute::_1>,
          class ValLayout = cute::Layout<cute::_1>>
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
{
  using namespace cute;

  // Both layouts must be compile-time constants: the result is a static layout atom.
  static_assert(is_static_v<ValLayout>, "ValLayout must be static");
  static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");

  // 1. Choose an MMA atom to access TV layout and MN shape
  // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
  using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
  using MmaTraits = MMA_Traits<MmaAtom>;
  auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
  auto tv_layout_mma = typename MmaTraits::ALayout{};
  static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");

  // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
  // Note: this assumes A is partitioned between warps along M mode
  auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
  auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
  auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
  auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);

  // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
  auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);

  // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
  auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
  auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
  auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
  auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);

  return layout_atom;
}
|
| 387 |
+
|
| 388 |
+
template <class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
|
| 389 |
+
__global__ void reorder_tensor_kernel(
|
| 390 |
+
cute::Tensor<EngineSrc, LayoutSrc> S,
|
| 391 |
+
cute::Tensor<EngineDst, LayoutDst> D,
|
| 392 |
+
TiledCopy tiled_copy)
|
| 393 |
+
{
|
| 394 |
+
using namespace cute;
|
| 395 |
+
|
| 396 |
+
using T = typename EngineDst::value_type;
|
| 397 |
+
|
| 398 |
+
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
| 399 |
+
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
| 400 |
+
|
| 401 |
+
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
|
| 402 |
+
Tensor tS = thread_copy.partition_S(gS);
|
| 403 |
+
Tensor tD = thread_copy.partition_D(gD);
|
| 404 |
+
|
| 405 |
+
copy(tiled_copy, tS, tD);
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
// Reorders device tensor S into D on the current device (synchronous).
// S and D must have the same element type; the relayout is expressed entirely
// through their CuTe layouts.
template <class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
void reorder_tensor(
    cute::Tensor<EngineSrc, LayoutSrc> S,
    cute::Tensor<EngineDst, LayoutDst> D)
{
  using namespace cute;

  using T = typename EngineDst::value_type;
  static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");

  // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
  // This avoids a race condition when writing out subbyte types (e.g. int4b_t).
  auto has_major_mode = [](auto s) {
    return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; });
  };
  // Exactly one of the first two destination modes must be stride-1.
  static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
                "Could not find stride-1 mode in destination layout");
  // N = number of elements per byte (1 for 8-bit+, more for subbyte types).
  constexpr int N = shape_div(Int<8>{}, Int<sizeof_bits_v<T>>{});
  auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
      make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
      make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));

  // Make a tiled copy with a simple row-major thread order and above layout
  int constexpr NumThreads = 128;
  auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
  auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);

  // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
  using TileShape = Shape<_16>;
  auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
  dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};

  // Launch and block until the reorder is complete (default stream).
  reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
  CUDA_CHECK(cudaDeviceSynchronize());
}
|
| 443 |
+
|
| 444 |
+
// Out-of-place pointer version: reorders `src` (described by layout_src) into
// `dst` (described by layout_dst). Both pointers are device (global) memory.
// NOTE(review): the original comment labeled this overload "In-place version";
// that label matches the single-pointer overload below, not this one.
template <class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
    T const* src,
    LayoutSrc const& layout_src,
    T * dst,
    LayoutDst const& layout_dst)
{
  using namespace cute;
  reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
                 make_tensor(make_gmem_ptr<T>(dst), layout_dst));
}
|
| 456 |
+
|
| 457 |
+
// In-place version: reorders `data` from layout_src to layout_dst by staging
// the result in a temporary device allocation, then copying it back over `data`.
// Assumes both layouts cover the same number of elements (size(layout_src)).
template <class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
    T * data,
    LayoutSrc const& layout_src,
    LayoutDst const& layout_dst)
{
  using namespace cute;
  cutlass::DeviceAllocation<T> temp(size(layout_src));
  reorder_tensor(data, layout_src, temp.get(), layout_dst);
  cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
}
|
| 469 |
+
|
| 470 |
+
#undef CUDA_CHECK
|
| 471 |
+
|
| 472 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cute/layout.hpp"
|
| 38 |
+
#include "cute/container/array.hpp" // cute::array
|
| 39 |
+
#include "cutlass/conv/convolution.h" // cutlass::conv::Operator
|
| 40 |
+
|
| 41 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
// Strides without batch mode
|
| 48 |
+
|
| 49 |
+
template <class IntT>
|
| 50 |
+
CUTLASS_HOST_DEVICE
|
| 51 |
+
cute::Stride<IntT, cute::Int<1>>
|
| 52 |
+
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
|
| 53 |
+
static_assert(std::is_integral_v<IntT>,
|
| 54 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 55 |
+
auto s_copy = s;
|
| 56 |
+
cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
|
| 57 |
+
return s_copy;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
template <class IntT>
|
| 61 |
+
CUTLASS_HOST_DEVICE
|
| 62 |
+
cute::Stride<cute::Int<1>, IntT>
|
| 63 |
+
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,int> shape_MKL) {
|
| 64 |
+
static_assert(std::is_integral_v<IntT>,
|
| 65 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 66 |
+
auto s_copy = s;
|
| 67 |
+
cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
|
| 68 |
+
return s_copy;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 72 |
+
|
| 73 |
+
// Strides with batch mode
|
| 74 |
+
|
| 75 |
+
template <class IntT>
|
| 76 |
+
CUTLASS_HOST_DEVICE
|
| 77 |
+
cute::Stride<IntT, cute::Int<1>, int64_t>
|
| 78 |
+
make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
|
| 79 |
+
static_assert(std::is_integral_v<IntT>,
|
| 80 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 81 |
+
auto s_copy = s;
|
| 82 |
+
cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
|
| 83 |
+
int batch_count = cute::get<2>(shape_MKL);
|
| 84 |
+
if (batch_count > 1) {
|
| 85 |
+
cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
|
| 86 |
+
}
|
| 87 |
+
else {
|
| 88 |
+
cute::get<2>(s_copy) = static_cast<IntT>(0);
|
| 89 |
+
}
|
| 90 |
+
return s_copy;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
template <class IntT>
|
| 94 |
+
CUTLASS_HOST_DEVICE
|
| 95 |
+
cute::Stride<cute::Int<1>, IntT, int64_t>
|
| 96 |
+
make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
|
| 97 |
+
static_assert(std::is_integral_v<IntT>,
|
| 98 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 99 |
+
auto s_copy = s;
|
| 100 |
+
cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
|
| 101 |
+
int batch_count = cute::get<2>(shape_MKL);
|
| 102 |
+
if (batch_count > 1) {
|
| 103 |
+
cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
|
| 104 |
+
}
|
| 105 |
+
else {
|
| 106 |
+
cute::get<2>(s_copy) = static_cast<IntT>(0);
|
| 107 |
+
}
|
| 108 |
+
return s_copy;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 112 |
+
|
| 113 |
+
// Strides with group mode
|
| 114 |
+
|
| 115 |
+
template <class StrideIntT>
|
| 116 |
+
CUTLASS_HOST_DEVICE
|
| 117 |
+
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
|
| 118 |
+
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
| 119 |
+
static_assert(std::is_integral_v<StrideIntT>,
|
| 120 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 121 |
+
auto s_copy = s;
|
| 122 |
+
cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
|
| 123 |
+
return s_copy;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
template <class StrideIntT>
|
| 127 |
+
CUTLASS_HOST_DEVICE
|
| 128 |
+
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
|
| 129 |
+
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
| 130 |
+
static_assert(std::is_integral_v<StrideIntT>,
|
| 131 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 132 |
+
auto s_copy = s;
|
| 133 |
+
cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
|
| 134 |
+
return s_copy;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 138 |
+
|
| 139 |
+
// Strides for convolutions
|
| 140 |
+
|
| 141 |
+
// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
|
| 142 |
+
// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order
|
| 143 |
+
// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout
|
| 144 |
+
// right in KTRSC order and can be coalesced to just k.
|
| 145 |
+
// We enforce this condition here with asserts.
|
| 146 |
+
// Builds the coalesced rank-3 output stride (InT,_1,_0) for a convolution
// result tensor, after asserting the given strides are fully packed
// (layout-right) so coalescing is valid.
template <class IntT, size_t RankT_>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
    cute::array<int32_t, RankT_> shape_output,
    cute::array<IntT, RankT_> stride_output,
    cutlass::conv::Operator conv_op) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  static_assert(RankT_ >= 3u);
  constexpr static int RankT = static_cast<int>(RankT_);

  // Innermost mode must be contiguous, and every other stride must be the
  // packed product of inner shape*stride (layout-right / row-major packing).
  assert(stride_output[RankT-1] == 1);
  cute::for_each(cute::make_seq<RankT-2>{}, [&](auto i) {
    assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]);
  });

  auto s_copy = s;
  // Wgrad output is KTRSC-ordered, so the coalesced stride is the outermost
  // (k) stride; fprop/dgrad use the next-to-innermost (q/w) stride.
  cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ?
      stride_output[0] :
      stride_output[RankT-2];
  return s_copy;
}
|
| 170 |
+
|
| 171 |
+
//
|
| 172 |
+
// Activation tensor ((w, h, d, n), _1) for fprop kernel
|
| 173 |
+
//
|
| 174 |
+
|
| 175 |
+
// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
|
| 176 |
+
// Activation TensorNWC (fprop) -> rank-2 stride ((W,N),_1).
// Copies the w and n strides from the NWC stride array; channel mode must be
// contiguous.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
    cute::array<IntT, 3> stride_nwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  // C must be the contiguous (stride-1) mode.
  assert(stride_nwc[2] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_nwc[1];  // w stride
  cute::get<0,1>(s_copy) = stride_nwc[0];  // n stride
  return s_copy;
}
|
| 191 |
+
|
| 192 |
+
// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
|
| 193 |
+
// Activation TensorNHWC (fprop) -> rank-2 stride ((W,H,N),_1).
// Reverses the (n,h,w) strides into the nested (w,h,n) mode order; channel
// mode must be contiguous.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
    cute::array<IntT, 4> stride_nhwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  // C must be the contiguous (stride-1) mode.
  assert(stride_nhwc[3] == 1);
  auto s_copy = s;
  // Mode i of the nested stride takes stride_nhwc in reverse order: w, h, n.
  cute::for_each(cute::make_seq<3>{}, [&](auto i) {
    cute::get<0,i>(s_copy) = stride_nhwc[2-i];
  });
  return s_copy;
}
|
| 209 |
+
|
| 210 |
+
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
|
| 211 |
+
// Activation TensorNDHWC (fprop) -> rank-2 stride ((W,H,D,N),_1).
// Reverses the (n,d,h,w) strides into the nested (w,h,d,n) mode order; channel
// mode must be contiguous.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
    cute::array<IntT, 5> stride_ndhwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // C must be the contiguous (stride-1) mode.
  assert(stride_ndhwc[4] == 1);
  auto s_copy = s;
  // Mode i of the nested stride takes stride_ndhwc in reverse order: w, h, d, n.
  cute::for_each(cute::make_seq<4>{}, [&](auto i) {
    cute::get<0,i>(s_copy) = stride_ndhwc[3-i];
  });
  return s_copy;
}
|
| 228 |
+
|
| 229 |
+
//
|
| 230 |
+
// Filter tensor (k, (_1, s, r, t)) for fprop kernel
|
| 231 |
+
//
|
| 232 |
+
|
| 233 |
+
// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
|
| 234 |
+
// Filter TensorNWC (fprop) -> rank-2 stride (k, (_1, s)).
// k stride goes to mode 0, s stride to the inner mode 1; channel mode must be
// contiguous.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
    cute::array<IntT, 3> stride_ksc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // C must be the contiguous (stride-1) mode.
  assert(stride_ksc[2] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_ksc[0];  // k stride
  cute::get<1,1>(s_copy) = stride_ksc[1];  // s stride
  return s_copy;
}
|
| 250 |
+
|
| 251 |
+
// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
|
| 252 |
+
// Filter TensorNHWC (fprop) -> rank-2 stride (k, (_1, s, r)).
// k stride goes to mode 0; (r,s) strides are placed reversed into the inner
// mode; channel mode must be contiguous.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
    cute::array<IntT, 4> stride_krsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // C must be the contiguous (stride-1) mode.
  assert(stride_krsc[3] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_krsc[0];  // k stride
  // Inner modes (1,1)=s and (1,2)=r take stride_krsc[2] and stride_krsc[1].
  cute::for_each(cute::make_seq<2>{}, [&](auto i) {
    cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
  });
  return s_copy;
}
|
| 270 |
+
|
| 271 |
+
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
|
| 272 |
+
// Filter TensorNDHWC (fprop) -> rank-2 stride (k, (_1, s, r, t)).
// k stride goes to mode 0; (t,r,s) strides are placed reversed into the inner
// mode; channel mode must be contiguous.
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
    cute::array<IntT, 5> stride_ktrsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // C must be the contiguous (stride-1) mode.
  assert(stride_ktrsc[4] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_ktrsc[0];  // k stride
  // Inner modes (1,1)=s, (1,2)=r, (1,3)=t take stride_ktrsc[3], [2], [1].
  cute::for_each(cute::make_seq<3>{}, [&](auto i) {
    cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
  });
  return s_copy;
}
|
| 290 |
+
|
| 291 |
+
//
|
| 292 |
+
// Activation tensor (_1, (w, h, d, n)) for wgrad kernel
|
| 293 |
+
//
|
| 294 |
+
// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel
|
| 295 |
+
//
|
| 296 |
+
|
| 297 |
+
// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
|
| 298 |
+
// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
|
| 299 |
+
template <class IntT>
|
| 300 |
+
CUTLASS_HOST_DEVICE
|
| 301 |
+
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
|
| 302 |
+
make_cute_packed_stride(
|
| 303 |
+
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
|
| 304 |
+
cute::array<IntT, 3> stride_nwc,
|
| 305 |
+
conv::Operator ConvOp) {
|
| 306 |
+
static_assert(std::is_integral_v<IntT>,
|
| 307 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 308 |
+
|
| 309 |
+
assert(stride_nwc[2] == 1);
|
| 310 |
+
auto s_copy = s;
|
| 311 |
+
if (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 312 |
+
cute::get<1,0>(s_copy) = stride_nwc[1];
|
| 313 |
+
cute::get<1,1>(s_copy) = stride_nwc[0];
|
| 314 |
+
}
|
| 315 |
+
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 316 |
+
// stride_nwc in dgrad is ksc.
|
| 317 |
+
cute::get<1,0>(s_copy) = stride_nwc[0];
|
| 318 |
+
cute::get<1,1>(s_copy) = stride_nwc[1];
|
| 319 |
+
}
|
| 320 |
+
return s_copy;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
|
| 324 |
+
// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
|
| 325 |
+
template <class IntT>
|
| 326 |
+
CUTLASS_HOST_DEVICE
|
| 327 |
+
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
|
| 328 |
+
make_cute_packed_stride(
|
| 329 |
+
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
|
| 330 |
+
cute::array<IntT, 4> stride_nhwc,
|
| 331 |
+
conv::Operator ConvOp) {
|
| 332 |
+
static_assert(std::is_integral_v<IntT>,
|
| 333 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 334 |
+
|
| 335 |
+
assert(stride_nhwc[3] == 1);
|
| 336 |
+
auto s_copy = s;
|
| 337 |
+
if (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 338 |
+
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 339 |
+
cute::get<1,i>(s_copy) = stride_nhwc[2-i];
|
| 340 |
+
});
|
| 341 |
+
}
|
| 342 |
+
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 343 |
+
// stride_nhwc in dgrad is krsc.
|
| 344 |
+
cute::get<1,0>(s_copy) = stride_nhwc[0];
|
| 345 |
+
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 346 |
+
cute::get<1,2-i>(s_copy) = stride_nhwc[i+1];
|
| 347 |
+
});
|
| 348 |
+
}
|
| 349 |
+
return s_copy;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
|
| 353 |
+
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
|
| 354 |
+
template <class IntT>
|
| 355 |
+
CUTLASS_HOST_DEVICE
|
| 356 |
+
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
|
| 357 |
+
make_cute_packed_stride(
|
| 358 |
+
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
|
| 359 |
+
cute::array<IntT, 5> stride_ndhwc,
|
| 360 |
+
conv::Operator ConvOp) {
|
| 361 |
+
static_assert(std::is_integral_v<IntT>,
|
| 362 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 363 |
+
|
| 364 |
+
assert(stride_ndhwc[4] == 1);
|
| 365 |
+
auto s_copy = s;
|
| 366 |
+
if (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 367 |
+
cute::for_each(cute::make_seq<4>{}, [&](auto i) {
|
| 368 |
+
cute::get<1,i>(s_copy) = stride_ndhwc[3-i];
|
| 369 |
+
});
|
| 370 |
+
}
|
| 371 |
+
else if (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 372 |
+
// stride_ndhwc in dgrad is ktrsc.
|
| 373 |
+
cute::get<1,0>(s_copy) = stride_ndhwc[0];
|
| 374 |
+
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 375 |
+
cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1];
|
| 376 |
+
});
|
| 377 |
+
}
|
| 378 |
+
return s_copy;
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
//
|
| 382 |
+
// NZPQ tensor (_1, nzpq) for wgrad kernel
|
| 383 |
+
//
|
| 384 |
+
|
| 385 |
+
// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
|
| 386 |
+
template <class IntT>
|
| 387 |
+
CUTLASS_HOST_DEVICE
|
| 388 |
+
cute::Stride<cute::Int<1>, IntT>
|
| 389 |
+
make_cute_packed_stride(
|
| 390 |
+
cute::Stride<cute::Int<1>, IntT> s,
|
| 391 |
+
cute::array<IntT, 3> stride_nqk,
|
| 392 |
+
conv::Operator ConvOp) {
|
| 393 |
+
static_assert(std::is_integral_v<IntT>,
|
| 394 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 395 |
+
|
| 396 |
+
assert(stride_nqk[2] == 1);
|
| 397 |
+
auto s_copy = s;
|
| 398 |
+
cute::get<1>(s_copy) = stride_nqk[1];
|
| 399 |
+
return s_copy;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
|
| 403 |
+
template <class IntT>
|
| 404 |
+
CUTLASS_HOST_DEVICE
|
| 405 |
+
cute::Stride<cute::Int<1>, IntT>
|
| 406 |
+
make_cute_packed_stride(
|
| 407 |
+
cute::Stride<cute::Int<1>, IntT> s,
|
| 408 |
+
cute::array<IntT, 4> stride_npqk,
|
| 409 |
+
conv::Operator ConvOp) {
|
| 410 |
+
static_assert(std::is_integral_v<IntT>,
|
| 411 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 412 |
+
|
| 413 |
+
assert(stride_npqk[3] == 1);
|
| 414 |
+
auto s_copy = s;
|
| 415 |
+
cute::get<1>(s_copy) = stride_npqk[2];
|
| 416 |
+
return s_copy;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
|
| 420 |
+
template <class IntT>
|
| 421 |
+
CUTLASS_HOST_DEVICE
|
| 422 |
+
cute::Stride<cute::Int<1>, IntT>
|
| 423 |
+
make_cute_packed_stride(
|
| 424 |
+
cute::Stride<cute::Int<1>, IntT> s,
|
| 425 |
+
cute::array<IntT, 5> stride_nzpqk,
|
| 426 |
+
conv::Operator ConvOp) {
|
| 427 |
+
static_assert(std::is_integral_v<IntT>,
|
| 428 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 429 |
+
|
| 430 |
+
assert(stride_nzpqk[4] == 1);
|
| 431 |
+
auto s_copy = s;
|
| 432 |
+
cute::get<1>(s_copy) = stride_nzpqk[3];
|
| 433 |
+
return s_copy;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
//
|
| 439 |
+
// Wgrad output tensor (k, (_1, s, r, t), _0)
|
| 440 |
+
//
|
| 441 |
+
|
| 442 |
+
// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
|
| 443 |
+
template <class IntT>
|
| 444 |
+
CUTLASS_HOST_DEVICE
|
| 445 |
+
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
|
| 446 |
+
make_cute_packed_stride(
|
| 447 |
+
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
|
| 448 |
+
[[maybe_unused]] cute::array<int32_t, 3> shape_output,
|
| 449 |
+
cute::array<IntT, 3> stride_ksc,
|
| 450 |
+
conv::Operator ConvOp) {
|
| 451 |
+
static_assert(std::is_integral_v<IntT>,
|
| 452 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 453 |
+
|
| 454 |
+
assert(stride_ksc[2] == 1);
|
| 455 |
+
auto s_copy = s;
|
| 456 |
+
cute::get<0,0>(s_copy) = stride_ksc[0];
|
| 457 |
+
cute::get<1,1>(s_copy) = stride_ksc[1];
|
| 458 |
+
return s_copy;
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
|
| 462 |
+
template <class IntT>
|
| 463 |
+
CUTLASS_HOST_DEVICE
|
| 464 |
+
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
|
| 465 |
+
make_cute_packed_stride(
|
| 466 |
+
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
|
| 467 |
+
[[maybe_unused]] cute::array<int32_t, 4> shape_output,
|
| 468 |
+
cute::array<IntT, 4> stride_krsc,
|
| 469 |
+
conv::Operator ConvOp) {
|
| 470 |
+
static_assert(std::is_integral_v<IntT>,
|
| 471 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 472 |
+
|
| 473 |
+
assert(stride_krsc[3] == 1);
|
| 474 |
+
auto s_copy = s;
|
| 475 |
+
cute::get<0,0>(s_copy) = stride_krsc[0];
|
| 476 |
+
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 477 |
+
cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
|
| 478 |
+
});
|
| 479 |
+
return s_copy;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
|
| 483 |
+
template <class IntT>
|
| 484 |
+
CUTLASS_HOST_DEVICE
|
| 485 |
+
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
|
| 486 |
+
make_cute_packed_stride(
|
| 487 |
+
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
|
| 488 |
+
[[maybe_unused]] cute::array<int32_t, 5> shape_output,
|
| 489 |
+
cute::array<IntT, 5> stride_ktrsc,
|
| 490 |
+
conv::Operator ConvOp) {
|
| 491 |
+
static_assert(std::is_integral_v<IntT>,
|
| 492 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 493 |
+
|
| 494 |
+
assert(stride_ktrsc[4] == 1);
|
| 495 |
+
auto s_copy = s;
|
| 496 |
+
cute::get<0,0>(s_copy) = stride_ktrsc[0];
|
| 497 |
+
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 498 |
+
cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
|
| 499 |
+
});
|
| 500 |
+
return s_copy;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
//
|
| 505 |
+
// Wgrad output tensor ((_1, s, r, t), k, _0)
|
| 506 |
+
//
|
| 507 |
+
|
| 508 |
+
// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0)
|
| 509 |
+
template <class IntT>
|
| 510 |
+
CUTLASS_HOST_DEVICE
|
| 511 |
+
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>>
|
| 512 |
+
make_cute_packed_stride(
|
| 513 |
+
cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>> s,
|
| 514 |
+
[[maybe_unused]] cute::array<int32_t, 3> shape_output,
|
| 515 |
+
cute::array<IntT, 3> stride_ksc,
|
| 516 |
+
conv::Operator ConvOp) {
|
| 517 |
+
static_assert(std::is_integral_v<IntT>,
|
| 518 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 519 |
+
|
| 520 |
+
assert(stride_ksc[2] == 1);
|
| 521 |
+
auto s_copy = s;
|
| 522 |
+
cute::get<1,0>(s_copy) = stride_ksc[0];
|
| 523 |
+
cute::get<0,1>(s_copy) = stride_ksc[1];
|
| 524 |
+
return s_copy;
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0)
|
| 528 |
+
template <class IntT>
|
| 529 |
+
CUTLASS_HOST_DEVICE
|
| 530 |
+
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>>
|
| 531 |
+
make_cute_packed_stride(
|
| 532 |
+
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>> s,
|
| 533 |
+
[[maybe_unused]] cute::array<int32_t, 4> shape_output,
|
| 534 |
+
cute::array<IntT, 4> stride_krsc,
|
| 535 |
+
conv::Operator ConvOp) {
|
| 536 |
+
static_assert(std::is_integral_v<IntT>,
|
| 537 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 538 |
+
|
| 539 |
+
assert(stride_krsc[3] == 1);
|
| 540 |
+
auto s_copy = s;
|
| 541 |
+
cute::get<1,0>(s_copy) = stride_krsc[0];
|
| 542 |
+
cute::for_each(cute::make_seq<2>{}, [&](auto i) {
|
| 543 |
+
cute::get<0,2-i>(s_copy) = stride_krsc[i+1];
|
| 544 |
+
});
|
| 545 |
+
return s_copy;
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0)
|
| 549 |
+
template <class IntT>
|
| 550 |
+
CUTLASS_HOST_DEVICE
|
| 551 |
+
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>>
|
| 552 |
+
make_cute_packed_stride(
|
| 553 |
+
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>> s,
|
| 554 |
+
[[maybe_unused]] cute::array<int32_t, 5> shape_output,
|
| 555 |
+
cute::array<IntT, 5> stride_ktrsc,
|
| 556 |
+
conv::Operator ConvOp) {
|
| 557 |
+
static_assert(std::is_integral_v<IntT>,
|
| 558 |
+
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
| 559 |
+
|
| 560 |
+
assert(stride_ktrsc[4] == 1);
|
| 561 |
+
auto s_copy = s;
|
| 562 |
+
cute::get<1,0>(s_copy) = stride_ktrsc[0];
|
| 563 |
+
cute::for_each(cute::make_seq<3>{}, [&](auto i) {
|
| 564 |
+
cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1];
|
| 565 |
+
});
|
| 566 |
+
return s_copy;
|
| 567 |
+
}
|
| 568 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 569 |
+
|
| 570 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <type_traits>

#include <cute/util/type_traits.hpp>
#include <cute/tensor.hpp>

#include <cute/numeric/numeric_types.hpp>
#include <cute/numeric/complex.hpp>

#include <cutlass/layout/layout.h>
|
| 47 |
+
|
| 48 |
+
// Result of an infinity-norm computation over a matrix.
// The computed infinity norm does not include
// any NaN column absolute-value sums.
struct matrix_inf_norm_result {
  // Accumulate errors in double, as this is generally
  // the highest precision that the examples use.
  // Largest non-NaN row absolute-value sum encountered.
  double inf_norm = 0.0;
  // True if any row absolute-value sum evaluated to NaN.
  bool found_nan = false;
};
|
| 56 |
+
|
| 57 |
+
// In theory, cute::Tensor<ViewEngine<T*>, T> could be treated as a view type,
|
| 58 |
+
// and thus passed by value (as std::span or std::string_view would be).
|
| 59 |
+
// However, generic cute::Tensor are more like containers
|
| 60 |
+
// and thus are best passed by reference or const reference.
|
| 61 |
+
template <typename EngineType, typename LayoutType>
|
| 62 |
+
matrix_inf_norm_result
|
| 63 |
+
matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
|
| 64 |
+
{
|
| 65 |
+
using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
|
| 66 |
+
using element_type = typename EngineType::value_type;
|
| 67 |
+
|
| 68 |
+
error_type inf_norm = 0.0;
|
| 69 |
+
bool found_nan = false;
|
| 70 |
+
|
| 71 |
+
// Computing the infinity norm requires that we be able
|
| 72 |
+
// to treat the input as a matrix, with rows and columns.
|
| 73 |
+
const int64_t num_rows = cute::size<0>(host_matrix);
|
| 74 |
+
const int64_t num_cols = cute::size<1>(host_matrix);
|
| 75 |
+
|
| 76 |
+
auto abs_fn = [] (element_type A_ij) {
|
| 77 |
+
if constexpr (not std::is_unsigned_v<element_type>) {
|
| 78 |
+
using std::abs;
|
| 79 |
+
return abs(A_ij);
|
| 80 |
+
}
|
| 81 |
+
else {
|
| 82 |
+
return A_ij;
|
| 83 |
+
}
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
for (int64_t i = 0; i < num_rows; ++i) {
|
| 87 |
+
error_type row_abs_sum = 0.0;
|
| 88 |
+
for(int64_t j = 0; j < num_cols; ++j) {
|
| 89 |
+
row_abs_sum += abs_fn(host_matrix(i, j));
|
| 90 |
+
}
|
| 91 |
+
if (std::isnan(row_abs_sum)) {
|
| 92 |
+
found_nan = true;
|
| 93 |
+
}
|
| 94 |
+
else {
|
| 95 |
+
inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return {inf_norm, found_nan};
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// Infinity norm of (X - Y).
|
| 103 |
+
template <typename EngineType, typename LayoutType>
|
| 104 |
+
matrix_inf_norm_result
|
| 105 |
+
matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
|
| 106 |
+
cute::Tensor<EngineType, LayoutType> const& Y)
|
| 107 |
+
{
|
| 108 |
+
using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
|
| 109 |
+
using element_type = typename EngineType::value_type;
|
| 110 |
+
|
| 111 |
+
auto abs_fn = [] (element_type A_ij) {
|
| 112 |
+
if constexpr (not std::is_unsigned_v<element_type>) {
|
| 113 |
+
using std::abs;
|
| 114 |
+
return abs(A_ij);
|
| 115 |
+
}
|
| 116 |
+
else {
|
| 117 |
+
return A_ij;
|
| 118 |
+
}
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
assert(cute::size<0>(X) == cute::size<0>(Y));
|
| 122 |
+
assert(cute::size<1>(X) == cute::size<1>(Y));
|
| 123 |
+
|
| 124 |
+
// Computing the infinity norm requires that we be able
|
| 125 |
+
// to treat the input as a matrix, with rows and columns.
|
| 126 |
+
const int64_t num_rows = cute::size<0>(X);
|
| 127 |
+
const int64_t num_cols = cute::size<1>(X);
|
| 128 |
+
|
| 129 |
+
error_type inf_norm = 0.0;
|
| 130 |
+
bool found_nan = false;
|
| 131 |
+
|
| 132 |
+
for (int64_t i = 0; i < num_rows; ++i) {
|
| 133 |
+
error_type row_abs_sum = 0.0;
|
| 134 |
+
for (int64_t j = 0; j < num_cols; ++j) {
|
| 135 |
+
row_abs_sum += error_type(abs_fn(element_type(X(i,j)) -
|
| 136 |
+
element_type(Y(i,j))));
|
| 137 |
+
}
|
| 138 |
+
if (std::isnan(row_abs_sum)) {
|
| 139 |
+
found_nan = true;
|
| 140 |
+
}
|
| 141 |
+
else {
|
| 142 |
+
inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
return {inf_norm, found_nan};
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
template <typename EngineType_A, typename LayoutType_A,
|
| 150 |
+
typename EngineType_B, typename LayoutType_B,
|
| 151 |
+
typename EngineType_C, typename LayoutType_C,
|
| 152 |
+
typename EngineType_C_ref, typename LayoutType_C_ref>
|
| 153 |
+
auto
|
| 154 |
+
print_matrix_multiply_mollified_relative_error(
|
| 155 |
+
char const A_value_type_name[],
|
| 156 |
+
cute::Tensor<EngineType_A, LayoutType_A> const& A,
|
| 157 |
+
char const B_value_type_name[],
|
| 158 |
+
cute::Tensor<EngineType_B, LayoutType_B> const& B,
|
| 159 |
+
char const C_value_type_name[],
|
| 160 |
+
cute::Tensor<EngineType_C, LayoutType_C> const& C,
|
| 161 |
+
cute::Tensor<EngineType_C_ref, LayoutType_C_ref> const& C_ref)
|
| 162 |
+
{
|
| 163 |
+
const auto [A_norm, A_has_nan] = matrix_inf_norm(A);
|
| 164 |
+
const auto [B_norm, B_has_nan] = matrix_inf_norm(B);
|
| 165 |
+
const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref);
|
| 166 |
+
const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref);
|
| 167 |
+
|
| 168 |
+
const auto A_norm_times_B_norm = A_norm * B_norm;
|
| 169 |
+
const auto relative_error = A_norm_times_B_norm == 0.0 ?
|
| 170 |
+
diff_norm : (diff_norm / A_norm_times_B_norm);
|
| 171 |
+
|
| 172 |
+
// For expected error bounds, please refer to the LAPACK Users' Guide,
|
| 173 |
+
// in particular https://netlib.org/lapack/lug/node108.html .
|
| 174 |
+
// Printing the infinity norm of C is a way to check
|
| 175 |
+
// that both the function being tested (C)
|
| 176 |
+
// and the reference implementation (C_ref)
|
| 177 |
+
// don't just do nothing (or fill with zeros).
|
| 178 |
+
using std::cout;
|
| 179 |
+
using cute::shape;
|
| 180 |
+
cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
|
| 181 |
+
<< "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
|
| 182 |
+
<< "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
|
| 183 |
+
<< std::scientific
|
| 184 |
+
<< "Infinity norm of A: " << A_norm << '\n'
|
| 185 |
+
<< "Infinity norm of B: " << B_norm << '\n'
|
| 186 |
+
<< "Infinity norm of C: " << C_norm << '\n'
|
| 187 |
+
<< "Infinity norm of (C - C_ref): " << diff_norm << '\n';
|
| 188 |
+
|
| 189 |
+
if(A_norm_times_B_norm == 0.0) {
|
| 190 |
+
cout << "Mollified relative error: " << relative_error << '\n';
|
| 191 |
+
} else {
|
| 192 |
+
cout << "Relative error: " << relative_error << '\n';
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
|
| 196 |
+
cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
|
| 197 |
+
<< "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
|
| 198 |
+
<< "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
|
| 199 |
+
<< "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
|
| 200 |
+
}
|
| 201 |
+
return relative_error;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
template <typename EngineType, typename LayoutType>
|
| 205 |
+
auto
|
| 206 |
+
print_matrix_multiply_mollified_relative_error(
|
| 207 |
+
const char value_type_name[],
|
| 208 |
+
const cute::Tensor<EngineType, LayoutType>& A,
|
| 209 |
+
const cute::Tensor<EngineType, LayoutType>& B,
|
| 210 |
+
const cute::Tensor<EngineType, LayoutType>& C_computed,
|
| 211 |
+
const cute::Tensor<EngineType, LayoutType>& C_expected)
|
| 212 |
+
{
|
| 213 |
+
return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
|
| 214 |
+
value_type_name, C_computed, C_expected);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
// Take a CUTLASS HostTensor (or the like) as input,
|
| 218 |
+
// and return a const CuTe Tensor.
|
| 219 |
+
// This is useful for use with the above error printing functions.
|
| 220 |
+
// This implicitly "transposes" if the layout is RowMajor.
|
| 221 |
+
// Note that the HostTensor must be captured by nonconst reference
|
| 222 |
+
// in order for X.host_ref().data() to compile.
|
| 223 |
+
// (CUTLASS is a bit more container-y than CuTe.)
|
| 224 |
+
template<class CutlassHostTensorType>
|
| 225 |
+
auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X)
|
| 226 |
+
{
|
| 227 |
+
// The tensors were created with post-transposed extents.
|
| 228 |
+
const auto extents = X.extent();
|
| 229 |
+
const auto shape = cute::Shape<int, int>{extents[0], extents[1]};
|
| 230 |
+
// Both RowMajor and ColumnMajor only store one stride.
|
| 231 |
+
const int LDX = X.stride(0);
|
| 232 |
+
const auto strides = [&]() {
|
| 233 |
+
using input_layout_type = typename std::decay_t<decltype(X)>::Layout;
|
| 234 |
+
if constexpr (std::is_same_v<input_layout_type, cutlass::layout::ColumnMajor>) {
|
| 235 |
+
return cute::Stride<int, int>{1, LDX};
|
| 236 |
+
}
|
| 237 |
+
else {
|
| 238 |
+
static_assert(std::is_same_v<input_layout_type, cutlass::layout::RowMajor>);
|
| 239 |
+
return cute::Stride<int, int>{LDX, 1};
|
| 240 |
+
}
|
| 241 |
+
}();
|
| 242 |
+
const auto layout = cute::make_layout(shape, strides);
|
| 243 |
+
auto X_data = X.host_ref().data();
|
| 244 |
+
auto X_data_const = const_cast<std::add_const_t< decltype(X_data)> >(X_data);
|
| 245 |
+
return cute::make_tensor(X_data_const, layout);
|
| 246 |
+
};
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE.
|
| 250 |
+
// This makes the return value suitable as the return value of main().
|
| 251 |
+
template <typename T1, typename T2>
|
| 252 |
+
int
|
| 253 |
+
print_relative_error(
|
| 254 |
+
std::size_t n,
|
| 255 |
+
T1 const& data,
|
| 256 |
+
T2 const& reference,
|
| 257 |
+
bool print_verbose = false,
|
| 258 |
+
bool print_error = true,
|
| 259 |
+
double error_margin = 0.00001) {
|
| 260 |
+
using std::abs; using std::sqrt;
|
| 261 |
+
|
| 262 |
+
// Use either double or complex<double> for error computation
|
| 263 |
+
using value_type = cute::remove_cvref_t<decltype(reference[0])>;
|
| 264 |
+
using error_type = std::conditional_t<cute::is_complex<value_type>::value,
|
| 265 |
+
cute::complex<double>,
|
| 266 |
+
double>;
|
| 267 |
+
|
| 268 |
+
if (print_verbose) {
|
| 269 |
+
std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
double eps = 1e-200;
|
| 273 |
+
|
| 274 |
+
double tot_error_sq = 0;
|
| 275 |
+
double tot_norm_sq = 0;
|
| 276 |
+
double tot_ind_rel_err = 0;
|
| 277 |
+
double max_ind_rel_err = 0;
|
| 278 |
+
double max_diff = 0;
|
| 279 |
+
for (std::size_t i = 0; i < n; ++i) {
|
| 280 |
+
error_type val = data[i];
|
| 281 |
+
error_type ref = reference[i];
|
| 282 |
+
|
| 283 |
+
double aref = abs(ref);
|
| 284 |
+
double diff = abs(ref - val);
|
| 285 |
+
double rel_error = diff / (aref + eps);
|
| 286 |
+
|
| 287 |
+
// Individual relative error
|
| 288 |
+
tot_ind_rel_err += rel_error;
|
| 289 |
+
|
| 290 |
+
// Maximum relative error
|
| 291 |
+
max_ind_rel_err = std::max(max_ind_rel_err, rel_error);
|
| 292 |
+
|
| 293 |
+
// Maximum delta in value error
|
| 294 |
+
max_diff = std::max(max_diff, diff);
|
| 295 |
+
|
| 296 |
+
// Total relative error
|
| 297 |
+
tot_error_sq += diff * diff;
|
| 298 |
+
tot_norm_sq += aref * aref;
|
| 299 |
+
|
| 300 |
+
if (print_verbose) {
|
| 301 |
+
std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl;
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
double ave_rel_err = tot_ind_rel_err / double(n);
|
| 306 |
+
if (print_error) {
|
| 307 |
+
printf("Average relative error: %.3e\n", ave_rel_err);
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
if (print_error) {
|
| 311 |
+
printf("Maximum relative error: %.3e\n", max_ind_rel_err);
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
if (print_error) {
|
| 315 |
+
printf("Maximum difference : %.3e\n", max_diff);
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps));
|
| 319 |
+
if (print_error) {
|
| 320 |
+
printf("Vector relative error: %.3e\n", tot_rel_err);
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq));
|
| 324 |
+
|
| 325 |
+
return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
// Overload for cute::Tensor<>
|
| 329 |
+
template <class Engine, class Layout>
|
| 330 |
+
int
|
| 331 |
+
print_relative_error(
|
| 332 |
+
cute::Tensor<Engine, Layout> data,
|
| 333 |
+
cute::Tensor<Engine, Layout> reference,
|
| 334 |
+
bool print_verbose = false,
|
| 335 |
+
bool print_error = true,
|
| 336 |
+
double error_margin = 0.00001) {
|
| 337 |
+
assert(size(data) == size(reference));
|
| 338 |
+
return print_relative_error(static_cast<std::size_t>(size(data)),
|
| 339 |
+
data, reference,
|
| 340 |
+
print_verbose, print_error, error_margin);
|
| 341 |
+
}
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "cutlass/array.h"
|
| 38 |
+
|
| 39 |
+
namespace cutlass {
|
| 40 |
+
namespace reference {
|
| 41 |
+
namespace detail {
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
/// Template function to compute an inner product.
|
| 46 |
+
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a
|
| 47 |
+
// host-only type
|
| 48 |
+
template <typename Atype, typename Btype, typename Ctype>
|
| 49 |
+
CUTLASS_HOST_DEVICE
|
| 50 |
+
Ctype inner_product(Atype a, Btype b, Ctype c) {
|
| 51 |
+
return Ctype(a) * Ctype(b) + c;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
/// Specialization for matrix multiplication with binary operands
|
| 55 |
+
template <>
|
| 56 |
+
CUTLASS_HOST_DEVICE
|
| 57 |
+
int inner_product<Array<bin1_t, 32>, Array<bin1_t, 32>, int>(
|
| 58 |
+
Array<bin1_t, 32> a,
|
| 59 |
+
Array<bin1_t, 32> b,
|
| 60 |
+
int c) {
|
| 61 |
+
|
| 62 |
+
int accum = 0;
|
| 63 |
+
for (int bit = 0; bit < 32; bit++) {
|
| 64 |
+
accum += a[bit] ^ b[bit];
|
| 65 |
+
}
|
| 66 |
+
return accum + c;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
/*
|
| 70 |
+
/// Specialization for matrix multiplication with signed 4-bit integer operands
|
| 71 |
+
template <>
|
| 72 |
+
CUTLASS_HOST_DEVICE
|
| 73 |
+
int inner_product<Array<int4b_t, 8>, Array<int4b_t, 8>, int>(
|
| 74 |
+
Array<int4b_t, 8> a,
|
| 75 |
+
Array<int4b_t, 8> b,
|
| 76 |
+
int c) {
|
| 77 |
+
|
| 78 |
+
int accum = 0;
|
| 79 |
+
for (int k = 0; k < 8; k++) {
|
| 80 |
+
accum += a[k] * b[k];
|
| 81 |
+
}
|
| 82 |
+
return accum + c;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
/// Specialization for matrix multiplication with unsigned 4-bit integer operands
|
| 86 |
+
template <>
|
| 87 |
+
CUTLASS_HOST_DEVICE
|
| 88 |
+
int inner_product<Array<uint4b_t, 8>, Array<uint4b_t, 8>, int>(
|
| 89 |
+
Array<uint4b_t, 8> a,
|
| 90 |
+
Array<uint4b_t, 8> b,
|
| 91 |
+
int c) {
|
| 92 |
+
|
| 93 |
+
int accum = 0;
|
| 94 |
+
for (int k = 0; k < 8; k++) {
|
| 95 |
+
accum += a[k] * b[k];
|
| 96 |
+
}
|
| 97 |
+
return accum + c;
|
| 98 |
+
}
|
| 99 |
+
*/
|
| 100 |
+
|
| 101 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 102 |
+
|
| 103 |
+
template <typename SrcType, typename DstType>
|
| 104 |
+
struct Cast {
|
| 105 |
+
// Default behavior: convert to the destination type
|
| 106 |
+
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
| 107 |
+
// host-only type
|
| 108 |
+
CUTLASS_HOST_DEVICE
|
| 109 |
+
static DstType apply(SrcType src) { return static_cast<DstType>(src); };
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
template <>
|
| 113 |
+
struct Cast<float, int8_t> {
|
| 114 |
+
CUTLASS_HOST_DEVICE
|
| 115 |
+
static int8_t apply(float src) {
|
| 116 |
+
// Clamp to the range of signed 8-bit integers.
|
| 117 |
+
return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
|
| 118 |
+
};
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
template <>
|
| 122 |
+
struct Cast<float, uint8_t> {
|
| 123 |
+
CUTLASS_HOST_DEVICE
|
| 124 |
+
static uint8_t apply(float src) {
|
| 125 |
+
// Clamp to the range of signed 8-bit integers.
|
| 126 |
+
return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
|
| 127 |
+
};
|
| 128 |
+
};
|
| 129 |
+
|
| 130 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 131 |
+
|
| 132 |
+
} // namespace detail
|
| 133 |
+
} // namespace reference
|
| 134 |
+
} // namespace cutlass
|
| 135 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
|
| 39 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace reference {
|
| 43 |
+
namespace detail {
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
template <int Rank, int Index>
|
| 48 |
+
struct LinearToCoordinateHelper {
|
| 49 |
+
|
| 50 |
+
CUTLASS_HOST_DEVICE
|
| 51 |
+
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
|
| 52 |
+
|
| 53 |
+
int64_t prod = 1;
|
| 54 |
+
|
| 55 |
+
CUTLASS_PRAGMA_UNROLL
|
| 56 |
+
for (int i = Rank - Index; i < Rank; ++i) {
|
| 57 |
+
prod *= int64_t(extent[i]);
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
coord[Rank - Index - 1] = int(idx / prod);
|
| 61 |
+
|
| 62 |
+
int64_t residual = idx % prod;
|
| 63 |
+
LinearToCoordinateHelper<Rank, Index - 1>()(coord, residual, extent);
|
| 64 |
+
}
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
template <int Rank>
|
| 68 |
+
struct LinearToCoordinateHelper<Rank, 0> {
|
| 69 |
+
|
| 70 |
+
CUTLASS_HOST_DEVICE
|
| 71 |
+
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &) const {
|
| 72 |
+
coord[Rank - 1] = int(idx);
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
+
|
| 78 |
+
template <int Rank>
|
| 79 |
+
struct LinearToCoordinate {
|
| 80 |
+
|
| 81 |
+
CUTLASS_HOST_DEVICE
|
| 82 |
+
void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
|
| 83 |
+
LinearToCoordinateHelper<Rank, Rank - 1>()(coord, idx, extent);
|
| 84 |
+
}
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 88 |
+
|
| 89 |
+
} // namespace detail
|
| 90 |
+
} // namespace reference
|
| 91 |
+
} // namespace cutlass
|
| 92 |
+
|
| 93 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h
ADDED
|
@@ -0,0 +1,1549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Reference implementation for convolution in device-side code.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/coord.h"
|
| 39 |
+
#include "cutlass/functional.h"
|
| 40 |
+
#include "cutlass/layout/tensor.h"
|
| 41 |
+
#include "cutlass/matrix_shape.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
#include "cutlass/tensor_ref.h"
|
| 45 |
+
#include "cutlass/conv/convolution.h"
|
| 46 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 47 |
+
#include "cutlass/conv/conv3d_problem_size.h"
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace reference {
|
| 51 |
+
namespace device {
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
namespace kernel {
|
| 56 |
+
|
| 57 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
/// Conv2d device reference kernel
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
// Conv2d Fprop kernel - y = fprop(x, w)
|
| 62 |
+
template <
|
| 63 |
+
typename ElementA,
|
| 64 |
+
typename LayoutA,
|
| 65 |
+
typename ElementB,
|
| 66 |
+
typename LayoutB,
|
| 67 |
+
typename ElementC,
|
| 68 |
+
typename LayoutC,
|
| 69 |
+
typename ElementCompute,
|
| 70 |
+
typename ElementAccumulator = ElementCompute,
|
| 71 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 72 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>,
|
| 73 |
+
int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
|
| 74 |
+
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
|
| 75 |
+
int kCtaShapeM = 16, // shape of a threadblock in units of threads
|
| 76 |
+
int kCtaShapeN = 8 // shape of a threadblock in units of threads
|
| 77 |
+
>
|
| 78 |
+
__global__ void Conv2dFprop(
|
| 79 |
+
conv::Conv2dProblemSize problem_size,
|
| 80 |
+
TensorRef<ElementA, LayoutA> tensor_x,
|
| 81 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 82 |
+
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 83 |
+
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 84 |
+
ElementCompute alpha,
|
| 85 |
+
ElementCompute beta
|
| 86 |
+
) {
|
| 87 |
+
|
| 88 |
+
ConvertOp convert_op;
|
| 89 |
+
InnerProductOp inner_product_op;
|
| 90 |
+
|
| 91 |
+
ElementAccumulator element_A[kThreadM];
|
| 92 |
+
ElementAccumulator element_B[kThreadN];
|
| 93 |
+
ElementAccumulator accum[kThreadM][kThreadN];
|
| 94 |
+
|
| 95 |
+
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
| 96 |
+
int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
|
| 97 |
+
|
| 98 |
+
int thread_n[kThreadM];
|
| 99 |
+
int thread_p[kThreadM];
|
| 100 |
+
int thread_q[kThreadM];
|
| 101 |
+
|
| 102 |
+
// Compute N, P, Q coordinates for each row of a thread's tile
|
| 103 |
+
int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
|
| 104 |
+
|
| 105 |
+
CUTLASS_PRAGMA_UNROLL
|
| 106 |
+
for (int m = 0; m < kThreadM; ++m) {
|
| 107 |
+
|
| 108 |
+
int64_t npq = npq_start + m;
|
| 109 |
+
|
| 110 |
+
thread_n[m] = int(npq / PQ);
|
| 111 |
+
|
| 112 |
+
int64_t residual = npq % PQ;
|
| 113 |
+
thread_p[m] = int(residual / problem_size.Q);
|
| 114 |
+
thread_q[m] = int(residual % problem_size.Q);
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// Clear accumulators
|
| 118 |
+
CUTLASS_PRAGMA_UNROLL
|
| 119 |
+
for (int m = 0; m < kThreadM; ++m) {
|
| 120 |
+
CUTLASS_PRAGMA_UNROLL
|
| 121 |
+
for (int n = 0; n < kThreadN; ++n) {
|
| 122 |
+
accum[m][n] = ElementAccumulator();
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
int c_per_group = problem_size.C / problem_size.groups;
|
| 127 |
+
int k_per_group = problem_size.K / problem_size.groups;
|
| 128 |
+
|
| 129 |
+
// Compute convolution
|
| 130 |
+
for (int R = 0; R < problem_size.R; ++R) {
|
| 131 |
+
for (int S = 0; S < problem_size.S; ++S) {
|
| 132 |
+
for (int C = 0; C < problem_size.C; ++C) {
|
| 133 |
+
|
| 134 |
+
// Get group id of currnet channel
|
| 135 |
+
int c_group_idx = C / c_per_group;
|
| 136 |
+
|
| 137 |
+
// Load from activations tensor
|
| 138 |
+
int filter_r = R;
|
| 139 |
+
int filter_s = S;
|
| 140 |
+
|
| 141 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 142 |
+
filter_r = problem_size.R - 1 - R;
|
| 143 |
+
filter_s = problem_size.S - 1 - S;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
CUTLASS_PRAGMA_UNROLL
|
| 147 |
+
for (int m = 0; m < kThreadM; ++m) {
|
| 148 |
+
int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 149 |
+
int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 150 |
+
|
| 151 |
+
if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
|
| 152 |
+
element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C}));
|
| 153 |
+
}
|
| 154 |
+
else {
|
| 155 |
+
element_A[m] = ElementAccumulator();
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// Load from filters tensor
|
| 160 |
+
CUTLASS_PRAGMA_UNROLL
|
| 161 |
+
for (int n = 0; n < kThreadN; ++n) {
|
| 162 |
+
int thread_k = k_start + n;
|
| 163 |
+
int k_group_idx = thread_k / k_per_group;
|
| 164 |
+
|
| 165 |
+
if (thread_k < problem_size.K && k_group_idx == c_group_idx) {
|
| 166 |
+
element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group}));
|
| 167 |
+
}
|
| 168 |
+
else {
|
| 169 |
+
element_B[n] = ElementAccumulator();
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// Accumulate matrix product
|
| 174 |
+
CUTLASS_PRAGMA_UNROLL
|
| 175 |
+
for (int m = 0; m < kThreadM; ++m) {
|
| 176 |
+
CUTLASS_PRAGMA_UNROLL
|
| 177 |
+
for (int n = 0; n < kThreadN; ++n) {
|
| 178 |
+
accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
// Write out the results
|
| 186 |
+
CUTLASS_PRAGMA_UNROLL
|
| 187 |
+
for (int m = 0; m < kThreadM; ++m) {
|
| 188 |
+
if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
|
| 189 |
+
CUTLASS_PRAGMA_UNROLL
|
| 190 |
+
for (int n = 0; n < kThreadN; ++n) {
|
| 191 |
+
int thread_k = k_start + n;
|
| 192 |
+
if (thread_k < problem_size.K) {
|
| 193 |
+
|
| 194 |
+
ElementCompute c_ref = ElementCompute();
|
| 195 |
+
if (beta != ElementCompute()) {
|
| 196 |
+
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
|
| 200 |
+
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// Conv3d Fprop kernel - y = fprop(x, w)
//
// Reference implementation: each thread computes a kThreadM x kThreadN tile of
// the implicit-GEMM output, where GEMM M = N*Z*P*Q (output voxels) and
// GEMM N = K (output channels). Computes y_out = alpha * conv(x, w) + beta * y_in.
// NOTE(review): unlike the Conv2d variant above, this kernel has no grouped-
// convolution handling — presumably 3d groups are unsupported here; confirm.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 16,    // shape of a threadblock in units of threads
  int kCtaShapeN = 8      // shape of a threadblock in units of threads
>
__global__ void Conv3dFprop(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_x,        // input activations (NDHWC)
  TensorRef<ElementB, LayoutB> tensor_w,        // filters
  TensorRef<ElementC, LayoutC> tensor_y_in,     // source tensor read when beta != 0
  TensorRef<ElementC, LayoutC> tensor_y_out,    // destination tensor
  ElementCompute alpha,
  ElementCompute beta
  ) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // First linearized output voxel (n*Z*P*Q + z*P*Q + p*Q + q) and filter
  // index handled by this thread. 64-bit to avoid overflow for large volumes.
  int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_n[kThreadM];
  int thread_z[kThreadM];
  int thread_p[kThreadM];
  int thread_q[kThreadM];

  // Compute N, Z, P, Q coordinates for each row of a thread's tile
  int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
  int64_t ZPQ = PQ * problem_size.Z;

  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    int64_t nzpq = nzpq_start + m;

    thread_n[m] = int(nzpq / ZPQ);

    int64_t residual = nzpq % ZPQ;
    thread_z[m] = int(residual / PQ);

    residual = residual % PQ;
    thread_p[m] = int(residual / problem_size.Q);
    thread_q[m] = int(residual % problem_size.Q);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution
  for (int T = 0; T < problem_size.T; ++T) {
    for (int R = 0; R < problem_size.R; ++R) {
      for (int S = 0; S < problem_size.S; ++S) {
        for (int C = 0; C < problem_size.C; ++C) {

          // Load from activations tensor
          int filter_t = T;
          int filter_r = R;
          int filter_s = S;

          // kConvolution mode flips the filter (true convolution vs cross-correlation).
          if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
            filter_t = problem_size.T - 1 - T;
            filter_r = problem_size.R - 1 - R;
            filter_s = problem_size.S - 1 - S;
          }

          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
            int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
            int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

            // Out-of-bounds (padding) positions contribute zero.
            if (thread_n[m] < problem_size.N &&
              d >= 0 && d < problem_size.D &&
              h >= 0 && h < problem_size.H &&
              w >= 0 && w < problem_size.W) {

              element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C}));
            }
            else {
              element_A[m] = ElementAccumulator();
            }
          }

          // Load from filters tensor
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            int thread_k = k_start + n;

            if (thread_k < problem_size.K) {
              element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C}));
            }
            else {
              element_B[n] = ElementAccumulator();
            }
          }

          // Accumulate matrix product
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
              accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
            }
          }

        } // for (C)
      } // for (S)
    } // for (R)
  } // for (T)

  // Write out the results
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    if (thread_n[m] < problem_size.N &&
      thread_z[m] < problem_size.Z &&
      thread_p[m] < problem_size.P &&
      thread_q[m] < problem_size.Q) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {
        int thread_k = k_start + n;
        if (thread_k < problem_size.K) {

          // Only read the source tensor when beta is nonzero.
          ElementCompute c_ref = ElementCompute();
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}));
          }

          tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      } // for (n)

    }
  } // for (m)
}
|
| 366 |
+
|
| 367 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 368 |
+
|
| 369 |
+
// Conv2d dgrad kernel - dx = dgrad(dy, w)
//
// Naive reference kernel: each thread owns a kThreadM x kThreadN tile of the
// implicit GEMM, where M spans the (N, H, W) input positions and N spans the
// input channels C. Accumulation loops over the filter taps (R, S) and output
// channels K.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 16,    // shape of a threadblock in units of threads
  int kCtaShapeN = 8      // shape of a threadblock in units of threads
>
__global__ void Conv2dDgrad(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,
  TensorRef<ElementB, LayoutB> tensor_w,
  TensorRef<ElementC, LayoutC> tensor_dx_in,
  TensorRef<ElementC, LayoutC> tensor_dx_out,
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // Linear offset of this thread's first (n, h, w) position and first channel.
  int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_n[kThreadM];
  int thread_h[kThreadM];
  int thread_w[kThreadM];

  // Compute N, H, W coordinates for each row of a thread's tile
  int64_t HW = int64_t(problem_size.H) * problem_size.W;

  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    int64_t nhw = nhw_start + m;

    thread_n[m] = int(nhw / HW);

    int64_t residual = nhw % HW;
    thread_h[m] = int(residual / problem_size.W);
    thread_w[m] = int(residual % problem_size.W);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution
  for (int R = 0; R < problem_size.R; ++R) {
    for (int S = 0; S < problem_size.S; ++S) {
      for (int K = 0; K < problem_size.K; ++K) {

        // Mirror the filter taps for convolution (vs. cross-correlation) mode.
        int filter_r = R;
        int filter_s = S;

        if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
          filter_r = problem_size.R - 1 - R;
          filter_s = problem_size.S - 1 - S;
        }

        // Load from the output-gradient tensor (tensor_dy)
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {

          // Invert the fprop mapping: which output position (p, q) would have
          // read input (h, w) through tap (filter_r, filter_s)?
          int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
          int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;

          element_A[m] = ElementAccumulator();

          // Only output positions aligned to the stride contribute.
          if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) {

            p = p / problem_size.stride_h;
            q = q / problem_size.stride_w;

            if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) {
              element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K}));
            }
          }
        }

        // Load from filters tensor
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < kThreadN; ++n) {
          int thread_c = c_start + n;

          if (thread_c < problem_size.C) {
            element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c}));
          }
          else {
            element_B[n] = ElementAccumulator();
          }
        }

        // Accumulate matrix product
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
          }
        }
      }
    }
  }

  // Write out the results, applying the linear-scaling epilogue
  //   dx_out = alpha * accum + beta * dx_in
  // (dx_in is only read when beta is nonzero).
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {
        int thread_c = c_start + n;
        if (thread_c < problem_size.C) {

          ElementCompute c_ref = ElementCompute();
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c}));
          }

          tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
|
| 515 |
+
|
| 516 |
+
// Conv3d dgrad kernel - dx = dgrad(dy, w)
//
// 3-D analogue of Conv2dDgrad: the implicit-GEMM M dimension spans the
// (N, D, H, W) input positions and the N dimension spans input channels C.
// Accumulation loops over filter taps (T, R, S) and output channels K.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 16,    // shape of a threadblock in units of threads
  int kCtaShapeN = 8      // shape of a threadblock in units of threads
>
__global__ void Conv3dDgrad(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,
  TensorRef<ElementB, LayoutB> tensor_w,
  TensorRef<ElementC, LayoutC> tensor_dx_in,
  TensorRef<ElementC, LayoutC> tensor_dx_out,
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // Linear offset of this thread's first (n, d, h, w) position and first channel.
  int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_n[kThreadM];
  int thread_d[kThreadM];
  int thread_h[kThreadM];
  int thread_w[kThreadM];

  // Compute N, D, H, W coordinates for each row of a thread's tile
  int64_t HW = int64_t(problem_size.H) * problem_size.W;
  int64_t DHW = HW * problem_size.D;

  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    int64_t ndhw = ndhw_start + m;

    thread_n[m] = int(ndhw / DHW);

    int64_t residual = ndhw % DHW;
    thread_d[m] = int(residual / HW);

    residual = residual % HW;
    thread_h[m] = int(residual / problem_size.W);
    thread_w[m] = int(residual % problem_size.W);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution
  for (int T = 0; T < problem_size.T; ++T) {
    for (int R = 0; R < problem_size.R; ++R) {
      for (int S = 0; S < problem_size.S; ++S) {
        for (int K = 0; K < problem_size.K; ++K) {

          // Mirror the filter taps for convolution (vs. cross-correlation) mode.
          int filter_t = T;
          int filter_r = R;
          int filter_s = S;

          if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
            filter_t = problem_size.T - 1 - T;
            filter_r = problem_size.R - 1 - R;
            filter_s = problem_size.S - 1 - S;
          }

          // Load from the output-gradient tensor (tensor_dy)
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {

            // Invert the fprop mapping: which output position (z, p, q) would
            // have read input (d, h, w) through tap (filter_t, filter_r, filter_s)?
            int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d;
            int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
            int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;

            element_A[m] = ElementAccumulator();

            // Only output positions aligned to the stride contribute.
            if (z >= 0 && !(z % problem_size.stride_d) &&
              p >= 0 && !(p % problem_size.stride_h) &&
              q >= 0 && !(q % problem_size.stride_w)) {

              z = z / problem_size.stride_d;
              p = p / problem_size.stride_h;
              q = q / problem_size.stride_w;

              if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
                element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K}));
              }
            }
          }

          // Load from filters tensor
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            int thread_c = c_start + n;

            if (thread_c < problem_size.C) {
              element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c}));
            }
            else {
              element_B[n] = ElementAccumulator();
            }
          }

          // Accumulate matrix product
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
              accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
            }
          }

        } // for (C)
      } // for (S)
    } // for (R)
  } // for (T)

  // Write out the results, applying the linear-scaling epilogue
  //   dx_out = alpha * accum + beta * dx_in
  // (dx_in is only read when beta is nonzero).
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    if (thread_n[m] < problem_size.N &&
      thread_d[m] < problem_size.D &&
      thread_h[m] < problem_size.H &&
      thread_w[m] < problem_size.W) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {
        int thread_c = c_start + n;
        if (thread_c < problem_size.C) {

          ElementCompute c_ref = ElementCompute();
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}));
          }

          tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
|
| 679 |
+
|
| 680 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 681 |
+
|
| 682 |
+
// Conv2d wgrad kernel - dw = wgrad(dy, x)
//
// Naive reference kernel: each thread owns a kThreadM x kThreadN tile of the
// implicit GEMM, where M spans output channels K and N spans the flattened
// (R, S, C) filter coordinates. Accumulation loops over every output
// position (N, P, Q).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 8,     // shape of a threadblock in units of threads
  int kCtaShapeN = 16     // shape of a threadblock in units of threads
>
__global__ void Conv2dWgrad(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,
  TensorRef<ElementB, LayoutB> tensor_x,
  TensorRef<ElementC, LayoutC> tensor_dw_in,
  TensorRef<ElementC, LayoutC> tensor_dw_out,
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // First output channel and first flattened (r, s, c) coordinate of this thread's tile.
  int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_r[kThreadN];
  int thread_s[kThreadN];
  int thread_c[kThreadN];

  // Compute R, S, C coordinates for each column of a thread's tile
  int64_t SC = int64_t(problem_size.S) * problem_size.C;

  CUTLASS_PRAGMA_UNROLL
  for (int n = 0; n < kThreadN; ++n) {

    int64_t rsc = rsc_start + n;
    int64_t residual = rsc % SC;

    thread_r[n] = int(rsc / SC);
    thread_s[n] = int(residual / problem_size.C);
    thread_c[n] = int(residual % problem_size.C);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution: reduce over all output positions (N, P, Q)
  for (int N = 0; N < problem_size.N; ++N) {
    for (int P = 0; P < problem_size.P; ++P) {
      for (int Q = 0; Q < problem_size.Q; ++Q) {

        // Load from the output-gradient tensor (tensor_dy)
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          int thread_k = k_start + m;

          element_A[m] = ElementAccumulator();

          if (thread_k < problem_size.K) {
            element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k}));
          }
        }

        // Load from the activations tensor (tensor_x)
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < kThreadN; ++n) {

          // Mirror the filter taps for convolution (vs. cross-correlation) mode.
          int filter_r = thread_r[n];
          int filter_s = thread_s[n];

          if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
            filter_r = problem_size.R - 1 - filter_r;
            filter_s = problem_size.S - 1 - filter_s;
          }

          // Input position read by output (P, Q) through tap (filter_r, filter_s).
          int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
          int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

          element_B[n] = ElementAccumulator();

          if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) {
            element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]}));
          }
        }

        // Accumulate matrix product
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
          }
        }
      }
    }
  }

  // Write out the results, applying the linear-scaling epilogue
  //   dw_out = alpha * accum + beta * dw_in
  // (dw_in is only read when beta is nonzero).
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    int thread_k = k_start + m;

    if (thread_k < problem_size.K) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {

        if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) {

          ElementCompute c_ref = ElementCompute();

          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}));
          }

          tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
|
| 822 |
+
|
| 823 |
+
// Conv3d wgrad kernel - dw = wgrad(dy, x)
//
// 3-D analogue of Conv2dWgrad: the implicit-GEMM M dimension spans output
// channels K and the N dimension spans the flattened (T, R, S, C) filter
// coordinates. Accumulation loops over every output position (N, Z, P, Q).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 8,     // shape of a threadblock in units of threads
  int kCtaShapeN = 16     // shape of a threadblock in units of threads
>
__global__ void Conv3dWgrad(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,
  TensorRef<ElementB, LayoutB> tensor_x,
  TensorRef<ElementC, LayoutC> tensor_dw_in,
  TensorRef<ElementC, LayoutC> tensor_dw_out,
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // First output channel and first flattened (t, r, s, c) coordinate of this thread's tile.
  int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_t[kThreadN];
  int thread_r[kThreadN];
  int thread_s[kThreadN];
  int thread_c[kThreadN];

  // Compute T, R, S, C coordinates for each column of a thread's tile
  int64_t SC = int64_t(problem_size.S) * problem_size.C;
  int64_t RSC = SC * problem_size.R;

  CUTLASS_PRAGMA_UNROLL
  for (int n = 0; n < kThreadN; ++n) {

    int64_t trsc = trsc_start + n;

    thread_t[n] = int(trsc / RSC);

    int64_t residual = trsc % RSC;
    thread_r[n] = int(residual / SC);

    residual = residual % SC;
    thread_s[n] = int(residual / problem_size.C);
    thread_c[n] = int(residual % problem_size.C);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution: reduce over all output positions (N, Z, P, Q)
  for (int N = 0; N < problem_size.N; ++N) {
    for (int Z = 0; Z < problem_size.Z; ++Z) {
      for (int P = 0; P < problem_size.P; ++P) {
        for (int Q = 0; Q < problem_size.Q; ++Q) {

          // Load from the output-gradient tensor (tensor_dy)
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            int thread_k = k_start + m;

            element_A[m] = ElementAccumulator();

            if (thread_k < problem_size.K) {
              element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k}));
            }
          }

          // Load from the activations tensor (tensor_x)
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {

            // Mirror the filter taps for convolution (vs. cross-correlation) mode.
            int filter_t = thread_t[n];
            int filter_r = thread_r[n];
            int filter_s = thread_s[n];

            if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
              filter_t = problem_size.T - 1 - filter_t;
              filter_r = problem_size.R - 1 - filter_r;
              filter_s = problem_size.S - 1 - filter_s;
            }

            // Input position read by output (Z, P, Q) through tap (filter_t, filter_r, filter_s).
            int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
            int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
            int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

            element_B[n] = ElementAccumulator();

            if (d >= 0 && d < problem_size.D &&
              h >= 0 && h < problem_size.H &&
              w >= 0 && w < problem_size.W &&
              thread_c[n] < problem_size.C) {

              element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]}));
            }
          }

          // Accumulate matrix product
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
              accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
            }
          }

        } // for (Q)
      } // for (P)
    } // for (Z)
  } // for (N)

  // Write out the results, applying the linear-scaling epilogue
  //   dw_out = alpha * accum + beta * dw_in
  // (dw_in is only read when beta is nonzero).
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    int thread_k = k_start + m;

    if (thread_k < problem_size.K) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {

        if (thread_t[n] < problem_size.T &&
          thread_r[n] < problem_size.R &&
          thread_s[n] < problem_size.S &&
          thread_c[n] < problem_size.C) {

          ElementCompute c_ref = ElementCompute();

          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}));
          }

          tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
|
| 982 |
+
|
| 983 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 984 |
+
|
| 985 |
+
} // namespace kernel
|
| 986 |
+
|
| 987 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 988 |
+
|
| 989 |
+
/// Conv2d Fprop dispatcher - y = fprop(x, w)
|
| 990 |
+
template <
|
| 991 |
+
typename ElementA,
|
| 992 |
+
typename LayoutA,
|
| 993 |
+
typename ElementB,
|
| 994 |
+
typename LayoutB,
|
| 995 |
+
typename ElementC,
|
| 996 |
+
typename LayoutC,
|
| 997 |
+
typename ElementCompute,
|
| 998 |
+
typename ElementAccumulator = ElementCompute,
|
| 999 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1000 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1001 |
+
>
|
| 1002 |
+
Status Conv2dFprop(
|
| 1003 |
+
conv::Conv2dProblemSize problem_size,
|
| 1004 |
+
TensorRef<ElementA, LayoutA> tensor_x,
|
| 1005 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1006 |
+
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 1007 |
+
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 1008 |
+
ElementCompute alpha,
|
| 1009 |
+
ElementCompute beta,
|
| 1010 |
+
cudaStream_t stream = nullptr) {
|
| 1011 |
+
|
| 1012 |
+
//
|
| 1013 |
+
// Blocking factors improve performance of reference implementation
|
| 1014 |
+
//
|
| 1015 |
+
|
| 1016 |
+
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
|
| 1017 |
+
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1018 |
+
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1019 |
+
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1020 |
+
|
| 1021 |
+
int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
|
| 1022 |
+
int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1023 |
+
|
| 1024 |
+
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1025 |
+
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1026 |
+
|
| 1027 |
+
kernel::Conv2dFprop<
|
| 1028 |
+
ElementA,
|
| 1029 |
+
LayoutA,
|
| 1030 |
+
ElementB,
|
| 1031 |
+
LayoutB,
|
| 1032 |
+
ElementC,
|
| 1033 |
+
LayoutC,
|
| 1034 |
+
ElementCompute,
|
| 1035 |
+
ElementAccumulator,
|
| 1036 |
+
ConvertOp,
|
| 1037 |
+
InnerProductOp,
|
| 1038 |
+
kThreadM,
|
| 1039 |
+
kThreadN,
|
| 1040 |
+
kCtaShapeM,
|
| 1041 |
+
kCtaShapeN
|
| 1042 |
+
><<< grid, block, 0, stream >>>(
|
| 1043 |
+
problem_size,
|
| 1044 |
+
tensor_x,
|
| 1045 |
+
tensor_w,
|
| 1046 |
+
tensor_y_in,
|
| 1047 |
+
tensor_y_out,
|
| 1048 |
+
alpha,
|
| 1049 |
+
beta
|
| 1050 |
+
);
|
| 1051 |
+
|
| 1052 |
+
cudaError_t result = cudaPeekAtLastError();
|
| 1053 |
+
if (result != cudaSuccess) {
|
| 1054 |
+
return Status::kErrorInternal;
|
| 1055 |
+
}
|
| 1056 |
+
|
| 1057 |
+
return Status::kSuccess;
|
| 1058 |
+
}
|
| 1059 |
+
|
| 1060 |
+
/// Conv3d Fprop dispatcher - y = fprop(x, w)
|
| 1061 |
+
template <
|
| 1062 |
+
typename ElementA,
|
| 1063 |
+
typename LayoutA,
|
| 1064 |
+
typename ElementB,
|
| 1065 |
+
typename LayoutB,
|
| 1066 |
+
typename ElementC,
|
| 1067 |
+
typename LayoutC,
|
| 1068 |
+
typename ElementCompute,
|
| 1069 |
+
typename ElementAccumulator = ElementCompute,
|
| 1070 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1071 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1072 |
+
>
|
| 1073 |
+
Status Conv3dFprop(
|
| 1074 |
+
conv::Conv3dProblemSize problem_size,
|
| 1075 |
+
TensorRef<ElementA, LayoutA> tensor_x,
|
| 1076 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1077 |
+
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 1078 |
+
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 1079 |
+
ElementCompute alpha,
|
| 1080 |
+
ElementCompute beta,
|
| 1081 |
+
cudaStream_t stream = nullptr) {
|
| 1082 |
+
|
| 1083 |
+
//
|
| 1084 |
+
// Blocking factors improve performance of reference implementation
|
| 1085 |
+
//
|
| 1086 |
+
|
| 1087 |
+
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
|
| 1088 |
+
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1089 |
+
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1090 |
+
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1091 |
+
|
| 1092 |
+
int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q;
|
| 1093 |
+
int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1094 |
+
|
| 1095 |
+
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1096 |
+
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1097 |
+
|
| 1098 |
+
kernel::Conv3dFprop<
|
| 1099 |
+
ElementA,
|
| 1100 |
+
LayoutA,
|
| 1101 |
+
ElementB,
|
| 1102 |
+
LayoutB,
|
| 1103 |
+
ElementC,
|
| 1104 |
+
LayoutC,
|
| 1105 |
+
ElementCompute,
|
| 1106 |
+
ElementAccumulator,
|
| 1107 |
+
ConvertOp,
|
| 1108 |
+
InnerProductOp,
|
| 1109 |
+
kThreadM,
|
| 1110 |
+
kThreadN,
|
| 1111 |
+
kCtaShapeM,
|
| 1112 |
+
kCtaShapeN
|
| 1113 |
+
><<< grid, block, 0, stream >>>(
|
| 1114 |
+
problem_size,
|
| 1115 |
+
tensor_x,
|
| 1116 |
+
tensor_w,
|
| 1117 |
+
tensor_y_in,
|
| 1118 |
+
tensor_y_out,
|
| 1119 |
+
alpha,
|
| 1120 |
+
beta
|
| 1121 |
+
);
|
| 1122 |
+
|
| 1123 |
+
cudaError_t result = cudaPeekAtLastError();
|
| 1124 |
+
if (result != cudaSuccess) {
|
| 1125 |
+
return Status::kErrorInternal;
|
| 1126 |
+
}
|
| 1127 |
+
|
| 1128 |
+
return Status::kSuccess;
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w)
///
/// Launches the reference data-gradient kernel on the given stream and
/// reports launch failures (not asynchronous kernel errors) via Status.
/// Computes: tensor_dx_out = alpha * dgrad(tensor_dy, tensor_w) + beta * tensor_dx_in
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv2dDgrad(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,     // output gradient (operand A)
  TensorRef<ElementB, LayoutB> tensor_w,      // filter tensor (operand B)
  TensorRef<ElementC, LayoutC> tensor_dx_in,  // source term scaled by beta
  TensorRef<ElementC, LayoutC> tensor_dx_out, // computed input gradient (destination)
  ElementCompute alpha,
  ElementCompute beta,
  cudaStream_t stream = nullptr) {

  //
  // Blocking factors improve performance of reference implementation
  //

  int const kThreadM = 2;        // shape of a thread's tile in the GEMM M dimension
  int const kThreadN = 4;        // shape of a thread's tile in the GEMM N dimension
  int const kCtaShapeM = 16;     // shape of a threadblock in units of threads
  int const kCtaShapeN = 8;      // shape of a threadblock in units of threads

  // GEMM M extent for dgrad is the count of input activations (N * H * W);
  // 64-bit arithmetic avoids overflow for large problem sizes.
  int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W;
  int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);

  dim3 block(kCtaShapeM, kCtaShapeN);
  // Grid y covers the GEMM N extent (input channels C), rounded up per threadblock tile.
  dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));

  kernel::Conv2dDgrad<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementCompute,
    ElementAccumulator,
    ConvertOp,
    InnerProductOp,
    kThreadM,
    kThreadN,
    kCtaShapeM,
    kCtaShapeN
  ><<< grid, block, 0, stream >>>(
    problem_size,
    tensor_dy,
    tensor_w,
    tensor_dx_in,
    tensor_dx_out,
    alpha,
    beta
  );

  // Only detects launch-configuration errors; kernel execution is asynchronous.
  cudaError_t result = cudaPeekAtLastError();
  if (result != cudaSuccess) {
    return Status::kErrorInternal;
  }

  return Status::kSuccess;
}
|
| 1201 |
+
|
| 1202 |
+
/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w)
|
| 1203 |
+
template <
|
| 1204 |
+
typename ElementA,
|
| 1205 |
+
typename LayoutA,
|
| 1206 |
+
typename ElementB,
|
| 1207 |
+
typename LayoutB,
|
| 1208 |
+
typename ElementC,
|
| 1209 |
+
typename LayoutC,
|
| 1210 |
+
typename ElementCompute,
|
| 1211 |
+
typename ElementAccumulator = ElementCompute,
|
| 1212 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1213 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1214 |
+
>
|
| 1215 |
+
Status Conv3dDgrad(
|
| 1216 |
+
conv::Conv3dProblemSize problem_size,
|
| 1217 |
+
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 1218 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 1219 |
+
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 1220 |
+
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 1221 |
+
ElementCompute alpha,
|
| 1222 |
+
ElementCompute beta,
|
| 1223 |
+
cudaStream_t stream = nullptr) {
|
| 1224 |
+
|
| 1225 |
+
//
|
| 1226 |
+
// Blocking factors improve performance of reference implementation
|
| 1227 |
+
//
|
| 1228 |
+
|
| 1229 |
+
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
|
| 1230 |
+
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1231 |
+
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
|
| 1232 |
+
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
|
| 1233 |
+
|
| 1234 |
+
int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W;
|
| 1235 |
+
int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
|
| 1236 |
+
|
| 1237 |
+
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1238 |
+
dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
|
| 1239 |
+
|
| 1240 |
+
kernel::Conv3dDgrad<
|
| 1241 |
+
ElementA,
|
| 1242 |
+
LayoutA,
|
| 1243 |
+
ElementB,
|
| 1244 |
+
LayoutB,
|
| 1245 |
+
ElementC,
|
| 1246 |
+
LayoutC,
|
| 1247 |
+
ElementCompute,
|
| 1248 |
+
ElementAccumulator,
|
| 1249 |
+
ConvertOp,
|
| 1250 |
+
InnerProductOp,
|
| 1251 |
+
kThreadM,
|
| 1252 |
+
kThreadN,
|
| 1253 |
+
kCtaShapeM,
|
| 1254 |
+
kCtaShapeN
|
| 1255 |
+
><<< grid, block, 0, stream >>>(
|
| 1256 |
+
problem_size,
|
| 1257 |
+
tensor_dy,
|
| 1258 |
+
tensor_w,
|
| 1259 |
+
tensor_dx_in,
|
| 1260 |
+
tensor_dx_out,
|
| 1261 |
+
alpha,
|
| 1262 |
+
beta
|
| 1263 |
+
);
|
| 1264 |
+
|
| 1265 |
+
cudaError_t result = cudaPeekAtLastError();
|
| 1266 |
+
if (result != cudaSuccess) {
|
| 1267 |
+
return Status::kErrorInternal;
|
| 1268 |
+
}
|
| 1269 |
+
|
| 1270 |
+
return Status::kSuccess;
|
| 1271 |
+
}
|
| 1272 |
+
|
| 1273 |
+
/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x)
|
| 1274 |
+
template <
|
| 1275 |
+
typename ElementA,
|
| 1276 |
+
typename LayoutA,
|
| 1277 |
+
typename ElementB,
|
| 1278 |
+
typename LayoutB,
|
| 1279 |
+
typename ElementC,
|
| 1280 |
+
typename LayoutC,
|
| 1281 |
+
typename ElementCompute,
|
| 1282 |
+
typename ElementAccumulator = ElementCompute,
|
| 1283 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1284 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1285 |
+
>
|
| 1286 |
+
Status Conv2dWgrad(
|
| 1287 |
+
conv::Conv2dProblemSize problem_size,
|
| 1288 |
+
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 1289 |
+
TensorRef<ElementB, LayoutB> tensor_x,
|
| 1290 |
+
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 1291 |
+
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 1292 |
+
ElementCompute alpha,
|
| 1293 |
+
ElementCompute beta,
|
| 1294 |
+
cudaStream_t stream = nullptr) {
|
| 1295 |
+
|
| 1296 |
+
//
|
| 1297 |
+
// Blocking factors improve performance of reference implementation
|
| 1298 |
+
//
|
| 1299 |
+
|
| 1300 |
+
int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
|
| 1301 |
+
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
|
| 1302 |
+
int const kCtaShapeM = 8; // shape of a threadblock in units of threads
|
| 1303 |
+
int const kCtaShapeN = 16; // shape of a threadblock in units of threads
|
| 1304 |
+
|
| 1305 |
+
int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C;
|
| 1306 |
+
int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
|
| 1307 |
+
|
| 1308 |
+
dim3 block(kCtaShapeM, kCtaShapeN);
|
| 1309 |
+
dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
|
| 1310 |
+
|
| 1311 |
+
kernel::Conv2dWgrad<
|
| 1312 |
+
ElementA,
|
| 1313 |
+
LayoutA,
|
| 1314 |
+
ElementB,
|
| 1315 |
+
LayoutB,
|
| 1316 |
+
ElementC,
|
| 1317 |
+
LayoutC,
|
| 1318 |
+
ElementCompute,
|
| 1319 |
+
ElementAccumulator,
|
| 1320 |
+
ConvertOp,
|
| 1321 |
+
InnerProductOp,
|
| 1322 |
+
kThreadM,
|
| 1323 |
+
kThreadN,
|
| 1324 |
+
kCtaShapeM,
|
| 1325 |
+
kCtaShapeN
|
| 1326 |
+
><<< grid, block, 0, stream >>>(
|
| 1327 |
+
problem_size,
|
| 1328 |
+
tensor_dy,
|
| 1329 |
+
tensor_x,
|
| 1330 |
+
tensor_dw_in,
|
| 1331 |
+
tensor_dw_out,
|
| 1332 |
+
alpha,
|
| 1333 |
+
beta
|
| 1334 |
+
);
|
| 1335 |
+
|
| 1336 |
+
cudaError_t result = cudaPeekAtLastError();
|
| 1337 |
+
if (result != cudaSuccess) {
|
| 1338 |
+
return Status::kErrorInternal;
|
| 1339 |
+
}
|
| 1340 |
+
|
| 1341 |
+
return Status::kSuccess;
|
| 1342 |
+
}
|
| 1343 |
+
|
| 1344 |
+
/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x)
///
/// Launches the reference weight-gradient kernel on the given stream.
/// Computes: tensor_dw_out = alpha * wgrad(tensor_dy, tensor_x) + beta * tensor_dw_in
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv3dWgrad(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,     // output gradient (operand A)
  TensorRef<ElementB, LayoutB> tensor_x,      // input activations (operand B)
  TensorRef<ElementC, LayoutC> tensor_dw_in,  // source term scaled by beta
  TensorRef<ElementC, LayoutC> tensor_dw_out, // computed weight gradient (destination)
  ElementCompute alpha,
  ElementCompute beta,
  cudaStream_t stream = nullptr) {

  //
  // Blocking factors improve performance of reference implementation
  //

  int const kThreadM = 2;        // shape of a thread's tile in the GEMM M dimension
  int const kThreadN = 4;        // shape of a thread's tile in the GEMM N dimension
  int const kCtaShapeM = 8;      // shape of a threadblock in units of threads
  int const kCtaShapeN = 16;     // shape of a threadblock in units of threads

  // GEMM N extent for wgrad is the filter volume (T * R * S * C);
  // 64-bit arithmetic avoids overflow for large filters.
  int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C;
  int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);

  dim3 block(kCtaShapeM, kCtaShapeN);
  // Grid x covers the GEMM M extent (output channels K).
  dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));

  kernel::Conv3dWgrad<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementCompute,
    ElementAccumulator,
    ConvertOp,
    InnerProductOp,
    kThreadM,
    kThreadN,
    kCtaShapeM,
    kCtaShapeN
  ><<< grid, block, 0, stream >>>(
    problem_size,
    tensor_dy,
    tensor_x,
    tensor_dw_in,
    tensor_dw_out,
    alpha,
    beta
  );

  // Only detects launch-configuration errors; kernel execution is asynchronous.
  cudaError_t result = cudaPeekAtLastError();
  if (result != cudaSuccess) {
    return Status::kErrorInternal;
  }

  return Status::kSuccess;
}
|
| 1414 |
+
|
| 1415 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1416 |
+
|
| 1417 |
+
/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
|
| 1418 |
+
template <
|
| 1419 |
+
typename ElementA,
|
| 1420 |
+
typename LayoutA,
|
| 1421 |
+
typename ElementB,
|
| 1422 |
+
typename LayoutB,
|
| 1423 |
+
typename ElementC,
|
| 1424 |
+
typename LayoutC,
|
| 1425 |
+
typename ElementCompute,
|
| 1426 |
+
typename ElementAccumulator = ElementCompute,
|
| 1427 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 1428 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 1429 |
+
>
|
| 1430 |
+
Status Conv2d(
|
| 1431 |
+
conv::Operator convolutional_operator,
|
| 1432 |
+
conv::Conv2dProblemSize problem_size,
|
| 1433 |
+
TensorRef<ElementA, LayoutA> tensor_A,
|
| 1434 |
+
TensorRef<ElementB, LayoutB> tensor_B,
|
| 1435 |
+
TensorRef<ElementC, LayoutC> tensor_C,
|
| 1436 |
+
TensorRef<ElementC, LayoutC> tensor_D,
|
| 1437 |
+
ElementCompute alpha,
|
| 1438 |
+
ElementCompute beta,
|
| 1439 |
+
cudaStream_t stream = nullptr) {
|
| 1440 |
+
|
| 1441 |
+
switch (convolutional_operator) {
|
| 1442 |
+
case conv::Operator::kFprop:
|
| 1443 |
+
return Conv2dFprop<
|
| 1444 |
+
ElementA, LayoutA,
|
| 1445 |
+
ElementB, LayoutB,
|
| 1446 |
+
ElementC, LayoutC,
|
| 1447 |
+
ElementCompute,
|
| 1448 |
+
ElementAccumulator,
|
| 1449 |
+
ConvertOp, InnerProductOp
|
| 1450 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1451 |
+
break;
|
| 1452 |
+
|
| 1453 |
+
case conv::Operator::kDgrad:
|
| 1454 |
+
return Conv2dDgrad<
|
| 1455 |
+
ElementA, LayoutA,
|
| 1456 |
+
ElementB, LayoutB,
|
| 1457 |
+
ElementC, LayoutC,
|
| 1458 |
+
ElementCompute,
|
| 1459 |
+
ElementAccumulator,
|
| 1460 |
+
ConvertOp, InnerProductOp
|
| 1461 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1462 |
+
break;
|
| 1463 |
+
|
| 1464 |
+
case conv::Operator::kWgrad:
|
| 1465 |
+
return Conv2dWgrad<
|
| 1466 |
+
ElementA, LayoutA,
|
| 1467 |
+
ElementB, LayoutB,
|
| 1468 |
+
ElementC, LayoutC,
|
| 1469 |
+
ElementCompute,
|
| 1470 |
+
ElementAccumulator,
|
| 1471 |
+
ConvertOp, InnerProductOp
|
| 1472 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
|
| 1473 |
+
break;
|
| 1474 |
+
|
| 1475 |
+
default: break;
|
| 1476 |
+
}
|
| 1477 |
+
|
| 1478 |
+
return Status::kErrorNotSupported;
|
| 1479 |
+
}
|
| 1480 |
+
|
| 1481 |
+
/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
///
/// Dispatches on \p convolutional_operator to the matching reference launcher.
/// Returns Status::kErrorNotSupported for operators without a reference path.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
Status Conv3d(
  conv::Operator convolutional_operator,
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_A,
  TensorRef<ElementB, LayoutB> tensor_B,
  TensorRef<ElementC, LayoutC> tensor_C,
  TensorRef<ElementC, LayoutC> tensor_D,
  ElementCompute alpha,
  ElementCompute beta,
  cudaStream_t stream = nullptr) {

  switch (convolutional_operator) {
  case conv::Operator::kFprop:
    return Conv3dFprop<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);

  case conv::Operator::kDgrad:
    return Conv3dDgrad<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);

  case conv::Operator::kWgrad:
    return Conv3dWgrad<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);

  default: break;
  }

  return Status::kErrorNotSupported;
}
|
| 1541 |
+
|
| 1542 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1543 |
+
|
| 1544 |
+
} // namespace device
|
| 1545 |
+
} // namespace reference
|
| 1546 |
+
} // namespace cutlass
|
| 1547 |
+
|
| 1548 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1549 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GEMM in device-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/tensor_view.h"
|
| 44 |
+
#include "cutlass/gemm/gemm.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/util/reference/device/kernel/gemm.h"
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reference {
|
| 50 |
+
namespace device {
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
/// AccumulatorType(0) as the last function argument can be easier than naming all template
/// arguments explicitly.
///
/// Computes: tensor_d = ConvertOp(alpha * (A @ B) + beta * tensor_c)
/// NOTE: launches on the default stream and performs no launch-error check
/// (unlike the Conv dispatchers); the function returns void.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename AccumulatorType,
  typename InnerProductOp = multiply_add<AccumulatorType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_gemm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRef<ElementA, LayoutA> tensor_a,
  TensorRef<ElementB, LayoutB> tensor_b,
  ScalarType beta,
  TensorRef<ElementC, LayoutC> tensor_c,
  TensorRef<ElementC, LayoutC> tensor_d,
  AccumulatorType initial_accum) {

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  // Blocking structure potentially improves performance of reference implementation
  // with a minor increase in complexity.
  //
  // Note, this reference implementation is NOT expected to approach peak performance.
  using OutputTile = MatrixShape<4, 4>;

  dim3 block(16, 8);

  // One grid cell per (block.x * tile rows) x (block.y * tile columns) of output,
  // rounded up to cover the full problem extent.
  dim3 grid(
    (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
    (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
  );

  // Launch a GEMM kernel
  kernel::Gemm<
    TensorRef<ElementA, LayoutA>,
    TensorRef<ElementB, LayoutB>,
    TensorRef<ElementC, LayoutC>,
    ScalarType,
    AccumulatorType,
    OutputTile,
    InnerProductOp,
    ConvertOp
  ><<< grid, block >>>(
    problem_size,
    alpha,
    tensor_a,
    tensor_b,
    beta,
    tensor_c,
    tensor_d,
    initial_accum
  );
}
|
| 122 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 123 |
+
|
| 124 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 125 |
+
/// objects.
|
| 126 |
+
///
|
| 127 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 128 |
+
template <
|
| 129 |
+
typename ElementA,
|
| 130 |
+
typename LayoutA,
|
| 131 |
+
typename ElementB,
|
| 132 |
+
typename LayoutB,
|
| 133 |
+
typename ElementC,
|
| 134 |
+
typename LayoutC,
|
| 135 |
+
typename ScalarType,
|
| 136 |
+
typename AccumulatorType,
|
| 137 |
+
typename InnerProductOp = multiply_add<AccumulatorType>,
|
| 138 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 139 |
+
>
|
| 140 |
+
void compute_gemm(
|
| 141 |
+
gemm::GemmCoord problem_size,
|
| 142 |
+
ScalarType alpha,
|
| 143 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 144 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 145 |
+
ScalarType beta,
|
| 146 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 147 |
+
AccumulatorType initial_accum) {
|
| 148 |
+
|
| 149 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 150 |
+
ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
|
| 151 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
| 152 |
+
initial_accum);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
/// Functor form of the reference GEMM. The inner-product operation is selected
/// by an arch-level tag type; only the partial specializations below are defined.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename AccumulatorType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
>
struct Gemm;
|
| 167 |
+
|
| 168 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 169 |
+
|
| 170 |
+
/// Partial specialization for multiply-add
///
/// Forwards to compute_gemm() with multiply_add<AccumulatorType> as the
/// inner-product operation and the default NumericConverter for output.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename AccumulatorType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
            ScalarType, AccumulatorType, arch::OpMultiplyAdd> {

  /// In-place variant: result is written back into tensor_c.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  AccumulatorType initial_accum = AccumulatorType(0)) {

    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  /// Out-of-place variant: result is written to tensor_d, tensor_c is read-only.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  AccumulatorType initial_accum = AccumulatorType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};
|
| 207 |
+
|
| 208 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 209 |
+
|
| 210 |
+
/// Partial specialization for multiply-add-saturate
///
/// Same as the OpMultiplyAdd specialization but converts the epilogue result
/// with NumericConverterClamp, saturating to the range of ElementC.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename AccumulatorType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            AccumulatorType, arch::OpMultiplyAddSaturate> {

  /// In-place variant: result is written back into tensor_c.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  AccumulatorType initial_accum = AccumulatorType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
                 NumericConverterClamp<ElementC, ScalarType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  /// Out-of-place variant: result is written to tensor_d, tensor_c is read-only.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  AccumulatorType initial_accum = AccumulatorType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
                 NumericConverterClamp<ElementC, ScalarType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};
|
| 248 |
+
|
| 249 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 250 |
+
|
| 251 |
+
/// Partial specialization for XOR-popc
///
/// Uses xor_add as the inner-product operation (binary GEMM reference).
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ScalarType, typename AccumulatorType>
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
            AccumulatorType, arch::OpXorPopc> {

  /// In-place variant: result is written back into tensor_c.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  AccumulatorType initial_accum = AccumulatorType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  /// Out-of-place variant: result is written to tensor_d, tensor_c is read-only.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  AccumulatorType initial_accum = AccumulatorType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                 ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 290 |
+
//
|
| 291 |
+
// Batched GEMM
|
| 292 |
+
//
|
| 293 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 294 |
+
|
| 295 |
+
/// Computes a batch of GEMMs over a set of matrices of common dimension.
//
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
//
// Results are accumulated in place into tensor_c. The batch index maps to
// grid dimension z. Launches on the default stream; no launch-error check
// is performed (the function returns void).
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType,
  typename InnerProductOp,
  typename ConvertOp
>
void BatchedGemm(
  gemm::GemmCoord problem_size,
  int batch_count,
  ScalarType alpha,
  TensorRefCollectionA const& tensor_a,
  TensorRefCollectionB const& tensor_b,
  ScalarType beta,
  TensorRefCollectionC &tensor_c,
  AccumulatorType initial_accum) {

  static_assert(
    TensorRefCollectionA::kRank == 2 &&
    TensorRefCollectionB::kRank == 2 &&
    TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2");

  // Blocking structure potentially improves performance of reference implementation
  // with a minor increase in complexity.
  //
  // Note, this reference implementation is NOT expected to approach peak performance.
  using OutputTile = MatrixShape<4, 4>;

  dim3 block(16, 8);
  // Grid z indexes the batch; x/y tile the output of each individual GEMM.
  dim3 grid(
    (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
    (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
    batch_count
  );

  // Launch a GEMM kernel
  kernel::BatchedGemm<
    TensorRefCollectionA,
    TensorRefCollectionB,
    TensorRefCollectionC,
    ScalarType,
    AccumulatorType,
    OutputTile,
    InnerProductOp,
    ConvertOp
  ><<< grid, block >>>(
    problem_size,
    alpha,
    tensor_a,
    tensor_b,
    beta,
    tensor_c,
    initial_accum
  );
}
|
| 356 |
+
|
| 357 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 358 |
+
/// objects.
|
| 359 |
+
//
|
| 360 |
+
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 361 |
+
//
|
| 362 |
+
template <
|
| 363 |
+
typename TensorRefCollectionA,
|
| 364 |
+
typename TensorRefCollectionB,
|
| 365 |
+
typename TensorRefCollectionC,
|
| 366 |
+
typename ScalarType,
|
| 367 |
+
typename AccumulatorType
|
| 368 |
+
>
|
| 369 |
+
void BatchedGemm(
|
| 370 |
+
gemm::GemmCoord problem_size,
|
| 371 |
+
int batch_count,
|
| 372 |
+
ScalarType alpha,
|
| 373 |
+
TensorRefCollectionA const& tensor_a,
|
| 374 |
+
TensorRefCollectionB const& tensor_b,
|
| 375 |
+
ScalarType beta,
|
| 376 |
+
TensorRefCollectionC &tensor_c) {
|
| 377 |
+
|
| 378 |
+
BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 382 |
+
|
| 383 |
+
} // namespace device
|
| 384 |
+
} // namespace reference
|
| 385 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued GEMM in device-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/tensor_view.h"
|
| 44 |
+
#include "cutlass/gemm/gemm.h"
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace reference {
|
| 48 |
+
namespace device {
|
| 49 |
+
|
| 50 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
namespace kernel {
|
| 53 |
+
|
| 54 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
/// AccumulatorType(0) as the last function argument can be easier than naming all template
/// arguments explicitly.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename ElementD = ElementC,
  typename ConvertOp = NumericConverter<ElementD, ScalarType>,
  typename InnerProductOp = multiply_add<ComputeType>,
  int kMblock = 4,                        // output rows computed per thread
  int kNblock = 4                         // output columns computed per thread
>
__global__ void GemmComplex(
  gemm::GemmCoord problem_size,           // GEMM dimensions (m, n, k)
  ScalarType alpha,                       // scales the A*B product
  TensorRef<ElementA, LayoutA> tensor_a,
  ComplexTransform transform_a,           // optional conjugation applied to A elements
  TensorRef<ElementB, LayoutB> tensor_b,
  ComplexTransform transform_b,           // optional conjugation applied to B elements
  ScalarType beta,                        // scales the source C operand
  TensorRef<ElementC, LayoutC> tensor_c,
  TensorRef<ElementD, LayoutC> tensor_d,  // destination operand
  ComputeType initial_accum,              // starting accumulator value (usually zero)
  int batch_count = 1,
  int64_t batch_stride_A = 0,             // element offset between consecutive batches of A
  int64_t batch_stride_B = 0,
  int64_t batch_stride_C = 0,
  int64_t batch_stride_D = 0) {

  // Reference implementations only support rank-2 (matrix) operands.
  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Each thread owns a kMblock x kNblock tile of the output.
  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
  int batch_idx = blockIdx.z;

  // Advance every operand to this z-slice's first batch entry.
  tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
  tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
  tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
  tensor_d.add_pointer_offset(batch_idx * batch_stride_D);

  // Batches are covered by striding along gridDim.z, so any grid.z >= 1
  // eventually processes every batch entry.
  for (; batch_idx < batch_count; batch_idx += gridDim.z) {

    // Compute matrix product using blocks
    ComputeType accum[kMblock][kNblock];

    // Reset the per-thread accumulator tile for this batch entry.
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {
        accum[i][j] = initial_accum;
      }
    }

    // Inner-product accumulation over the K dimension.
    for (int k_block = 0; k_block < K; ++k_block) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < kNblock; j++) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kMblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          if (row < M && col < N) {
            ElementA a = tensor_a.at(MatrixCoord(row, k_block));
            ElementB b = tensor_b.at(MatrixCoord(k_block, col));

            ComputeType a_ik = ComputeType(a);
            ComputeType b_kj = ComputeType(b);

            // Apply requested conjugations before accumulating.
            if (transform_a == ComplexTransform::kConjugate) {
              a_ik = conj(a_ik);
            }

            if (transform_b == ComplexTransform::kConjugate) {
              b_kj = conj(b_kj);
            }

            accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
          }
        }
      }
    }

    // Epilogue: D = convert(alpha * accum + beta * C), guarded against
    // partial tiles at the right/bottom edges of the problem.
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {
        int row = row_block + i;
        int col = col_block + j;

        MatrixCoord coord = MatrixCoord(row, col);

        if (row < M && col < N) {

          tensor_d.at(coord) = convert_op(
            alpha * ScalarType(accum[i][j]) +
            beta * ScalarType(tensor_c.at(coord)));
        }
      }
    }

    // Advance all operands to the next batch handled by this z-slice.
    tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
    tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
    tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
    tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);

  } // for (batch_idx)
}
|
| 181 |
+
|
| 182 |
+
} // namespace kernel
|
| 183 |
+
|
| 184 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 185 |
+
|
| 186 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 187 |
+
/// objects.
|
| 188 |
+
///
|
| 189 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 190 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 191 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 192 |
+
/// arguments explicitly.
|
| 193 |
+
template <
|
| 194 |
+
typename ElementA,
|
| 195 |
+
typename LayoutA,
|
| 196 |
+
typename ElementB,
|
| 197 |
+
typename LayoutB,
|
| 198 |
+
typename ElementC,
|
| 199 |
+
typename LayoutC,
|
| 200 |
+
typename ScalarType,
|
| 201 |
+
typename ComputeType,
|
| 202 |
+
typename ElementD = ElementC,
|
| 203 |
+
typename ConvertOp = NumericConverter<ElementD, ScalarType>,
|
| 204 |
+
typename InnerProductOp = multiply_add<ComputeType>
|
| 205 |
+
>
|
| 206 |
+
void GemmComplex(
|
| 207 |
+
gemm::GemmCoord problem_size,
|
| 208 |
+
ScalarType alpha,
|
| 209 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 210 |
+
ComplexTransform transform_a,
|
| 211 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 212 |
+
ComplexTransform transform_b,
|
| 213 |
+
ScalarType beta,
|
| 214 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 215 |
+
TensorRef<ElementD, LayoutC> tensor_d,
|
| 216 |
+
ComputeType initial_accum,
|
| 217 |
+
int batch_count = 1,
|
| 218 |
+
int64_t batch_stride_A = 0,
|
| 219 |
+
int64_t batch_stride_B = 0,
|
| 220 |
+
int64_t batch_stride_C = 0,
|
| 221 |
+
int64_t batch_stride_D = 0) {
|
| 222 |
+
|
| 223 |
+
static_assert(
|
| 224 |
+
LayoutA::kRank == 2 &&
|
| 225 |
+
LayoutB::kRank == 2 &&
|
| 226 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 227 |
+
|
| 228 |
+
int const kMblock = 4;
|
| 229 |
+
int const kNblock = 4;
|
| 230 |
+
|
| 231 |
+
dim3 block(16, 8);
|
| 232 |
+
dim3 grid(
|
| 233 |
+
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
| 234 |
+
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
| 235 |
+
batch_count % std::numeric_limits<uint16_t>::max()
|
| 236 |
+
);
|
| 237 |
+
|
| 238 |
+
if (grid.y <= std::numeric_limits<uint16_t>::max()) {
|
| 239 |
+
kernel::GemmComplex<
|
| 240 |
+
ElementA,
|
| 241 |
+
LayoutA,
|
| 242 |
+
ElementB,
|
| 243 |
+
LayoutB,
|
| 244 |
+
ElementC,
|
| 245 |
+
LayoutC,
|
| 246 |
+
ScalarType,
|
| 247 |
+
ComputeType,
|
| 248 |
+
ElementD,
|
| 249 |
+
ConvertOp,
|
| 250 |
+
InnerProductOp,
|
| 251 |
+
kMblock,
|
| 252 |
+
kNblock
|
| 253 |
+
><<< grid, block >>>(
|
| 254 |
+
problem_size,
|
| 255 |
+
alpha,
|
| 256 |
+
tensor_a,
|
| 257 |
+
transform_a,
|
| 258 |
+
tensor_b,
|
| 259 |
+
transform_b,
|
| 260 |
+
beta,
|
| 261 |
+
tensor_c,
|
| 262 |
+
tensor_d,
|
| 263 |
+
initial_accum,
|
| 264 |
+
batch_count,
|
| 265 |
+
batch_stride_A,
|
| 266 |
+
batch_stride_B,
|
| 267 |
+
batch_stride_C,
|
| 268 |
+
batch_stride_D
|
| 269 |
+
);
|
| 270 |
+
} else {
|
| 271 |
+
// Using bigger thread tile size
|
| 272 |
+
int const kBigMblock = 4;
|
| 273 |
+
int const kBigNblock = 16;
|
| 274 |
+
|
| 275 |
+
dim3 Bigblock(16, 8);
|
| 276 |
+
dim3 Biggrid(
|
| 277 |
+
(problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock),
|
| 278 |
+
(problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock),
|
| 279 |
+
batch_count % std::numeric_limits<uint16_t>::max()
|
| 280 |
+
);
|
| 281 |
+
|
| 282 |
+
kernel::GemmComplex<
|
| 283 |
+
ElementA,
|
| 284 |
+
LayoutA,
|
| 285 |
+
ElementB,
|
| 286 |
+
LayoutB,
|
| 287 |
+
ElementC,
|
| 288 |
+
LayoutC,
|
| 289 |
+
ScalarType,
|
| 290 |
+
ComputeType,
|
| 291 |
+
ElementD,
|
| 292 |
+
ConvertOp,
|
| 293 |
+
InnerProductOp,
|
| 294 |
+
kBigMblock,
|
| 295 |
+
kBigNblock
|
| 296 |
+
><<< Biggrid, Bigblock >>>(
|
| 297 |
+
problem_size,
|
| 298 |
+
alpha,
|
| 299 |
+
tensor_a,
|
| 300 |
+
transform_a,
|
| 301 |
+
tensor_b,
|
| 302 |
+
transform_b,
|
| 303 |
+
beta,
|
| 304 |
+
tensor_c,
|
| 305 |
+
tensor_d,
|
| 306 |
+
initial_accum,
|
| 307 |
+
batch_count,
|
| 308 |
+
batch_stride_A,
|
| 309 |
+
batch_stride_B,
|
| 310 |
+
batch_stride_C,
|
| 311 |
+
batch_stride_D
|
| 312 |
+
);
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 317 |
+
|
| 318 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 319 |
+
/// objects.
|
| 320 |
+
///
|
| 321 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 322 |
+
template <
|
| 323 |
+
typename ElementA,
|
| 324 |
+
typename LayoutA,
|
| 325 |
+
typename ElementB,
|
| 326 |
+
typename LayoutB,
|
| 327 |
+
typename ElementC,
|
| 328 |
+
typename LayoutC,
|
| 329 |
+
typename ScalarType,
|
| 330 |
+
typename ElementD = ElementC
|
| 331 |
+
>
|
| 332 |
+
void GemmComplex(
|
| 333 |
+
gemm::GemmCoord problem_size,
|
| 334 |
+
ScalarType alpha,
|
| 335 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 336 |
+
ComplexTransform transform_a,
|
| 337 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 338 |
+
ComplexTransform transform_b,
|
| 339 |
+
ScalarType beta,
|
| 340 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 341 |
+
TensorRef<ElementD, LayoutC> tensor_d) {
|
| 342 |
+
|
| 343 |
+
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 347 |
+
|
| 348 |
+
} // namespace device
|
| 349 |
+
} // namespace reference
|
| 350 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued GEMM in device code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/matrix_coord.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/tensor_ref_planar_complex.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/tensor_view.h"
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reference {
|
| 50 |
+
namespace device {
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace kernel {
|
| 55 |
+
|
| 56 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
// Per-thread output tile extent (rows and columns) used by the planar-complex
// reference kernel and its host-side launcher.
static int const kGemmPlanarComplexBlockSize = 4;

/// Reference GEMM kernel for planar-complex operands: computes
/// D = convert(alpha * accum + beta * C) element-wise, where accum is the
/// inner-product accumulation of (optionally conjugated) A and B.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
  typename InnerProductOp = multiply_add<complex<ComputeType>>
>
__global__ void GemmPlanarComplex(
  gemm::GemmCoord problem_size,                          // GEMM dimensions (m, n, k)
  complex<ScalarType> alpha,                             // scales the A*B product
  TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
  ComplexTransform transform_a,                          // optional conjugation of A
  TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
  ComplexTransform transform_b,                          // optional conjugation of B
  complex<ScalarType> beta,                              // scales the source C operand
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,    // destination operand
  complex<ComputeType> initial_accum) {                  // starting accumulator value

  int const kMblock = kGemmPlanarComplexBlockSize;
  int const kNblock = kGemmPlanarComplexBlockSize;

  using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
  using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
  using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  complex<ComputeType> accum[kMblock][kNblock];

  // Each thread owns a kMblock x kNblock tile of the output.
  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;

  // Seed every accumulator in the per-thread tile.
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kNblock; j++) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kMblock; i++) {
      accum[i][j] = initial_accum;
    }
  }

  // Inner-product accumulation over the K dimension.
  CUTLASS_PRAGMA_NO_UNROLL
  for (int k_block = 0; k_block < K; ++k_block) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {

        int row = row_block + i;
        int col = col_block + j;

        if (row < M && col < N) {

          ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
          ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));

          // Widen operands to the compute precision before transforming.
          complex<ComputeType> a = complex<ComputeType>{
            ComputeType(a_ik.real()),
            ComputeType(a_ik.imag())
          };

          complex<ComputeType> b = complex<ComputeType>{
            ComputeType(b_kj.real()),
            ComputeType(b_kj.imag())
          };

          if (transform_a == ComplexTransform::kConjugate) {
            a = conj(a);
          }

          if (transform_b == ComplexTransform::kConjugate) {
            b = conj(b);
          }

          accum[i][j] = inner_product_op(a, b, accum[i][j]);
        }
      }
    }
  }

  // Epilogue: D = convert(alpha * accum + beta * C), guarded against partial
  // tiles at the problem's edges.
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kNblock; j++) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kMblock; i++) {

      int row = row_block + i;
      int col = col_block + j;

      MatrixCoord coord = MatrixCoord(row, col);

      if (row < M && col < N) {

        complex<ScalarType> acc{
          ScalarType(accum[i][j].real()),
          ScalarType(accum[i][j].imag())
        };

        ComplexC c_ij = ComplexC();

        // Skip the C load when beta is exactly zero; C may then be unallocated.
        if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
          c_ij = tensor_c.at(coord);
        }

        complex<ScalarType> src{
          ScalarType(c_ij.real()),
          ScalarType(c_ij.imag())
        };

        complex<ScalarType> result = alpha * acc + beta * src;

        ComplexC d_ij;

        // Convert real and imaginary parts independently (planar layout).
        d_ij.real() = convert_op(result.real());
        d_ij.imag() = convert_op(result.imag());

        tensor_d.at(coord) = d_ij;
      }
    }
  }
}
|
| 193 |
+
|
| 194 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 195 |
+
|
| 196 |
+
} // namespace kernel
|
| 197 |
+
|
| 198 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 199 |
+
|
| 200 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects, launching kernel::GemmPlanarComplex on the device.
///
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
/// AccumulatorType(0) as the last function argument can be easier than naming all template
/// arguments explicitly.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
  typename InnerProductOp = multiply_add<complex<ComputeType>>
>
void GemmPlanarComplex(
  gemm::GemmCoord problem_size,                          // GEMM dimensions (m, n, k)
  complex<ScalarType> alpha,                             // scales the A*B product
  TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
  ComplexTransform transform_a,                          // optional conjugation of A
  TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
  ComplexTransform transform_b,                          // optional conjugation of B
  complex<ScalarType> beta,                              // scales the source C operand
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,    // destination operand
  complex<ComputeType> initial_accum) {                  // starting accumulator value

  // Reference implementations only support rank-2 (matrix) operands.
  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  int const kMblock = kernel::kGemmPlanarComplexBlockSize;
  int const kNblock = kernel::kGemmPlanarComplexBlockSize;

  // Each thread computes a kMblock x kNblock output tile. grid.z is fixed at 1:
  // the planar-complex reference kernel does not handle batches.
  dim3 block(16, 8);

  dim3 grid(
    (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
    (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
    1);

  kernel::GemmPlanarComplex<
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ScalarType,
    ComputeType,
    ConvertOp,
    InnerProductOp
  ><<< grid, block >>>(
    problem_size,
    alpha,
    tensor_a,
    transform_a,
    tensor_b,
    transform_b,
    beta,
    tensor_c,
    tensor_d,
    initial_accum
  );
}
|
| 267 |
+
|
| 268 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 269 |
+
|
| 270 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 271 |
+
/// objects.
|
| 272 |
+
///
|
| 273 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 274 |
+
template <
|
| 275 |
+
typename ElementA,
|
| 276 |
+
typename LayoutA,
|
| 277 |
+
typename ElementB,
|
| 278 |
+
typename LayoutB,
|
| 279 |
+
typename ElementC,
|
| 280 |
+
typename LayoutC,
|
| 281 |
+
typename ScalarType
|
| 282 |
+
>
|
| 283 |
+
void GemmPlanarComplex(
|
| 284 |
+
gemm::GemmCoord problem_size,
|
| 285 |
+
complex<ScalarType> alpha,
|
| 286 |
+
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 287 |
+
ComplexTransform transform_a,
|
| 288 |
+
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 289 |
+
ComplexTransform transform_b,
|
| 290 |
+
complex<ScalarType> beta,
|
| 291 |
+
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 292 |
+
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
|
| 293 |
+
|
| 294 |
+
GemmPlanarComplex(
|
| 295 |
+
problem_size,
|
| 296 |
+
alpha,
|
| 297 |
+
tensor_a, transform_a,
|
| 298 |
+
tensor_b, transform_b,
|
| 299 |
+
beta,
|
| 300 |
+
tensor_c,
|
| 301 |
+
tensor_d,
|
| 302 |
+
complex<ScalarType>());
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 306 |
+
|
| 307 |
+
} // namespace device
|
| 308 |
+
} // namespace reference
|
| 309 |
+
} // namespace cutlass
|
| 310 |
+
|
| 311 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief GETT device reference code
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <cute/tensor.hpp>
|
| 37 |
+
|
| 38 |
+
namespace cutlass::reference::device {
|
| 39 |
+
|
| 40 |
+
template <
|
| 41 |
+
class ATensor,
|
| 42 |
+
class BTensor,
|
| 43 |
+
class CTensor,
|
| 44 |
+
class DTensor,
|
| 45 |
+
class ElementAccumulator,
|
| 46 |
+
class ElementEpilogue>
|
| 47 |
+
__global__ static
|
| 48 |
+
void
|
| 49 |
+
gett_kernel(
|
| 50 |
+
DTensor D,
|
| 51 |
+
ATensor const A,
|
| 52 |
+
BTensor const B,
|
| 53 |
+
CTensor const C,
|
| 54 |
+
ElementEpilogue alpha, ElementEpilogue beta,
|
| 55 |
+
ElementAccumulator acc_init)
|
| 56 |
+
{
|
| 57 |
+
using namespace cute;
|
| 58 |
+
|
| 59 |
+
static_assert(DTensor::rank == 3, "(M,N,L)");
|
| 60 |
+
static_assert(ATensor::rank == 3, "(M,K,L)");
|
| 61 |
+
static_assert(BTensor::rank == 3, "(N,K,L)");
|
| 62 |
+
static_assert(CTensor::rank == 3, "(M,N,L)");
|
| 63 |
+
|
| 64 |
+
assert(size<0>(A) == size<0>(D)); // M
|
| 65 |
+
assert(size<0>(C) == size<0>(D)); // M
|
| 66 |
+
assert(size<0>(B) == size<1>(D)); // N
|
| 67 |
+
assert(size<1>(C) == size<1>(D)); // N
|
| 68 |
+
assert(size<1>(A) == size<1>(B)); // K
|
| 69 |
+
assert(size<2>(A) == size<2>(D)); // L
|
| 70 |
+
assert(size<2>(B) == size<2>(D)); // L
|
| 71 |
+
assert(size<2>(C) == size<2>(D)); // L
|
| 72 |
+
|
| 73 |
+
NumericConverter<ElementAccumulator, typename ATensor::value_type> a_converter;
|
| 74 |
+
NumericConverter<ElementAccumulator, typename BTensor::value_type> b_converter;
|
| 75 |
+
NumericConverter<ElementEpilogue, ElementAccumulator> acc_converter;
|
| 76 |
+
NumericConverter<ElementEpilogue, typename CTensor::value_type> source_converter;
|
| 77 |
+
NumericConverter<typename DTensor::value_type, ElementEpilogue> output_converter;
|
| 78 |
+
|
| 79 |
+
// Thread id to each element of D
|
| 80 |
+
for (int tid = threadIdx.x + blockDim.x * blockIdx.x;
|
| 81 |
+
tid < size(D);
|
| 82 |
+
tid += blockDim.x * gridDim.x) {
|
| 83 |
+
// (m,n,l) coordinate
|
| 84 |
+
auto mnl_coord = idx2crd(tid, product_each(shape(D)));
|
| 85 |
+
auto m = get<0>(mnl_coord);
|
| 86 |
+
auto n = get<1>(mnl_coord);
|
| 87 |
+
auto l = get<2>(mnl_coord);
|
| 88 |
+
|
| 89 |
+
auto A_ml = A(m,_,l);
|
| 90 |
+
auto B_nl = B(n,_,l);
|
| 91 |
+
|
| 92 |
+
ElementAccumulator accum = ElementAccumulator(0);
|
| 93 |
+
for (int k = 0; k < size<1>(A); ++k) {
|
| 94 |
+
ElementAccumulator a = a_converter(A_ml(k));
|
| 95 |
+
ElementAccumulator b = b_converter(B_nl(k));
|
| 96 |
+
accum += a * b;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l)));
|
| 100 |
+
D(m,n,l) = output_converter(scaled_output);
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Most general version
|
| 105 |
+
template <
|
| 106 |
+
class ProblemShapeMNKL,
|
| 107 |
+
class ElementA,
|
| 108 |
+
class StrideA,
|
| 109 |
+
class ElementB,
|
| 110 |
+
class StrideB,
|
| 111 |
+
class ElementAccumulator,
|
| 112 |
+
class ElementC,
|
| 113 |
+
class StrideC,
|
| 114 |
+
class ElementD,
|
| 115 |
+
class StrideD,
|
| 116 |
+
class ElementEpilogue>
|
| 117 |
+
void
|
| 118 |
+
gett(
|
| 119 |
+
ProblemShapeMNKL problem_shape_mnkl,
|
| 120 |
+
ElementA const* ptr_A, StrideA stride_a_mkl,
|
| 121 |
+
ElementB const* ptr_B, StrideB stride_b_nkl,
|
| 122 |
+
ElementAccumulator _,
|
| 123 |
+
ElementC const* ptr_C, StrideC stride_c_mnl,
|
| 124 |
+
ElementD * ptr_D, StrideD stride_d_mnl,
|
| 125 |
+
ElementEpilogue alpha, ElementEpilogue beta,
|
| 126 |
+
cudaStream_t stream = 0) {
|
| 127 |
+
using namespace cute;
|
| 128 |
+
|
| 129 |
+
static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
|
| 130 |
+
auto M = get<0>(problem_shape_mnkl);
|
| 131 |
+
auto N = get<1>(problem_shape_mnkl);
|
| 132 |
+
auto K = get<2>(problem_shape_mnkl);
|
| 133 |
+
auto L = get<3>(problem_shape_mnkl);
|
| 134 |
+
|
| 135 |
+
// Represent the full tensors
|
| 136 |
+
auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L)
|
| 137 |
+
auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L)
|
| 138 |
+
auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L)
|
| 139 |
+
auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L)
|
| 140 |
+
|
| 141 |
+
dim3 dimBlock(256);
|
| 142 |
+
dim3 dimGrid(240);
|
| 143 |
+
gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0));
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
} // namespace cutlass::reference::device
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/tensor_view.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/util/reference/device/thread/gemm.h"
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace reference {
|
| 45 |
+
namespace device {
|
| 46 |
+
namespace kernel {
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 51 |
+
/// objects.
|
| 52 |
+
template <
|
| 53 |
+
typename TensorRefA,
|
| 54 |
+
typename TensorRefB,
|
| 55 |
+
typename TensorRefC,
|
| 56 |
+
typename ScalarType,
|
| 57 |
+
typename AccumulatorType,
|
| 58 |
+
typename OutputTile,
|
| 59 |
+
typename InnerProductOp,
|
| 60 |
+
typename ConvertOp
|
| 61 |
+
>
|
| 62 |
+
__global__ void Gemm(
|
| 63 |
+
gemm::GemmCoord problem_size,
|
| 64 |
+
ScalarType alpha,
|
| 65 |
+
TensorRefA tensor_a,
|
| 66 |
+
TensorRefB tensor_b,
|
| 67 |
+
ScalarType beta,
|
| 68 |
+
TensorRefC tensor_c,
|
| 69 |
+
TensorRefC tensor_d,
|
| 70 |
+
AccumulatorType initial_accum) {
|
| 71 |
+
|
| 72 |
+
// Map each thread to a unique tile of the output matrix
|
| 73 |
+
MatrixCoord output_coord(
|
| 74 |
+
MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
|
| 75 |
+
MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
|
| 76 |
+
);
|
| 77 |
+
|
| 78 |
+
// Compute the general matrix product
|
| 79 |
+
thread::Gemm<
|
| 80 |
+
TensorRefA,
|
| 81 |
+
TensorRefB,
|
| 82 |
+
TensorRefC,
|
| 83 |
+
ScalarType,
|
| 84 |
+
AccumulatorType,
|
| 85 |
+
OutputTile,
|
| 86 |
+
InnerProductOp,
|
| 87 |
+
ConvertOp
|
| 88 |
+
> gemm(initial_accum);
|
| 89 |
+
|
| 90 |
+
gemm.multiply_add(
|
| 91 |
+
problem_size,
|
| 92 |
+
tensor_a,
|
| 93 |
+
tensor_b,
|
| 94 |
+
output_coord);
|
| 95 |
+
|
| 96 |
+
gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord);
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 100 |
+
|
| 101 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 102 |
+
/// objects.
|
| 103 |
+
template <
|
| 104 |
+
typename TensorRefCollectionA,
|
| 105 |
+
typename TensorRefCollectionB,
|
| 106 |
+
typename TensorRefCollectionC,
|
| 107 |
+
typename ScalarType,
|
| 108 |
+
typename AccumulatorType,
|
| 109 |
+
typename OutputTile,
|
| 110 |
+
typename InnerProductOp,
|
| 111 |
+
typename ConvertOp
|
| 112 |
+
>
|
| 113 |
+
__global__ void BatchedGemm(
|
| 114 |
+
gemm::GemmCoord problem_size,
|
| 115 |
+
ScalarType alpha,
|
| 116 |
+
TensorRefCollectionA tensor_collection_a,
|
| 117 |
+
TensorRefCollectionB tensor_collection_b,
|
| 118 |
+
ScalarType beta,
|
| 119 |
+
TensorRefCollectionC tensor_collection_c,
|
| 120 |
+
AccumulatorType initial_accum) {
|
| 121 |
+
|
| 122 |
+
// Obtain batch ID
|
| 123 |
+
int batch_id = blockIdx.z;
|
| 124 |
+
|
| 125 |
+
// Dereference based on batch_id
|
| 126 |
+
typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
|
| 127 |
+
typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
|
| 128 |
+
typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);
|
| 129 |
+
|
| 130 |
+
// Map each thread to a unique tile of the output matrix
|
| 131 |
+
MatrixCoord output_coord(
|
| 132 |
+
(threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn,
|
| 133 |
+
(threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow
|
| 134 |
+
);
|
| 135 |
+
|
| 136 |
+
// Compute the general matrix product
|
| 137 |
+
thread::Gemm<
|
| 138 |
+
typename TensorRefCollectionA::TensorRef,
|
| 139 |
+
typename TensorRefCollectionB::TensorRef,
|
| 140 |
+
typename TensorRefCollectionC::TensorRef,
|
| 141 |
+
ScalarType,
|
| 142 |
+
AccumulatorType,
|
| 143 |
+
OutputTile,
|
| 144 |
+
InnerProductOp,
|
| 145 |
+
ConvertOp
|
| 146 |
+
> gemm(initial_accum);
|
| 147 |
+
|
| 148 |
+
gemm.multiply_add(
|
| 149 |
+
problem_size,
|
| 150 |
+
tensor_a,
|
| 151 |
+
tensor_b,
|
| 152 |
+
output_coord);
|
| 153 |
+
|
| 154 |
+
gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 158 |
+
|
| 159 |
+
} // namespace kernel
|
| 160 |
+
} // namespace device
|
| 161 |
+
} // namespace reference
|
| 162 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include <curand_kernel.h>
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
|
| 38 |
+
namespace cutlass {
|
| 39 |
+
namespace reference {
|
| 40 |
+
namespace device {
|
| 41 |
+
namespace kernel {
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
/// Kernel to initialize tensor to uniform random distribution
|
| 46 |
+
template <typename T>
|
| 47 |
+
__global__ void TensorInitializeUniform(
|
| 48 |
+
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 49 |
+
__shared__ curandState_t rng_state[1024];
|
| 50 |
+
|
| 51 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 52 |
+
|
| 53 |
+
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 54 |
+
|
| 55 |
+
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 56 |
+
int s_idx = blockIdx.y * blockDim.x;
|
| 57 |
+
|
| 58 |
+
tensor += s_idx * ldm + c_idx;
|
| 59 |
+
|
| 60 |
+
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 61 |
+
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 62 |
+
double range = dist.uniform.max - dist.uniform.min;
|
| 63 |
+
|
| 64 |
+
double rnd = curand_uniform(&rng_state[threadIdx.x]);
|
| 65 |
+
|
| 66 |
+
rnd = dist.uniform.min + range * rnd;
|
| 67 |
+
|
| 68 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 69 |
+
// testing
|
| 70 |
+
if (dist.int_scale >= 0) {
|
| 71 |
+
rnd = double(int(rnd * double(1 << dist.int_scale)));
|
| 72 |
+
*tensor = T(rnd / double(1 << dist.int_scale));
|
| 73 |
+
} else {
|
| 74 |
+
*tensor = T(rnd);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
tensor += ldm;
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 83 |
+
|
| 84 |
+
/// Kernel to initialize tensor to uniform distribution
|
| 85 |
+
template <typename T>
|
| 86 |
+
__global__ void TensorInitializeGaussian(
|
| 87 |
+
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 88 |
+
__shared__ curandState_t rng_state[1024];
|
| 89 |
+
|
| 90 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 91 |
+
|
| 92 |
+
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 93 |
+
|
| 94 |
+
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 95 |
+
int s_idx = blockIdx.y * blockDim.x;
|
| 96 |
+
|
| 97 |
+
tensor += s_idx * ldm + c_idx;
|
| 98 |
+
|
| 99 |
+
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 100 |
+
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 101 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 102 |
+
// testing
|
| 103 |
+
|
| 104 |
+
double rnd = curand_normal(&rng_state[threadIdx.x]);
|
| 105 |
+
|
| 106 |
+
rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd;
|
| 107 |
+
|
| 108 |
+
if (dist.int_scale >= 0) {
|
| 109 |
+
rnd = double(int(rnd * double(1 << dist.int_scale)));
|
| 110 |
+
*tensor = T(rnd / double(1 << dist.int_scale));
|
| 111 |
+
} else {
|
| 112 |
+
*tensor = T(rnd);
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
/// Kernel to initialize tensor to an identity matrix
|
| 119 |
+
template <typename T>
|
| 120 |
+
__global__ void TensorInitializeLinear(
|
| 121 |
+
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 122 |
+
__shared__ curandState_t rng_state[1024];
|
| 123 |
+
|
| 124 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 125 |
+
|
| 126 |
+
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 127 |
+
|
| 128 |
+
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 129 |
+
int s_idx = blockIdx.y * blockDim.x;
|
| 130 |
+
|
| 131 |
+
tensor += s_idx * ldm + c_idx;
|
| 132 |
+
|
| 133 |
+
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 134 |
+
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 135 |
+
*tensor =
|
| 136 |
+
dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx;
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Kernel to initialize tensor to an identity matrix
|
| 142 |
+
template <typename T>
|
| 143 |
+
__global__ void TensorInitializeIdentity(
|
| 144 |
+
Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
|
| 145 |
+
__shared__ curandState_t rng_state[1024];
|
| 146 |
+
|
| 147 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
|
| 148 |
+
|
| 149 |
+
curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
|
| 150 |
+
|
| 151 |
+
int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 152 |
+
int s_idx = blockIdx.y * blockDim.x;
|
| 153 |
+
|
| 154 |
+
tensor += s_idx * ldm + c_idx;
|
| 155 |
+
|
| 156 |
+
for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
|
| 157 |
+
if (s_idx < dim_strided && c_idx < dim_contiguous) {
|
| 158 |
+
*tensor = (c_idx == s_idx ? T(1) : T(0));
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 164 |
+
|
| 165 |
+
} // namespace kernel
|
| 166 |
+
} // namespace device
|
| 167 |
+
} // namespace reference
|
| 168 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include "cutlass/cutlass.h"
|
| 35 |
+
#include "cutlass/coord.h"
|
| 36 |
+
#include "cutlass/subbyte_reference.h"
|
| 37 |
+
#include "cutlass/fast_math.h"
|
| 38 |
+
|
| 39 |
+
namespace cutlass {
|
| 40 |
+
namespace reference {
|
| 41 |
+
namespace device {
|
| 42 |
+
namespace kernel {
|
| 43 |
+
|
| 44 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
/// Defines several helpers
|
| 47 |
+
namespace detail {
|
| 48 |
+
|
| 49 |
+
/// Helper to perform for-each operation
|
| 50 |
+
template <typename Func, int Rank, int RankRemaining>
|
| 51 |
+
struct TensorForEachHelper {
|
| 52 |
+
|
| 53 |
+
/// Constructor for general rank
|
| 54 |
+
__inline__ __device__
|
| 55 |
+
TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
|
| 56 |
+
|
| 57 |
+
int64_t product = 1;
|
| 58 |
+
|
| 59 |
+
CUTLASS_PRAGMA_UNROLL
|
| 60 |
+
for (int i = Rank - RankRemaining; i < Rank; ++i) {
|
| 61 |
+
product *= size[i];
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
coord[Rank - 1 - RankRemaining] = index / product;
|
| 65 |
+
int64_t remaining = index % product;
|
| 66 |
+
|
| 67 |
+
TensorForEachHelper<Func, Rank, RankRemaining-1>(func, size, coord, remaining);
|
| 68 |
+
}
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
/// Helper to perform for-each operation
|
| 72 |
+
template <typename Func, int Rank>
|
| 73 |
+
struct TensorForEachHelper<Func, Rank, 0> {
|
| 74 |
+
|
| 75 |
+
/// Constructor for fastest changing rank
|
| 76 |
+
__inline__ __device__
|
| 77 |
+
TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
|
| 78 |
+
|
| 79 |
+
coord[Rank - 1] = index;
|
| 80 |
+
|
| 81 |
+
if (coord < size) {
|
| 82 |
+
func(coord);
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
} // namespace detail
|
| 88 |
+
|
| 89 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 90 |
+
|
| 91 |
+
/// Kernel calls a functor for each element in a tensor's index space
|
| 92 |
+
/// Kernel: invokes `Func` once per coordinate in the tensor's index space.
/// `Params` is forwarded to the functor's constructor per thread.
template <typename Func, int Rank, typename Params>
__global__ void TensorForEach(Coord<Rank> size, Params params = Params()) {

  Func func(params);

  int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
  int64_t max_index = 1;

  // Total number of elements in the tensor's index space
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < Rank; ++i) {
    max_index *= size[i];
  }

  // Grid-stride loop: each thread visits linear indices spaced by the total
  // number of launched threads, so any grid size covers the whole tensor.
  CUTLASS_PRAGMA_NO_UNROLL
  while (index < max_index) {
    Coord<Rank> coord;

    // Convert the linear index to a rank-Rank coordinate and apply the functor
    detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord, index);
    index += blockDim.x * gridDim.x;
  }
}
|
| 113 |
+
|
| 114 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 115 |
+
|
| 116 |
+
/// Kernel calls a functor for each element along a tensor's diagonal
|
| 117 |
+
/// Kernel: invokes `Func` for each element on the tensor's diagonal whose
/// index lies in [start, end). One thread handles at most one diagonal element.
template <typename Func, int Rank, typename Params>
__global__ void TensorDiagonalForEach(Coord<Rank> size, Params params, int start, int end) {

  Func func(params);

  // Global thread id offset into the [start, end) diagonal range
  int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start;

  if (index < end) {
    Coord<Rank> coord;

    // A diagonal coordinate carries the same value in every dimension
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Rank; ++i) {
      coord[i] = index;
    }

    func(coord);
  }
}
|
| 135 |
+
|
| 136 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 137 |
+
|
| 138 |
+
/// Kernel: fills a linear block of `capacity` elements by assigning `func()`
/// to each position.
template <typename Element, typename Func>
__global__ void BlockForEach(
  Element *ptr,
  size_t capacity,
  typename Func::Params params) {

  Func func(params);

  size_t index = threadIdx.x + blockIdx.x * blockDim.x;

  // Grid-stride loop; ReferenceFactory provides element access that also
  // works for sub-byte element types that cannot be addressed directly.
  for (; index < capacity; index += blockDim.x * gridDim.x) {
    ReferenceFactory<Element>::get(ptr, index) = func();
  }
}
|
| 152 |
+
|
| 153 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 154 |
+
|
| 155 |
+
} // namespace kernel
|
| 156 |
+
} // namespace device
|
| 157 |
+
} // namespace reference
|
| 158 |
+
} // namespace cutlass
|
| 159 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued GEMM in device-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/blas3.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/numeric_conversion.h"
|
| 40 |
+
#include "cutlass/tensor_view.h"
|
| 41 |
+
#include "cutlass/gemm/gemm.h"
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace reference {
|
| 45 |
+
namespace device {
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace kernel {
|
| 50 |
+
|
| 51 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 52 |
+
/// objects.
|
| 53 |
+
///
|
| 54 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 55 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 56 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 57 |
+
/// arguments explicitly.
|
| 58 |
+
template <
|
| 59 |
+
typename ElementA,
|
| 60 |
+
typename LayoutA,
|
| 61 |
+
typename ElementB,
|
| 62 |
+
typename LayoutB,
|
| 63 |
+
typename ElementC,
|
| 64 |
+
typename LayoutC,
|
| 65 |
+
typename ScalarType,
|
| 66 |
+
typename ComputeType,
|
| 67 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 68 |
+
typename InnerProductOp = multiply_add<ComputeType>,
|
| 69 |
+
int kMblock = 4,
|
| 70 |
+
int kNblock = 4
|
| 71 |
+
>
|
| 72 |
+
__global__ void Rank2KComplex(
|
| 73 |
+
gemm::GemmCoord problem_size,
|
| 74 |
+
ScalarType alpha,
|
| 75 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 76 |
+
ComplexTransform transform_a,
|
| 77 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 78 |
+
ComplexTransform transform_b,
|
| 79 |
+
ScalarType beta,
|
| 80 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 81 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 82 |
+
ComputeType initial_accum,
|
| 83 |
+
FillMode fill_mode_c,
|
| 84 |
+
BlasMode blas_mode,
|
| 85 |
+
int batch_count = 1,
|
| 86 |
+
int64_t batch_stride_A = 0,
|
| 87 |
+
int64_t batch_stride_B = 0,
|
| 88 |
+
int64_t batch_stride_C = 0,
|
| 89 |
+
int64_t batch_stride_D = 0) {
|
| 90 |
+
|
| 91 |
+
static_assert(
|
| 92 |
+
LayoutA::kRank == 2 &&
|
| 93 |
+
LayoutB::kRank == 2 &&
|
| 94 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 95 |
+
|
| 96 |
+
int const M = problem_size.m();
|
| 97 |
+
int const N = problem_size.n();
|
| 98 |
+
int const K = problem_size.k();
|
| 99 |
+
|
| 100 |
+
assert(M=N);
|
| 101 |
+
|
| 102 |
+
ConvertOp convert_op;
|
| 103 |
+
InnerProductOp inner_product_op;
|
| 104 |
+
|
| 105 |
+
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
| 106 |
+
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
| 107 |
+
int batch_idx = blockIdx.z;
|
| 108 |
+
|
| 109 |
+
tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
|
| 110 |
+
tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
|
| 111 |
+
tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
|
| 112 |
+
tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
|
| 113 |
+
|
| 114 |
+
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
| 115 |
+
|
| 116 |
+
// Compute matrix product using blocks
|
| 117 |
+
ComputeType accum[kMblock][kNblock];
|
| 118 |
+
|
| 119 |
+
CUTLASS_PRAGMA_UNROLL
|
| 120 |
+
for (int j = 0; j < kNblock; j++) {
|
| 121 |
+
CUTLASS_PRAGMA_UNROLL
|
| 122 |
+
for (int i = 0; i < kMblock; i++) {
|
| 123 |
+
accum[i][j] = initial_accum;
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 128 |
+
CUTLASS_PRAGMA_UNROLL
|
| 129 |
+
for (int j = 0; j < kNblock; j++) {
|
| 130 |
+
CUTLASS_PRAGMA_UNROLL
|
| 131 |
+
for (int i = 0; i < kMblock; i++) {
|
| 132 |
+
int row = row_block + i;
|
| 133 |
+
int col = col_block + j;
|
| 134 |
+
|
| 135 |
+
if (row < M && col < N &&
|
| 136 |
+
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
| 137 |
+
(fill_mode_c == FillMode::kUpper && row <= col) )
|
| 138 |
+
) {
|
| 139 |
+
|
| 140 |
+
// A x B^T (Symmetric) or A x B^H (Hermitian)
|
| 141 |
+
// complex conjugation on operandB (b_t) is function of blas3 computation
|
| 142 |
+
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 143 |
+
ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
|
| 144 |
+
conj(tensor_b.at(MatrixCoord(col, k_block))) :
|
| 145 |
+
tensor_b.at(MatrixCoord(col, k_block));
|
| 146 |
+
|
| 147 |
+
ComputeType a_ik = ComputeType(a);
|
| 148 |
+
ComputeType b_jk = ComputeType(b_t);
|
| 149 |
+
|
| 150 |
+
// complex conjugation is a function of operand layouts
|
| 151 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 152 |
+
a_ik = conj(a_ik);
|
| 153 |
+
}
|
| 154 |
+
// complex conjugation is a function of operand layouts
|
| 155 |
+
if (transform_b == ComplexTransform::kConjugate) {
|
| 156 |
+
b_jk = conj(b_jk);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
| 160 |
+
|
| 161 |
+
// B x A^T (Symmetric) or B x A^H (Hermitian)
|
| 162 |
+
// complex conjugation on operandB (a_t) is function of blas3 computation
|
| 163 |
+
ElementB b = tensor_b.at(MatrixCoord(row, k_block));
|
| 164 |
+
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
| 165 |
+
conj(tensor_a.at(MatrixCoord(col, k_block))):
|
| 166 |
+
tensor_a.at(MatrixCoord(col, k_block));
|
| 167 |
+
|
| 168 |
+
ComputeType b_ik = ComputeType(b);
|
| 169 |
+
ComputeType a_jk = ComputeType(a_t);
|
| 170 |
+
|
| 171 |
+
// complex conjugation here is a function of operand layouts
|
| 172 |
+
if (transform_b == ComplexTransform::kConjugate) {
|
| 173 |
+
b_ik = conj(b_ik);
|
| 174 |
+
}
|
| 175 |
+
// complex conjugation here is a function of operand layouts
|
| 176 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 177 |
+
a_jk = conj(a_jk);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
CUTLASS_PRAGMA_UNROLL
|
| 187 |
+
for (int j = 0; j < kNblock; j++) {
|
| 188 |
+
CUTLASS_PRAGMA_UNROLL
|
| 189 |
+
for (int i = 0; i < kMblock; i++) {
|
| 190 |
+
int row = row_block + i;
|
| 191 |
+
int col = col_block + j;
|
| 192 |
+
|
| 193 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 194 |
+
|
| 195 |
+
if (row < M && col < N &&
|
| 196 |
+
((fill_mode_c == FillMode::kLower && row >= col) ||
|
| 197 |
+
(fill_mode_c == FillMode::kUpper && row <= col))
|
| 198 |
+
) {
|
| 199 |
+
|
| 200 |
+
ScalarType c = tensor_c.at(coord);
|
| 201 |
+
// The imaginary parts of the diagonal elements of
|
| 202 |
+
// a complex data type are assumed and set to zero
|
| 203 |
+
if (blas_mode == BlasMode::kHermitian) {
|
| 204 |
+
c = (row == col) ? real(c) : c;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
tensor_d.at(coord) = convert_op(
|
| 208 |
+
alpha * ScalarType(accum[i][j]) +
|
| 209 |
+
beta * c);
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
|
| 215 |
+
tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
|
| 216 |
+
tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
|
| 217 |
+
tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
|
| 218 |
+
|
| 219 |
+
} // for (batch_idx)
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
} // namespace kernel
|
| 223 |
+
|
| 224 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 225 |
+
|
| 226 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 227 |
+
/// objects.
|
| 228 |
+
///
|
| 229 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 230 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 231 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 232 |
+
/// arguments explicitly.
|
| 233 |
+
template <
|
| 234 |
+
typename ElementA,
|
| 235 |
+
typename LayoutA,
|
| 236 |
+
typename ElementB,
|
| 237 |
+
typename LayoutB,
|
| 238 |
+
typename ElementC,
|
| 239 |
+
typename LayoutC,
|
| 240 |
+
typename ScalarType,
|
| 241 |
+
typename ComputeType,
|
| 242 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 243 |
+
typename InnerProductOp = multiply_add<ComputeType>
|
| 244 |
+
>
|
| 245 |
+
void Rank2KComplex(
|
| 246 |
+
gemm::GemmCoord problem_size,
|
| 247 |
+
ScalarType alpha,
|
| 248 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 249 |
+
ComplexTransform transform_a,
|
| 250 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 251 |
+
ComplexTransform transform_b,
|
| 252 |
+
ScalarType beta,
|
| 253 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 254 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 255 |
+
ComputeType initial_accum,
|
| 256 |
+
FillMode fill_mode_c,
|
| 257 |
+
BlasMode blas_mode,
|
| 258 |
+
int batch_count = 1,
|
| 259 |
+
int64_t batch_stride_A = 0,
|
| 260 |
+
int64_t batch_stride_B = 0,
|
| 261 |
+
int64_t batch_stride_C = 0,
|
| 262 |
+
int64_t batch_stride_D = 0) {
|
| 263 |
+
|
| 264 |
+
static_assert(
|
| 265 |
+
LayoutA::kRank == 2 &&
|
| 266 |
+
LayoutB::kRank == 2 &&
|
| 267 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 268 |
+
|
| 269 |
+
int const kMblock = 4;
|
| 270 |
+
int const kNblock = 4;
|
| 271 |
+
|
| 272 |
+
dim3 block(16, 8);
|
| 273 |
+
dim3 grid(
|
| 274 |
+
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
| 275 |
+
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
| 276 |
+
batch_count % std::numeric_limits<uint16_t>::max()
|
| 277 |
+
);
|
| 278 |
+
|
| 279 |
+
kernel::Rank2KComplex<
|
| 280 |
+
ElementA,
|
| 281 |
+
LayoutA,
|
| 282 |
+
ElementB,
|
| 283 |
+
LayoutB,
|
| 284 |
+
ElementC,
|
| 285 |
+
LayoutC,
|
| 286 |
+
ScalarType,
|
| 287 |
+
ComputeType,
|
| 288 |
+
ConvertOp,
|
| 289 |
+
InnerProductOp,
|
| 290 |
+
kMblock,
|
| 291 |
+
kNblock
|
| 292 |
+
><<< grid, block >>>(
|
| 293 |
+
problem_size,
|
| 294 |
+
alpha,
|
| 295 |
+
tensor_a,
|
| 296 |
+
transform_a,
|
| 297 |
+
tensor_b,
|
| 298 |
+
transform_b,
|
| 299 |
+
beta,
|
| 300 |
+
tensor_c,
|
| 301 |
+
tensor_d,
|
| 302 |
+
initial_accum,
|
| 303 |
+
fill_mode_c,
|
| 304 |
+
blas_mode,
|
| 305 |
+
batch_count,
|
| 306 |
+
batch_stride_A,
|
| 307 |
+
batch_stride_B,
|
| 308 |
+
batch_stride_C,
|
| 309 |
+
batch_stride_D
|
| 310 |
+
);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 314 |
+
|
| 315 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 316 |
+
/// objects.
|
| 317 |
+
///
|
| 318 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 319 |
+
template <
|
| 320 |
+
typename ElementA,
|
| 321 |
+
typename LayoutA,
|
| 322 |
+
typename ElementB,
|
| 323 |
+
typename LayoutB,
|
| 324 |
+
typename ElementC,
|
| 325 |
+
typename LayoutC,
|
| 326 |
+
typename ScalarType
|
| 327 |
+
>
|
| 328 |
+
void Rank2KComplex(
|
| 329 |
+
gemm::GemmCoord problem_size,
|
| 330 |
+
ScalarType alpha,
|
| 331 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 332 |
+
ComplexTransform transform_a,
|
| 333 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 334 |
+
ComplexTransform transform_b,
|
| 335 |
+
ScalarType beta,
|
| 336 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 337 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 338 |
+
FillMode fill_mode_c,
|
| 339 |
+
BlasMode blas_mode) {
|
| 340 |
+
|
| 341 |
+
Rank2KComplex(
|
| 342 |
+
problem_size, alpha,
|
| 343 |
+
tensor_a, transform_a,
|
| 344 |
+
tensor_b, transform_b,
|
| 345 |
+
beta, tensor_c, tensor_d,
|
| 346 |
+
ScalarType(0),
|
| 347 |
+
fill_mode_c,
|
| 348 |
+
blas_mode);
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 352 |
+
|
| 353 |
+
} // namespace device
|
| 354 |
+
} // namespace reference
|
| 355 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Defines host-side elementwise operations on TensorView.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
// Standard Library includes
|
| 37 |
+
#include <utility>
|
| 38 |
+
|
| 39 |
+
// Cutlass includes
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/relatively_equal.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/util/distribution.h"
|
| 44 |
+
|
| 45 |
+
#include "tensor_foreach.h"
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace reference {
|
| 49 |
+
namespace device {
|
| 50 |
+
|
| 51 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace kernel {
|
| 54 |
+
|
| 55 |
+
template <typename Element>
|
| 56 |
+
__global__ void BlockCompareEqual(
|
| 57 |
+
int *equal,
|
| 58 |
+
Element const *ptr_A,
|
| 59 |
+
Element const *ptr_B,
|
| 60 |
+
size_t capacity) {
|
| 61 |
+
|
| 62 |
+
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
|
| 63 |
+
|
| 64 |
+
for (; idx < capacity; idx += gridDim.x * blockDim.x) {
|
| 65 |
+
|
| 66 |
+
Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
|
| 67 |
+
Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
|
| 68 |
+
|
| 69 |
+
if (a != b) {
|
| 70 |
+
*equal = 0;
|
| 71 |
+
|
| 72 |
+
return;
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template <typename Element>
|
| 78 |
+
__global__ void BlockCompareRelativelyEqual(
|
| 79 |
+
int *equal,
|
| 80 |
+
Element const *ptr_A,
|
| 81 |
+
Element const *ptr_B,
|
| 82 |
+
size_t capacity,
|
| 83 |
+
Element epsilon,
|
| 84 |
+
Element nonzero_floor) {
|
| 85 |
+
|
| 86 |
+
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
|
| 87 |
+
|
| 88 |
+
for (; idx < capacity; idx += gridDim.x * blockDim.x) {
|
| 89 |
+
|
| 90 |
+
Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
|
| 91 |
+
Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
|
| 92 |
+
|
| 93 |
+
if (!relatively_equal(a, b, epsilon, nonzero_floor)) {
|
| 94 |
+
*equal = 0;
|
| 95 |
+
return;
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
} // namespace kernel
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 104 |
+
|
| 105 |
+
/// Performs a bit-level equality check between two blocks
|
| 106 |
+
template <typename Element>
|
| 107 |
+
bool BlockCompareEqual(
|
| 108 |
+
Element const *ptr_A,
|
| 109 |
+
Element const *ptr_B,
|
| 110 |
+
size_t capacity,
|
| 111 |
+
int grid_size = 0,
|
| 112 |
+
int block_size = 0,
|
| 113 |
+
cudaStream_t stream = nullptr) {
|
| 114 |
+
|
| 115 |
+
int equal_flag = 1;
|
| 116 |
+
int *device_equal_flag = nullptr;
|
| 117 |
+
|
| 118 |
+
if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
|
| 119 |
+
throw std::runtime_error("Failed to allocate device flag.");
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
if (cudaMemcpy(
|
| 123 |
+
device_equal_flag,
|
| 124 |
+
&equal_flag,
|
| 125 |
+
sizeof(int),
|
| 126 |
+
cudaMemcpyHostToDevice) != cudaSuccess) {
|
| 127 |
+
|
| 128 |
+
throw std::runtime_error("Failed to copy equality flag to device.");
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
if (!grid_size || !block_size) {
|
| 132 |
+
|
| 133 |
+
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 134 |
+
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 135 |
+
&grid_size,
|
| 136 |
+
&block_size,
|
| 137 |
+
reinterpret_cast<void const *>(kernel::BlockCompareEqual<Element>));
|
| 138 |
+
|
| 139 |
+
if (result != cudaSuccess) {
|
| 140 |
+
throw std::runtime_error("Failed to query occupancy.");
|
| 141 |
+
}
|
| 142 |
+
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 143 |
+
// single thread and reduces the impact of initialization overhead.
|
| 144 |
+
block_size = (block_size < 128 ? block_size : 128);
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
dim3 grid(grid_size, 1, 1);
|
| 148 |
+
dim3 block(block_size, 1, 1);
|
| 149 |
+
|
| 150 |
+
kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
|
| 151 |
+
|
| 152 |
+
cudaStreamSynchronize(stream);
|
| 153 |
+
|
| 154 |
+
if (cudaMemcpy(
|
| 155 |
+
&equal_flag,
|
| 156 |
+
device_equal_flag,
|
| 157 |
+
sizeof(int),
|
| 158 |
+
cudaMemcpyDeviceToHost) != cudaSuccess) {
|
| 159 |
+
|
| 160 |
+
cudaFree(device_equal_flag);
|
| 161 |
+
|
| 162 |
+
throw std::runtime_error("Failed to copy equality flag from device.");
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
cudaFree(device_equal_flag);
|
| 166 |
+
|
| 167 |
+
return equal_flag;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 171 |
+
|
| 172 |
+
/// Performs a bit-level equality check between two blocks
|
| 173 |
+
template <typename Element>
|
| 174 |
+
bool BlockCompareRelativelyEqual(
|
| 175 |
+
Element const *ptr_A,
|
| 176 |
+
Element const *ptr_B,
|
| 177 |
+
size_t capacity,
|
| 178 |
+
Element epsilon,
|
| 179 |
+
Element nonzero_floor,
|
| 180 |
+
int grid_size = 0,
|
| 181 |
+
int block_size = 0,
|
| 182 |
+
cudaStream_t stream = nullptr) {
|
| 183 |
+
|
| 184 |
+
int equal_flag = 1;
|
| 185 |
+
int *device_equal_flag = nullptr;
|
| 186 |
+
|
| 187 |
+
if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
|
| 188 |
+
throw std::runtime_error("Failed to allocate device flag.");
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
if (cudaMemcpy(
|
| 192 |
+
device_equal_flag,
|
| 193 |
+
&equal_flag,
|
| 194 |
+
sizeof(int),
|
| 195 |
+
cudaMemcpyHostToDevice) != cudaSuccess) {
|
| 196 |
+
|
| 197 |
+
throw std::runtime_error("Failed to copy equality flag to device.");
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
if (!grid_size || !block_size) {
|
| 201 |
+
|
| 202 |
+
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 203 |
+
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 204 |
+
&grid_size,
|
| 205 |
+
&block_size,
|
| 206 |
+
reinterpret_cast<void const *>(kernel::BlockCompareRelativelyEqual<Element>));
|
| 207 |
+
|
| 208 |
+
if (result != cudaSuccess) {
|
| 209 |
+
throw std::runtime_error("Failed to query occupancy.");
|
| 210 |
+
}
|
| 211 |
+
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 212 |
+
// single thread and reduces the impact of initialization overhead.
|
| 213 |
+
block_size = (block_size < 128 ? block_size : 128);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
dim3 grid(grid_size, 1, 1);
|
| 217 |
+
dim3 block(block_size, 1, 1);
|
| 218 |
+
|
| 219 |
+
kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
|
| 220 |
+
device_equal_flag,
|
| 221 |
+
ptr_A,
|
| 222 |
+
ptr_B,
|
| 223 |
+
capacity,
|
| 224 |
+
epsilon,
|
| 225 |
+
nonzero_floor
|
| 226 |
+
);
|
| 227 |
+
|
| 228 |
+
cudaStreamSynchronize(stream);
|
| 229 |
+
|
| 230 |
+
if (cudaMemcpy(
|
| 231 |
+
&equal_flag,
|
| 232 |
+
device_equal_flag,
|
| 233 |
+
sizeof(int),
|
| 234 |
+
cudaMemcpyDeviceToHost) != cudaSuccess) {
|
| 235 |
+
|
| 236 |
+
cudaFree(device_equal_flag);
|
| 237 |
+
|
| 238 |
+
throw std::runtime_error("Failed to copy equality flag from device.");
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
cudaFree(device_equal_flag);
|
| 242 |
+
|
| 243 |
+
return equal_flag;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 247 |
+
|
| 248 |
+
} // device
|
| 249 |
+
} // reference
|
| 250 |
+
} // cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h
ADDED
|
@@ -0,0 +1,2075 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Defines device-side elementwise operations on TensorView. Note, the operations defined
|
| 33 |
+
in this header are not specialized for any particular data layout and are therefore not
|
| 34 |
+
intended to offer the best possible performance. Rather, they are intended to be generic
|
| 35 |
+
reference implementations to support the CUTLASS unit tests.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#if !defined(__CUDACC_RTC__)
|
| 41 |
+
|
| 42 |
+
// Standard Library includes
|
| 43 |
+
#include <utility>
|
| 44 |
+
#include <cstdlib>
|
| 45 |
+
#include <cmath>
|
| 46 |
+
#include <type_traits>
|
| 47 |
+
#include <cstdint>
|
| 48 |
+
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
// CUDA includes
|
| 52 |
+
#include <curand_kernel.h>
|
| 53 |
+
|
| 54 |
+
// Cutlass includes
|
| 55 |
+
#include "cutlass/cutlass.h"
|
| 56 |
+
#include "cutlass/array.h"
|
| 57 |
+
#include "cutlass/complex.h"
|
| 58 |
+
#include "cutlass/tensor_view.h"
|
| 59 |
+
#include "cutlass/blas3.h"
|
| 60 |
+
#include "cutlass/numeric_types.h"
|
| 61 |
+
|
| 62 |
+
#include "cutlass/layout/vector.h"
|
| 63 |
+
|
| 64 |
+
#include "cutlass/util/reference/device/tensor_foreach.h"
|
| 65 |
+
#include "cutlass/util/distribution.h"
|
| 66 |
+
|
| 67 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
namespace cutlass {
|
| 70 |
+
namespace reference {
|
| 71 |
+
namespace device {
|
| 72 |
+
|
| 73 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 74 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 75 |
+
|
| 76 |
+
namespace detail {
|
| 77 |
+
|
| 78 |
+
template <typename FloatType>
|
| 79 |
+
CUTLASS_DEVICE
|
| 80 |
+
FloatType random_normal_float(curandState_t *state) {
|
| 81 |
+
return curand_normal(state);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
template <>
|
| 85 |
+
CUTLASS_DEVICE
|
| 86 |
+
double random_normal_float<double>(curandState_t *state) {
|
| 87 |
+
return curand_normal_double(state);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
template <typename FloatType>
|
| 91 |
+
CUTLASS_DEVICE
|
| 92 |
+
FloatType random_uniform_float(curandState_t *state) {
|
| 93 |
+
return curand_uniform(state);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
template <>
|
| 97 |
+
CUTLASS_DEVICE
|
| 98 |
+
double random_uniform_float<double>(curandState_t *state) {
|
| 99 |
+
return curand_uniform_double(state);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
template <typename Element>
|
| 103 |
+
struct RandomGaussianFunc {
|
| 104 |
+
|
| 105 |
+
using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type;
|
| 106 |
+
using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type;
|
| 107 |
+
|
| 108 |
+
/// Parameters structure
|
| 109 |
+
struct Params {
|
| 110 |
+
|
| 111 |
+
//
|
| 112 |
+
// Data members
|
| 113 |
+
//
|
| 114 |
+
|
| 115 |
+
uint64_t seed;
|
| 116 |
+
FloatType mean;
|
| 117 |
+
FloatType stddev;
|
| 118 |
+
int int_scale;
|
| 119 |
+
FloatType float_scale_up;
|
| 120 |
+
FloatType float_scale_down;
|
| 121 |
+
int exclude_zero; ///< If non-negative, excludes zeros
|
| 122 |
+
|
| 123 |
+
//
|
| 124 |
+
// Methods
|
| 125 |
+
//
|
| 126 |
+
|
| 127 |
+
/// Construction of Gaussian RNG functor.
|
| 128 |
+
Params(
|
| 129 |
+
uint64_t seed_ = 0,
|
| 130 |
+
Element mean_ = 0,
|
| 131 |
+
Element stddev_ = 1,
|
| 132 |
+
int int_scale_ = -1,
|
| 133 |
+
int exclude_zero_ = -1
|
| 134 |
+
):
|
| 135 |
+
seed(seed_),
|
| 136 |
+
mean(static_cast<FloatType>(mean_)),
|
| 137 |
+
stddev(static_cast<FloatType>(stddev_)),
|
| 138 |
+
int_scale(int_scale_),
|
| 139 |
+
exclude_zero(exclude_zero_) {
|
| 140 |
+
|
| 141 |
+
float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
|
| 142 |
+
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 143 |
+
}
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
//
|
| 147 |
+
// Data members
|
| 148 |
+
//
|
| 149 |
+
|
| 150 |
+
/// Parameters object
|
| 151 |
+
Params params;
|
| 152 |
+
|
| 153 |
+
/// RNG state object
|
| 154 |
+
curandState_t rng_state;
|
| 155 |
+
|
| 156 |
+
//
|
| 157 |
+
// Methods
|
| 158 |
+
//
|
| 159 |
+
|
| 160 |
+
/// Device-side initialization of RNG
|
| 161 |
+
CUTLASS_DEVICE
|
| 162 |
+
RandomGaussianFunc(Params const ¶ms): params(params) {
|
| 163 |
+
|
| 164 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 165 |
+
|
| 166 |
+
curand_init(params.seed, gtid, 0, &rng_state);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/// Compute random value and update RNG state
|
| 170 |
+
CUTLASS_DEVICE
|
| 171 |
+
Element operator()() {
|
| 172 |
+
|
| 173 |
+
FloatType rnd = random_normal_float<FloatType>(&rng_state);
|
| 174 |
+
rnd = params.mean + params.stddev * rnd;
|
| 175 |
+
|
| 176 |
+
Element result;
|
| 177 |
+
if (params.int_scale >= 0) {
|
| 178 |
+
rnd = FloatType(std::llround(rnd * params.float_scale_up));
|
| 179 |
+
result = Element(rnd * params.float_scale_down);
|
| 180 |
+
}
|
| 181 |
+
else {
|
| 182 |
+
result = Element(rnd);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
if (params.exclude_zero >=0 && result == Element(0.0)) {
|
| 186 |
+
if (rnd > FloatType(0)) {
|
| 187 |
+
rnd += FloatType(1);
|
| 188 |
+
} else {
|
| 189 |
+
rnd -= FloatType(1);
|
| 190 |
+
}
|
| 191 |
+
result = Element(rnd);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
return result;
|
| 195 |
+
}
|
| 196 |
+
};
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
template <typename Real>
|
| 200 |
+
struct RandomGaussianFunc<complex<Real>> {
|
| 201 |
+
|
| 202 |
+
using Element = complex<Real>;
|
| 203 |
+
using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type;
|
| 204 |
+
using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type;
|
| 205 |
+
|
| 206 |
+
/// Parameters structure
|
| 207 |
+
struct Params {
|
| 208 |
+
|
| 209 |
+
//
|
| 210 |
+
// Data members
|
| 211 |
+
//
|
| 212 |
+
|
| 213 |
+
uint64_t seed;
|
| 214 |
+
FloatType mean;
|
| 215 |
+
FloatType stddev;
|
| 216 |
+
int int_scale;
|
| 217 |
+
FloatType float_scale_up;
|
| 218 |
+
FloatType float_scale_down;
|
| 219 |
+
int exclude_zero; ///< If non-negative, excludes zeros
|
| 220 |
+
|
| 221 |
+
//
|
| 222 |
+
// Methods
|
| 223 |
+
//
|
| 224 |
+
|
| 225 |
+
/// Construction of Gaussian RNG functor.
|
| 226 |
+
Params(
|
| 227 |
+
uint64_t seed_ = 0,
|
| 228 |
+
Real mean_ = 0,
|
| 229 |
+
Real stddev_ = 1,
|
| 230 |
+
int int_scale_ = -1,
|
| 231 |
+
int exclude_zero_ = -1
|
| 232 |
+
):
|
| 233 |
+
seed(seed_),
|
| 234 |
+
mean(static_cast<FloatType>(mean_)),
|
| 235 |
+
stddev(static_cast<FloatType>(stddev_)),
|
| 236 |
+
int_scale(int_scale_),
|
| 237 |
+
exclude_zero(exclude_zero_) {
|
| 238 |
+
|
| 239 |
+
float_scale_up = FloatType(IntType(1) << int_scale);
|
| 240 |
+
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 241 |
+
}
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
//
|
| 245 |
+
// Data members
|
| 246 |
+
//
|
| 247 |
+
|
| 248 |
+
/// Parameters object
|
| 249 |
+
Params params;
|
| 250 |
+
|
| 251 |
+
/// RNG state object
|
| 252 |
+
curandState_t rng_state;
|
| 253 |
+
|
| 254 |
+
//
|
| 255 |
+
// Methods
|
| 256 |
+
//
|
| 257 |
+
|
| 258 |
+
/// Device-side initialization of RNG
|
| 259 |
+
CUTLASS_DEVICE
|
| 260 |
+
RandomGaussianFunc(Params const ¶ms): params(params) {
|
| 261 |
+
|
| 262 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 263 |
+
|
| 264 |
+
curand_init(params.seed, gtid, 0, &rng_state);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
/// Compute random value and update RNG state
|
| 268 |
+
CUTLASS_DEVICE
|
| 269 |
+
Element operator()() {
|
| 270 |
+
|
| 271 |
+
FloatType rnd_r = random_normal_float<FloatType>(&rng_state);
|
| 272 |
+
FloatType rnd_i = random_normal_float<FloatType>(&rng_state);
|
| 273 |
+
rnd_r = params.mean + params.stddev * rnd_r;
|
| 274 |
+
rnd_i = params.mean + params.stddev * rnd_i;
|
| 275 |
+
|
| 276 |
+
Element result;
|
| 277 |
+
if (params.int_scale >= 0) {
|
| 278 |
+
rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
|
| 279 |
+
rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
|
| 280 |
+
|
| 281 |
+
result = {
|
| 282 |
+
Real(rnd_r * params.float_scale_down),
|
| 283 |
+
Real(rnd_i * params.float_scale_down)
|
| 284 |
+
};
|
| 285 |
+
}
|
| 286 |
+
else {
|
| 287 |
+
result = Element(Real(rnd_r), Real(rnd_i));
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
if (params.exclude_zero >= 0 &&
|
| 291 |
+
result.real() == Real(0.0) &&
|
| 292 |
+
result.imag() == Real(0.0)) {
|
| 293 |
+
|
| 294 |
+
if (rnd_r > FloatType(0)) {
|
| 295 |
+
rnd_r += FloatType(1);
|
| 296 |
+
} else {
|
| 297 |
+
rnd_r -= FloatType(1);
|
| 298 |
+
}
|
| 299 |
+
result = Element(Real(rnd_r), Real(rnd_i));
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
return result;
|
| 303 |
+
}
|
| 304 |
+
};
|
| 305 |
+
|
| 306 |
+
/// Computes a random Gaussian distribution
|
| 307 |
+
template <
|
| 308 |
+
typename Element, ///< Element type
|
| 309 |
+
typename Layout> ///< Layout function
|
| 310 |
+
struct TensorFillRandomGaussianFunc {
|
| 311 |
+
|
| 312 |
+
/// View type
|
| 313 |
+
using TensorView = TensorView<Element, Layout>;
|
| 314 |
+
|
| 315 |
+
/// Scalar type
|
| 316 |
+
typedef typename TensorView::Element T;
|
| 317 |
+
|
| 318 |
+
/// Coordinate in tensor's index space
|
| 319 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 320 |
+
|
| 321 |
+
using RandomFunc = RandomGaussianFunc<Element>;
|
| 322 |
+
|
| 323 |
+
/// Parameters structure
|
| 324 |
+
struct Params {
|
| 325 |
+
|
| 326 |
+
//
|
| 327 |
+
// Data members
|
| 328 |
+
//
|
| 329 |
+
|
| 330 |
+
TensorView view;
|
| 331 |
+
typename RandomFunc::Params random;
|
| 332 |
+
|
| 333 |
+
//
|
| 334 |
+
// Methods
|
| 335 |
+
//
|
| 336 |
+
|
| 337 |
+
/// Construction of Gaussian RNG functor.
|
| 338 |
+
Params(
|
| 339 |
+
TensorView view_ = TensorView(),
|
| 340 |
+
typename RandomFunc::Params random_ = typename RandomFunc::Params()
|
| 341 |
+
):
|
| 342 |
+
view(view_), random(random_) {
|
| 343 |
+
|
| 344 |
+
}
|
| 345 |
+
};
|
| 346 |
+
|
| 347 |
+
//
|
| 348 |
+
// Data members
|
| 349 |
+
//
|
| 350 |
+
|
| 351 |
+
Params params;
|
| 352 |
+
RandomFunc random;
|
| 353 |
+
|
| 354 |
+
//
|
| 355 |
+
// Methods
|
| 356 |
+
//
|
| 357 |
+
|
| 358 |
+
/// Device-side initialization of RNG
|
| 359 |
+
CUTLASS_DEVICE
|
| 360 |
+
TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) {
|
| 361 |
+
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
/// Compute random value and update RNG state
|
| 365 |
+
CUTLASS_DEVICE
|
| 366 |
+
void operator()(TensorCoord const &coord) {
|
| 367 |
+
|
| 368 |
+
params.view.at(coord) = random();
|
| 369 |
+
}
|
| 370 |
+
};
|
| 371 |
+
|
| 372 |
+
} // namespace detail
|
| 373 |
+
|
| 374 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 375 |
+
|
| 376 |
+
/// Fills a tensor with random values with a Gaussian distribution.
|
| 377 |
+
template <
|
| 378 |
+
typename Element, ///< Element type
|
| 379 |
+
typename Layout> ///< Layout function
|
| 380 |
+
void TensorFillRandomGaussian(
|
| 381 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 382 |
+
uint64_t seed, ///< seed for RNG
|
| 383 |
+
typename RealType<Element>::Type mean = Element(0), ///< Gaussian distribution's mean
|
| 384 |
+
typename RealType<Element>::Type stddev = Element(1), ///< Gaussian distribution's standard deviation
|
| 385 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 386 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 387 |
+
/// data.
|
| 388 |
+
int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
|
| 389 |
+
cudaStream_t stream = nullptr) {
|
| 390 |
+
|
| 391 |
+
using RandomFunc = detail::RandomGaussianFunc<Element>;
|
| 392 |
+
using Func = detail::TensorFillRandomGaussianFunc<Element, Layout>;
|
| 393 |
+
using Params = typename Func::Params;
|
| 394 |
+
|
| 395 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 396 |
+
view.extent(),
|
| 397 |
+
Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)),
|
| 398 |
+
/*grid_size*/0, /*block_size*/0,
|
| 399 |
+
stream
|
| 400 |
+
);
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 404 |
+
|
| 405 |
+
/// Fills a tensor with random values with a Gaussian distribution.
|
| 406 |
+
template <typename Element> ///< Element type
|
| 407 |
+
void BlockFillRandomGaussian(
|
| 408 |
+
Element *ptr,
|
| 409 |
+
size_t capacity,
|
| 410 |
+
uint64_t seed, ///< seed for RNG
|
| 411 |
+
typename RealType<Element>::Type mean, ///< Gaussian distribution's mean
|
| 412 |
+
typename RealType<Element>::Type stddev, ///< Gaussian distribution's standard deviation
|
| 413 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 414 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 415 |
+
/// data.
|
| 416 |
+
cudaStream_t stream = nullptr) {
|
| 417 |
+
|
| 418 |
+
using RandomFunc = detail::RandomGaussianFunc<Element>;
|
| 419 |
+
|
| 420 |
+
typename RandomFunc::Params params(seed, mean, stddev, bits);
|
| 421 |
+
|
| 422 |
+
BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 426 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 427 |
+
|
| 428 |
+
namespace detail {
|
| 429 |
+
|
| 430 |
+
/// Computes a random uniform distribution
|
| 431 |
+
template <typename Element> ///< Element type
|
| 432 |
+
struct RandomUniformFunc {
|
| 433 |
+
|
| 434 |
+
using FloatType = typename std::conditional<
|
| 435 |
+
(sizeof(Element) > 4),
|
| 436 |
+
double,
|
| 437 |
+
float>::type;
|
| 438 |
+
|
| 439 |
+
using IntType = typename std::conditional<
|
| 440 |
+
(sizeof(Element) > 4),
|
| 441 |
+
int64_t,
|
| 442 |
+
int>::type;
|
| 443 |
+
|
| 444 |
+
/// Parameters structure
|
| 445 |
+
struct Params {
|
| 446 |
+
|
| 447 |
+
//
|
| 448 |
+
// Data members
|
| 449 |
+
//
|
| 450 |
+
|
| 451 |
+
uint64_t seed;
|
| 452 |
+
FloatType range;
|
| 453 |
+
FloatType max;
|
| 454 |
+
int int_scale;
|
| 455 |
+
double pnan;
|
| 456 |
+
FloatType float_scale_up;
|
| 457 |
+
FloatType float_scale_down;
|
| 458 |
+
int exclude_zero; ///< If non-negative, excludes zeros
|
| 459 |
+
|
| 460 |
+
/// Default ctor
|
| 461 |
+
CUTLASS_HOST_DEVICE
|
| 462 |
+
Params() { }
|
| 463 |
+
|
| 464 |
+
//
|
| 465 |
+
// Methods
|
| 466 |
+
//
|
| 467 |
+
|
| 468 |
+
/// Construction of Gaussian RNG functor.
|
| 469 |
+
Params(
|
| 470 |
+
uint64_t seed_ = 0,
|
| 471 |
+
Element max_ = 1,
|
| 472 |
+
Element min = 0,
|
| 473 |
+
int int_scale_ = -1,
|
| 474 |
+
double pnan_ = 0,
|
| 475 |
+
int exclude_zero_ = -1
|
| 476 |
+
):
|
| 477 |
+
seed(seed_),
|
| 478 |
+
range(static_cast<FloatType>(max_) - static_cast<FloatType>(min)),
|
| 479 |
+
max(static_cast<FloatType>(max_)),
|
| 480 |
+
int_scale(int_scale_),
|
| 481 |
+
pnan(pnan_),
|
| 482 |
+
exclude_zero(exclude_zero_) {
|
| 483 |
+
|
| 484 |
+
float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
|
| 485 |
+
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 486 |
+
|
| 487 |
+
// Handle cases where min = 0 or max = 0 for excluding zeros
|
| 488 |
+
if (exclude_zero >= 0) {
|
| 489 |
+
range = (min == Element(0)) ? range - FloatType(1): range;
|
| 490 |
+
max = (max_ == Element(0)) ? max - FloatType(1): max;
|
| 491 |
+
}
|
| 492 |
+
}
|
| 493 |
+
};
|
| 494 |
+
|
| 495 |
+
//
|
| 496 |
+
// Data members
|
| 497 |
+
//
|
| 498 |
+
|
| 499 |
+
/// Parameters object
|
| 500 |
+
Params params;
|
| 501 |
+
|
| 502 |
+
/// RNG state object
|
| 503 |
+
curandState_t rng_state;
|
| 504 |
+
|
| 505 |
+
//
|
| 506 |
+
// Methods
|
| 507 |
+
//
|
| 508 |
+
|
| 509 |
+
/// Device-side initialization of RNG
|
| 510 |
+
CUTLASS_DEVICE
|
| 511 |
+
RandomUniformFunc(Params const ¶ms): params(params) {
|
| 512 |
+
|
| 513 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 514 |
+
|
| 515 |
+
curand_init(params.seed, gtid, 0, &rng_state);
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
/// Compute random value and update RNG state
|
| 519 |
+
CUTLASS_DEVICE
|
| 520 |
+
Element operator()() {
|
| 521 |
+
|
| 522 |
+
// Draw random float in [0.0, 1.0] to determine if element should be NaN.
|
| 523 |
+
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 524 |
+
if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
|
| 525 |
+
return Element(NAN);
|
| 526 |
+
}
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
FloatType rnd = random_uniform_float<FloatType>(&rng_state);
|
| 530 |
+
rnd = params.max - params.range * rnd;
|
| 531 |
+
|
| 532 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 533 |
+
// testing
|
| 534 |
+
Element result;
|
| 535 |
+
|
| 536 |
+
if (params.int_scale >= 0) {
|
| 537 |
+
rnd = FloatType(std::llround(rnd * params.float_scale_up));
|
| 538 |
+
result = Element(rnd * params.float_scale_down);
|
| 539 |
+
}
|
| 540 |
+
else {
|
| 541 |
+
result = Element(rnd);
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
if (params.exclude_zero >=0 && result == Element(0.0)) {
|
| 545 |
+
if (rnd > FloatType(0)) {
|
| 546 |
+
rnd = std::min(params.max, rnd + FloatType(1));
|
| 547 |
+
} else {
|
| 548 |
+
rnd = std::max((params.max - params.range), rnd - FloatType(1));
|
| 549 |
+
}
|
| 550 |
+
result = Element(rnd);
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
return result;
|
| 554 |
+
}
|
| 555 |
+
};
|
| 556 |
+
|
| 557 |
+
/// Computes a random Gaussian distribution
|
| 558 |
+
template <typename Real>
|
| 559 |
+
struct RandomUniformFunc<complex<Real>> {
|
| 560 |
+
|
| 561 |
+
using Element = complex<Real>;
|
| 562 |
+
|
| 563 |
+
using FloatType = typename std::conditional<
|
| 564 |
+
(sizeof(Real) > 4),
|
| 565 |
+
double,
|
| 566 |
+
float>::type;
|
| 567 |
+
|
| 568 |
+
using IntType = typename std::conditional<
|
| 569 |
+
(sizeof(Real) > 4),
|
| 570 |
+
int64_t,
|
| 571 |
+
int>::type;
|
| 572 |
+
|
| 573 |
+
/// Parameters structure
|
| 574 |
+
struct Params {
|
| 575 |
+
|
| 576 |
+
//
|
| 577 |
+
// Data members
|
| 578 |
+
//
|
| 579 |
+
|
| 580 |
+
uint64_t seed;
|
| 581 |
+
FloatType range;
|
| 582 |
+
FloatType min;
|
| 583 |
+
int int_scale;
|
| 584 |
+
double pnan;
|
| 585 |
+
FloatType float_scale_up;
|
| 586 |
+
FloatType float_scale_down;
|
| 587 |
+
int exclude_zero; ///< If non-negative, excludes zeros
|
| 588 |
+
|
| 589 |
+
/// Default ctor
|
| 590 |
+
CUTLASS_HOST_DEVICE
|
| 591 |
+
Params() { }
|
| 592 |
+
|
| 593 |
+
//
|
| 594 |
+
// Methods
|
| 595 |
+
//
|
| 596 |
+
|
| 597 |
+
/// Construction of Gaussian RNG functor.
|
| 598 |
+
Params(
|
| 599 |
+
uint64_t seed_ = 0,
|
| 600 |
+
FloatType max = 1,
|
| 601 |
+
FloatType min_ = 0,
|
| 602 |
+
int int_scale_ = -1,
|
| 603 |
+
double pnan_ = 0,
|
| 604 |
+
int exclude_zero_ = -1
|
| 605 |
+
):
|
| 606 |
+
seed(seed_),
|
| 607 |
+
range(static_cast<FloatType>(max - min_)),
|
| 608 |
+
min(static_cast<FloatType>(min_)),
|
| 609 |
+
int_scale(int_scale_),
|
| 610 |
+
pnan(pnan_),
|
| 611 |
+
exclude_zero(exclude_zero_) {
|
| 612 |
+
|
| 613 |
+
float_scale_up = FloatType(IntType(1) << int_scale);
|
| 614 |
+
float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
|
| 615 |
+
|
| 616 |
+
// Handle cases where min = 0 or max = 0 for excluding zeros
|
| 617 |
+
if (exclude_zero >= 0) {
|
| 618 |
+
min = (min == FloatType(0)) ? min + FloatType(1): min;
|
| 619 |
+
range = (max == FloatType(0)) ? range - FloatType(1): range;
|
| 620 |
+
}
|
| 621 |
+
}
|
| 622 |
+
};
|
| 623 |
+
|
| 624 |
+
//
|
| 625 |
+
// Data members
|
| 626 |
+
//
|
| 627 |
+
|
| 628 |
+
/// Parameters object
|
| 629 |
+
Params params;
|
| 630 |
+
|
| 631 |
+
/// RNG state object
|
| 632 |
+
curandState_t rng_state;
|
| 633 |
+
|
| 634 |
+
//
|
| 635 |
+
// Methods
|
| 636 |
+
//
|
| 637 |
+
|
| 638 |
+
/// Device-side initialization of RNG
|
| 639 |
+
CUTLASS_DEVICE
|
| 640 |
+
RandomUniformFunc(Params const ¶ms): params(params) {
|
| 641 |
+
|
| 642 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 643 |
+
|
| 644 |
+
curand_init(params.seed, gtid, 0, &rng_state);
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
/// Compute random value and update RNG state
|
| 648 |
+
CUTLASS_DEVICE
|
| 649 |
+
Element operator()() {
|
| 650 |
+
|
| 651 |
+
// Draw random float in [0.0, 1.0] to determine if element should be NaN.
|
| 652 |
+
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 653 |
+
if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
|
| 654 |
+
return Element(Real(NAN), Real(NAN));
|
| 655 |
+
}
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
FloatType rnd_r = random_uniform_float<FloatType>(&rng_state);
|
| 659 |
+
FloatType rnd_i = random_uniform_float<FloatType>(&rng_state);
|
| 660 |
+
|
| 661 |
+
rnd_r = params.min + params.range * rnd_r;
|
| 662 |
+
rnd_i = params.min + params.range * rnd_i;
|
| 663 |
+
|
| 664 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 665 |
+
// testing
|
| 666 |
+
Element result;
|
| 667 |
+
|
| 668 |
+
if (params.int_scale >= 0) {
|
| 669 |
+
rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
|
| 670 |
+
rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
|
| 671 |
+
|
| 672 |
+
result = {
|
| 673 |
+
Real(rnd_r * params.float_scale_down),
|
| 674 |
+
Real(rnd_i * params.float_scale_down)
|
| 675 |
+
};
|
| 676 |
+
}
|
| 677 |
+
else {
|
| 678 |
+
result = Element(Real(rnd_r), Real(rnd_i));
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
if (params.exclude_zero >= 0 &&
|
| 682 |
+
result.real() == Real(0.0) &&
|
| 683 |
+
result.imag() == Real(0.0)) {
|
| 684 |
+
|
| 685 |
+
if (rnd_r > FloatType(0)) {
|
| 686 |
+
rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1));
|
| 687 |
+
} else {
|
| 688 |
+
rnd_r = std::max((params.min), rnd_r - FloatType(1));
|
| 689 |
+
}
|
| 690 |
+
result = Element(Real(rnd_r), Real(rnd_i));
|
| 691 |
+
}
|
| 692 |
+
|
| 693 |
+
return result;
|
| 694 |
+
}
|
| 695 |
+
};
|
| 696 |
+
|
| 697 |
+
/// Computes a random uniform distribution
|
| 698 |
+
template <
|
| 699 |
+
typename Element, ///< Element type
|
| 700 |
+
typename Layout> ///< Layout function
|
| 701 |
+
struct TensorFillRandomUniformFunc {
|
| 702 |
+
|
| 703 |
+
/// View type
|
| 704 |
+
using TensorView = TensorView<Element, Layout>;
|
| 705 |
+
|
| 706 |
+
/// Scalar type
|
| 707 |
+
typedef typename TensorView::Element T;
|
| 708 |
+
|
| 709 |
+
/// Coordinate in tensor's index space
|
| 710 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 711 |
+
|
| 712 |
+
using RandomFunc = RandomUniformFunc<Element>;
|
| 713 |
+
|
| 714 |
+
/// Parameters structure
|
| 715 |
+
struct Params {
|
| 716 |
+
|
| 717 |
+
//
|
| 718 |
+
// Data members
|
| 719 |
+
//
|
| 720 |
+
|
| 721 |
+
TensorView view;
|
| 722 |
+
typename RandomFunc::Params random;
|
| 723 |
+
|
| 724 |
+
/// Default ctor
|
| 725 |
+
CUTLASS_HOST_DEVICE
|
| 726 |
+
Params() { }
|
| 727 |
+
|
| 728 |
+
//
|
| 729 |
+
// Methods
|
| 730 |
+
//
|
| 731 |
+
|
| 732 |
+
/// Construction of Gaussian RNG functor.
|
| 733 |
+
Params(
|
| 734 |
+
TensorView view_ = TensorView(),
|
| 735 |
+
typename RandomFunc::Params random_ = RandomFunc::Params()
|
| 736 |
+
):
|
| 737 |
+
view(view_), random(random_) {
|
| 738 |
+
|
| 739 |
+
}
|
| 740 |
+
};
|
| 741 |
+
|
| 742 |
+
//
|
| 743 |
+
// Data members
|
| 744 |
+
//
|
| 745 |
+
|
| 746 |
+
Params params;
|
| 747 |
+
RandomFunc random;
|
| 748 |
+
|
| 749 |
+
//
|
| 750 |
+
// Methods
|
| 751 |
+
//
|
| 752 |
+
|
| 753 |
+
/// Device-side initialization of RNG
|
| 754 |
+
CUTLASS_DEVICE
|
| 755 |
+
TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) {
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
/// Compute random value and update RNG state
|
| 759 |
+
CUTLASS_DEVICE
|
| 760 |
+
void operator()(TensorCoord const &coord) {
|
| 761 |
+
|
| 762 |
+
params.view.at(coord) = random();
|
| 763 |
+
}
|
| 764 |
+
};
|
| 765 |
+
|
| 766 |
+
} // namespace detail
|
| 767 |
+
|
| 768 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 769 |
+
|
| 770 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 771 |
+
template <
|
| 772 |
+
typename Element, ///< Element type
|
| 773 |
+
typename Layout> ///< Layout function
|
| 774 |
+
void TensorFillRandomUniform(
|
| 775 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 776 |
+
uint64_t seed, ///< seed for RNG
|
| 777 |
+
typename RealType<Element>::Type max = Element(1), ///< upper bound of distribution
|
| 778 |
+
typename RealType<Element>::Type min = Element(0), ///< lower bound for distribution
|
| 779 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 780 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 781 |
+
/// data.
|
| 782 |
+
double pnan = 0, ///< Percentage of NaN elements.
|
| 783 |
+
int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
|
| 784 |
+
cudaStream_t stream = nullptr) {
|
| 785 |
+
|
| 786 |
+
using RandomFunc = detail::RandomUniformFunc<Element>;
|
| 787 |
+
using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
|
| 788 |
+
using Params = typename Func::Params;
|
| 789 |
+
|
| 790 |
+
typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero);
|
| 791 |
+
|
| 792 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 793 |
+
view.extent(),
|
| 794 |
+
Params(view, random),
|
| 795 |
+
/*grid_size*/0, /*block_size*/0,
|
| 796 |
+
stream
|
| 797 |
+
);
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 801 |
+
|
| 802 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 803 |
+
template <typename Element>
|
| 804 |
+
void BlockFillRandomUniform(
|
| 805 |
+
Element *ptr,
|
| 806 |
+
size_t capacity,
|
| 807 |
+
uint64_t seed, ///< seed for RNG
|
| 808 |
+
typename RealType<Element>::Type max, ///< upper bound of distribution
|
| 809 |
+
typename RealType<Element>::Type min, ///< lower bound for distribution
|
| 810 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 811 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 812 |
+
/// data.
|
| 813 |
+
double pnan = 0, ///< Percentage of NaN elements.
|
| 814 |
+
cudaStream_t stream = nullptr) {
|
| 815 |
+
|
| 816 |
+
using RandomFunc = detail::RandomUniformFunc<Element>;
|
| 817 |
+
|
| 818 |
+
typename RandomFunc::Params params(seed, max, min, bits, pnan);
|
| 819 |
+
|
| 820 |
+
BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
|
| 821 |
+
}
|
| 822 |
+
|
| 823 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 824 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 825 |
+
|
| 826 |
+
namespace detail {
|
| 827 |
+
|
| 828 |
+
/// Computes a random sparse meta
|
| 829 |
+
template <typename Element> ///< Element type
|
| 830 |
+
struct RandomSparseMetaFunc {
|
| 831 |
+
|
| 832 |
+
using FloatType = float;
|
| 833 |
+
|
| 834 |
+
using IntType = int32_t;
|
| 835 |
+
|
| 836 |
+
/// Parameters structure
|
| 837 |
+
struct Params {
|
| 838 |
+
|
| 839 |
+
//
|
| 840 |
+
// Data members
|
| 841 |
+
//
|
| 842 |
+
|
| 843 |
+
uint64_t seed;
|
| 844 |
+
FloatType range;
|
| 845 |
+
int MetaSizeInBits;
|
| 846 |
+
|
| 847 |
+
/// Default ctor
|
| 848 |
+
CUTLASS_HOST_DEVICE
|
| 849 |
+
Params() { }
|
| 850 |
+
|
| 851 |
+
//
|
| 852 |
+
// Methods
|
| 853 |
+
//
|
| 854 |
+
|
| 855 |
+
/// Construction of Gaussian RNG functor.
|
| 856 |
+
Params(
|
| 857 |
+
uint64_t seed_ = 0,
|
| 858 |
+
int MetaSizeInBits_ = 2
|
| 859 |
+
):
|
| 860 |
+
seed(seed_),
|
| 861 |
+
MetaSizeInBits(MetaSizeInBits_) {
|
| 862 |
+
if (MetaSizeInBits_ == 2) {
|
| 863 |
+
range = 6;
|
| 864 |
+
}
|
| 865 |
+
else if (MetaSizeInBits_ == 4) {
|
| 866 |
+
range = 2;
|
| 867 |
+
}
|
| 868 |
+
else {
|
| 869 |
+
throw std::invalid_argument("Invalid MetaSizeInBits");
|
| 870 |
+
}
|
| 871 |
+
}
|
| 872 |
+
};
|
| 873 |
+
|
| 874 |
+
//
|
| 875 |
+
// Data members
|
| 876 |
+
//
|
| 877 |
+
|
| 878 |
+
/// Parameters object
|
| 879 |
+
Params params;
|
| 880 |
+
|
| 881 |
+
/// RNG state object
|
| 882 |
+
curandState_t rng_state;
|
| 883 |
+
|
| 884 |
+
//
|
| 885 |
+
// Methods
|
| 886 |
+
//
|
| 887 |
+
|
| 888 |
+
/// Device-side initialization of RNG
|
| 889 |
+
CUTLASS_DEVICE
|
| 890 |
+
RandomSparseMetaFunc(Params const ¶ms): params(params) {
|
| 891 |
+
|
| 892 |
+
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
| 893 |
+
|
| 894 |
+
curand_init(params.seed, gtid, 0, &rng_state);
|
| 895 |
+
}
|
| 896 |
+
|
| 897 |
+
/// Compute random value and update RNG state
|
| 898 |
+
CUTLASS_DEVICE
|
| 899 |
+
Element operator()() {
|
| 900 |
+
Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
|
| 901 |
+
Element TwoToOneMeta[2] = {0x4, 0xe};
|
| 902 |
+
|
| 903 |
+
Element *MetaArray =
|
| 904 |
+
(params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta;
|
| 905 |
+
|
| 906 |
+
Element result = 0x0;
|
| 907 |
+
|
| 908 |
+
CUTLASS_PRAGMA_UNROLL
|
| 909 |
+
for (int i = 0; i < cutlass::sizeof_bits<Element>::value / 4; ++i) {
|
| 910 |
+
FloatType rnd = random_uniform_float<FloatType>(&rng_state);
|
| 911 |
+
rnd = params.range * rnd;
|
| 912 |
+
Element meta = MetaArray[(int)rnd];
|
| 913 |
+
|
| 914 |
+
result = (Element)(result | ((Element)(meta << (i * 4))));
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
return result;
|
| 918 |
+
}
|
| 919 |
+
};
|
| 920 |
+
|
| 921 |
+
/// Computes a random Gaussian distribution
|
| 922 |
+
template <
|
| 923 |
+
typename Element, ///< Element type
|
| 924 |
+
typename Layout> ///< Layout function
|
| 925 |
+
struct TensorFillRandomSparseMetaFunc {
|
| 926 |
+
|
| 927 |
+
/// View type
|
| 928 |
+
using TensorView = TensorView<Element, Layout>;
|
| 929 |
+
|
| 930 |
+
/// Scalar type
|
| 931 |
+
typedef typename TensorView::Element T;
|
| 932 |
+
|
| 933 |
+
/// Coordinate in tensor's index space
|
| 934 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 935 |
+
|
| 936 |
+
using RandomFunc = RandomSparseMetaFunc<Element>;
|
| 937 |
+
|
| 938 |
+
/// Parameters structure
|
| 939 |
+
struct Params {
|
| 940 |
+
|
| 941 |
+
//
|
| 942 |
+
// Data members
|
| 943 |
+
//
|
| 944 |
+
|
| 945 |
+
TensorView view;
|
| 946 |
+
typename RandomFunc::Params random;
|
| 947 |
+
|
| 948 |
+
/// Default ctor
|
| 949 |
+
CUTLASS_HOST_DEVICE
|
| 950 |
+
Params() { }
|
| 951 |
+
|
| 952 |
+
//
|
| 953 |
+
// Methods
|
| 954 |
+
//
|
| 955 |
+
|
| 956 |
+
/// Construction of Gaussian RNG functor.
|
| 957 |
+
Params(
|
| 958 |
+
TensorView view_ = TensorView(),
|
| 959 |
+
typename RandomFunc::Params random_ = RandomFunc::Params()
|
| 960 |
+
):
|
| 961 |
+
view(view_), random(random_) {
|
| 962 |
+
|
| 963 |
+
}
|
| 964 |
+
};
|
| 965 |
+
|
| 966 |
+
//
|
| 967 |
+
// Data members
|
| 968 |
+
//
|
| 969 |
+
|
| 970 |
+
Params params;
|
| 971 |
+
RandomFunc random;
|
| 972 |
+
|
| 973 |
+
//
|
| 974 |
+
// Methods
|
| 975 |
+
//
|
| 976 |
+
|
| 977 |
+
/// Device-side initialization of RNG
|
| 978 |
+
CUTLASS_DEVICE
|
| 979 |
+
TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) {
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
/// Compute random value and update RNG state
|
| 983 |
+
CUTLASS_DEVICE
|
| 984 |
+
void operator()(TensorCoord const &coord) {
|
| 985 |
+
|
| 986 |
+
params.view.at(coord) = random();
|
| 987 |
+
}
|
| 988 |
+
};
|
| 989 |
+
|
| 990 |
+
} // namespace detail
|
| 991 |
+
|
| 992 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 993 |
+
|
| 994 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 995 |
+
template <
|
| 996 |
+
typename Element, ///< Element type
|
| 997 |
+
typename Layout> ///< Layout function
|
| 998 |
+
void TensorFillRandomSparseMeta(
|
| 999 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1000 |
+
uint64_t seed, ///< seed for RNG
|
| 1001 |
+
int MetaSizeInBits = 2, ///< meta data size
|
| 1002 |
+
cudaStream_t stream = nullptr) {
|
| 1003 |
+
|
| 1004 |
+
using RandomFunc = detail::RandomSparseMetaFunc<Element>;
|
| 1005 |
+
using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
|
| 1006 |
+
using Params = typename Func::Params;
|
| 1007 |
+
|
| 1008 |
+
typename RandomFunc::Params random(seed, MetaSizeInBits);
|
| 1009 |
+
|
| 1010 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1011 |
+
view.extent(),
|
| 1012 |
+
Params(view, random),
|
| 1013 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1014 |
+
stream
|
| 1015 |
+
);
|
| 1016 |
+
}
|
| 1017 |
+
|
| 1018 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1019 |
+
|
| 1020 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 1021 |
+
template <typename Element>
|
| 1022 |
+
void BlockFillRandomSparseMeta(
|
| 1023 |
+
Element *ptr,
|
| 1024 |
+
size_t capacity,
|
| 1025 |
+
uint64_t seed, ///< seed for RNG
|
| 1026 |
+
int MetaSizeInBits = 2, ///< meta data size
|
| 1027 |
+
cudaStream_t stream = nullptr) {
|
| 1028 |
+
|
| 1029 |
+
using RandomFunc = detail::RandomSparseMetaFunc<Element>;
|
| 1030 |
+
|
| 1031 |
+
typename RandomFunc::Params params(seed, MetaSizeInBits);
|
| 1032 |
+
|
| 1033 |
+
BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
|
| 1034 |
+
}
|
| 1035 |
+
|
| 1036 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1037 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1038 |
+
|
| 1039 |
+
namespace detail {
|
| 1040 |
+
|
| 1041 |
+
/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal.
|
| 1042 |
+
template <
|
| 1043 |
+
typename Element, ///< Element type
|
| 1044 |
+
typename Layout> ///< Layout function
|
| 1045 |
+
struct TensorFillDiagonalFunc {
|
| 1046 |
+
|
| 1047 |
+
/// View type
|
| 1048 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1049 |
+
|
| 1050 |
+
/// Scalar type
|
| 1051 |
+
typedef typename TensorView::Element T;
|
| 1052 |
+
|
| 1053 |
+
/// Coordinate in tensor's index space
|
| 1054 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1055 |
+
|
| 1056 |
+
/// Parameters structure
|
| 1057 |
+
struct Params {
|
| 1058 |
+
|
| 1059 |
+
//
|
| 1060 |
+
// Data members
|
| 1061 |
+
//
|
| 1062 |
+
|
| 1063 |
+
TensorView view;
|
| 1064 |
+
Element diag;
|
| 1065 |
+
Element other;
|
| 1066 |
+
|
| 1067 |
+
/// Default ctor
|
| 1068 |
+
CUTLASS_HOST_DEVICE
|
| 1069 |
+
Params() { }
|
| 1070 |
+
|
| 1071 |
+
//
|
| 1072 |
+
// Methods
|
| 1073 |
+
//
|
| 1074 |
+
|
| 1075 |
+
Params(
|
| 1076 |
+
TensorView view_ = TensorView(),
|
| 1077 |
+
Element diag_ = Element(1),
|
| 1078 |
+
Element other_ = Element(0)
|
| 1079 |
+
):
|
| 1080 |
+
view(view_), diag(diag_), other(other_) {
|
| 1081 |
+
|
| 1082 |
+
}
|
| 1083 |
+
};
|
| 1084 |
+
|
| 1085 |
+
//
|
| 1086 |
+
// Data members
|
| 1087 |
+
//
|
| 1088 |
+
|
| 1089 |
+
/// Parameters object
|
| 1090 |
+
Params params;
|
| 1091 |
+
|
| 1092 |
+
//
|
| 1093 |
+
// Methods
|
| 1094 |
+
//
|
| 1095 |
+
|
| 1096 |
+
/// Device-side initialization of RNG
|
| 1097 |
+
CUTLASS_DEVICE
|
| 1098 |
+
TensorFillDiagonalFunc(Params const ¶ms): params(params) {
|
| 1099 |
+
|
| 1100 |
+
}
|
| 1101 |
+
|
| 1102 |
+
/// Updates the tensor
|
| 1103 |
+
CUTLASS_DEVICE
|
| 1104 |
+
void operator()(TensorCoord const &coord) {
|
| 1105 |
+
|
| 1106 |
+
bool is_diag = true;
|
| 1107 |
+
|
| 1108 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1109 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1110 |
+
if (coord[i] != coord[i - 1]) {
|
| 1111 |
+
is_diag = false;
|
| 1112 |
+
break;
|
| 1113 |
+
}
|
| 1114 |
+
}
|
| 1115 |
+
|
| 1116 |
+
params.view.at(coord) = (is_diag ? params.diag : params.other);
|
| 1117 |
+
}
|
| 1118 |
+
};
|
| 1119 |
+
|
| 1120 |
+
// Overwrites the elements of a tensor with a uniform value depending on fill mode
|
| 1121 |
+
template <
|
| 1122 |
+
typename Element, ///< Element type
|
| 1123 |
+
typename Layout> ///< Layout function
|
| 1124 |
+
struct TensorFillPartialFunc {
|
| 1125 |
+
|
| 1126 |
+
/// View type
|
| 1127 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1128 |
+
|
| 1129 |
+
/// Scalar type
|
| 1130 |
+
typedef typename TensorView::Element T;
|
| 1131 |
+
|
| 1132 |
+
/// Coordinate in tensor's index space
|
| 1133 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1134 |
+
|
| 1135 |
+
/// Parameters structure
|
| 1136 |
+
struct Params {
|
| 1137 |
+
|
| 1138 |
+
//
|
| 1139 |
+
// Data members
|
| 1140 |
+
//
|
| 1141 |
+
|
| 1142 |
+
TensorView view;
|
| 1143 |
+
Element element;
|
| 1144 |
+
FillMode fill_mode;
|
| 1145 |
+
|
| 1146 |
+
/// Default ctor
|
| 1147 |
+
CUTLASS_HOST_DEVICE
|
| 1148 |
+
Params(): fill_mode(FillMode::kNone) { }
|
| 1149 |
+
|
| 1150 |
+
//
|
| 1151 |
+
// Methods
|
| 1152 |
+
//
|
| 1153 |
+
|
| 1154 |
+
/// Construction of Gaussian RNG functor.
|
| 1155 |
+
Params(
|
| 1156 |
+
TensorView view_,
|
| 1157 |
+
Element element_,
|
| 1158 |
+
FillMode fill_mode_
|
| 1159 |
+
):
|
| 1160 |
+
view(view_), element(element_), fill_mode(fill_mode_) {
|
| 1161 |
+
|
| 1162 |
+
}
|
| 1163 |
+
};
|
| 1164 |
+
|
| 1165 |
+
//
|
| 1166 |
+
// Data members
|
| 1167 |
+
//
|
| 1168 |
+
|
| 1169 |
+
/// Parameters object
|
| 1170 |
+
Params params;
|
| 1171 |
+
|
| 1172 |
+
//
|
| 1173 |
+
// Methods
|
| 1174 |
+
//
|
| 1175 |
+
|
| 1176 |
+
CUTLASS_DEVICE
|
| 1177 |
+
TensorFillPartialFunc(Params const ¶ms): params(params) {
|
| 1178 |
+
|
| 1179 |
+
}
|
| 1180 |
+
|
| 1181 |
+
/// Overwrites the element if it is within the covered region.
|
| 1182 |
+
CUTLASS_DEVICE
|
| 1183 |
+
void operator()(TensorCoord const &coord) {
|
| 1184 |
+
|
| 1185 |
+
bool predicate = true;
|
| 1186 |
+
|
| 1187 |
+
switch (params.fill_mode) {
|
| 1188 |
+
case FillMode::kFull:
|
| 1189 |
+
predicate = true;
|
| 1190 |
+
break;
|
| 1191 |
+
|
| 1192 |
+
case FillMode::kLower:
|
| 1193 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1194 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1195 |
+
if (coord[i - 1] < coord[i]) {
|
| 1196 |
+
predicate = false;
|
| 1197 |
+
break;
|
| 1198 |
+
}
|
| 1199 |
+
}
|
| 1200 |
+
break;
|
| 1201 |
+
|
| 1202 |
+
case FillMode::kUpper:
|
| 1203 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1204 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1205 |
+
if (coord[i - 1] > coord[i]) {
|
| 1206 |
+
predicate = false;
|
| 1207 |
+
break;
|
| 1208 |
+
}
|
| 1209 |
+
}
|
| 1210 |
+
break;
|
| 1211 |
+
|
| 1212 |
+
case FillMode::kDiagonal:
|
| 1213 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1214 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1215 |
+
if (coord[i - 1] != coord[i]) {
|
| 1216 |
+
predicate = false;
|
| 1217 |
+
break;
|
| 1218 |
+
}
|
| 1219 |
+
}
|
| 1220 |
+
break;
|
| 1221 |
+
|
| 1222 |
+
case FillMode::kNone: // fall-through
|
| 1223 |
+
|
| 1224 |
+
default:
|
| 1225 |
+
predicate = false;
|
| 1226 |
+
break;
|
| 1227 |
+
}
|
| 1228 |
+
|
| 1229 |
+
if (predicate) {
|
| 1230 |
+
params.view.at(coord) = params.element;
|
| 1231 |
+
}
|
| 1232 |
+
}
|
| 1233 |
+
};
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
template <
|
| 1237 |
+
typename Element, ///< Element type
|
| 1238 |
+
typename Layout> ///< Layout function
|
| 1239 |
+
struct TensorClearPartialFunc {
|
| 1240 |
+
|
| 1241 |
+
/// View type
|
| 1242 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1243 |
+
|
| 1244 |
+
/// Scalar type
|
| 1245 |
+
typedef typename TensorView::Element T;
|
| 1246 |
+
|
| 1247 |
+
/// Coordinate in tensor's index space
|
| 1248 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1249 |
+
|
| 1250 |
+
///
|
| 1251 |
+
static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices");
|
| 1252 |
+
|
| 1253 |
+
/// Parameters structure
|
| 1254 |
+
struct Params {
|
| 1255 |
+
TensorView view{};
|
| 1256 |
+
Element element{};
|
| 1257 |
+
FillMode fill_mode{FillMode::kNone};
|
| 1258 |
+
int alignment{0};
|
| 1259 |
+
};
|
| 1260 |
+
|
| 1261 |
+
//
|
| 1262 |
+
// Data members
|
| 1263 |
+
//
|
| 1264 |
+
|
| 1265 |
+
/// Parameters object
|
| 1266 |
+
Params params;
|
| 1267 |
+
|
| 1268 |
+
//
|
| 1269 |
+
// Methods
|
| 1270 |
+
//
|
| 1271 |
+
|
| 1272 |
+
CUTLASS_DEVICE
|
| 1273 |
+
TensorClearPartialFunc(Params const ¶ms): params(params) {
|
| 1274 |
+
|
| 1275 |
+
}
|
| 1276 |
+
|
| 1277 |
+
/// Overwrites the element if it is within the covered region.
|
| 1278 |
+
CUTLASS_DEVICE
|
| 1279 |
+
void operator()(TensorCoord const &coord) {
|
| 1280 |
+
|
| 1281 |
+
bool predicate = true;
|
| 1282 |
+
|
| 1283 |
+
switch (params.fill_mode) {
|
| 1284 |
+
|
| 1285 |
+
case FillMode::kLower:
|
| 1286 |
+
if ((coord[0] >= coord[1]) ||
|
| 1287 |
+
((coord[1] - coord[0]) >= params.alignment)) {
|
| 1288 |
+
predicate = false;
|
| 1289 |
+
break;
|
| 1290 |
+
}
|
| 1291 |
+
break;
|
| 1292 |
+
|
| 1293 |
+
case FillMode::kUpper:
|
| 1294 |
+
if ((coord[0] <= coord[1]) ||
|
| 1295 |
+
((coord[0] - coord[1]) >= params.alignment)) {
|
| 1296 |
+
predicate = false;
|
| 1297 |
+
break;
|
| 1298 |
+
}
|
| 1299 |
+
break;
|
| 1300 |
+
|
| 1301 |
+
case FillMode::kNone: // fall-through
|
| 1302 |
+
|
| 1303 |
+
default:
|
| 1304 |
+
predicate = false;
|
| 1305 |
+
break;
|
| 1306 |
+
}
|
| 1307 |
+
|
| 1308 |
+
if (predicate) {
|
| 1309 |
+
params.view.at(coord) = params.element;
|
| 1310 |
+
}
|
| 1311 |
+
}
|
| 1312 |
+
};
|
| 1313 |
+
|
| 1314 |
+
} // namespace detail
|
| 1315 |
+
|
| 1316 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1317 |
+
|
| 1318 |
+
/// Fills a tensor everywhere with a unique value for its diagonal.
|
| 1319 |
+
template <
|
| 1320 |
+
typename Element, ///< Element type
|
| 1321 |
+
typename Layout> ///< Layout function
|
| 1322 |
+
void TensorFillDiagonal(
|
| 1323 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1324 |
+
Element diag = Element(1), ///< value to write in the diagonal
|
| 1325 |
+
Element other = Element(0), ///< value to write off the diagonal
|
| 1326 |
+
cudaStream_t stream = nullptr) {
|
| 1327 |
+
|
| 1328 |
+
typedef detail::TensorFillDiagonalFunc<Element, Layout> Func;
|
| 1329 |
+
typedef typename Func::Params Params;
|
| 1330 |
+
|
| 1331 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1332 |
+
view.extent(),
|
| 1333 |
+
Params(view, diag, other),
|
| 1334 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1335 |
+
stream
|
| 1336 |
+
);
|
| 1337 |
+
}
|
| 1338 |
+
|
| 1339 |
+
/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are
|
| 1340 |
+
/// not written.
|
| 1341 |
+
template <
|
| 1342 |
+
typename Element, ///< Element type
|
| 1343 |
+
typename Layout> ///< Layout function
|
| 1344 |
+
void TensorFillPartial(
|
| 1345 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1346 |
+
Element element,
|
| 1347 |
+
FillMode fill_mode,
|
| 1348 |
+
cudaStream_t stream = nullptr) {
|
| 1349 |
+
|
| 1350 |
+
typedef detail::TensorFillPartialFunc<Element, Layout> Func;
|
| 1351 |
+
typedef typename Func::Params Params;
|
| 1352 |
+
|
| 1353 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1354 |
+
view.extent(),
|
| 1355 |
+
Params(view, element, fill_mode),
|
| 1356 |
+
stream
|
| 1357 |
+
);
|
| 1358 |
+
}
|
| 1359 |
+
|
| 1360 |
+
/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side
|
| 1361 |
+
/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros)
|
| 1362 |
+
template <
|
| 1363 |
+
typename Element, ///< Element type
|
| 1364 |
+
typename Layout> ///< Layout function
|
| 1365 |
+
void TensorClearPartial(
|
| 1366 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1367 |
+
Element element,
|
| 1368 |
+
FillMode fill_mode,
|
| 1369 |
+
int alignment,
|
| 1370 |
+
cudaStream_t stream = nullptr) {
|
| 1371 |
+
|
| 1372 |
+
typedef detail::TensorClearPartialFunc<Element, Layout> Func;
|
| 1373 |
+
typedef typename Func::Params Params;
|
| 1374 |
+
|
| 1375 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1376 |
+
view.extent(),
|
| 1377 |
+
Params{view, element, fill_mode, alignment},
|
| 1378 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1379 |
+
stream
|
| 1380 |
+
);
|
| 1381 |
+
}
|
| 1382 |
+
|
| 1383 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1384 |
+
|
| 1385 |
+
/// Fills a tensor with a uniform value
|
| 1386 |
+
template <
|
| 1387 |
+
typename Element, ///< Element type
|
| 1388 |
+
typename Layout> ///< Layout function
|
| 1389 |
+
void TensorFill(
|
| 1390 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1391 |
+
Element val = Element(0), ///< value to uniformly fill it with
|
| 1392 |
+
cudaStream_t stream = nullptr) {
|
| 1393 |
+
|
| 1394 |
+
TensorFillDiagonal(view, val, val, stream);
|
| 1395 |
+
}
|
| 1396 |
+
|
| 1397 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1398 |
+
|
| 1399 |
+
/// Fills a tensor's diagonal with 1 and 0 everywhere else.
|
| 1400 |
+
template <
|
| 1401 |
+
typename Element, ///< Element type
|
| 1402 |
+
typename Layout> ///< Layout function
|
| 1403 |
+
void TensorFillIdentity(
|
| 1404 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1405 |
+
cudaStream_t stream = nullptr) {
|
| 1406 |
+
|
| 1407 |
+
TensorFillDiagonal(view, Element(1), Element(0), stream);
|
| 1408 |
+
}
|
| 1409 |
+
|
| 1410 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1411 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1412 |
+
|
| 1413 |
+
namespace detail {
|
| 1414 |
+
|
| 1415 |
+
/// Computes a random Gaussian distribution
|
| 1416 |
+
template <
|
| 1417 |
+
typename Element, ///< Element type
|
| 1418 |
+
typename Layout> ///< Layout function
|
| 1419 |
+
struct TensorUpdateDiagonalFunc {
|
| 1420 |
+
|
| 1421 |
+
/// View type
|
| 1422 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1423 |
+
|
| 1424 |
+
/// Scalar type
|
| 1425 |
+
typedef typename TensorView::Element T;
|
| 1426 |
+
|
| 1427 |
+
/// Coordinate in tensor's index space
|
| 1428 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1429 |
+
|
| 1430 |
+
/// Parameters structure
|
| 1431 |
+
struct Params {
|
| 1432 |
+
|
| 1433 |
+
//
|
| 1434 |
+
// Data members
|
| 1435 |
+
//
|
| 1436 |
+
|
| 1437 |
+
TensorView view;
|
| 1438 |
+
Element diag;
|
| 1439 |
+
|
| 1440 |
+
/// Default ctor
|
| 1441 |
+
CUTLASS_HOST_DEVICE
|
| 1442 |
+
Params() { }
|
| 1443 |
+
|
| 1444 |
+
//
|
| 1445 |
+
// Methods
|
| 1446 |
+
//
|
| 1447 |
+
|
| 1448 |
+
/// Construction of Gaussian RNG functor.
|
| 1449 |
+
Params(
|
| 1450 |
+
TensorView view_ = TensorView(),
|
| 1451 |
+
Element diag_ = Element(1)
|
| 1452 |
+
):
|
| 1453 |
+
view(view_), diag(diag_) {
|
| 1454 |
+
|
| 1455 |
+
}
|
| 1456 |
+
};
|
| 1457 |
+
|
| 1458 |
+
//
|
| 1459 |
+
// Data members
|
| 1460 |
+
//
|
| 1461 |
+
|
| 1462 |
+
/// Parameters object
|
| 1463 |
+
Params params;
|
| 1464 |
+
|
| 1465 |
+
//
|
| 1466 |
+
// Methods
|
| 1467 |
+
//
|
| 1468 |
+
|
| 1469 |
+
/// Device-side initialization of RNG
|
| 1470 |
+
CUTLASS_DEVICE
|
| 1471 |
+
TensorUpdateDiagonalFunc(Params const ¶ms): params(params) {
|
| 1472 |
+
|
| 1473 |
+
}
|
| 1474 |
+
|
| 1475 |
+
/// Compute random value and update RNG state
|
| 1476 |
+
CUTLASS_DEVICE
|
| 1477 |
+
void operator()(TensorCoord const &coord) {
|
| 1478 |
+
|
| 1479 |
+
bool is_diag = true;
|
| 1480 |
+
|
| 1481 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1482 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1483 |
+
if (coord[i] != coord[i - 1]) {
|
| 1484 |
+
is_diag = false;
|
| 1485 |
+
break;
|
| 1486 |
+
}
|
| 1487 |
+
}
|
| 1488 |
+
|
| 1489 |
+
if (is_diag) {
|
| 1490 |
+
params.view.at(coord) = params.diag;
|
| 1491 |
+
}
|
| 1492 |
+
}
|
| 1493 |
+
};
|
| 1494 |
+
|
| 1495 |
+
} // namespace detail
|
| 1496 |
+
|
| 1497 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1498 |
+
|
| 1499 |
+
/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements.
|
| 1500 |
+
template <
|
| 1501 |
+
typename Element, ///< Element type
|
| 1502 |
+
typename Layout> ///< Layout function
|
| 1503 |
+
void TensorUpdateDiagonal(
|
| 1504 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1505 |
+
Element diag = Element(1),
|
| 1506 |
+
cudaStream_t stream = nullptr) {
|
| 1507 |
+
|
| 1508 |
+
typedef detail::TensorUpdateDiagonalFunc<Element, Layout> Func;
|
| 1509 |
+
typedef typename Func::Params Params;
|
| 1510 |
+
|
| 1511 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1512 |
+
view.extent(),
|
| 1513 |
+
Params(view, diag),
|
| 1514 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1515 |
+
stream
|
| 1516 |
+
);
|
| 1517 |
+
}
|
| 1518 |
+
|
| 1519 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1520 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1521 |
+
|
| 1522 |
+
namespace detail {
|
| 1523 |
+
|
| 1524 |
+
/// Computes a random Gaussian distribution
|
| 1525 |
+
template <
|
| 1526 |
+
typename Element, ///< Element type
|
| 1527 |
+
typename Layout> ///< Layout function
|
| 1528 |
+
struct TensorUpdateOffDiagonalFunc {
|
| 1529 |
+
|
| 1530 |
+
/// View type
|
| 1531 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1532 |
+
|
| 1533 |
+
/// Scalar type
|
| 1534 |
+
typedef typename TensorView::Element T;
|
| 1535 |
+
|
| 1536 |
+
/// Coordinate in tensor's index space
|
| 1537 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1538 |
+
|
| 1539 |
+
/// Parameters structure
|
| 1540 |
+
struct Params {
|
| 1541 |
+
|
| 1542 |
+
//
|
| 1543 |
+
// Data members
|
| 1544 |
+
//
|
| 1545 |
+
|
| 1546 |
+
TensorView view;
|
| 1547 |
+
Element other;
|
| 1548 |
+
|
| 1549 |
+
/// Default ctor
|
| 1550 |
+
CUTLASS_HOST_DEVICE
|
| 1551 |
+
Params() { }
|
| 1552 |
+
|
| 1553 |
+
//
|
| 1554 |
+
// Methods
|
| 1555 |
+
//
|
| 1556 |
+
|
| 1557 |
+
/// Construction of Gaussian RNG functor.
|
| 1558 |
+
Params(
|
| 1559 |
+
TensorView view_ = TensorView(),
|
| 1560 |
+
Element other_ = Element(0)
|
| 1561 |
+
):
|
| 1562 |
+
view(view_), other(other_) {
|
| 1563 |
+
|
| 1564 |
+
}
|
| 1565 |
+
};
|
| 1566 |
+
|
| 1567 |
+
//
|
| 1568 |
+
// Data members
|
| 1569 |
+
//
|
| 1570 |
+
|
| 1571 |
+
/// Parameters object
|
| 1572 |
+
Params params;
|
| 1573 |
+
|
| 1574 |
+
//
|
| 1575 |
+
// Methods
|
| 1576 |
+
//
|
| 1577 |
+
|
| 1578 |
+
/// Device-side initialization of RNG
|
| 1579 |
+
CUTLASS_DEVICE
|
| 1580 |
+
TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) {
|
| 1581 |
+
|
| 1582 |
+
}
|
| 1583 |
+
|
| 1584 |
+
/// Compute random value and update RNG state
|
| 1585 |
+
CUTLASS_DEVICE
|
| 1586 |
+
void operator()(TensorCoord const &coord) {
|
| 1587 |
+
|
| 1588 |
+
bool is_diag = true;
|
| 1589 |
+
|
| 1590 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1591 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1592 |
+
if (coord[i] != coord[i - 1]) {
|
| 1593 |
+
is_diag = false;
|
| 1594 |
+
break;
|
| 1595 |
+
}
|
| 1596 |
+
}
|
| 1597 |
+
|
| 1598 |
+
if (!is_diag) {
|
| 1599 |
+
params.view.at(coord) = params.other;
|
| 1600 |
+
}
|
| 1601 |
+
}
|
| 1602 |
+
};
|
| 1603 |
+
|
| 1604 |
+
} // namespace detail
|
| 1605 |
+
|
| 1606 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1607 |
+
|
| 1608 |
+
/// Writes a uniform value to all elements in the tensor without modifying diagonal elements.
|
| 1609 |
+
template <
|
| 1610 |
+
typename Element, ///< Element type
|
| 1611 |
+
typename Layout> ///< Layout function
|
| 1612 |
+
void TensorUpdateOffDiagonal(
|
| 1613 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1614 |
+
Element other = Element(1),
|
| 1615 |
+
cudaStream_t stream = nullptr) {
|
| 1616 |
+
|
| 1617 |
+
typedef detail::TensorUpdateOffDiagonalFunc<Element, Layout> Func;
|
| 1618 |
+
typedef typename Func::Params Params;
|
| 1619 |
+
|
| 1620 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1621 |
+
view.extent(),
|
| 1622 |
+
Params(view, other),
|
| 1623 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1624 |
+
stream
|
| 1625 |
+
);
|
| 1626 |
+
}
|
| 1627 |
+
|
| 1628 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1629 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1630 |
+
|
| 1631 |
+
namespace detail {
|
| 1632 |
+
|
| 1633 |
+
/// Computes a random Gaussian distribution
|
| 1634 |
+
template <
|
| 1635 |
+
typename Element, ///< Element type
|
| 1636 |
+
typename Layout> ///< Layout function
|
| 1637 |
+
struct TensorFillLinearFunc {
|
| 1638 |
+
|
| 1639 |
+
/// View type
|
| 1640 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1641 |
+
|
| 1642 |
+
/// Scalar type
|
| 1643 |
+
typedef typename TensorView::Element T;
|
| 1644 |
+
|
| 1645 |
+
/// Coordinate in tensor's index space
|
| 1646 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1647 |
+
|
| 1648 |
+
/// Parameters structure
|
| 1649 |
+
struct Params {
|
| 1650 |
+
|
| 1651 |
+
//
|
| 1652 |
+
// Data members
|
| 1653 |
+
//
|
| 1654 |
+
|
| 1655 |
+
TensorView view;
|
| 1656 |
+
Array<Element, Layout::kRank> v;
|
| 1657 |
+
Element s;
|
| 1658 |
+
|
| 1659 |
+
/// Default ctor
|
| 1660 |
+
CUTLASS_HOST_DEVICE
|
| 1661 |
+
Params() { }
|
| 1662 |
+
|
| 1663 |
+
//
|
| 1664 |
+
// Methods
|
| 1665 |
+
//
|
| 1666 |
+
|
| 1667 |
+
/// Construction of Gaussian RNG functor.
|
| 1668 |
+
Params(
|
| 1669 |
+
TensorView view_, ///< destination tensor
|
| 1670 |
+
Array<Element, Layout::kRank> const & v_,
|
| 1671 |
+
Element s_ = Element(0)
|
| 1672 |
+
):
|
| 1673 |
+
view(view_), v(v_), s(s_) {
|
| 1674 |
+
|
| 1675 |
+
}
|
| 1676 |
+
};
|
| 1677 |
+
|
| 1678 |
+
//
|
| 1679 |
+
// Data members
|
| 1680 |
+
//
|
| 1681 |
+
|
| 1682 |
+
/// Parameters object
|
| 1683 |
+
Params params;
|
| 1684 |
+
|
| 1685 |
+
//
|
| 1686 |
+
// Methods
|
| 1687 |
+
//
|
| 1688 |
+
|
| 1689 |
+
/// Device-side initialization of RNG
|
| 1690 |
+
CUTLASS_DEVICE
|
| 1691 |
+
TensorFillLinearFunc(Params const ¶ms): params(params) {
|
| 1692 |
+
|
| 1693 |
+
}
|
| 1694 |
+
|
| 1695 |
+
/// Compute random value and update RNG state
|
| 1696 |
+
CUTLASS_DEVICE
|
| 1697 |
+
void operator()(TensorCoord const &coord) {
|
| 1698 |
+
|
| 1699 |
+
Element sum = params.s;
|
| 1700 |
+
|
| 1701 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1702 |
+
for (int i = 0; i < Layout::kRank; ++i) {
|
| 1703 |
+
if constexpr (is_complex<Element>::value) {
|
| 1704 |
+
if constexpr (sizeof_bits<Element>::value <= 32) {
|
| 1705 |
+
sum = Element(static_cast<complex<float>>(sum) +
|
| 1706 |
+
static_cast<complex<float>>(params.v[i]) * static_cast<complex<float>>(coord[i]));
|
| 1707 |
+
}
|
| 1708 |
+
}
|
| 1709 |
+
else if constexpr (sizeof_bits<Element>::value <= 32) {
|
| 1710 |
+
if constexpr (std::numeric_limits<Element>::is_integer) {
|
| 1711 |
+
sum = Element(static_cast<int32_t>(sum) +
|
| 1712 |
+
static_cast<int32_t>(params.v[i]) * static_cast<int32_t>(coord[i]));
|
| 1713 |
+
}
|
| 1714 |
+
else {
|
| 1715 |
+
sum = Element(static_cast<float>(sum) +
|
| 1716 |
+
static_cast<float>(params.v[i]) * static_cast<float>(coord[i]));
|
| 1717 |
+
}
|
| 1718 |
+
}
|
| 1719 |
+
else {
|
| 1720 |
+
sum += params.v[i] * coord[i];
|
| 1721 |
+
}
|
| 1722 |
+
}
|
| 1723 |
+
|
| 1724 |
+
params.view.at(coord) = sum;
|
| 1725 |
+
}
|
| 1726 |
+
};
|
| 1727 |
+
|
| 1728 |
+
} // namespace detail
|
| 1729 |
+
|
| 1730 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1731 |
+
|
| 1732 |
+
/// Fills tensor with a linear combination of its coordinate and another vector
|
| 1733 |
+
template <
|
| 1734 |
+
typename Element, ///< Element type
|
| 1735 |
+
typename Layout> ///< Layout function
|
| 1736 |
+
void TensorFillLinear(
|
| 1737 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1738 |
+
Array<Element, Layout::kRank> const & v,
|
| 1739 |
+
Element s = Element(0),
|
| 1740 |
+
cudaStream_t stream = nullptr) {
|
| 1741 |
+
|
| 1742 |
+
using Func = detail::TensorFillLinearFunc<Element, Layout>;
|
| 1743 |
+
using Params = typename Func::Params;
|
| 1744 |
+
|
| 1745 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1746 |
+
view.extent(),
|
| 1747 |
+
Params(view, v, s),
|
| 1748 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1749 |
+
stream
|
| 1750 |
+
);
|
| 1751 |
+
}
|
| 1752 |
+
|
| 1753 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1754 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1755 |
+
|
| 1756 |
+
/// Fills a tensor with random values from a distribution.
|
| 1757 |
+
template <
|
| 1758 |
+
typename Element, ///< Element type
|
| 1759 |
+
typename Layout> ///< Layout function
|
| 1760 |
+
void TensorFillRandom(
|
| 1761 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1762 |
+
uint64_t seed,
|
| 1763 |
+
Distribution dist,
|
| 1764 |
+
cudaStream_t stream = nullptr,
|
| 1765 |
+
int exclude_zero = -1 ///< If non-negative, excludes 0.
|
| 1766 |
+
/// Note that setting this flag will result in more 1's,
|
| 1767 |
+
/// as we use a simple mechanism to replace 0's by adding/subtracting 1's.
|
| 1768 |
+
) {
|
| 1769 |
+
|
| 1770 |
+
using Real = typename RealType<Element>::Type;
|
| 1771 |
+
|
| 1772 |
+
if (dist.kind == Distribution::Gaussian) {
|
| 1773 |
+
TensorFillRandomGaussian<Element, Layout>(
|
| 1774 |
+
view,
|
| 1775 |
+
seed,
|
| 1776 |
+
static_cast<Real>(dist.gaussian.mean),
|
| 1777 |
+
static_cast<Real>(dist.gaussian.stddev),
|
| 1778 |
+
dist.int_scale,
|
| 1779 |
+
exclude_zero,
|
| 1780 |
+
stream);
|
| 1781 |
+
} else if (dist.kind == Distribution::Uniform) {
|
| 1782 |
+
TensorFillRandomUniform<Element, Layout>(
|
| 1783 |
+
view,
|
| 1784 |
+
seed,
|
| 1785 |
+
static_cast<Real>(dist.uniform.max),
|
| 1786 |
+
static_cast<Real>(dist.uniform.min),
|
| 1787 |
+
dist.int_scale,
|
| 1788 |
+
dist.uniform.pnan,
|
| 1789 |
+
exclude_zero,
|
| 1790 |
+
stream);
|
| 1791 |
+
}
|
| 1792 |
+
}
|
| 1793 |
+
|
| 1794 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1795 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1796 |
+
|
| 1797 |
+
/// Fills a block of data with sequential elements
|
| 1798 |
+
template <
|
| 1799 |
+
typename Element
|
| 1800 |
+
>
|
| 1801 |
+
void BlockFillSequential(
|
| 1802 |
+
Element *ptr,
|
| 1803 |
+
int64_t capacity,
|
| 1804 |
+
Element v = Element(1),
|
| 1805 |
+
Element s = Element(0)) {
|
| 1806 |
+
|
| 1807 |
+
using Layout = layout::PackedVectorLayout;
|
| 1808 |
+
Layout::TensorCoord size(static_cast<Layout::Index>(capacity)); // -Wconversion
|
| 1809 |
+
Layout layout = Layout::packed(size);
|
| 1810 |
+
TensorView<Element, Layout> view(ptr, layout, size);
|
| 1811 |
+
|
| 1812 |
+
Array<Element, Layout::kRank> c{};
|
| 1813 |
+
c[0] = v;
|
| 1814 |
+
|
| 1815 |
+
TensorFillLinear(view, c, s);
|
| 1816 |
+
}
|
| 1817 |
+
|
| 1818 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1819 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1820 |
+
|
| 1821 |
+
/// Fills a block of data with sequential elements
|
| 1822 |
+
template <
|
| 1823 |
+
typename Element
|
| 1824 |
+
>
|
| 1825 |
+
void BlockFillRandom(
|
| 1826 |
+
Element *ptr,
|
| 1827 |
+
size_t capacity,
|
| 1828 |
+
uint64_t seed,
|
| 1829 |
+
Distribution dist,
|
| 1830 |
+
cudaStream_t stream = nullptr) {
|
| 1831 |
+
|
| 1832 |
+
using Real = typename RealType<Element>::Type;
|
| 1833 |
+
|
| 1834 |
+
if (dist.kind == Distribution::Gaussian) {
|
| 1835 |
+
BlockFillRandomGaussian<Element>(
|
| 1836 |
+
ptr,
|
| 1837 |
+
capacity,
|
| 1838 |
+
seed,
|
| 1839 |
+
static_cast<Real>(dist.gaussian.mean),
|
| 1840 |
+
static_cast<Real>(dist.gaussian.stddev),
|
| 1841 |
+
dist.int_scale,
|
| 1842 |
+
stream);
|
| 1843 |
+
}
|
| 1844 |
+
else if (dist.kind == Distribution::Uniform) {
|
| 1845 |
+
BlockFillRandomUniform<Element>(
|
| 1846 |
+
ptr,
|
| 1847 |
+
capacity,
|
| 1848 |
+
seed,
|
| 1849 |
+
static_cast<Real>(dist.uniform.max),
|
| 1850 |
+
static_cast<Real>(dist.uniform.min),
|
| 1851 |
+
dist.int_scale,
|
| 1852 |
+
dist.uniform.pnan,
|
| 1853 |
+
stream);
|
| 1854 |
+
}
|
| 1855 |
+
}
|
| 1856 |
+
|
| 1857 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1858 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1859 |
+
|
| 1860 |
+
namespace detail {
|
| 1861 |
+
|
| 1862 |
+
/// Computes a random Gaussian distribution
|
| 1863 |
+
template <
|
| 1864 |
+
typename Element, ///< Element type
|
| 1865 |
+
typename Layout> ///< Layout function
|
| 1866 |
+
struct TensorCopyDiagonalInFunc {
|
| 1867 |
+
|
| 1868 |
+
/// View type
|
| 1869 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1870 |
+
|
| 1871 |
+
/// Scalar type
|
| 1872 |
+
typedef typename TensorView::Element T;
|
| 1873 |
+
|
| 1874 |
+
/// Coordinate in tensor's index space
|
| 1875 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1876 |
+
|
| 1877 |
+
/// Parameters structure
|
| 1878 |
+
struct Params {
|
| 1879 |
+
|
| 1880 |
+
//
|
| 1881 |
+
// Data members
|
| 1882 |
+
//
|
| 1883 |
+
|
| 1884 |
+
TensorView view;
|
| 1885 |
+
Element const *ptr;
|
| 1886 |
+
|
| 1887 |
+
/// Default ctor
|
| 1888 |
+
CUTLASS_HOST_DEVICE
|
| 1889 |
+
Params() { }
|
| 1890 |
+
|
| 1891 |
+
//
|
| 1892 |
+
// Methods
|
| 1893 |
+
//
|
| 1894 |
+
|
| 1895 |
+
/// Construction of Gaussian RNG functor.
|
| 1896 |
+
Params(
|
| 1897 |
+
TensorView view_, ///< destination tensor
|
| 1898 |
+
Element const *ptr_
|
| 1899 |
+
):
|
| 1900 |
+
view(view_), ptr(ptr_) {
|
| 1901 |
+
|
| 1902 |
+
}
|
| 1903 |
+
};
|
| 1904 |
+
|
| 1905 |
+
//
|
| 1906 |
+
// Data members
|
| 1907 |
+
//
|
| 1908 |
+
|
| 1909 |
+
/// Parameters object
|
| 1910 |
+
Params params;
|
| 1911 |
+
|
| 1912 |
+
//
|
| 1913 |
+
// Methods
|
| 1914 |
+
//
|
| 1915 |
+
|
| 1916 |
+
/// Device-side initialization of RNG
|
| 1917 |
+
CUTLASS_DEVICE
|
| 1918 |
+
TensorCopyDiagonalInFunc(Params const ¶ms): params(params) {
|
| 1919 |
+
|
| 1920 |
+
}
|
| 1921 |
+
|
| 1922 |
+
/// Only update the diagonal element
|
| 1923 |
+
CUTLASS_DEVICE
|
| 1924 |
+
void operator()(TensorCoord const &coord) {
|
| 1925 |
+
bool is_diagonal = true;
|
| 1926 |
+
|
| 1927 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1928 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1929 |
+
if (coord[i] != coord[0]) {
|
| 1930 |
+
is_diagonal = false;
|
| 1931 |
+
}
|
| 1932 |
+
}
|
| 1933 |
+
if (is_diagonal) {
|
| 1934 |
+
params.view.at(coord) = params.ptr[coord[0]];
|
| 1935 |
+
}
|
| 1936 |
+
}
|
| 1937 |
+
};
|
| 1938 |
+
|
| 1939 |
+
} // namespace detail
|
| 1940 |
+
|
| 1941 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1942 |
+
|
| 1943 |
+
/// Copies a diagonal in from host memory without modifying off-diagonal elements.
|
| 1944 |
+
template <
|
| 1945 |
+
typename Element, ///< Element type
|
| 1946 |
+
typename Layout> ///< Layout function
|
| 1947 |
+
void TensorCopyDiagonalIn(
|
| 1948 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1949 |
+
Element const *ptr, ///< dense buffer of elements
|
| 1950 |
+
cudaStream_t stream = nullptr) {
|
| 1951 |
+
|
| 1952 |
+
using Func = detail::TensorCopyDiagonalInFunc<Element, Layout>;
|
| 1953 |
+
using Params = typename Func::Params;
|
| 1954 |
+
|
| 1955 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 1956 |
+
view.extent(),
|
| 1957 |
+
Params(view, ptr),
|
| 1958 |
+
/*grid_size*/0, /*block_size*/0,
|
| 1959 |
+
stream
|
| 1960 |
+
);
|
| 1961 |
+
}
|
| 1962 |
+
|
| 1963 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1964 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1965 |
+
|
| 1966 |
+
|
| 1967 |
+
namespace detail {
|
| 1968 |
+
|
| 1969 |
+
/// Computes a random Gaussian distribution
|
| 1970 |
+
template <
|
| 1971 |
+
typename Element, ///< Element type
|
| 1972 |
+
typename Layout> ///< Layout function
|
| 1973 |
+
struct TensorCopyDiagonalOutFunc {
|
| 1974 |
+
|
| 1975 |
+
/// View type
|
| 1976 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1977 |
+
|
| 1978 |
+
/// Scalar type
|
| 1979 |
+
typedef typename TensorView::Element T;
|
| 1980 |
+
|
| 1981 |
+
/// Coordinate in tensor's index space
|
| 1982 |
+
typedef typename TensorView::TensorCoord TensorCoord;
|
| 1983 |
+
|
| 1984 |
+
/// Parameters structure
|
| 1985 |
+
struct Params {
|
| 1986 |
+
|
| 1987 |
+
//
|
| 1988 |
+
// Data members
|
| 1989 |
+
//
|
| 1990 |
+
|
| 1991 |
+
TensorView view;
|
| 1992 |
+
Element *ptr;
|
| 1993 |
+
|
| 1994 |
+
/// Default ctor
|
| 1995 |
+
CUTLASS_HOST_DEVICE
|
| 1996 |
+
Params() { }
|
| 1997 |
+
|
| 1998 |
+
//
|
| 1999 |
+
// Methods
|
| 2000 |
+
//
|
| 2001 |
+
|
| 2002 |
+
/// Construction of Gaussian RNG functor.
|
| 2003 |
+
Params(
|
| 2004 |
+
TensorView view_, ///< destination tensor
|
| 2005 |
+
Element *ptr_
|
| 2006 |
+
):
|
| 2007 |
+
view(view_), ptr(ptr_) {
|
| 2008 |
+
|
| 2009 |
+
}
|
| 2010 |
+
};
|
| 2011 |
+
|
| 2012 |
+
//
|
| 2013 |
+
// Data members
|
| 2014 |
+
//
|
| 2015 |
+
|
| 2016 |
+
/// Parameters object
|
| 2017 |
+
Params params;
|
| 2018 |
+
|
| 2019 |
+
//
|
| 2020 |
+
// Methods
|
| 2021 |
+
//
|
| 2022 |
+
|
| 2023 |
+
/// Device-side initialization of RNG
|
| 2024 |
+
CUTLASS_DEVICE
|
| 2025 |
+
TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) {
|
| 2026 |
+
|
| 2027 |
+
}
|
| 2028 |
+
|
| 2029 |
+
/// Compute random value and update RNG state
|
| 2030 |
+
CUTLASS_DEVICE
|
| 2031 |
+
void operator()(TensorCoord const &coord) {
|
| 2032 |
+
bool is_diagonal = true;
|
| 2033 |
+
|
| 2034 |
+
CUTLASS_PRAGMA_UNROLL
|
| 2035 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 2036 |
+
if (coord[i] != coord[0]) {
|
| 2037 |
+
is_diagonal = false;
|
| 2038 |
+
}
|
| 2039 |
+
}
|
| 2040 |
+
if (is_diagonal) {
|
| 2041 |
+
params.ptr[coord[0]] = params.view.at(coord);
|
| 2042 |
+
}
|
| 2043 |
+
}
|
| 2044 |
+
};
|
| 2045 |
+
|
| 2046 |
+
} // namespace detail
|
| 2047 |
+
|
| 2048 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2049 |
+
|
| 2050 |
+
/// Copies the diagonal of a tensor into a dense buffer in host memory.
|
| 2051 |
+
template <
|
| 2052 |
+
typename Element, ///< Element type
|
| 2053 |
+
typename Layout> ///< Layout function
|
| 2054 |
+
void TensorCopyDiagonalOut(
|
| 2055 |
+
Element *ptr, ///< dense buffer of elements
|
| 2056 |
+
TensorView<Element, Layout> view, ///< source tensor
|
| 2057 |
+
cudaStream_t stream = nullptr) {
|
| 2058 |
+
|
| 2059 |
+
using Func = detail::TensorCopyDiagonalOutFunc<Element, Layout>;
|
| 2060 |
+
using Params = typename Func::Params;
|
| 2061 |
+
|
| 2062 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 2063 |
+
view.extent(),
|
| 2064 |
+
Params(view, ptr),
|
| 2065 |
+
/*grid_size*/0, /*block_size*/0,
|
| 2066 |
+
stream
|
| 2067 |
+
);
|
| 2068 |
+
}
|
| 2069 |
+
|
| 2070 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2071 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 2072 |
+
|
| 2073 |
+
} // namespace device
|
| 2074 |
+
} // namespace reference
|
| 2075 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include <stdexcept>
|
| 34 |
+
#include "cutlass/cutlass.h"
|
| 35 |
+
#include "cutlass/util/reference/device/kernel/tensor_foreach.h"
|
| 36 |
+
|
| 37 |
+
namespace cutlass {
|
| 38 |
+
namespace reference {
|
| 39 |
+
namespace device {
|
| 40 |
+
|
| 41 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
/// Launches a kernel calling a functor for each element in a tensor's index space.
|
| 44 |
+
template <typename Func, int Rank, typename Params>
|
| 45 |
+
struct TensorForEach {
|
| 46 |
+
|
| 47 |
+
/// Constructor performs the operation.
|
| 48 |
+
TensorForEach(
|
| 49 |
+
Coord<Rank> size, Params params = Params(),
|
| 50 |
+
int grid_size = 0, int block_size = 0,
|
| 51 |
+
cudaStream_t stream = nullptr) {
|
| 52 |
+
|
| 53 |
+
if (!grid_size || !block_size) {
|
| 54 |
+
|
| 55 |
+
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 56 |
+
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 57 |
+
&grid_size,
|
| 58 |
+
&block_size,
|
| 59 |
+
reinterpret_cast<void const *>(kernel::TensorForEach<Func, Rank, Params>));
|
| 60 |
+
|
| 61 |
+
if (result != cudaSuccess) {
|
| 62 |
+
throw std::runtime_error("Failed to query occupancy.");
|
| 63 |
+
}
|
| 64 |
+
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 65 |
+
// single thread and reduces the impact of initialization overhead.
|
| 66 |
+
block_size = (block_size < 128 ? block_size : 128);
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
dim3 grid(grid_size, 1, 1);
|
| 70 |
+
dim3 block(block_size, 1, 1);
|
| 71 |
+
|
| 72 |
+
kernel::TensorForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(size, params);
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
+
|
| 78 |
+
/// Launches a kernel calling a functor for each element along a tensor's diagonal
|
| 79 |
+
template <typename Func, int Rank, typename Params>
|
| 80 |
+
struct TensorDiagonalForEach {
|
| 81 |
+
|
| 82 |
+
/// Constructor performs the operation
|
| 83 |
+
TensorDiagonalForEach(
|
| 84 |
+
Coord<Rank> size, Params params = Params(),
|
| 85 |
+
int start = 0, int end = -1,
|
| 86 |
+
int block_size = 128, cudaStream_t stream = nullptr) {
|
| 87 |
+
|
| 88 |
+
if (end < 0) {
|
| 89 |
+
end = size.min();
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
dim3 block(block_size, 1, 1);
|
| 93 |
+
dim3 grid((end - start + block_size - 1) / block_size, 1, 1);
|
| 94 |
+
|
| 95 |
+
kernel::TensorDiagonalForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(
|
| 96 |
+
size, params, start, end);
|
| 97 |
+
}
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 102 |
+
|
| 103 |
+
template <typename Element, typename Func>
|
| 104 |
+
struct BlockForEach {
|
| 105 |
+
|
| 106 |
+
/// Constructor performs the operation.
|
| 107 |
+
BlockForEach(
|
| 108 |
+
Element *ptr,
|
| 109 |
+
size_t capacity,
|
| 110 |
+
typename Func::Params params = typename Func::Params(),
|
| 111 |
+
int grid_size = 0,
|
| 112 |
+
int block_size = 0,
|
| 113 |
+
cudaStream_t stream = nullptr) {
|
| 114 |
+
|
| 115 |
+
if (!grid_size || !block_size) {
|
| 116 |
+
|
| 117 |
+
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
|
| 118 |
+
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
|
| 119 |
+
&grid_size,
|
| 120 |
+
&block_size,
|
| 121 |
+
reinterpret_cast<void const *>(kernel::BlockForEach<Element, Func>));
|
| 122 |
+
|
| 123 |
+
if (result != cudaSuccess) {
|
| 124 |
+
throw std::runtime_error("Failed to query occupancy.");
|
| 125 |
+
}
|
| 126 |
+
// Limit block size. This has the effect of increasing the number of items processed by a
|
| 127 |
+
// single thread and reduces the impact of initialization overhead.
|
| 128 |
+
block_size = (block_size < 128 ? block_size : 128);
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
dim3 grid(grid_size, 1, 1);
|
| 132 |
+
dim3 block(block_size, 1, 1);
|
| 133 |
+
|
| 134 |
+
kernel::BlockForEach<Element, Func><<< grid, block, 0, stream >>>(ptr, capacity, params);
|
| 135 |
+
}
|
| 136 |
+
};
|
| 137 |
+
|
| 138 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 139 |
+
|
| 140 |
+
} // namespace device
|
| 141 |
+
} // namespace reference
|
| 142 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include <cmath>
|
| 34 |
+
|
| 35 |
+
#include "cutlass/cutlass.h"
|
| 36 |
+
#include "cutlass/complex.h"
|
| 37 |
+
#include "cutlass/functional.h"
|
| 38 |
+
#include "cutlass/numeric_conversion.h"
|
| 39 |
+
#include "cutlass/tensor_view.h"
|
| 40 |
+
#include "cutlass/util/device_memory.h"
|
| 41 |
+
#include "cutlass/util/reference/detail/linear_to_coordinate.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace reference {
|
| 47 |
+
namespace device {
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace kernel {
|
| 52 |
+
|
| 53 |
+
template <
|
| 54 |
+
typename Element,
|
| 55 |
+
typename Layout,
|
| 56 |
+
typename ComputeType,
|
| 57 |
+
typename ReduceOp,
|
| 58 |
+
typename TransformOp,
|
| 59 |
+
int kBlockSize = 128
|
| 60 |
+
>
|
| 61 |
+
__global__ void TensorTransformReducePartial(
|
| 62 |
+
TensorView<Element, Layout> view, /// View of the tensor to reduce over
|
| 63 |
+
ComputeType identity, /// Identity element of the reduction operation
|
| 64 |
+
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 65 |
+
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 66 |
+
ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 67 |
+
|
| 68 |
+
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 69 |
+
int64_t size = view.size();
|
| 70 |
+
|
| 71 |
+
__shared__ ComputeType scratchpad[kBlockSize];
|
| 72 |
+
|
| 73 |
+
for (; idx < size; idx += blockDim.x * gridDim.x) {
|
| 74 |
+
|
| 75 |
+
// Map linear thread ID onto tensor coordinate
|
| 76 |
+
typename Layout::TensorCoord coord;
|
| 77 |
+
|
| 78 |
+
cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());
|
| 79 |
+
|
| 80 |
+
if (view.contains(coord)) {
|
| 81 |
+
|
| 82 |
+
// Fetch element
|
| 83 |
+
Element x = view.at(coord);
|
| 84 |
+
|
| 85 |
+
// Transform
|
| 86 |
+
identity = reduce(identity, transform(x));
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
scratchpad[threadIdx.x] = identity;
|
| 91 |
+
|
| 92 |
+
__syncthreads();
|
| 93 |
+
|
| 94 |
+
// One thread performs the final reduction and stores out. This could be enhanced via
|
| 95 |
+
// a tree reduction and pipelining.
|
| 96 |
+
if (threadIdx.x == 0) {
|
| 97 |
+
|
| 98 |
+
for (int i = 1; i < kBlockSize; ++i) {
|
| 99 |
+
identity = reduce(identity, scratchpad[i]);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
workspace[blockIdx.x] = identity;
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
template <
|
| 107 |
+
typename Element,
|
| 108 |
+
typename Layout,
|
| 109 |
+
typename ComputeType,
|
| 110 |
+
typename ReduceOp,
|
| 111 |
+
typename TransformOp,
|
| 112 |
+
int kBlockSize = 128
|
| 113 |
+
>
|
| 114 |
+
__global__ void TensorTransformReducePartial(
|
| 115 |
+
TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
|
| 116 |
+
TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
|
| 117 |
+
ComputeType identity, /// Identity element of the reduction operation
|
| 118 |
+
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 119 |
+
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 120 |
+
ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 121 |
+
|
| 122 |
+
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
|
| 123 |
+
auto size = static_cast<int64_t>(view_A.size());
|
| 124 |
+
|
| 125 |
+
__shared__ ComputeType scratchpad[kBlockSize];
|
| 126 |
+
|
| 127 |
+
for (; idx < size; idx += blockDim.x * gridDim.x) {
|
| 128 |
+
|
| 129 |
+
// Map linear thread ID onto tensor coordinate
|
| 130 |
+
typename Layout::TensorCoord coord;
|
| 131 |
+
|
| 132 |
+
cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());
|
| 133 |
+
|
| 134 |
+
if (view_A.contains(coord)) {
|
| 135 |
+
|
| 136 |
+
// Fetch element
|
| 137 |
+
Element a = view_A.at(coord);
|
| 138 |
+
Element b = view_B.at(coord);
|
| 139 |
+
|
| 140 |
+
// Transform
|
| 141 |
+
identity = reduce(identity, transform(a, b));
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
scratchpad[threadIdx.x] = identity;
|
| 146 |
+
|
| 147 |
+
__syncthreads();
|
| 148 |
+
|
| 149 |
+
// One thread performs the final reduction and stores out. This could be enhanced via
|
| 150 |
+
// a tree reduction and pipelining.
|
| 151 |
+
if (threadIdx.x == 0) {
|
| 152 |
+
|
| 153 |
+
for (int i = 1; i < kBlockSize; ++i) {
|
| 154 |
+
identity = reduce(identity, scratchpad[i]);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
workspace[blockIdx.x] = identity;
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
template <
|
| 163 |
+
typename ComputeType,
|
| 164 |
+
typename ReduceOp,
|
| 165 |
+
int kBlockSize = 32
|
| 166 |
+
>
|
| 167 |
+
__global__ void TensorTransformReduceFinalize(
|
| 168 |
+
ComputeType *workspace,
|
| 169 |
+
ComputeType identity,
|
| 170 |
+
int workspace_size,
|
| 171 |
+
ReduceOp reduce) {
|
| 172 |
+
|
| 173 |
+
__shared__ ComputeType scratchpad[kBlockSize];
|
| 174 |
+
|
| 175 |
+
for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) {
|
| 176 |
+
identity = reduce(identity, workspace[idx]);
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
scratchpad[threadIdx.x] = identity;
|
| 180 |
+
|
| 181 |
+
__syncthreads();
|
| 182 |
+
|
| 183 |
+
if (threadIdx.x == 0) {
|
| 184 |
+
|
| 185 |
+
for (int i = 1; i < kBlockSize; ++i) {
|
| 186 |
+
identity = reduce(identity, scratchpad[i]);
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
workspace[0] = identity;
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
} // namespace kernel
|
| 194 |
+
|
| 195 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 196 |
+
|
| 197 |
+
/// Transform-reduce operation over the elements of a tensor
|
| 198 |
+
template <
|
| 199 |
+
typename Element,
|
| 200 |
+
typename Layout,
|
| 201 |
+
typename ComputeType,
|
| 202 |
+
typename ReduceOp,
|
| 203 |
+
typename TransformOp
|
| 204 |
+
>
|
| 205 |
+
ComputeType TensorTransformReduce(
|
| 206 |
+
TensorView<Element, Layout> view, /// View of the tensor to reduce over
|
| 207 |
+
ComputeType identity, /// Identity element of the reduction operation
|
| 208 |
+
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 209 |
+
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 210 |
+
ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 211 |
+
int workspace_size, /// Number of elements in workspace
|
| 212 |
+
cudaStream_t stream = nullptr, /// CUDA stream to launch into
|
| 213 |
+
bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
|
| 214 |
+
) {
|
| 215 |
+
|
| 216 |
+
int const kBlockSize = 128;
|
| 217 |
+
|
| 218 |
+
dim3 block(kBlockSize, 1);
|
| 219 |
+
dim3 grid(workspace_size, 1);
|
| 220 |
+
|
| 221 |
+
kernel::TensorTransformReducePartial<
|
| 222 |
+
Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
|
| 223 |
+
><<< grid, block, 0, stream >>>(
|
| 224 |
+
view, identity, reduce, transform, workspace
|
| 225 |
+
);
|
| 226 |
+
|
| 227 |
+
int const kFinalizeBlockSize = 32;
|
| 228 |
+
|
| 229 |
+
kernel::TensorTransformReduceFinalize<
|
| 230 |
+
ComputeType, ReduceOp, kFinalizeBlockSize
|
| 231 |
+
><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
|
| 232 |
+
workspace, identity, workspace_size, reduce
|
| 233 |
+
);
|
| 234 |
+
|
| 235 |
+
cudaStreamSynchronize(stream);
|
| 236 |
+
|
| 237 |
+
if (copy_out) {
|
| 238 |
+
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
|
| 239 |
+
if (result != cudaSuccess) {
|
| 240 |
+
throw std::runtime_error("cudaMemcpy() failed");
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
return identity;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/// Transform-reduce operation over the elements of two tensors, zipped together
|
| 248 |
+
template <
|
| 249 |
+
typename Element,
|
| 250 |
+
typename Layout,
|
| 251 |
+
typename ComputeType,
|
| 252 |
+
typename ReduceOp,
|
| 253 |
+
typename TransformOp
|
| 254 |
+
>
|
| 255 |
+
ComputeType TensorTransformReduce(
|
| 256 |
+
TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
|
| 257 |
+
TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
|
| 258 |
+
ComputeType identity, /// Identity element of the reduction operation
|
| 259 |
+
ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
|
| 260 |
+
TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
|
| 261 |
+
ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
|
| 262 |
+
int workspace_size, /// Number of elements in workspace
|
| 263 |
+
cudaStream_t stream = nullptr, /// CUDA stream to launch into
|
| 264 |
+
bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
|
| 265 |
+
) {
|
| 266 |
+
|
| 267 |
+
if (view_A.extent() != view_B.extent()) {
|
| 268 |
+
throw std::runtime_error("Extents must be equal.");
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
int const kBlockSize = 128;
|
| 272 |
+
|
| 273 |
+
dim3 block(kBlockSize, 1);
|
| 274 |
+
dim3 grid(workspace_size, 1);
|
| 275 |
+
|
| 276 |
+
kernel::TensorTransformReducePartial<
|
| 277 |
+
Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
|
| 278 |
+
><<< grid, block, 0, stream >>>(
|
| 279 |
+
view_A, view_B, identity, reduce, transform, workspace
|
| 280 |
+
);
|
| 281 |
+
|
| 282 |
+
int const kFinalizeBlockSize = 32;
|
| 283 |
+
|
| 284 |
+
kernel::TensorTransformReduceFinalize<
|
| 285 |
+
ComputeType, ReduceOp, kFinalizeBlockSize
|
| 286 |
+
><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
|
| 287 |
+
workspace, identity, workspace_size, reduce
|
| 288 |
+
);
|
| 289 |
+
|
| 290 |
+
cudaStreamSynchronize(stream);
|
| 291 |
+
|
| 292 |
+
if (copy_out) {
|
| 293 |
+
cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
|
| 294 |
+
if (result != cudaSuccess) {
|
| 295 |
+
throw std::runtime_error("cudaMemcpy() failed");
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
return identity;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 303 |
+
/// workspace
|
| 304 |
+
template <
|
| 305 |
+
typename Element,
|
| 306 |
+
typename Layout,
|
| 307 |
+
typename ComputeType,
|
| 308 |
+
typename ReduceOp,
|
| 309 |
+
typename TransformOp
|
| 310 |
+
>
|
| 311 |
+
ComputeType TensorTransformReduce(
|
| 312 |
+
TensorView<Element, Layout> view,
|
| 313 |
+
ComputeType identity,
|
| 314 |
+
ReduceOp reduce,
|
| 315 |
+
TransformOp transform,
|
| 316 |
+
cudaStream_t stream = nullptr,
|
| 317 |
+
int workspace_size = 0
|
| 318 |
+
) {
|
| 319 |
+
|
| 320 |
+
// Optionally query for the SM count to size the workspace.
|
| 321 |
+
if (!workspace_size) {
|
| 322 |
+
|
| 323 |
+
int device_idx = 0;
|
| 324 |
+
cudaDeviceProp prop;
|
| 325 |
+
|
| 326 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 327 |
+
if (result != cudaSuccess) {
|
| 328 |
+
throw std::runtime_error("cudaGetDevice() failed");
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
result = cudaGetDeviceProperties(&prop, device_idx);
|
| 332 |
+
if (result != cudaSuccess) {
|
| 333 |
+
throw std::runtime_error("cudaGetDeviceProp() failed");
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
workspace_size = int(prop.multiProcessorCount);
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
DeviceAllocation<ComputeType> workspace(workspace_size);
|
| 340 |
+
|
| 341 |
+
ComputeType output = TensorTransformReduce(
|
| 342 |
+
view,
|
| 343 |
+
identity,
|
| 344 |
+
reduce,
|
| 345 |
+
transform,
|
| 346 |
+
workspace.get(),
|
| 347 |
+
workspace_size,
|
| 348 |
+
stream,
|
| 349 |
+
true);
|
| 350 |
+
|
| 351 |
+
return output;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 356 |
+
/// workspace
|
| 357 |
+
template <
|
| 358 |
+
typename Element,
|
| 359 |
+
typename Layout,
|
| 360 |
+
typename ComputeType,
|
| 361 |
+
typename ReduceOp,
|
| 362 |
+
typename TransformOp
|
| 363 |
+
>
|
| 364 |
+
ComputeType TensorTransformReduce(
|
| 365 |
+
TensorView<Element, Layout> view_A,
|
| 366 |
+
TensorView<Element, Layout> view_B,
|
| 367 |
+
ComputeType identity,
|
| 368 |
+
ReduceOp reduce,
|
| 369 |
+
TransformOp transform,
|
| 370 |
+
cudaStream_t stream = nullptr,
|
| 371 |
+
int workspace_size = 0
|
| 372 |
+
) {
|
| 373 |
+
|
| 374 |
+
// Optionally query for the SM count to size the workspace.
|
| 375 |
+
if (!workspace_size) {
|
| 376 |
+
|
| 377 |
+
int device_idx = 0;
|
| 378 |
+
cudaDeviceProp prop;
|
| 379 |
+
|
| 380 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 381 |
+
if (result != cudaSuccess) {
|
| 382 |
+
throw std::runtime_error("cudaGetDevice() failed");
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
result = cudaGetDeviceProperties(&prop, device_idx);
|
| 386 |
+
if (result != cudaSuccess) {
|
| 387 |
+
throw std::runtime_error("cudaGetDeviceProp() failed");
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
workspace_size = int(prop.multiProcessorCount);
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
DeviceAllocation<ComputeType> workspace(workspace_size);
|
| 394 |
+
|
| 395 |
+
ComputeType output = TensorTransformReduce(
|
| 396 |
+
view_A,
|
| 397 |
+
view_B,
|
| 398 |
+
identity,
|
| 399 |
+
reduce,
|
| 400 |
+
transform,
|
| 401 |
+
workspace.get(),
|
| 402 |
+
workspace_size,
|
| 403 |
+
stream,
|
| 404 |
+
true);
|
| 405 |
+
|
| 406 |
+
return output;
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 410 |
+
|
| 411 |
+
/// Helper to compute the sum of the elements of a tensor
|
| 412 |
+
template <
|
| 413 |
+
typename Element,
|
| 414 |
+
typename Layout,
|
| 415 |
+
typename ComputeType = Element
|
| 416 |
+
>
|
| 417 |
+
ComputeType TensorSum(
|
| 418 |
+
TensorView<Element, Layout> view,
|
| 419 |
+
ComputeType identity = ComputeType(),
|
| 420 |
+
cudaStream_t stream = nullptr,
|
| 421 |
+
int workspace_size = 0
|
| 422 |
+
) {
|
| 423 |
+
|
| 424 |
+
plus<ComputeType> reduce;
|
| 425 |
+
NumericConverter<ComputeType, Element> transform;
|
| 426 |
+
|
| 427 |
+
return TensorTransformReduce(
|
| 428 |
+
view, identity, reduce, transform, stream, workspace_size);
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
/// Helper to compute the sum of the squares of the elements of a tensor
|
| 432 |
+
template <
|
| 433 |
+
typename Element,
|
| 434 |
+
typename Layout,
|
| 435 |
+
typename ComputeType = Element
|
| 436 |
+
>
|
| 437 |
+
ComputeType TensorSumSq(
|
| 438 |
+
TensorView<Element, Layout> view,
|
| 439 |
+
ComputeType identity = ComputeType(),
|
| 440 |
+
cudaStream_t stream = nullptr,
|
| 441 |
+
int workspace_size = 0
|
| 442 |
+
) {
|
| 443 |
+
|
| 444 |
+
plus<ComputeType> reduce;
|
| 445 |
+
magnitude_squared<Element, ComputeType> transform;
|
| 446 |
+
|
| 447 |
+
return TensorTransformReduce(
|
| 448 |
+
view, identity, reduce, transform, stream, workspace_size);
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
/// Helper to compute the norm of the elements of a tensor.
|
| 452 |
+
template <
|
| 453 |
+
typename Element,
|
| 454 |
+
typename Layout,
|
| 455 |
+
typename ComputeType = double
|
| 456 |
+
>
|
| 457 |
+
ComputeType TensorNorm(
|
| 458 |
+
TensorView<Element, Layout> view,
|
| 459 |
+
ComputeType identity = ComputeType(),
|
| 460 |
+
cudaStream_t stream = nullptr,
|
| 461 |
+
int workspace_size = 0
|
| 462 |
+
) {
|
| 463 |
+
|
| 464 |
+
return std::sqrt(TensorSumSq(view, identity, stream, workspace_size));
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 468 |
+
|
| 469 |
+
/// Helper to compute the sum of the squares of the differences of two tensors
|
| 470 |
+
template <
|
| 471 |
+
typename Element,
|
| 472 |
+
typename Layout,
|
| 473 |
+
typename ComputeType = double
|
| 474 |
+
>
|
| 475 |
+
ComputeType TensorSumSqDiff(
|
| 476 |
+
TensorView<Element, Layout> view_A,
|
| 477 |
+
TensorView<Element, Layout> view_B,
|
| 478 |
+
ComputeType identity = ComputeType(),
|
| 479 |
+
cudaStream_t stream = nullptr,
|
| 480 |
+
int workspace_size = 0
|
| 481 |
+
) {
|
| 482 |
+
|
| 483 |
+
plus<ComputeType> reduce;
|
| 484 |
+
magnitude_squared_difference<Element, ComputeType> transform;
|
| 485 |
+
|
| 486 |
+
return TensorTransformReduce(
|
| 487 |
+
view_A, view_B, identity, reduce, transform, stream, workspace_size);
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
|
| 492 |
+
template <
|
| 493 |
+
typename Element,
|
| 494 |
+
typename Layout,
|
| 495 |
+
typename ComputeType = double
|
| 496 |
+
>
|
| 497 |
+
ComputeType TensorNormDiff(
|
| 498 |
+
TensorView<Element, Layout> view_A,
|
| 499 |
+
TensorView<Element, Layout> view_B,
|
| 500 |
+
ComputeType identity = ComputeType(),
|
| 501 |
+
cudaStream_t stream = nullptr,
|
| 502 |
+
int workspace_size = 0
|
| 503 |
+
) {
|
| 504 |
+
|
| 505 |
+
return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size));
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 509 |
+
|
| 510 |
+
} // namespace device
|
| 511 |
+
} // namespace reference
|
| 512 |
+
} // namespace cutlass
|
| 513 |
+
|
| 514 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Defines device-side elementwise operations on TensorView. Note, the operations defined
|
| 33 |
+
in this header are not specialized for any particular data layout and are therefore not
|
| 34 |
+
intended to offer the best possible performance. Rather, they are intended to be generic
|
| 35 |
+
reference implementations to support the CUTLASS unit tests.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
// Cutlass includes
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/util/reference/device/tensor_foreach.h"
|
| 45 |
+
|
| 46 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reference {
|
| 50 |
+
namespace device {
|
| 51 |
+
|
| 52 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
namespace detail {
|
| 56 |
+
|
| 57 |
+
template <
|
| 58 |
+
typename Element, ///< Element type
|
| 59 |
+
typename Layout> ///< Layout function
|
| 60 |
+
struct TensorReLuFunc {
|
| 61 |
+
|
| 62 |
+
/// View type
|
| 63 |
+
using TensorView = TensorView<Element, Layout>;
|
| 64 |
+
|
| 65 |
+
/// Coordinate in tensor's index space
|
| 66 |
+
using TensorCoord = typename TensorView::TensorCoord;
|
| 67 |
+
|
| 68 |
+
/// Parameters structure
|
| 69 |
+
struct Params {
|
| 70 |
+
|
| 71 |
+
//
|
| 72 |
+
// Data members
|
| 73 |
+
//
|
| 74 |
+
|
| 75 |
+
TensorView view;
|
| 76 |
+
Element threshold;
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
//
|
| 80 |
+
// Methods
|
| 81 |
+
//
|
| 82 |
+
|
| 83 |
+
Params(
|
| 84 |
+
TensorView view_ = TensorView(),
|
| 85 |
+
Element threshold_ = Element(0)
|
| 86 |
+
):
|
| 87 |
+
view(view_), threshold(threshold_) {
|
| 88 |
+
|
| 89 |
+
}
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
//
|
| 93 |
+
// Data members
|
| 94 |
+
//
|
| 95 |
+
|
| 96 |
+
Params params;
|
| 97 |
+
|
| 98 |
+
//
|
| 99 |
+
// Methods
|
| 100 |
+
//
|
| 101 |
+
|
| 102 |
+
CUTLASS_DEVICE
|
| 103 |
+
TensorReLuFunc(Params const ¶ms): params(params) {
|
| 104 |
+
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
CUTLASS_DEVICE
|
| 108 |
+
void operator()(TensorCoord const &coord) {
|
| 109 |
+
|
| 110 |
+
Element const & value = params.view.at(coord);
|
| 111 |
+
params.view.at(coord) = (value < params.threshold) ? params.threshold : value;
|
| 112 |
+
}
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
} // namespace detail
|
| 116 |
+
|
| 117 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 118 |
+
|
| 119 |
+
/// Apply ReLu on a tensor
|
| 120 |
+
template <
|
| 121 |
+
typename Element, ///< Element type
|
| 122 |
+
typename Layout> ///< Layout function
|
| 123 |
+
void TensorReLu(
|
| 124 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 125 |
+
Element threshold = Element(0)) { ///< ReLu threshold
|
| 126 |
+
|
| 127 |
+
using Func = detail::TensorReLuFunc<Element, Layout>;
|
| 128 |
+
using Params = typename Func::Params;
|
| 129 |
+
|
| 130 |
+
TensorForEach<Func, Layout::kRank, Params>(
|
| 131 |
+
view.extent(),
|
| 132 |
+
Params(view, threshold)
|
| 133 |
+
);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 137 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 138 |
+
|
| 139 |
+
} // namespace device
|
| 140 |
+
} // namespace reference
|
| 141 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/tensor_view.h"
|
| 39 |
+
#include "cutlass/gemm/gemm.h"
|
| 40 |
+
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace reference {
|
| 43 |
+
namespace device {
|
| 44 |
+
namespace thread {
|
| 45 |
+
|
| 46 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
/// Thread-level blocked general matrix product.
|
| 49 |
+
//
|
| 50 |
+
// Note, this is a reference implementation. Performance is not expected to approach peak.
|
| 51 |
+
//
|
| 52 |
+
template <
|
| 53 |
+
typename TensorRefA,
|
| 54 |
+
typename TensorRefB,
|
| 55 |
+
typename TensorRefC,
|
| 56 |
+
typename ScalarType,
|
| 57 |
+
typename AccumulatorType,
|
| 58 |
+
typename OutputTile,
|
| 59 |
+
typename InnerProductOp = multiply_add<AccumulatorType>,
|
| 60 |
+
typename ConvertOp = NumericConverter<typename TensorRefC::Element, ScalarType>
|
| 61 |
+
>
|
| 62 |
+
struct Gemm {
|
| 63 |
+
|
| 64 |
+
using ElementA = typename TensorRefA::Element;
|
| 65 |
+
using ElementB = typename TensorRefB::Element;
|
| 66 |
+
using ElementC = typename TensorRefC::Element;
|
| 67 |
+
|
| 68 |
+
//
|
| 69 |
+
// Data members
|
| 70 |
+
//
|
| 71 |
+
|
| 72 |
+
/// Tile for A operand
|
| 73 |
+
ElementA A_tile[OutputTile::kColumn];
|
| 74 |
+
|
| 75 |
+
/// Tile for B operand
|
| 76 |
+
ElementB B_tile[OutputTile::kRow];
|
| 77 |
+
|
| 78 |
+
/// Tile for Accumulator
|
| 79 |
+
AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow];
|
| 80 |
+
|
| 81 |
+
//
|
| 82 |
+
// Methods
|
| 83 |
+
//
|
| 84 |
+
|
| 85 |
+
/// Constructor
|
| 86 |
+
CUTLASS_HOST_DEVICE
|
| 87 |
+
Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {
|
| 88 |
+
|
| 89 |
+
// Clear fetch registers
|
| 90 |
+
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 91 |
+
A_tile[i] = ElementA(0);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 95 |
+
B_tile[j] = ElementB(0);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
// Clear accumulators
|
| 99 |
+
CUTLASS_PRAGMA_UNROLL
|
| 100 |
+
for (int j = 0; j < OutputTile::kColumn; ++j) {
|
| 101 |
+
CUTLASS_PRAGMA_UNROLL
|
| 102 |
+
for (int i = 0; i < OutputTile::kRow; ++i) {
|
| 103 |
+
accum[j][i] = initial_accum;
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
/// Computes a matrix product
|
| 109 |
+
CUTLASS_HOST_DEVICE
|
| 110 |
+
Gemm & multiply_add(
|
| 111 |
+
gemm::GemmCoord problem_size,
|
| 112 |
+
TensorRefA tensor_a,
|
| 113 |
+
TensorRefB tensor_b,
|
| 114 |
+
MatrixCoord output_coord = MatrixCoord()) {
|
| 115 |
+
|
| 116 |
+
InnerProductOp inner_product_op;
|
| 117 |
+
|
| 118 |
+
// Loop over the GEMM K dimension
|
| 119 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 120 |
+
for (int k = 0; k < problem_size.k(); ++k) {
|
| 121 |
+
|
| 122 |
+
// Fetch a slice of the A matrix
|
| 123 |
+
CUTLASS_PRAGMA_UNROLL
|
| 124 |
+
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 125 |
+
if (output_coord.row() + i < problem_size.m()) {
|
| 126 |
+
A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k));
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// Fetch a slice of the B matrix
|
| 131 |
+
CUTLASS_PRAGMA_UNROLL
|
| 132 |
+
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 133 |
+
if (output_coord.column() + j < problem_size.n()) {
|
| 134 |
+
B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j));
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// Compute an accumulated matrix product
|
| 139 |
+
CUTLASS_PRAGMA_UNROLL
|
| 140 |
+
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 141 |
+
CUTLASS_PRAGMA_UNROLL
|
| 142 |
+
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 143 |
+
accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]);
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
return *this;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/// Performs linear scaling of matrix product and updates output tensor
|
| 152 |
+
CUTLASS_HOST_DEVICE
|
| 153 |
+
Gemm & epilogue(
|
| 154 |
+
gemm::GemmCoord problem_size,
|
| 155 |
+
ScalarType alpha,
|
| 156 |
+
ScalarType beta,
|
| 157 |
+
TensorRefC tensor_c,
|
| 158 |
+
TensorRefC tensor_d,
|
| 159 |
+
MatrixCoord output_coord = MatrixCoord()) {
|
| 160 |
+
|
| 161 |
+
ConvertOp convert_op;
|
| 162 |
+
|
| 163 |
+
// Update the output tensor
|
| 164 |
+
for (int j = 0; j < OutputTile::kRow; ++j) {
|
| 165 |
+
for (int i = 0; i < OutputTile::kColumn; ++i) {
|
| 166 |
+
MatrixCoord coord = output_coord + MatrixCoord(i, j);
|
| 167 |
+
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
|
| 168 |
+
|
| 169 |
+
tensor_d.at(coord) = convert_op(
|
| 170 |
+
alpha * ScalarType(accum[j][i]) +
|
| 171 |
+
beta * ScalarType(tensor_c.at(coord))
|
| 172 |
+
);
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
return *this;
|
| 178 |
+
}
|
| 179 |
+
};
|
| 180 |
+
|
| 181 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 182 |
+
|
| 183 |
+
} // namespace thread
|
| 184 |
+
} // namespace device
|
| 185 |
+
} // namespace reference
|
| 186 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp
ADDED
|
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for CONV in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 37 |
+
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/numeric_conversion.h"
|
| 40 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 41 |
+
|
| 42 |
+
#include "cute/tensor.hpp"
|
| 43 |
+
|
| 44 |
+
#include <cuda_runtime.h>
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass::reference::host {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
namespace detail {
|
| 53 |
+
|
| 54 |
+
template<class EngineAct, class LayoutAct>
|
| 55 |
+
bool
|
| 56 |
+
is_activation_in_bounds(
|
| 57 |
+
cute::Tensor<EngineAct, LayoutAct> const& activation,
|
| 58 |
+
int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
|
| 59 |
+
return ((g_ >= 0 && g_ < size<5>(activation)) &&
|
| 60 |
+
(n_ >= 0 && n_ < size<4>(activation)) &&
|
| 61 |
+
(d_ >= 0 && d_ < size<3>(activation)) &&
|
| 62 |
+
(h_ >= 0 && h_ < size<2>(activation)) &&
|
| 63 |
+
(w_ >= 0 && w_ < size<1>(activation)) &&
|
| 64 |
+
(c_ >= 0 && c_ < size<0>(activation)));
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template<class EngineAct, class LayoutAct>
|
| 68 |
+
bool
|
| 69 |
+
is_activation_in_bounds(
|
| 70 |
+
cute::Tensor<EngineAct, LayoutAct> const& activation,
|
| 71 |
+
int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
|
| 72 |
+
return ((g_ >= 0 && g_ < size<4>(activation)) &&
|
| 73 |
+
(n_ >= 0 && n_ < size<3>(activation)) &&
|
| 74 |
+
(h_ >= 0 && h_ < size<2>(activation)) &&
|
| 75 |
+
(w_ >= 0 && w_ < size<1>(activation)) &&
|
| 76 |
+
(c_ >= 0 && c_ < size<0>(activation)));
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
template<class EngineAct, class LayoutAct>
|
| 80 |
+
bool
|
| 81 |
+
is_activation_in_bounds(
|
| 82 |
+
cute::Tensor<EngineAct, LayoutAct> const& activation,
|
| 83 |
+
int32_t n_, int32_t w_, int32_t c_, int32_t g_) {
|
| 84 |
+
return ((g_ >= 0 && g_ < size<3>(activation)) &&
|
| 85 |
+
(n_ >= 0 && n_ < size<2>(activation)) &&
|
| 86 |
+
(w_ >= 0 && w_ < size<1>(activation)) &&
|
| 87 |
+
(c_ >= 0 && c_ < size<0>(activation)));
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
} // namespace detail
|
| 91 |
+
|
| 92 |
+
template<
  class ElementAcc_,
  class ElementScalar_,
  class ElementCompute_,
  class ElementC_,
  class ElementOut_,
  bool ResidualAdd_,
  class TensorAlpha_,
  class TensorBeta_,
  class TensorBias_,
  class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>
>
struct ConvEpilogueFusionParams {
  // Bundles the element types, scaling factors, and activation functor that
  // parameterize the epilogue of the convolution reference implementation.
  using ElementAcc = ElementAcc_;          // accumulator element type
  using ElementScalar = ElementScalar_;    // type of the alpha/beta scalars
  using ElementCompute = ElementCompute_;  // type in which epilogue math is performed
  using ElementC = ElementC_;              // source (residual) tensor element type
  using ElementOut = ElementOut_;          // destination tensor element type
  using TensorAlpha = TensorAlpha_;        // optional per-channel alpha tensor type
  using TensorBeta = TensorBeta_;          // optional per-channel beta tensor type
  using TensorBias = TensorBias_;          // optional per-channel bias tensor type
  using ActivationFunctor = ActivationFunctor_;
  static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation

  // Scalar scaling factors; used when the corresponding per-channel tensors
  // below have null data pointers.
  ElementScalar alpha = ElementScalar(1);
  ElementScalar beta = ElementScalar(0);

  // Optional per-channel scale/bias tensors. A default-constructed (null-data)
  // tensor means the scalar alpha/beta apply and no bias is added.
  TensorAlpha tensor_alpha{};
  TensorBeta tensor_beta{};
  TensorBias tensor_bias{};
};
|
| 123 |
+
|
| 124 |
+
template<
|
| 125 |
+
cutlass::conv::Operator ConvOp,
|
| 126 |
+
int NumSpatialDims,
|
| 127 |
+
class TensorA,
|
| 128 |
+
class TensorB,
|
| 129 |
+
class TensorC,
|
| 130 |
+
class TensorD,
|
| 131 |
+
class ShapePadding,
|
| 132 |
+
class StrideTraversal,
|
| 133 |
+
class ShapeDilation,
|
| 134 |
+
class EpilogueFusionParams
|
| 135 |
+
>
|
| 136 |
+
struct ConvReferenceImpl {
|
| 137 |
+
// Hard-code the accumulator type to float to avoid data loss in the accumulating adds.
|
| 138 |
+
using ElementAcc = cutlass::platform::conditional_t<cutlass::platform::is_same_v<typename EpilogueFusionParams::ElementAcc, double>, double, float>;
|
| 139 |
+
using ElementC = typename EpilogueFusionParams::ElementC;
|
| 140 |
+
using ElementOut = typename EpilogueFusionParams::ElementOut;
|
| 141 |
+
using ElementScalar = typename EpilogueFusionParams::ElementScalar;
|
| 142 |
+
using ElementCompute = typename EpilogueFusionParams::ElementCompute;
|
| 143 |
+
using ElementBias = typename EpilogueFusionParams::TensorBias::value_type;
|
| 144 |
+
using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor;
|
| 145 |
+
|
| 146 |
+
// Input related converter
|
| 147 |
+
NumericConverter<ElementCompute, ElementAcc> acc_converter;
|
| 148 |
+
NumericConverter<ElementCompute, ElementC> residual_converter;
|
| 149 |
+
NumericConverter<ElementCompute, ElementBias> bias_converter;
|
| 150 |
+
// Scale related converter
|
| 151 |
+
NumericConverter<ElementCompute, ElementScalar> scale_converter;
|
| 152 |
+
// Output related converter
|
| 153 |
+
NumericConverter<ElementOut, ElementCompute> output_converter;
|
| 154 |
+
|
| 155 |
+
EpilogueFusionParams& epi_fusion_params_;
|
| 156 |
+
TensorA const& tensor_a_;
|
| 157 |
+
TensorB const& tensor_b_;
|
| 158 |
+
TensorC const& tensor_c_;
|
| 159 |
+
TensorD& tensor_d_;
|
| 160 |
+
|
| 161 |
+
ShapePadding const& padding_;
|
| 162 |
+
StrideTraversal const& tstride_;
|
| 163 |
+
ShapeDilation const& dilation_;
|
| 164 |
+
|
| 165 |
+
// Epilogue activation operation
|
| 166 |
+
ActivationFunctor epi_activation;
|
| 167 |
+
|
| 168 |
+
ConvReferenceImpl(
|
| 169 |
+
TensorA const& tensor_a,
|
| 170 |
+
TensorB const& tensor_b,
|
| 171 |
+
TensorC const& tensor_c,
|
| 172 |
+
TensorD& tensor_d,
|
| 173 |
+
ShapePadding const& padding,
|
| 174 |
+
StrideTraversal const& tstride,
|
| 175 |
+
ShapeDilation const& dilation,
|
| 176 |
+
EpilogueFusionParams& epi_fusion_params)
|
| 177 |
+
: tensor_a_(tensor_a),
|
| 178 |
+
tensor_b_(tensor_b),
|
| 179 |
+
tensor_c_(tensor_c),
|
| 180 |
+
tensor_d_(tensor_d),
|
| 181 |
+
padding_(padding),
|
| 182 |
+
tstride_(tstride),
|
| 183 |
+
dilation_(dilation),
|
| 184 |
+
epi_fusion_params_(epi_fusion_params)
|
| 185 |
+
{
|
| 186 |
+
static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
|
| 187 |
+
static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
void compute_reference() {
|
| 191 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
|
| 192 |
+
fprop_reference(cute::Int<NumSpatialDims>{});
|
| 193 |
+
}
|
| 194 |
+
else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 195 |
+
dgrad_reference(cute::Int<NumSpatialDims>{});
|
| 196 |
+
}
|
| 197 |
+
else {
|
| 198 |
+
wgrad_reference(cute::Int<NumSpatialDims>{});
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
private:
|
| 203 |
+
  // Specialization for 1D fprop kernel
  //
  // Computes D = activation(alpha * conv(A, B) + beta * C [+ bias]) for a
  // 1-D convolution. Output tensor modes are ordered (K, Q, N, G); the
  // filter is ordered (C, S, K, G).
  void fprop_reference(cute::Int<1> spatial_dims) {
    int32_t G = size<3>(tensor_d_);  // groups
    int32_t N = size<2>(tensor_d_);  // batch
    int32_t Q = size<1>(tensor_d_);  // output extent
    int32_t K = size<0>(tensor_d_);  // output channels
    int32_t S = size<1>(tensor_b_);  // filter extent
    int32_t C = size<0>(tensor_b_);  // input channels

#if defined(_OPENMP)
    #pragma omp parallel for collapse(2)
#endif
    for (int32_t g = 0; g < G; ++g) {
      for (int32_t n = 0; n < N; ++n) {
        for (int32_t q = 0; q < Q; ++q) {
          for (int32_t k = 0; k < K; ++k) {
            auto accumulator = ElementAcc(0);
            // Reduce over the filter taps and input channels.
            for (int32_t s = 0; s < S; ++s) {
              for (int32_t c = 0; c < C; ++c) {
                // Map the output coordinate q back to the input coordinate w.
                int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) {
                  auto a = tensor_a_(c, w, n, g);
                  auto b = tensor_b_(c, s, k, g);
                  accumulator += ElementAcc(a * b);
                }
              }
            }
            // Per-channel alpha/beta tensors override the scalars when set.
            ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
              epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
            ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
              epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
            ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
            // ResidualAdd selects whether beta*C is applied before (false)
            // or after (true) the activation function.
            if (not EpilogueFusionParams::ResidualAdd) {
              output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
            }
            if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
              output += bias_converter(epi_fusion_params_.tensor_bias[k]);
            }
            output = epi_activation(output);
            if (EpilogueFusionParams::ResidualAdd) {
              output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
            }
            tensor_d_(k, q, n, g) = output_converter(output);
          }
        }
      }
    }

  }
|
| 252 |
+
|
| 253 |
+
  // Specialization for 2D fprop kernel
  //
  // Computes D = activation(alpha * conv(A, B) + beta * C [+ bias]) for a
  // 2-D convolution. Output tensor modes are ordered (K, Q, P, N, G); the
  // filter is ordered (C, S, R, K, G).
  void fprop_reference(cute::Int<2> spatial_dims) {
    int32_t G = size<4>(tensor_d_);  // groups
    int32_t N = size<3>(tensor_d_);  // batch
    int32_t P = size<2>(tensor_d_);  // output height
    int32_t Q = size<1>(tensor_d_);  // output width
    int32_t K = size<0>(tensor_d_);  // output channels
    int32_t R = size<2>(tensor_b_);  // filter height
    int32_t S = size<1>(tensor_b_);  // filter width
    int32_t C = size<0>(tensor_b_);  // input channels

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t g = 0; g < G; ++g) {
      for (int32_t n = 0; n < N; ++n) {
        for (int32_t p = 0; p < P; ++p) {
          for (int32_t q = 0; q < Q; ++q) {
            for (int32_t k = 0; k < K; ++k) {
              auto accumulator = ElementAcc(0);
              // Reduce over the filter window and input channels.
              for (int32_t r = 0; r < R; ++r) {
                for (int32_t s = 0; s < S; ++s) {
                  for (int32_t c = 0; c < C; ++c) {
                    // Map output coordinates (p, q) back to input space (h, w).
                    int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                    int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
                    if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) {
                      auto a = tensor_a_(c, w, h, n, g);
                      auto b = tensor_b_(c, s, r, k, g);
                      accumulator += ElementAcc(a * b);
                    }
                  }
                }
              }
              // Per-channel alpha/beta tensors override the scalars when set.
              ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
                epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
              ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
                epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
              ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
              // ResidualAdd selects whether beta*C is applied before (false)
              // or after (true) the activation function.
              if (not EpilogueFusionParams::ResidualAdd) {
                output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
              }
              if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                output += bias_converter(epi_fusion_params_.tensor_bias[k]);
              }
              output = epi_activation(output);
              if (EpilogueFusionParams::ResidualAdd) {
                output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
              }
              tensor_d_(k, q, p, n, g) = output_converter(output);
            }
          }
        }
      }
    }

  }
|
| 309 |
+
|
| 310 |
+
  // Specialization for 3D fprop kernel
  //
  // Computes D = activation(alpha * conv(A, B) + beta * C [+ bias]) for a
  // 3-D convolution. Output tensor modes are ordered (K, Q, P, Z, N, G); the
  // filter is ordered (C, S, R, T, K, G).
  void fprop_reference(cute::Int<3> spatial_dims) {
    int32_t G = size<5>(tensor_d_);  // groups
    int32_t N = size<4>(tensor_d_);  // batch
    int32_t Z = size<3>(tensor_d_);  // output depth
    int32_t P = size<2>(tensor_d_);  // output height
    int32_t Q = size<1>(tensor_d_);  // output width
    int32_t K = size<0>(tensor_d_);  // output channels
    int32_t T = size<3>(tensor_b_);  // filter depth
    int32_t R = size<2>(tensor_b_);  // filter height
    int32_t S = size<1>(tensor_b_);  // filter width
    int32_t C = size<0>(tensor_b_);  // input channels

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t g = 0; g < G; ++g) {
      for (int32_t n = 0; n < N; ++n) {
        for (int32_t z = 0; z < Z; ++z) {
          for (int32_t p = 0; p < P; ++p) {
            for (int32_t q = 0; q < Q; ++q) {
              for (int32_t k = 0; k < K; ++k) {
                auto accumulator = ElementAcc(0);
                // Reduce over the 3-D filter window and input channels.
                for (int32_t t = 0; t < T; ++t) {
                  for (int32_t r = 0; r < R; ++r) {
                    for (int32_t s = 0; s < S; ++s) {
                      for (int32_t c = 0; c < C; ++c) {
                        // Map output coordinates (z, p, q) back to input (d, h, w).
                        int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
                        int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
                        int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
                        if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) {
                          auto a = tensor_a_(c, w, h, d, n, g);
                          auto b = tensor_b_(c, s, r, t, k, g);
                          accumulator += ElementAcc(a * b);
                        }
                      }
                    }
                  }
                }
                // Per-channel alpha/beta tensors override the scalars when set.
                ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
                  epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
                ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
                  epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
                ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
                // ResidualAdd selects whether beta*C is applied before (false)
                // or after (true) the activation function.
                if (not EpilogueFusionParams::ResidualAdd) {
                  output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
                }
                if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                  output += bias_converter(epi_fusion_params_.tensor_bias[k]);
                }
                output = epi_activation(output);
                if (EpilogueFusionParams::ResidualAdd) {
                  output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
                }
                tensor_d_(k, q, p, z, n, g) = output_converter(output);
              }
            }
          }
        }
      }
    }

  }
|
| 373 |
+
|
| 374 |
+
  // Specialization for 1D dgrad kernel
  //
  // Computes the activation (input) gradient: for each input position
  // (c, w, n, g), accumulates contributions from every output position /
  // filter tap that touched it during forward propagation. Output modes are
  // ordered (C, W, N, G); tensor_a_ holds the output gradient (K, Q, N, G).
  void dgrad_reference(cute::Int<1> spatial_dims) {
    int32_t G = size<3>(tensor_d_);  // groups
    int32_t N = size<2>(tensor_d_);  // batch
    int32_t W = size<1>(tensor_d_);  // input extent
    int32_t C = size<0>(tensor_d_);  // input channels
    int32_t K = size<2>(tensor_b_);  // output channels
    int32_t S = size<1>(tensor_b_);  // filter extent

#if defined(_OPENMP)
    #pragma omp parallel for collapse(2)
#endif
    for (int32_t g = 0; g < G; ++g) {
      for (int32_t n = 0; n < N; ++n) {
        for (int32_t w = 0; w < W; ++w) {
          for (int32_t c = 0; c < C; ++c) {
            auto accumulator = ElementAcc(0);
            for (int32_t k = 0; k < K; ++k) {
              for (int32_t s = 0; s < S; ++s) {
                int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);

                // Only taps aligned with the traversal stride map to an
                // integer output position; skip the rest.
                if (q % cute::get<0>(tstride_) == 0) {
                  q /= cute::get<0>(tstride_);
                } else {
                  continue;
                }

                if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) {
                  accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g));
                }
              }
            }
            // Per-channel alpha/beta (indexed by input channel c) override
            // the scalar values when set.
            ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
              ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
            ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
              ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
            ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
            // ResidualAdd selects whether beta*C is applied before (false)
            // or after (true) the activation function.
            if (not EpilogueFusionParams::ResidualAdd) {
              output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
            }
            if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
              output += bias_converter(epi_fusion_params_.tensor_bias[c]);
            }
            output = epi_activation(output);
            if (EpilogueFusionParams::ResidualAdd) {
              output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
            }
            tensor_d_(c, w, n, g) = output_converter(output);
          }
        }
      }
    }

  }
|
| 428 |
+
|
| 429 |
+
  // Specialization for 2D dgrad kernel
  //
  // Computes the activation (input) gradient for a 2-D convolution. Output
  // modes are ordered (C, W, H, N, G); tensor_a_ holds the output gradient
  // (K, Q, P, N, G) and tensor_b_ the filter (C, S, R, K, G).
  void dgrad_reference(cute::Int<2> spatial_dims) {
    int32_t G = size<4>(tensor_d_);  // groups
    int32_t N = size<3>(tensor_d_);  // batch
    int32_t H = size<2>(tensor_d_);  // input height
    int32_t W = size<1>(tensor_d_);  // input width
    int32_t C = size<0>(tensor_d_);  // input channels
    int32_t K = size<3>(tensor_b_);  // output channels
    int32_t R = size<2>(tensor_b_);  // filter height
    int32_t S = size<1>(tensor_b_);  // filter width

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t g = 0; g < G; ++g) {
      for (int32_t n = 0; n < N; ++n) {
        for (int32_t h = 0; h < H; ++h) {
          for (int32_t w = 0; w < W; ++w) {
            for (int32_t c = 0; c < C; ++c) {
              auto accumulator = ElementAcc(0);
              for (int32_t k = 0; k < K; ++k) {
                for (int32_t r = 0; r < R; ++r) {
                  for (int32_t s = 0; s < S; ++s) {
                    int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
                    int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);

                    // Only taps aligned with the traversal stride map to an
                    // integer output position; skip the rest.
                    if (q % cute::get<0>(tstride_) == 0) {
                      q /= cute::get<0>(tstride_);
                    } else {
                      continue;
                    }

                    if (p % cute::get<1>(tstride_) == 0) {
                      p /= cute::get<1>(tstride_);
                    } else {
                      continue;
                    }

                    if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) {
                      accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g));
                    }
                  }
                }
              }
              // Per-channel alpha/beta (indexed by input channel c) override
              // the scalar values when set.
              ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
                ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
              ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
                ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
              ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
              // ResidualAdd selects whether beta*C is applied before (false)
              // or after (true) the activation function.
              if (not EpilogueFusionParams::ResidualAdd) {
                output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
              }
              if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                output += bias_converter(epi_fusion_params_.tensor_bias[c]);
              }
              output = epi_activation(output);
              if (EpilogueFusionParams::ResidualAdd) {
                output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
              }

              tensor_d_(c, w, h, n, g) = output_converter(output);
            }
          }
        }
      }
    }

  }
|
| 497 |
+
|
| 498 |
+
  // Specialization for 3D dgrad kernel
  //
  // Computes the activation (input) gradient for a 3-D convolution. Output
  // modes are ordered (C, W, H, D, N, G); tensor_a_ holds the output gradient
  // (K, Q, P, Z, N, G) and tensor_b_ the filter (C, S, R, T, K, G).
  void dgrad_reference(cute::Int<3> spatial_dims) {
    int32_t G = size<5>(tensor_d_);  // groups
    int32_t N = size<4>(tensor_d_);  // batch
    int32_t D = size<3>(tensor_d_);  // input depth
    int32_t H = size<2>(tensor_d_);  // input height
    int32_t W = size<1>(tensor_d_);  // input width
    int32_t C = size<0>(tensor_d_);  // input channels
    int32_t K = size<4>(tensor_b_);  // output channels
    int32_t T = size<3>(tensor_b_);  // filter depth
    int32_t R = size<2>(tensor_b_);  // filter height
    int32_t S = size<1>(tensor_b_);  // filter width

#if defined(_OPENMP)
    #pragma omp parallel for collapse(3)
#endif
    for (int32_t g = 0; g < G; ++g) {
      for (int32_t n = 0; n < N; ++n) {
        for (int32_t d = 0; d < D; ++d) {
          for (int32_t h = 0; h < H; ++h) {
            for (int32_t w = 0; w < W; ++w) {
              for (int32_t c = 0; c < C; ++c) {
                auto accumulator = ElementAcc(0);
                for (int32_t k = 0; k < K; ++k) {
                  for (int32_t t = 0; t < T; ++t) {
                    for (int32_t r = 0; r < R; ++r) {
                      for (int32_t s = 0; s < S; ++s) {
                        int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
                        int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
                        int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_);

                        // Only taps aligned with the traversal stride map to
                        // an integer output position; skip the rest.
                        if (q % cute::get<0>(tstride_) == 0) {
                          q /= cute::get<0>(tstride_);
                        } else {
                          continue;
                        }

                        if (p % cute::get<1>(tstride_) == 0) {
                          p /= cute::get<1>(tstride_);
                        } else {
                          continue;
                        }

                        if (z % cute::get<2>(tstride_) == 0) {
                          z /= cute::get<2>(tstride_);
                        } else {
                          continue;
                        }

                        if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) {
                          accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g));
                        }
                      }
                    }
                  }
                }
                // Per-channel alpha/beta (indexed by input channel c) override
                // the scalar values when set.
                ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
                  ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
                ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
                  ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
                ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
                // ResidualAdd selects whether beta*C is applied before (false)
                // or after (true) the activation function.
                if (not EpilogueFusionParams::ResidualAdd) {
                  output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
                }
                if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
                  output += bias_converter(epi_fusion_params_.tensor_bias[c]);
                }
                output = epi_activation(output);
                if (EpilogueFusionParams::ResidualAdd) {
                  output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
                }
                tensor_d_(c, w, h, d, n, g) = output_converter(output);
              }
            }
          }
        }
      }
    }

  }
|
| 578 |
+
|
| 579 |
+
// Specialization for 1D wgrad kernel
|
| 580 |
+
void wgrad_reference(cute::Int<1> spatial_dims) {
|
| 581 |
+
int32_t G = size<3>(tensor_d_);
|
| 582 |
+
int32_t N =
|
| 583 |
+
size<2>(tensor_a_);
|
| 584 |
+
int32_t Q =
|
| 585 |
+
size<1>(tensor_a_);
|
| 586 |
+
int32_t K =
|
| 587 |
+
size<0>(tensor_a_);
|
| 588 |
+
int32_t S = size<1>(tensor_d_);
|
| 589 |
+
int32_t C = size<0>(tensor_d_);
|
| 590 |
+
|
| 591 |
+
#if defined(_OPENMP)
|
| 592 |
+
#pragma omp parallel for collapse(2)
|
| 593 |
+
#endif
|
| 594 |
+
for (int32_t g = 0; g < G; ++g) {
|
| 595 |
+
for (int32_t k = 0; k < K; ++k) {
|
| 596 |
+
for (int32_t s = 0; s < S; ++s) {
|
| 597 |
+
for (int32_t c = 0; c < C; ++c) {
|
| 598 |
+
auto accumulator = ElementAcc(0);
|
| 599 |
+
for (int32_t n = 0; n < N; ++n) {
|
| 600 |
+
for (int32_t q = 0; q < Q; ++q) {
|
| 601 |
+
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 602 |
+
bool is_in_bounds =
|
| 603 |
+
detail::is_activation_in_bounds(tensor_b_, n, w, c, g);
|
| 604 |
+
if (is_in_bounds) {
|
| 605 |
+
auto act =
|
| 606 |
+
tensor_b_(c, w, n, g);
|
| 607 |
+
auto xformed_act =
|
| 608 |
+
tensor_a_(k, q, n, g);
|
| 609 |
+
accumulator += ElementAcc(act * xformed_act);
|
| 610 |
+
}
|
| 611 |
+
}
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 615 |
+
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 616 |
+
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 617 |
+
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 618 |
+
|
| 619 |
+
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 620 |
+
if (not EpilogueFusionParams::ResidualAdd) {
|
| 621 |
+
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
|
| 622 |
+
}
|
| 623 |
+
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 624 |
+
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 625 |
+
}
|
| 626 |
+
output = epi_activation(output);
|
| 627 |
+
if (EpilogueFusionParams::ResidualAdd) {
|
| 628 |
+
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
|
| 629 |
+
}
|
| 630 |
+
tensor_d_(c, s, k, g) = output_converter(output);
|
| 631 |
+
}
|
| 632 |
+
}
|
| 633 |
+
}
|
| 634 |
+
}
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
// Specialization for 2D wgrad kernel
|
| 638 |
+
void wgrad_reference(cute::Int<2> spatial_dims) {
|
| 639 |
+
int32_t G = size<4>(tensor_d_);
|
| 640 |
+
int32_t N =
|
| 641 |
+
size<3>(tensor_a_);
|
| 642 |
+
int32_t P =
|
| 643 |
+
size<2>(tensor_a_);
|
| 644 |
+
int32_t Q =
|
| 645 |
+
size<1>(tensor_a_);
|
| 646 |
+
int32_t K =
|
| 647 |
+
size<0>(tensor_a_);
|
| 648 |
+
int32_t R = size<2>(tensor_d_);
|
| 649 |
+
int32_t S = size<1>(tensor_d_);
|
| 650 |
+
int32_t C = size<0>(tensor_d_);
|
| 651 |
+
|
| 652 |
+
#if defined(_OPENMP)
|
| 653 |
+
#pragma omp parallel for collapse(3)
|
| 654 |
+
#endif
|
| 655 |
+
for (int32_t g = 0; g < G; ++g) {
|
| 656 |
+
for (int32_t k = 0; k < K; ++k) {
|
| 657 |
+
for (int32_t r = 0; r < R; ++r) {
|
| 658 |
+
for (int32_t s = 0; s < S; ++s) {
|
| 659 |
+
for (int32_t c = 0; c < C; ++c) {
|
| 660 |
+
auto accumulator = ElementAcc(0);
|
| 661 |
+
for (int32_t n = 0; n < N; ++n) {
|
| 662 |
+
for (int32_t p = 0; p < P; ++p) {
|
| 663 |
+
for (int32_t q = 0; q < Q; ++q) {
|
| 664 |
+
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 665 |
+
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
| 666 |
+
bool is_in_bounds =
|
| 667 |
+
detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g);
|
| 668 |
+
if (is_in_bounds) {
|
| 669 |
+
auto act =
|
| 670 |
+
tensor_b_(c, w, h, n, g);
|
| 671 |
+
auto xformed_act =
|
| 672 |
+
tensor_a_(k, q, p, n, g);
|
| 673 |
+
accumulator += ElementAcc(act * xformed_act);
|
| 674 |
+
}
|
| 675 |
+
}
|
| 676 |
+
}
|
| 677 |
+
}
|
| 678 |
+
|
| 679 |
+
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 680 |
+
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 681 |
+
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 682 |
+
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 683 |
+
|
| 684 |
+
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 685 |
+
if (not EpilogueFusionParams::ResidualAdd) {
|
| 686 |
+
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
|
| 687 |
+
}
|
| 688 |
+
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 689 |
+
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 690 |
+
}
|
| 691 |
+
output = epi_activation(output);
|
| 692 |
+
if (EpilogueFusionParams::ResidualAdd) {
|
| 693 |
+
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
|
| 694 |
+
}
|
| 695 |
+
tensor_d_(c, s, r, k, g) = output_converter(output);
|
| 696 |
+
}
|
| 697 |
+
}
|
| 698 |
+
}
|
| 699 |
+
}
|
| 700 |
+
}
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
// Specialization for 3D wgrad kernel
|
| 704 |
+
void wgrad_reference(cute::Int<3> spatial_dims) {
|
| 705 |
+
int32_t G = size<5>(tensor_d_);
|
| 706 |
+
int32_t N =
|
| 707 |
+
size<4>(tensor_a_);
|
| 708 |
+
int32_t Z =
|
| 709 |
+
size<3>(tensor_a_);
|
| 710 |
+
int32_t P =
|
| 711 |
+
size<2>(tensor_a_);
|
| 712 |
+
int32_t Q =
|
| 713 |
+
size<1>(tensor_a_);
|
| 714 |
+
int32_t K =
|
| 715 |
+
size<0>(tensor_a_);
|
| 716 |
+
int32_t T = size<3>(tensor_d_);
|
| 717 |
+
int32_t R = size<2>(tensor_d_);
|
| 718 |
+
int32_t S = size<1>(tensor_d_);
|
| 719 |
+
int32_t C = size<0>(tensor_d_);
|
| 720 |
+
|
| 721 |
+
#if defined(_OPENMP)
|
| 722 |
+
#pragma omp parallel for collapse(3)
|
| 723 |
+
#endif
|
| 724 |
+
for (int32_t g = 0 ; g < G; ++g) {
|
| 725 |
+
for (int32_t k = 0; k < K; ++k) {
|
| 726 |
+
for (int32_t t = 0; t < T; ++t) {
|
| 727 |
+
for (int32_t r = 0; r < R; ++r) {
|
| 728 |
+
for (int32_t s = 0; s < S; ++s) {
|
| 729 |
+
for (int32_t c = 0; c < C; ++c) {
|
| 730 |
+
auto accumulator = ElementAcc(0);
|
| 731 |
+
for (int32_t n = 0; n < N; ++n) {
|
| 732 |
+
for (int32_t z = 0; z < Z; ++z) {
|
| 733 |
+
for (int32_t p = 0; p < P; ++p) {
|
| 734 |
+
for (int32_t q = 0; q < Q; ++q) {
|
| 735 |
+
int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
|
| 736 |
+
int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
|
| 737 |
+
int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
|
| 738 |
+
bool is_in_bounds =
|
| 739 |
+
detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g);
|
| 740 |
+
if (is_in_bounds) {
|
| 741 |
+
auto act =
|
| 742 |
+
tensor_b_(c, w, h, d, n, g);
|
| 743 |
+
auto xformed_act =
|
| 744 |
+
tensor_a_(k, q, p, z, n, g);
|
| 745 |
+
accumulator += ElementAcc(act * xformed_act);
|
| 746 |
+
}
|
| 747 |
+
}
|
| 748 |
+
}
|
| 749 |
+
}
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
|
| 753 |
+
epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
|
| 754 |
+
ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
|
| 755 |
+
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
|
| 756 |
+
|
| 757 |
+
ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
|
| 758 |
+
if (not EpilogueFusionParams::ResidualAdd) {
|
| 759 |
+
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
|
| 760 |
+
}
|
| 761 |
+
if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
|
| 762 |
+
output += bias_converter(epi_fusion_params_.tensor_bias[c]);
|
| 763 |
+
}
|
| 764 |
+
output = epi_activation(output);
|
| 765 |
+
if (EpilogueFusionParams::ResidualAdd) {
|
| 766 |
+
output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
|
| 767 |
+
}
|
| 768 |
+
tensor_d_(c, s, r, t, k, g) = output_converter(output);
|
| 769 |
+
}
|
| 770 |
+
}
|
| 771 |
+
}
|
| 772 |
+
}
|
| 773 |
+
}
|
| 774 |
+
}
|
| 775 |
+
}
|
| 776 |
+
};
|
| 777 |
+
|
| 778 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 779 |
+
|
| 780 |
+
} // cutlass::reference::host
|
| 781 |
+
|
| 782 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
ADDED
|
@@ -0,0 +1,802 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Reference implementation for convolution in host-side code.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/coord.h"
|
| 39 |
+
#include "cutlass/functional.h"
|
| 40 |
+
#include "cutlass/layout/tensor.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/numeric_types.h"
|
| 43 |
+
#include "cutlass/tensor_ref.h"
|
| 44 |
+
#include "cutlass/tensor_view.h"
|
| 45 |
+
#include "cutlass/conv/convolution.h"
|
| 46 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 47 |
+
#include "cutlass/conv/conv3d_problem_size.h"
|
| 48 |
+
#include <iostream>
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace reference {
|
| 52 |
+
namespace host {
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
/// Forward propagation
|
| 56 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
/// y = conv2d(x, w)
|
| 59 |
+
template <
|
| 60 |
+
typename ElementA,
|
| 61 |
+
typename LayoutA,
|
| 62 |
+
typename ElementB,
|
| 63 |
+
typename LayoutB,
|
| 64 |
+
typename ElementC,
|
| 65 |
+
typename LayoutC,
|
| 66 |
+
typename ElementCompute,
|
| 67 |
+
typename ElementAccumulator = ElementCompute,
|
| 68 |
+
typename ElementD = ElementC,
|
| 69 |
+
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 70 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 71 |
+
>
|
| 72 |
+
void Conv2dFprop(
|
| 73 |
+
conv::Conv2dProblemSize problem_size,
|
| 74 |
+
TensorRef<ElementA, LayoutA> tensor_x,
|
| 75 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 76 |
+
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 77 |
+
TensorRef<ElementD, LayoutC> tensor_y_out,
|
| 78 |
+
ElementCompute alpha,
|
| 79 |
+
ElementCompute beta) {
|
| 80 |
+
|
| 81 |
+
ConvertOp convert_op;
|
| 82 |
+
InnerProductOp inner_product_op;
|
| 83 |
+
|
| 84 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 85 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 86 |
+
for (int p = 0; p < problem_size.P; ++p) {
|
| 87 |
+
for (int q = 0; q < problem_size.Q; ++q) {
|
| 88 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 89 |
+
|
| 90 |
+
int group_idx = k / (problem_size.K / problem_size.groups);
|
| 91 |
+
int channels_per_group = problem_size.C / problem_size.groups;
|
| 92 |
+
|
| 93 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 94 |
+
|
| 95 |
+
for (int r = 0; r < problem_size.R; ++r) {
|
| 96 |
+
for (int s = 0; s < problem_size.S; ++s) {
|
| 97 |
+
for (int c = 0; c < channels_per_group; ++c) {
|
| 98 |
+
|
| 99 |
+
int filter_r = r;
|
| 100 |
+
int filter_s = s;
|
| 101 |
+
|
| 102 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 103 |
+
filter_r = problem_size.R - 1 - r;
|
| 104 |
+
filter_s = problem_size.S - 1 - s;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 108 |
+
int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 109 |
+
|
| 110 |
+
if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
|
| 111 |
+
|
| 112 |
+
ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group});
|
| 113 |
+
ElementB b = tensor_w.at({k, r, s, c});
|
| 114 |
+
|
| 115 |
+
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 116 |
+
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 123 |
+
ElementC c_ref = ElementC();
|
| 124 |
+
|
| 125 |
+
if (beta != ElementCompute()) {
|
| 126 |
+
c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k));
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) =
|
| 130 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
/// Depthwise-separable convolution
|
| 138 |
+
template <typename ElementA,
|
| 139 |
+
typename LayoutA,
|
| 140 |
+
typename ElementB,
|
| 141 |
+
typename LayoutB,
|
| 142 |
+
typename ElementC,
|
| 143 |
+
typename LayoutC,
|
| 144 |
+
typename ElementCompute,
|
| 145 |
+
typename ElementAccumulator = ElementCompute,
|
| 146 |
+
typename ElementD = ElementC,
|
| 147 |
+
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 148 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>>
|
| 149 |
+
void Depsep_Fprop(cutlass::TensorView<ElementA, LayoutA> tensor_A,
|
| 150 |
+
cutlass::TensorView<ElementB, LayoutB> tensor_B,
|
| 151 |
+
cutlass::TensorView<ElementC, LayoutC> tensor_C,
|
| 152 |
+
cutlass::TensorView<ElementD, LayoutC> tensor_D,
|
| 153 |
+
ElementCompute alpha,
|
| 154 |
+
ElementCompute beta,
|
| 155 |
+
cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(),
|
| 156 |
+
cutlass::Coord<2> conv_stride = cutlass::Coord<2>(),
|
| 157 |
+
cutlass::Coord<2> dilation = cutlass::Coord<2>(),
|
| 158 |
+
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) {
|
| 159 |
+
|
| 160 |
+
ConvertOp convert_op;
|
| 161 |
+
InnerProductOp inner_product_op;
|
| 162 |
+
|
| 163 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 164 |
+
for (int n = 0; n < tensor_C.extent().n(); ++n) {
|
| 165 |
+
for (int p = 0; p < tensor_C.extent().h(); ++p) {
|
| 166 |
+
for (int q = 0; q < tensor_C.extent().w(); ++q) {
|
| 167 |
+
for (int g = 0; g < tensor_C.extent().c(); ++g) {
|
| 168 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 169 |
+
for (int r = 0; r < tensor_B.extent().h(); ++r) {
|
| 170 |
+
for (int s = 0; s < tensor_B.extent().w(); ++s) {
|
| 171 |
+
|
| 172 |
+
// input activation H and W
|
| 173 |
+
int h = p * conv_stride[0] - padding[0] + r * dilation[0];
|
| 174 |
+
int w = q * conv_stride[1] - padding[2] + s * dilation[1];
|
| 175 |
+
|
| 176 |
+
if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) {
|
| 177 |
+
ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g));
|
| 178 |
+
|
| 179 |
+
ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation)
|
| 180 |
+
? tensor_B.at(cutlass::make_Coord(g, r, s, 0))
|
| 181 |
+
: tensor_B.at(cutlass::make_Coord(
|
| 182 |
+
g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0));
|
| 183 |
+
|
| 184 |
+
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 190 |
+
ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g));
|
| 191 |
+
tensor_D.at(cutlass::make_Coord(n, p, q, g)) =
|
| 192 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 200 |
+
/// Dgrad / Deconv
|
| 201 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 202 |
+
|
| 203 |
+
/// dx = dgrad(dy, w)
|
| 204 |
+
template <
|
| 205 |
+
typename ElementA,
|
| 206 |
+
typename LayoutA,
|
| 207 |
+
typename ElementB,
|
| 208 |
+
typename LayoutB,
|
| 209 |
+
typename ElementC,
|
| 210 |
+
typename LayoutC,
|
| 211 |
+
typename ElementCompute,
|
| 212 |
+
typename ElementAccumulator = ElementCompute,
|
| 213 |
+
typename ElementD = ElementC,
|
| 214 |
+
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 215 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 216 |
+
>
|
| 217 |
+
void Conv2dDgrad(
|
| 218 |
+
cutlass::conv::Conv2dProblemSize problem_size,
|
| 219 |
+
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 220 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 221 |
+
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 222 |
+
TensorRef<ElementD, LayoutC> tensor_dx_out,
|
| 223 |
+
ElementCompute alpha,
|
| 224 |
+
ElementCompute beta,
|
| 225 |
+
bool is_deconv = false) {
|
| 226 |
+
|
| 227 |
+
ConvertOp convert_op;
|
| 228 |
+
InnerProductOp inner_product_op;
|
| 229 |
+
|
| 230 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 231 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 232 |
+
for (int h = 0; h < problem_size.H; ++h) {
|
| 233 |
+
for (int w = 0; w < problem_size.W; ++w) {
|
| 234 |
+
for (int c = 0; c < problem_size.C; ++c) {
|
| 235 |
+
|
| 236 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 237 |
+
|
| 238 |
+
for (int r = 0; r < problem_size.R; ++r) {
|
| 239 |
+
for (int s = 0; s < problem_size.S; ++s) {
|
| 240 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 241 |
+
|
| 242 |
+
int filter_r = r;
|
| 243 |
+
int filter_s = s;
|
| 244 |
+
|
| 245 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 246 |
+
filter_r = problem_size.R - 1 - r;
|
| 247 |
+
filter_s = problem_size.S - 1 - s;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
|
| 251 |
+
int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
|
| 252 |
+
|
| 253 |
+
if (p >= 0 && (p % problem_size.stride_h) == 0 &&
|
| 254 |
+
q >= 0 && (q % problem_size.stride_w) == 0) {
|
| 255 |
+
|
| 256 |
+
p = p / problem_size.stride_h;
|
| 257 |
+
q = q / problem_size.stride_w;
|
| 258 |
+
#if 0
|
| 259 |
+
std::cout << "row:"
|
| 260 |
+
<< n * problem_size.H * problem_size.W +
|
| 261 |
+
h * problem_size.W +
|
| 262 |
+
w << " "
|
| 263 |
+
<< "n, p, q: ("
|
| 264 |
+
<< n << ", "
|
| 265 |
+
<< p << ", "
|
| 266 |
+
<< q << ") * "
|
| 267 |
+
<< "r, s: ("
|
| 268 |
+
<< r << ", "
|
| 269 |
+
<< s << ") ["
|
| 270 |
+
<< ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]"
|
| 271 |
+
<< std::endl;
|
| 272 |
+
#endif
|
| 273 |
+
if (p < problem_size.P && q < problem_size.Q) {
|
| 274 |
+
|
| 275 |
+
ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k));
|
| 276 |
+
ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k))
|
| 277 |
+
: tensor_w.at(cutlass::make_Coord(k, r, s, c));
|
| 278 |
+
|
| 279 |
+
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
} // for (K)
|
| 284 |
+
} // for (S)
|
| 285 |
+
} // for (R)
|
| 286 |
+
|
| 287 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 288 |
+
ElementC c_ref = ElementC();
|
| 289 |
+
|
| 290 |
+
if (beta != ElementCompute()) {
|
| 291 |
+
c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c));
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) =
|
| 295 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 296 |
+
|
| 297 |
+
} // for (C)
|
| 298 |
+
} // for (W)
|
| 299 |
+
} // for (H)
|
| 300 |
+
} // for (N)
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 304 |
+
/// Wgrad
|
| 305 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 306 |
+
|
| 307 |
+
/// dw = wgrad(dy, x)
|
| 308 |
+
template <
|
| 309 |
+
typename ElementA,
|
| 310 |
+
typename LayoutA,
|
| 311 |
+
typename ElementB,
|
| 312 |
+
typename LayoutB,
|
| 313 |
+
typename ElementC,
|
| 314 |
+
typename LayoutC,
|
| 315 |
+
typename ElementCompute,
|
| 316 |
+
typename ElementAccumulator = ElementCompute,
|
| 317 |
+
typename ElementD = ElementC,
|
| 318 |
+
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 319 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 320 |
+
>
|
| 321 |
+
void Conv2dWgrad(
|
| 322 |
+
cutlass::conv::Conv2dProblemSize problem_size,
|
| 323 |
+
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 324 |
+
TensorRef<ElementB, LayoutB> tensor_x,
|
| 325 |
+
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 326 |
+
TensorRef<ElementD, LayoutC> tensor_dw_out,
|
| 327 |
+
ElementCompute alpha,
|
| 328 |
+
ElementCompute beta) {
|
| 329 |
+
|
| 330 |
+
InnerProductOp inner_product_op;
|
| 331 |
+
ConvertOp convert_op;
|
| 332 |
+
|
| 333 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 334 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 335 |
+
for (int r = 0; r < problem_size.R; ++r) {
|
| 336 |
+
for (int s = 0; s < problem_size.S; ++s) {
|
| 337 |
+
for (int c = 0; c < problem_size.C; ++c) {
|
| 338 |
+
|
| 339 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 340 |
+
|
| 341 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 342 |
+
for (int p = 0; p < problem_size.P; ++p) {
|
| 343 |
+
for (int q = 0; q < problem_size.Q; ++q) {
|
| 344 |
+
|
| 345 |
+
cutlass::Tensor4DCoord b_coord;
|
| 346 |
+
|
| 347 |
+
int filter_r = r;
|
| 348 |
+
int filter_s = s;
|
| 349 |
+
|
| 350 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 351 |
+
filter_r = problem_size.R - 1 - r;
|
| 352 |
+
filter_s = problem_size.S - 1 - s;
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
b_coord = make_Coord(
|
| 356 |
+
n,
|
| 357 |
+
p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
|
| 358 |
+
q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
|
| 359 |
+
c);
|
| 360 |
+
|
| 361 |
+
if (b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
|
| 362 |
+
b_coord.w() < problem_size.W && b_coord.w() >= 0) {
|
| 363 |
+
|
| 364 |
+
ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k)));
|
| 365 |
+
ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
|
| 366 |
+
acc = inner_product_op(a, b, acc);
|
| 367 |
+
}
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 373 |
+
ElementC c_ref = ElementC();
|
| 374 |
+
|
| 375 |
+
if (beta != ElementCompute()) {
|
| 376 |
+
c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c));
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) =
|
| 380 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 381 |
+
|
| 382 |
+
} // for (C)
|
| 383 |
+
} // for (S)
|
| 384 |
+
} // for (R)
|
| 385 |
+
} // for (K)
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
|
| 389 |
+
template <
|
| 390 |
+
typename ElementA,
|
| 391 |
+
typename LayoutA,
|
| 392 |
+
typename ElementB,
|
| 393 |
+
typename LayoutB,
|
| 394 |
+
typename ElementC,
|
| 395 |
+
typename LayoutC,
|
| 396 |
+
typename ElementCompute,
|
| 397 |
+
typename ElementAccumulator = ElementCompute,
|
| 398 |
+
typename ElementD = ElementC,
|
| 399 |
+
typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
|
| 400 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 401 |
+
>
|
| 402 |
+
void Conv2d(
|
| 403 |
+
conv::Operator convolutional_operator,
|
| 404 |
+
conv::Conv2dProblemSize problem_size,
|
| 405 |
+
TensorRef<ElementA, LayoutA> tensor_A,
|
| 406 |
+
TensorRef<ElementB, LayoutB> tensor_B,
|
| 407 |
+
TensorRef<ElementC, LayoutC> tensor_C,
|
| 408 |
+
TensorRef<ElementD, LayoutC> tensor_D,
|
| 409 |
+
ElementCompute alpha,
|
| 410 |
+
ElementCompute beta) {
|
| 411 |
+
|
| 412 |
+
switch (convolutional_operator) {
|
| 413 |
+
case conv::Operator::kFprop:
|
| 414 |
+
Conv2dFprop<
|
| 415 |
+
ElementA, LayoutA,
|
| 416 |
+
ElementB, LayoutB,
|
| 417 |
+
ElementC, LayoutC,
|
| 418 |
+
ElementCompute,
|
| 419 |
+
ElementAccumulator,
|
| 420 |
+
ElementD,
|
| 421 |
+
ConvertOp, InnerProductOp
|
| 422 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 423 |
+
break;
|
| 424 |
+
|
| 425 |
+
case conv::Operator::kDeconv:
|
| 426 |
+
case conv::Operator::kDgrad:
|
| 427 |
+
Conv2dDgrad<
|
| 428 |
+
ElementA, LayoutA,
|
| 429 |
+
ElementB, LayoutB,
|
| 430 |
+
ElementC, LayoutC,
|
| 431 |
+
ElementCompute,
|
| 432 |
+
ElementAccumulator,
|
| 433 |
+
ElementD,
|
| 434 |
+
ConvertOp, InnerProductOp
|
| 435 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
|
| 436 |
+
break;
|
| 437 |
+
|
| 438 |
+
case conv::Operator::kWgrad:
|
| 439 |
+
Conv2dWgrad<
|
| 440 |
+
ElementA, LayoutA,
|
| 441 |
+
ElementB, LayoutB,
|
| 442 |
+
ElementC, LayoutC,
|
| 443 |
+
ElementCompute,
|
| 444 |
+
ElementAccumulator,
|
| 445 |
+
ElementD,
|
| 446 |
+
ConvertOp, InnerProductOp
|
| 447 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 448 |
+
break;
|
| 449 |
+
|
| 450 |
+
default:
|
| 451 |
+
break;
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 456 |
+
/// 3D convolution
|
| 457 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 458 |
+
|
| 459 |
+
/// y = conv3d(x, w)
|
| 460 |
+
template <
|
| 461 |
+
typename ElementA,
|
| 462 |
+
typename LayoutA,
|
| 463 |
+
typename ElementB,
|
| 464 |
+
typename LayoutB,
|
| 465 |
+
typename ElementC,
|
| 466 |
+
typename LayoutC,
|
| 467 |
+
typename ElementCompute,
|
| 468 |
+
typename ElementAccumulator = ElementCompute,
|
| 469 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 470 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 471 |
+
>
|
| 472 |
+
void Conv3dFprop(
|
| 473 |
+
conv::Conv3dProblemSize problem_size,
|
| 474 |
+
TensorRef<ElementA, LayoutA> tensor_x,
|
| 475 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 476 |
+
TensorRef<ElementC, LayoutC> tensor_y_in,
|
| 477 |
+
TensorRef<ElementC, LayoutC> tensor_y_out,
|
| 478 |
+
ElementCompute alpha,
|
| 479 |
+
ElementCompute beta) {
|
| 480 |
+
|
| 481 |
+
ConvertOp convert_op;
|
| 482 |
+
InnerProductOp inner_product_op;
|
| 483 |
+
|
| 484 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 485 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 486 |
+
for (int z = 0; z < problem_size.Z; ++z) {
|
| 487 |
+
for (int p = 0; p < problem_size.P; ++p) {
|
| 488 |
+
for (int q = 0; q < problem_size.Q; ++q) {
|
| 489 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 490 |
+
|
| 491 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 492 |
+
|
| 493 |
+
for (int t = 0; t < problem_size.T; ++t) {
|
| 494 |
+
for (int r = 0; r < problem_size.R; ++r) {
|
| 495 |
+
for (int s = 0; s < problem_size.S; ++s) {
|
| 496 |
+
for (int c = 0; c < problem_size.C; ++c) {
|
| 497 |
+
|
| 498 |
+
int filter_t = t;
|
| 499 |
+
int filter_r = r;
|
| 500 |
+
int filter_s = s;
|
| 501 |
+
|
| 502 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 503 |
+
filter_t = problem_size.T - 1 - t;
|
| 504 |
+
filter_r = problem_size.R - 1 - r;
|
| 505 |
+
filter_s = problem_size.S - 1 - s;
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
|
| 509 |
+
int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
|
| 510 |
+
int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
|
| 511 |
+
|
| 512 |
+
if (d >= 0 && d < problem_size.D &&
|
| 513 |
+
h >=0 && h < problem_size.H &&
|
| 514 |
+
w >= 0 && w < problem_size.W) {
|
| 515 |
+
|
| 516 |
+
ElementA a = tensor_x.at({n, d, h, w, c});
|
| 517 |
+
ElementB b = tensor_w.at({k, t, r, s, c});
|
| 518 |
+
|
| 519 |
+
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 520 |
+
}
|
| 521 |
+
}
|
| 522 |
+
}
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 527 |
+
ElementC c_ref = ElementC();
|
| 528 |
+
|
| 529 |
+
if (beta != ElementCompute()) {
|
| 530 |
+
c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k));
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) =
|
| 534 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
}
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 543 |
+
/// Dgrad / Deconv
|
| 544 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 545 |
+
|
| 546 |
+
/// dx = dgrad(dy, w)
|
| 547 |
+
template <
|
| 548 |
+
typename ElementA,
|
| 549 |
+
typename LayoutA,
|
| 550 |
+
typename ElementB,
|
| 551 |
+
typename LayoutB,
|
| 552 |
+
typename ElementC,
|
| 553 |
+
typename LayoutC,
|
| 554 |
+
typename ElementCompute,
|
| 555 |
+
typename ElementAccumulator = ElementCompute,
|
| 556 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 557 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 558 |
+
>
|
| 559 |
+
void Conv3dDgrad(
|
| 560 |
+
cutlass::conv::Conv3dProblemSize problem_size,
|
| 561 |
+
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 562 |
+
TensorRef<ElementB, LayoutB> tensor_w,
|
| 563 |
+
TensorRef<ElementC, LayoutC> tensor_dx_in,
|
| 564 |
+
TensorRef<ElementC, LayoutC> tensor_dx_out,
|
| 565 |
+
ElementCompute alpha,
|
| 566 |
+
ElementCompute beta,
|
| 567 |
+
bool is_deconv = false) {
|
| 568 |
+
|
| 569 |
+
ConvertOp convert_op;
|
| 570 |
+
InnerProductOp inner_product_op;
|
| 571 |
+
|
| 572 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 573 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 574 |
+
for (int d = 0; d < problem_size.D; ++d) {
|
| 575 |
+
for (int h = 0; h < problem_size.H; ++h) {
|
| 576 |
+
for (int w = 0; w < problem_size.W; ++w) {
|
| 577 |
+
for (int c = 0; c < problem_size.C; ++c) {
|
| 578 |
+
|
| 579 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 580 |
+
|
| 581 |
+
for (int t = 0; t < problem_size.T; ++t) {
|
| 582 |
+
for (int r = 0; r < problem_size.R; ++r) {
|
| 583 |
+
for (int s = 0; s < problem_size.S; ++s) {
|
| 584 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 585 |
+
|
| 586 |
+
int filter_t = t;
|
| 587 |
+
int filter_r = r;
|
| 588 |
+
int filter_s = s;
|
| 589 |
+
|
| 590 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 591 |
+
filter_t = problem_size.T - 1 - t;
|
| 592 |
+
filter_r = problem_size.R - 1 - r;
|
| 593 |
+
filter_s = problem_size.S - 1 - s;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d;
|
| 597 |
+
int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
|
| 598 |
+
int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
|
| 599 |
+
|
| 600 |
+
if (z >= 0 && (z % problem_size.stride_d) == 0 &&
|
| 601 |
+
p >= 0 && (p % problem_size.stride_h) == 0 &&
|
| 602 |
+
q >= 0 && (q % problem_size.stride_w) == 0) {
|
| 603 |
+
|
| 604 |
+
z = z / problem_size.stride_d;
|
| 605 |
+
p = p / problem_size.stride_h;
|
| 606 |
+
q = q / problem_size.stride_w;
|
| 607 |
+
|
| 608 |
+
if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
|
| 609 |
+
|
| 610 |
+
ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k));
|
| 611 |
+
ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k))
|
| 612 |
+
: tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
|
| 613 |
+
acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
} // for (K)
|
| 618 |
+
} // for (S)
|
| 619 |
+
} // for (R)
|
| 620 |
+
} // for (T)
|
| 621 |
+
|
| 622 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 623 |
+
ElementC c_ref = ElementC();
|
| 624 |
+
|
| 625 |
+
if (beta != ElementCompute()) {
|
| 626 |
+
c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c));
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) =
|
| 630 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 631 |
+
|
| 632 |
+
} // for (C)
|
| 633 |
+
} // for (W)
|
| 634 |
+
} // for (H)
|
| 635 |
+
} // for (D)
|
| 636 |
+
} // for (N)
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 640 |
+
/// Wgrad
|
| 641 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 642 |
+
|
| 643 |
+
/// dw = wgrad(dy, x)
|
| 644 |
+
template <
|
| 645 |
+
typename ElementA,
|
| 646 |
+
typename LayoutA,
|
| 647 |
+
typename ElementB,
|
| 648 |
+
typename LayoutB,
|
| 649 |
+
typename ElementC,
|
| 650 |
+
typename LayoutC,
|
| 651 |
+
typename ElementCompute,
|
| 652 |
+
typename ElementAccumulator = ElementCompute,
|
| 653 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 654 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 655 |
+
>
|
| 656 |
+
void Conv3dWgrad(
|
| 657 |
+
cutlass::conv::Conv3dProblemSize problem_size,
|
| 658 |
+
TensorRef<ElementA, LayoutA> tensor_dy,
|
| 659 |
+
TensorRef<ElementB, LayoutB> tensor_x,
|
| 660 |
+
TensorRef<ElementC, LayoutC> tensor_dw_in,
|
| 661 |
+
TensorRef<ElementC, LayoutC> tensor_dw_out,
|
| 662 |
+
ElementCompute alpha,
|
| 663 |
+
ElementCompute beta) {
|
| 664 |
+
|
| 665 |
+
InnerProductOp inner_product_op;
|
| 666 |
+
ConvertOp convert_op;
|
| 667 |
+
|
| 668 |
+
// Apply MMA and accumulate ElementAccumulator
|
| 669 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 670 |
+
for (int t = 0; t < problem_size.T; ++t) {
|
| 671 |
+
for (int r = 0; r < problem_size.R; ++r) {
|
| 672 |
+
for (int s = 0; s < problem_size.S; ++s) {
|
| 673 |
+
for (int c = 0; c < problem_size.C; ++c) {
|
| 674 |
+
|
| 675 |
+
ElementAccumulator acc = ElementAccumulator();
|
| 676 |
+
|
| 677 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 678 |
+
for (int z = 0; z < problem_size.Z; ++z) {
|
| 679 |
+
for (int p = 0; p < problem_size.P; ++p) {
|
| 680 |
+
for (int q = 0; q < problem_size.Q; ++q) {
|
| 681 |
+
|
| 682 |
+
int filter_t = t;
|
| 683 |
+
int filter_r = r;
|
| 684 |
+
int filter_s = s;
|
| 685 |
+
|
| 686 |
+
if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
|
| 687 |
+
filter_t = problem_size.T - 1 - t;
|
| 688 |
+
filter_r = problem_size.R - 1 - r;
|
| 689 |
+
filter_s = problem_size.S - 1 - s;
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
Tensor5DCoord b_coord = make_Coord(
|
| 693 |
+
n,
|
| 694 |
+
z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d,
|
| 695 |
+
p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
|
| 696 |
+
q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
|
| 697 |
+
c);
|
| 698 |
+
|
| 699 |
+
if (b_coord.d() < problem_size.D && b_coord.d() >= 0 &&
|
| 700 |
+
b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
|
| 701 |
+
b_coord.w() < problem_size.W && b_coord.w() >= 0) {
|
| 702 |
+
|
| 703 |
+
ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)));
|
| 704 |
+
ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
|
| 705 |
+
|
| 706 |
+
acc = inner_product_op(a, b, acc);
|
| 707 |
+
}
|
| 708 |
+
}
|
| 709 |
+
}
|
| 710 |
+
}
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
// Apply Epilogue, compute ElementCompute, convert and store ElementC
|
| 714 |
+
ElementC c_ref = ElementC();
|
| 715 |
+
|
| 716 |
+
if (beta != ElementCompute()) {
|
| 717 |
+
c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c));
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) =
|
| 721 |
+
convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
|
| 722 |
+
|
| 723 |
+
} // for (C)
|
| 724 |
+
} // for (S)
|
| 725 |
+
} // for (R)
|
| 726 |
+
} // for (T)
|
| 727 |
+
} // for (K)
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 731 |
+
|
| 732 |
+
/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
|
| 733 |
+
template <
|
| 734 |
+
typename ElementA,
|
| 735 |
+
typename LayoutA,
|
| 736 |
+
typename ElementB,
|
| 737 |
+
typename LayoutB,
|
| 738 |
+
typename ElementC,
|
| 739 |
+
typename LayoutC,
|
| 740 |
+
typename ElementCompute,
|
| 741 |
+
typename ElementAccumulator = ElementCompute,
|
| 742 |
+
typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
|
| 743 |
+
typename InnerProductOp = multiply_add<ElementAccumulator>
|
| 744 |
+
>
|
| 745 |
+
void Conv3d(
|
| 746 |
+
conv::Operator convolutional_operator,
|
| 747 |
+
conv::Conv3dProblemSize problem_size,
|
| 748 |
+
TensorRef<ElementA, LayoutA> tensor_A,
|
| 749 |
+
TensorRef<ElementB, LayoutB> tensor_B,
|
| 750 |
+
TensorRef<ElementC, LayoutC> tensor_C,
|
| 751 |
+
TensorRef<ElementC, LayoutC> tensor_D,
|
| 752 |
+
ElementCompute alpha,
|
| 753 |
+
ElementCompute beta) {
|
| 754 |
+
|
| 755 |
+
switch (convolutional_operator) {
|
| 756 |
+
case conv::Operator::kFprop:
|
| 757 |
+
Conv3dFprop<
|
| 758 |
+
ElementA, LayoutA,
|
| 759 |
+
ElementB, LayoutB,
|
| 760 |
+
ElementC, LayoutC,
|
| 761 |
+
ElementCompute,
|
| 762 |
+
ElementAccumulator,
|
| 763 |
+
ConvertOp, InnerProductOp
|
| 764 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 765 |
+
break;
|
| 766 |
+
|
| 767 |
+
case conv::Operator::kDeconv:
|
| 768 |
+
case conv::Operator::kDgrad:
|
| 769 |
+
Conv3dDgrad<
|
| 770 |
+
ElementA, LayoutA,
|
| 771 |
+
ElementB, LayoutB,
|
| 772 |
+
ElementC, LayoutC,
|
| 773 |
+
ElementCompute,
|
| 774 |
+
ElementAccumulator,
|
| 775 |
+
ConvertOp, InnerProductOp
|
| 776 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
|
| 777 |
+
break;
|
| 778 |
+
|
| 779 |
+
case conv::Operator::kWgrad:
|
| 780 |
+
Conv3dWgrad<
|
| 781 |
+
ElementA, LayoutA,
|
| 782 |
+
ElementB, LayoutB,
|
| 783 |
+
ElementC, LayoutC,
|
| 784 |
+
ElementCompute,
|
| 785 |
+
ElementAccumulator,
|
| 786 |
+
ConvertOp, InnerProductOp
|
| 787 |
+
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
|
| 788 |
+
break;
|
| 789 |
+
|
| 790 |
+
default:
|
| 791 |
+
break;
|
| 792 |
+
}
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 796 |
+
|
| 797 |
+
} // namespace host
|
| 798 |
+
} // namespace reference
|
| 799 |
+
} // namespace cutlass
|
| 800 |
+
|
| 801 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 802 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
/***************************************************************************************************
|
| 3 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
*
|
| 6 |
+
* Redistribution and use in source and binary forms, with or without
|
| 7 |
+
* modification, are permitted provided that the following conditions are met:
|
| 8 |
+
*
|
| 9 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
* list of conditions and the following disclaimer.
|
| 11 |
+
*
|
| 12 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
* and/or other materials provided with the distribution.
|
| 15 |
+
*
|
| 16 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
* contributors may be used to endorse or promote products derived from
|
| 18 |
+
* this software without specific prior written permission.
|
| 19 |
+
*
|
| 20 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
*
|
| 31 |
+
**************************************************************************************************/
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include <cmath>
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "cutlass/complex.h"
|
| 38 |
+
#include "cutlass/util/reference/host/tensor_reduce.h"
|
| 39 |
+
#include "cutlass/core_io.h"
|
| 40 |
+
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace reference {
|
| 43 |
+
namespace host {
|
| 44 |
+
|
| 45 |
+
/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference
|
| 46 |
+
template <
|
| 47 |
+
typename Element,
|
| 48 |
+
typename Layout,
|
| 49 |
+
typename ComputeType = double
|
| 50 |
+
>
|
| 51 |
+
ComputeType TensorRelativeErrorMetric(
|
| 52 |
+
TensorView<Element, Layout> view_A_computed,
|
| 53 |
+
TensorView<Element, Layout> view_B_reference,
|
| 54 |
+
ComputeType identity = ComputeType()
|
| 55 |
+
) {
|
| 56 |
+
|
| 57 |
+
return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) /
|
| 58 |
+
cutlass::reference::host::TensorNorm(view_B_reference, identity);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
|
| 64 |
+
} // namespace host
|
| 65 |
+
} // namespace reference
|
| 66 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/functional.h"
|
| 40 |
+
#include "cutlass/numeric_conversion.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
#include "cutlass/arch/mma.h"
|
| 45 |
+
#include "cutlass/util/host_tensor.h"
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace reference {
|
| 49 |
+
namespace host {
|
| 50 |
+
|
| 51 |
+
template<typename Out, typename In>
|
| 52 |
+
struct CastIfScalar {
|
| 53 |
+
static Out cast(In in) {
|
| 54 |
+
return Out(in);
|
| 55 |
+
}
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template<typename OutScalar, typename In>
|
| 59 |
+
struct CastIfScalar<cutlass::complex<OutScalar>, In> {
|
| 60 |
+
typedef cutlass::complex<OutScalar> Out;
|
| 61 |
+
static Out cast(In in) {
|
| 62 |
+
return Out(static_cast<OutScalar>(in));
|
| 63 |
+
}
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
template<typename OutScalar, typename InScalar>
|
| 67 |
+
struct CastIfScalar<cutlass::complex<OutScalar>, cutlass::complex<InScalar>> {
|
| 68 |
+
typedef cutlass::complex<OutScalar> Out;
|
| 69 |
+
typedef cutlass::complex<InScalar> In;
|
| 70 |
+
static Out cast(In in) {
|
| 71 |
+
return Out(in);
|
| 72 |
+
}
|
| 73 |
+
};
|
| 74 |
+
|
| 75 |
+
template<typename Out, typename In>
|
| 76 |
+
Out cast_if_scalar(In in) {
|
| 77 |
+
return CastIfScalar<Out, In>::cast(in);
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 81 |
+
|
| 82 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 83 |
+
/// objects.
|
| 84 |
+
template <
|
| 85 |
+
typename ElementA,
|
| 86 |
+
typename LayoutA,
|
| 87 |
+
typename ElementB,
|
| 88 |
+
typename LayoutB,
|
| 89 |
+
typename ElementC,
|
| 90 |
+
typename LayoutC,
|
| 91 |
+
typename ScalarType,
|
| 92 |
+
typename ComputeType,
|
| 93 |
+
typename InnerProductOp = multiply_add<ComputeType>,
|
| 94 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 95 |
+
>
|
| 96 |
+
void compute_gemm(
|
| 97 |
+
gemm::GemmCoord problem_size,
|
| 98 |
+
ScalarType alpha,
|
| 99 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 100 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 101 |
+
ScalarType beta,
|
| 102 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 103 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 104 |
+
ComputeType initial_accum) {
|
| 105 |
+
|
| 106 |
+
static_assert(
|
| 107 |
+
LayoutA::kRank == 2 &&
|
| 108 |
+
LayoutB::kRank == 2 &&
|
| 109 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
// Note: batch is ignored.
|
| 113 |
+
int const M = problem_size.m();
|
| 114 |
+
int const N = problem_size.n();
|
| 115 |
+
int const K = problem_size.k();
|
| 116 |
+
|
| 117 |
+
// Blocking necessary to speedup reference implementation
|
| 118 |
+
int const Mblock = 16;
|
| 119 |
+
int const Nblock = 16;
|
| 120 |
+
|
| 121 |
+
ConvertOp convert_op;
|
| 122 |
+
InnerProductOp inner_product_op;
|
| 123 |
+
|
| 124 |
+
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 125 |
+
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 126 |
+
|
| 127 |
+
ComputeType accum[Mblock][Nblock];
|
| 128 |
+
|
| 129 |
+
for (int j = 0; j < Nblock; j++) {
|
| 130 |
+
for (int i = 0; i < Mblock; i++) {
|
| 131 |
+
accum[i][j] = initial_accum;
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 136 |
+
for (int j = 0; j < Nblock; j++) {
|
| 137 |
+
for (int i = 0; i < Mblock; i++) {
|
| 138 |
+
int row = row_block + i;
|
| 139 |
+
int col = col_block + j;
|
| 140 |
+
|
| 141 |
+
if (row < M && col < N) {
|
| 142 |
+
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 143 |
+
ElementB b = tensor_b.at(MatrixCoord(k_block, col));
|
| 144 |
+
|
| 145 |
+
ComputeType compute_a(cast_if_scalar<ComputeType>(a));
|
| 146 |
+
ComputeType compute_b(cast_if_scalar<ComputeType>(b));
|
| 147 |
+
|
| 148 |
+
accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
for (int j = 0; j < Nblock; j++) {
|
| 155 |
+
for (int i = 0; i < Mblock; i++) {
|
| 156 |
+
int row = row_block + i;
|
| 157 |
+
int col = col_block + j;
|
| 158 |
+
|
| 159 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 160 |
+
|
| 161 |
+
if (row < M && col < N) {
|
| 162 |
+
tensor_d.at(coord) = convert_op(
|
| 163 |
+
alpha * ScalarType(accum[i][j]) +
|
| 164 |
+
beta * ScalarType(tensor_c.at(coord)));
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 173 |
+
|
| 174 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 175 |
+
/// objects.
|
| 176 |
+
template <
|
| 177 |
+
typename ElementA,
|
| 178 |
+
typename LayoutA,
|
| 179 |
+
typename ElementB,
|
| 180 |
+
typename LayoutB,
|
| 181 |
+
typename ElementC,
|
| 182 |
+
typename LayoutC,
|
| 183 |
+
typename ScalarType,
|
| 184 |
+
typename ComputeType,
|
| 185 |
+
typename InnerProductOp = multiply_add<ComputeType>,
|
| 186 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 187 |
+
>
|
| 188 |
+
void compute_gemm(
|
| 189 |
+
gemm::GemmCoord problem_size,
|
| 190 |
+
ScalarType alpha,
|
| 191 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 192 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 193 |
+
ScalarType beta,
|
| 194 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 195 |
+
ComputeType initial_accum) {
|
| 196 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 197 |
+
ScalarType, ComputeType, InnerProductOp, ConvertOp>(
|
| 198 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
| 199 |
+
initial_accum);
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 203 |
+
|
| 204 |
+
template <
|
| 205 |
+
typename ElementA,
|
| 206 |
+
typename LayoutA,
|
| 207 |
+
typename ElementB,
|
| 208 |
+
typename LayoutB,
|
| 209 |
+
typename ElementC,
|
| 210 |
+
typename LayoutC,
|
| 211 |
+
typename ScalarType,
|
| 212 |
+
typename ComputeType,
|
| 213 |
+
typename InnerProductOp = cutlass::arch::OpMultiplyAdd
|
| 214 |
+
>
|
| 215 |
+
struct Gemm;
|
| 216 |
+
|
| 217 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 218 |
+
|
| 219 |
+
/// Partial specialization for multiply-add
|
| 220 |
+
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 221 |
+
typename LayoutB, typename ElementC, typename LayoutC,
|
| 222 |
+
typename ScalarType, typename ComputeType>
|
| 223 |
+
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 224 |
+
ComputeType, arch::OpMultiplyAdd> {
|
| 225 |
+
|
| 226 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 227 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 228 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 229 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 230 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 231 |
+
static_assert(
|
| 232 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 233 |
+
"Tensors must be of rank 2");
|
| 234 |
+
|
| 235 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 236 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 237 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 241 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 242 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 243 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 244 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 245 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 246 |
+
static_assert(
|
| 247 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 248 |
+
"Tensors must be of rank 2");
|
| 249 |
+
|
| 250 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 251 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 252 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 253 |
+
}
|
| 254 |
+
};
|
| 255 |
+
|
| 256 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 257 |
+
|
| 258 |
+
/// Partial specialization for multiply-add
|
| 259 |
+
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 260 |
+
typename LayoutB, typename ElementC, typename LayoutC,
|
| 261 |
+
typename ScalarType, typename ComputeType>
|
| 262 |
+
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 263 |
+
ComputeType, arch::OpMultiplyAddFastBF16> {
|
| 264 |
+
|
| 265 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 266 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 267 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 268 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 269 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 270 |
+
static_assert(
|
| 271 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 272 |
+
"Tensors must be of rank 2");
|
| 273 |
+
|
| 274 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 275 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 276 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 280 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 281 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 282 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 283 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 284 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 285 |
+
static_assert(
|
| 286 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 287 |
+
"Tensors must be of rank 2");
|
| 288 |
+
|
| 289 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 290 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 291 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 292 |
+
}
|
| 293 |
+
};
|
| 294 |
+
|
| 295 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 296 |
+
|
| 297 |
+
/// Partial specialization for multiply-add-saturate
|
| 298 |
+
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 299 |
+
typename LayoutB, typename ElementC, typename LayoutC,
|
| 300 |
+
typename ScalarType, typename ComputeType>
|
| 301 |
+
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 302 |
+
ComputeType, arch::OpMultiplyAddSaturate> {
|
| 303 |
+
|
| 304 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 305 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 306 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 307 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 308 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 309 |
+
static_assert(
|
| 310 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 311 |
+
"Tensors must be of rank 2");
|
| 312 |
+
|
| 313 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 314 |
+
ScalarType, ComputeType, multiply_add<ComputeType>,
|
| 315 |
+
NumericConverterClamp<ElementC, ScalarType>>(
|
| 316 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 320 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 321 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 322 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 323 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 324 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 325 |
+
static_assert(
|
| 326 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 327 |
+
"Tensors must be of rank 2");
|
| 328 |
+
|
| 329 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 330 |
+
ScalarType, ComputeType, multiply_add<ComputeType>,
|
| 331 |
+
NumericConverterClamp<ElementC, ScalarType>>(
|
| 332 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 333 |
+
}
|
| 334 |
+
};
|
| 335 |
+
|
| 336 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 337 |
+
|
| 338 |
+
/// Partial specialization for XOR-popc
|
| 339 |
+
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 340 |
+
typename LayoutB, typename ElementC, typename LayoutC,
|
| 341 |
+
typename ScalarType, typename ComputeType>
|
| 342 |
+
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 343 |
+
ComputeType, arch::OpXorPopc> {
|
| 344 |
+
|
| 345 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 346 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 347 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 348 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 349 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 350 |
+
static_assert(
|
| 351 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 352 |
+
"Tensors must be of rank 2");
|
| 353 |
+
|
| 354 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 355 |
+
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
|
| 356 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 360 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 361 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 362 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 363 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 364 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 365 |
+
static_assert(
|
| 366 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 367 |
+
"Tensors must be of rank 2");
|
| 368 |
+
|
| 369 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 370 |
+
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
|
| 371 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 372 |
+
}
|
| 373 |
+
};
|
| 374 |
+
|
| 375 |
+
/// Partial specialization for AND-popc
|
| 376 |
+
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 377 |
+
typename LayoutB, typename ElementC, typename LayoutC,
|
| 378 |
+
typename ScalarType, typename ComputeType>
|
| 379 |
+
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 380 |
+
ComputeType, arch::OpAndPopc> {
|
| 381 |
+
|
| 382 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 383 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 384 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 385 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 386 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 387 |
+
static_assert(
|
| 388 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 389 |
+
"Tensors must be of rank 2");
|
| 390 |
+
|
| 391 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 392 |
+
ScalarType, ComputeType, and_popc_add<ComputeType>>(
|
| 393 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 397 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 398 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 399 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 400 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 401 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 402 |
+
static_assert(
|
| 403 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 404 |
+
"Tensors must be of rank 2");
|
| 405 |
+
|
| 406 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 407 |
+
ScalarType, ComputeType, and_popc_add<ComputeType>>(
|
| 408 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 409 |
+
}
|
| 410 |
+
};
|
| 411 |
+
|
| 412 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 413 |
+
|
| 414 |
+
/// Partial specialization for multiply-add
|
| 415 |
+
template <typename ElementA, typename LayoutA, typename ElementB,
|
| 416 |
+
typename LayoutB, typename ElementC, typename LayoutC,
|
| 417 |
+
typename ScalarType, typename ComputeType>
|
| 418 |
+
struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 419 |
+
ComputeType, arch::OpMultiplyAddFastF32> {
|
| 420 |
+
|
| 421 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 422 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 423 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 424 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 425 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 426 |
+
static_assert(
|
| 427 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 428 |
+
"Tensors must be of rank 2");
|
| 429 |
+
|
| 430 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 431 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 432 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 436 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 437 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 438 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 439 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 440 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 441 |
+
static_assert(
|
| 442 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 443 |
+
"Tensors must be of rank 2");
|
| 444 |
+
|
| 445 |
+
compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
| 446 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 447 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 448 |
+
}
|
| 449 |
+
};
|
| 450 |
+
|
| 451 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 452 |
+
|
| 453 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 454 |
+
//
|
| 455 |
+
// Batched GEMM
|
| 456 |
+
//
|
| 457 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 458 |
+
|
| 459 |
+
/// Computes a batch of GEMMs over a set of matrices of common dimension.
|
| 460 |
+
//
|
| 461 |
+
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 462 |
+
//
|
| 463 |
+
template <
|
| 464 |
+
typename TensorRefCollectionA,
|
| 465 |
+
typename TensorRefCollectionB,
|
| 466 |
+
typename TensorRefCollectionC,
|
| 467 |
+
typename ScalarType,
|
| 468 |
+
typename AccumulatorType
|
| 469 |
+
>
|
| 470 |
+
void BatchedGemm(
|
| 471 |
+
gemm::GemmCoord problem_size,
|
| 472 |
+
int batch_count,
|
| 473 |
+
ScalarType alpha,
|
| 474 |
+
TensorRefCollectionA const& tensor_a,
|
| 475 |
+
TensorRefCollectionB const& tensor_b,
|
| 476 |
+
ScalarType beta,
|
| 477 |
+
TensorRefCollectionC &tensor_c,
|
| 478 |
+
AccumulatorType initial_accum) {
|
| 479 |
+
|
| 480 |
+
typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
|
| 481 |
+
typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
|
| 482 |
+
typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();
|
| 483 |
+
|
| 484 |
+
for (int batch = 0;
|
| 485 |
+
batch < batch_count;
|
| 486 |
+
++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {
|
| 487 |
+
|
| 488 |
+
Gemm<typename TensorRefCollectionA::Element,
|
| 489 |
+
typename TensorRefCollectionA::Layout,
|
| 490 |
+
typename TensorRefCollectionB::Element,
|
| 491 |
+
typename TensorRefCollectionB::Layout,
|
| 492 |
+
typename TensorRefCollectionC::Element,
|
| 493 |
+
typename TensorRefCollectionC::Layout,
|
| 494 |
+
typename TensorRefCollectionC::Element,
|
| 495 |
+
typename TensorRefCollectionC::Element>
|
| 496 |
+
gemm;
|
| 497 |
+
|
| 498 |
+
gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
|
| 499 |
+
initial_accum);
|
| 500 |
+
}
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 504 |
+
/// objects.
|
| 505 |
+
//
|
| 506 |
+
// TensorRefCollection* is a type satisfying the TensorRefCollection concept.
|
| 507 |
+
//
|
| 508 |
+
template <
|
| 509 |
+
typename TensorRefCollectionA,
|
| 510 |
+
typename TensorRefCollectionB,
|
| 511 |
+
typename TensorRefCollectionC,
|
| 512 |
+
typename ScalarType,
|
| 513 |
+
typename AccumulatorType
|
| 514 |
+
>
|
| 515 |
+
void BatchedGemm(
|
| 516 |
+
gemm::GemmCoord problem_size,
|
| 517 |
+
int batch_count,
|
| 518 |
+
ScalarType alpha,
|
| 519 |
+
TensorRefCollectionA const& tensor_a,
|
| 520 |
+
TensorRefCollectionB const& tensor_b,
|
| 521 |
+
ScalarType beta,
|
| 522 |
+
TensorRefCollectionC &tensor_c) {
|
| 523 |
+
|
| 524 |
+
BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 528 |
+
|
| 529 |
+
} // namespace host
|
| 530 |
+
} // namespace reference
|
| 531 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/matrix_coord.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/tensor_view.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/gemm/gemm.h"
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reference {
|
| 50 |
+
namespace host {
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 55 |
+
/// objects.
|
| 56 |
+
///
|
| 57 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 58 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 59 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 60 |
+
/// arguments explicitly.
|
| 61 |
+
template <
|
| 62 |
+
typename ElementA,
|
| 63 |
+
typename LayoutA,
|
| 64 |
+
typename ElementB,
|
| 65 |
+
typename LayoutB,
|
| 66 |
+
typename ElementC,
|
| 67 |
+
typename LayoutC,
|
| 68 |
+
typename ScalarType,
|
| 69 |
+
typename ComputeType,
|
| 70 |
+
typename ElementD = ElementC,
|
| 71 |
+
typename ConvertOp = NumericConverter<ElementD, ScalarType>,
|
| 72 |
+
typename InnerProductOp = multiply_add<ComputeType>
|
| 73 |
+
>
|
| 74 |
+
void GemmComplex(
|
| 75 |
+
gemm::GemmCoord problem_size,
|
| 76 |
+
ScalarType alpha,
|
| 77 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 78 |
+
ComplexTransform transform_a,
|
| 79 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 80 |
+
ComplexTransform transform_b,
|
| 81 |
+
ScalarType beta,
|
| 82 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 83 |
+
TensorRef<ElementD, LayoutC> tensor_d,
|
| 84 |
+
ComputeType initial_accum,
|
| 85 |
+
int batch_count = 1,
|
| 86 |
+
int64_t batch_stride_A = 0,
|
| 87 |
+
int64_t batch_stride_B = 0,
|
| 88 |
+
int64_t batch_stride_C = 0,
|
| 89 |
+
int64_t batch_stride_D = 0) {
|
| 90 |
+
|
| 91 |
+
static_assert(
|
| 92 |
+
LayoutA::kRank == 2 &&
|
| 93 |
+
LayoutB::kRank == 2 &&
|
| 94 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 95 |
+
|
| 96 |
+
// Note: batch is ignored.
|
| 97 |
+
int const M = problem_size.m();
|
| 98 |
+
int const N = problem_size.n();
|
| 99 |
+
int const K = problem_size.k();
|
| 100 |
+
|
| 101 |
+
// Blocking necessary to speedup reference implementation
|
| 102 |
+
int const Mblock = 16;
|
| 103 |
+
int const Nblock = 16;
|
| 104 |
+
|
| 105 |
+
ConvertOp convert_op;
|
| 106 |
+
InnerProductOp inner_product_op;
|
| 107 |
+
|
| 108 |
+
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
| 109 |
+
|
| 110 |
+
// Compute matrix product using blocks
|
| 111 |
+
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 112 |
+
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 113 |
+
|
| 114 |
+
ComputeType accum[Mblock][Nblock];
|
| 115 |
+
|
| 116 |
+
for (int j = 0; j < Nblock; j++) {
|
| 117 |
+
for (int i = 0; i < Mblock; i++) {
|
| 118 |
+
accum[i][j] = initial_accum;
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 123 |
+
for (int j = 0; j < Nblock; j++) {
|
| 124 |
+
for (int i = 0; i < Mblock; i++) {
|
| 125 |
+
int row = row_block + i;
|
| 126 |
+
int col = col_block + j;
|
| 127 |
+
|
| 128 |
+
if (row < M && col < N) {
|
| 129 |
+
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 130 |
+
ElementB b = tensor_b.at(MatrixCoord(k_block, col));
|
| 131 |
+
|
| 132 |
+
ComputeType a_ik = ComputeType(a);
|
| 133 |
+
ComputeType b_kj = ComputeType(b);
|
| 134 |
+
|
| 135 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 136 |
+
a_ik = conj(a_ik);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
if (transform_b == ComplexTransform::kConjugate) {
|
| 140 |
+
b_kj = conj(b_kj);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
for (int j = 0; j < Nblock; j++) {
|
| 150 |
+
for (int i = 0; i < Mblock; i++) {
|
| 151 |
+
int row = row_block + i;
|
| 152 |
+
int col = col_block + j;
|
| 153 |
+
|
| 154 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 155 |
+
|
| 156 |
+
if (row < M && col < N) {
|
| 157 |
+
|
| 158 |
+
tensor_d.at(coord) = convert_op(
|
| 159 |
+
alpha * ScalarType(accum[i][j]) +
|
| 160 |
+
beta * ScalarType(tensor_c.at(coord)));
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
} // for (col_block)
|
| 166 |
+
} // for (row_block)
|
| 167 |
+
|
| 168 |
+
tensor_a.add_pointer_offset(batch_stride_A);
|
| 169 |
+
tensor_b.add_pointer_offset(batch_stride_B);
|
| 170 |
+
tensor_c.add_pointer_offset(batch_stride_C);
|
| 171 |
+
tensor_d.add_pointer_offset(batch_stride_D);
|
| 172 |
+
|
| 173 |
+
} // for (batch_idx)
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 177 |
+
|
| 178 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 179 |
+
/// objects.
|
| 180 |
+
///
|
| 181 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 182 |
+
template <
|
| 183 |
+
typename ElementA,
|
| 184 |
+
typename LayoutA,
|
| 185 |
+
typename ElementB,
|
| 186 |
+
typename LayoutB,
|
| 187 |
+
typename ElementC,
|
| 188 |
+
typename LayoutC,
|
| 189 |
+
typename ScalarType,
|
| 190 |
+
typename ElementD = ElementC
|
| 191 |
+
>
|
| 192 |
+
void GemmComplex(
|
| 193 |
+
gemm::GemmCoord problem_size,
|
| 194 |
+
ScalarType alpha,
|
| 195 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 196 |
+
ComplexTransform transform_a,
|
| 197 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 198 |
+
ComplexTransform transform_b,
|
| 199 |
+
ScalarType beta,
|
| 200 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 201 |
+
TensorRef<ElementD, LayoutC> tensor_d) {
|
| 202 |
+
|
| 203 |
+
GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 207 |
+
|
| 208 |
+
} // namespace host
|
| 209 |
+
} // namespace reference
|
| 210 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued GEMM in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/tensor_ref_planar_complex.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/tensor_view.h"
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace reference {
|
| 49 |
+
namespace host {
|
| 50 |
+
|
| 51 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 54 |
+
/// objects.
|
| 55 |
+
///
|
| 56 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 57 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 58 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 59 |
+
/// arguments explicitly.
|
| 60 |
+
template <
|
| 61 |
+
typename ElementA,
|
| 62 |
+
typename LayoutA,
|
| 63 |
+
typename ElementB,
|
| 64 |
+
typename LayoutB,
|
| 65 |
+
typename ElementC,
|
| 66 |
+
typename LayoutC,
|
| 67 |
+
typename ScalarType,
|
| 68 |
+
typename ComputeType,
|
| 69 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 70 |
+
typename InnerProductOp = multiply_add<complex<ComputeType>>
|
| 71 |
+
>
|
| 72 |
+
void GemmPlanarComplex(
|
| 73 |
+
gemm::GemmCoord problem_size,
|
| 74 |
+
complex<ScalarType> alpha,
|
| 75 |
+
TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
|
| 76 |
+
ComplexTransform transform_a,
|
| 77 |
+
TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
|
| 78 |
+
ComplexTransform transform_b,
|
| 79 |
+
complex<ScalarType> beta,
|
| 80 |
+
TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
|
| 81 |
+
TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
|
| 82 |
+
complex<ComputeType> initial_accum) {
|
| 83 |
+
|
| 84 |
+
static_assert(
|
| 85 |
+
LayoutA::kRank == 2 &&
|
| 86 |
+
LayoutB::kRank == 2 &&
|
| 87 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 88 |
+
|
| 89 |
+
using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
|
| 90 |
+
using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
|
| 91 |
+
using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
|
| 92 |
+
|
| 93 |
+
// Note: batch is ignored.
|
| 94 |
+
int const M = problem_size.m();
|
| 95 |
+
int const N = problem_size.n();
|
| 96 |
+
int const K = problem_size.k();
|
| 97 |
+
|
| 98 |
+
// Blocking necessary to speedup reference implementation
|
| 99 |
+
int const Mblock = 16;
|
| 100 |
+
int const Nblock = 16;
|
| 101 |
+
|
| 102 |
+
ConvertOp convert_op;
|
| 103 |
+
InnerProductOp inner_product_op;
|
| 104 |
+
|
| 105 |
+
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 106 |
+
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 107 |
+
|
| 108 |
+
complex<ComputeType> accum[Mblock][Nblock];
|
| 109 |
+
|
| 110 |
+
for (int j = 0; j < Nblock; j++) {
|
| 111 |
+
for (int i = 0; i < Mblock; i++) {
|
| 112 |
+
accum[i][j] = initial_accum;
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 117 |
+
for (int j = 0; j < Nblock; j++) {
|
| 118 |
+
for (int i = 0; i < Mblock; i++) {
|
| 119 |
+
int row = row_block + i;
|
| 120 |
+
int col = col_block + j;
|
| 121 |
+
|
| 122 |
+
if (row < M && col < N) {
|
| 123 |
+
|
| 124 |
+
ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
|
| 125 |
+
ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
|
| 126 |
+
|
| 127 |
+
complex<ComputeType> a = complex<ComputeType>{
|
| 128 |
+
ComputeType(a_ik.real()),
|
| 129 |
+
ComputeType(a_ik.imag())
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
complex<ComputeType> b = complex<ComputeType>{
|
| 133 |
+
ComputeType(b_kj.real()),
|
| 134 |
+
ComputeType(b_kj.imag())
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 138 |
+
a = conj(a);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
if (transform_b == ComplexTransform::kConjugate) {
|
| 142 |
+
b = conj(b);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
accum[i][j] = inner_product_op(a, b, accum[i][j]);
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
for (int j = 0; j < Nblock; j++) {
|
| 152 |
+
for (int i = 0; i < Mblock; i++) {
|
| 153 |
+
int row = row_block + i;
|
| 154 |
+
int col = col_block + j;
|
| 155 |
+
|
| 156 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 157 |
+
|
| 158 |
+
if (row < M && col < N) {
|
| 159 |
+
|
| 160 |
+
complex<ScalarType> acc{
|
| 161 |
+
ScalarType(accum[i][j].real()),
|
| 162 |
+
ScalarType(accum[i][j].imag())
|
| 163 |
+
};
|
| 164 |
+
|
| 165 |
+
ComplexC d_ij = tensor_c.at(coord);
|
| 166 |
+
|
| 167 |
+
complex<ScalarType> src{
|
| 168 |
+
ScalarType(d_ij.real()),
|
| 169 |
+
ScalarType(d_ij.imag())
|
| 170 |
+
};
|
| 171 |
+
|
| 172 |
+
complex<ScalarType> result = alpha * acc + beta * src;
|
| 173 |
+
|
| 174 |
+
d_ij.real() = convert_op(result.real());
|
| 175 |
+
d_ij.imag() = convert_op(result.imag());
|
| 176 |
+
|
| 177 |
+
tensor_d.at(coord) = d_ij;
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 186 |
+
|
| 187 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// This assumes the accumulator type is the same type as the scalars.
///
/// Convenience overload: forwards to the full GemmPlanarComplex reference
/// implementation with a value-initialized (zero) initial accumulator, so the
/// product is computed as D = alpha * op(A) * op(B) + beta * C.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType
>
void GemmPlanarComplex(
  gemm::GemmCoord problem_size,                          // GEMM extents (M, N, K)
  complex<ScalarType> alpha,                             // scale applied to A*B
  TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,    // operand A (M x K)
  ComplexTransform transform_a,                          // optional conjugation of A
  TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,    // operand B (K x N)
  ComplexTransform transform_b,                          // optional conjugation of B
  complex<ScalarType> beta,                              // scale applied to C
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,    // source operand C (M x N)
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {  // destination D (M x N)

  // Delegate to the primary overload with a zero initial accumulator.
  GemmPlanarComplex(
    problem_size,
    alpha,
    tensor_a, transform_a,
    tensor_b, transform_b,
    beta,
    tensor_c,
    tensor_d,
    complex<ScalarType>());
}
|
| 221 |
+
|
| 222 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 223 |
+
|
| 224 |
+
} // namespace host
|
| 225 |
+
} // namespace reference
|
| 226 |
+
} // namespace cutlass
|
| 227 |
+
|
| 228 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp
ADDED
|
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for GETT in host-side code.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 38 |
+
#include "cutlass/gemm/gemm.h"
|
| 39 |
+
#include "cutlass/complex.h"
|
| 40 |
+
#include "cutlass/numeric_conversion.h"
|
| 41 |
+
#include "cutlass/epilogue/thread/activation.h"
|
| 42 |
+
#include "cutlass/relatively_equal.h"
|
| 43 |
+
|
| 44 |
+
#include "cute/tensor.hpp"
|
| 45 |
+
#include "cute/pointer.hpp"
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass::reference::host {
|
| 50 |
+
|
| 51 |
+
// ElementTraits maps an engine's value type to the underlying arithmetic
// element type used by the reference kernels.
//
// Primary template: T is already a plain element type; expose it unchanged.
template<class T, class = void>
struct ElementTraits {
  using type = T;
};

// Specialization for proxy/reference wrapper value types that expose the
// wrapped element through a non-void .get() accessor (e.g. sub-byte element
// references). SFINAE selects this only when T has such a .get().
template<class T>
struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
  using type = decltype(std::declval<T>().get());
};
|
| 60 |
+
|
| 61 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
///////////////////////////////////////////////////////////
|
| 64 |
+
//
|
| 65 |
+
// Gett Mainloop Parameters
|
| 66 |
+
//
|
| 67 |
+
///////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
/// Parameter bundle consumed by the GETT mainloop reference kernel.
/// Carries the A/B operand tensors, per-operand complex transforms, and
/// optional block-scaling factor tensors (which default to the operand tensor
/// types so non-block-scaled instantiations remain valid).
template<
  class ElementAccumulator_,
  class TensorA_,               // (M, K, L)
  class TensorB_                // (N, K, L)

  , class TensorSfA_ = TensorA_,
  class TensorSfB_ = TensorB_

>
struct GettMainloopParams {
  using ElementAccumulator = ElementAccumulator_;
  using TensorA = TensorA_;
  using TensorB = TensorB_;
  using EngineA = typename TensorA::engine_type;
  using LayoutA = typename TensorA::layout_type;
  using EngineB = typename TensorB::engine_type;
  using LayoutB = typename TensorB::layout_type;

  // Operand tensors.
  TensorA A{};
  TensorB B{};

  // Optional conjugation applied to each operand inside the mainloop.
  ComplexTransform transform_A = ComplexTransform::kNone;
  ComplexTransform transform_B = ComplexTransform::kNone;

  // Scale-factor tensor types and instances for block-scaled GEMMs.
  using TensorSfA = TensorSfA_;
  using TensorSfB = TensorSfB_;
  using EngineSfA = typename TensorSfA::engine_type;
  using LayoutSfA = typename TensorSfA::layout_type;
  using EngineSfB = typename TensorSfB::engine_type;
  using LayoutSfB = typename TensorSfB::layout_type;
  TensorSfA_ SfA{};
  TensorSfB_ SfB{};

  GettMainloopParams() {}

  // Construct without scale factors (plain GEMM path).
  GettMainloopParams(TensorA tensor_A, TensorB tensor_B)
      : A(tensor_A), B(tensor_B) {}

  // Construct with per-operand scale-factor tensors (block-scaled GEMM path).
  GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
      : A(tensor_A), SfA(tensor_SfA),
        B(tensor_B), SfB(tensor_SfB) {}

};
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
////////////////////////////////////////////////////////////////////////
|
| 120 |
+
//
|
| 121 |
+
// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels
|
| 122 |
+
//
|
| 123 |
+
////////////////////////////////////////////////////////////////////////
|
| 124 |
+
|
| 125 |
+
/// Mainloop parameters for block-scaled GEMM kernels: a thin wrapper over
/// GettMainloopParams that makes the scale-factor tensor types mandatory
/// template parameters and re-exports the base's type aliases.
template<
  class ElementAccumulator_,
  class TensorA_,   // (M, K, L)
  class TensorSfA_, // (M, K, L)
  class TensorB_,   // (N, K, L)
  class TensorSfB_  // (N, K, L)
>
struct GettBlockScalingMainloopParams : public GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_> {
  using Base = GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_>;
  using ElementAccumulator = typename Base::ElementAccumulator;
  using TensorA = typename Base::TensorA;
  using TensorB = typename Base::TensorB;
  using EngineA = typename Base::EngineA;
  using LayoutA = typename Base::LayoutA;
  using EngineB = typename Base::EngineB;
  using LayoutB = typename Base::LayoutB;
  // NOTE(review): these members shadow Base::transform_A/transform_B with
  // copies initialized from the base members; writes here do not propagate to
  // the base — confirm whether the shadowing is intentional.
  ComplexTransform transform_A = Base::transform_A;
  ComplexTransform transform_B = Base::transform_B;

  using TensorSfA = typename Base::TensorSfA;
  using TensorSfB = typename Base::TensorSfB;
  using EngineSfA = typename Base::EngineSfA;
  using LayoutSfA = typename Base::LayoutSfA;
  using EngineSfB = typename Base::EngineSfB;
  using LayoutSfB = typename Base::LayoutSfB;

  GettBlockScalingMainloopParams() {}

  // Forward operands and their scale-factor tensors to the base bundle.
  GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
      : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {}

};
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 161 |
+
|
| 162 |
+
/// Selects whether the epilogue also generates scale factors for the output.
enum class SfStrategy {
  None = 0,   // no output scale-factor generation
  SfDGen = 1  // generate 1D block scale factors (SfD) for the quantized output D
};
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
///////////////////////////////////////////////////////////
|
| 169 |
+
//
|
| 170 |
+
// Gett Epilogue Parameters
|
| 171 |
+
//
|
| 172 |
+
///////////////////////////////////////////////////////////
|
| 173 |
+
|
| 174 |
+
/// Parameter bundle consumed by the GETT epilogue reference kernel.
/// Holds source/destination tensors, scalar and per-row (vector) alpha/beta,
/// optional bias / auxiliary tensors, optional output scale-factor tensor for
/// block-scaled quantized output, activation functor, and absolute-max
/// tracking pointers used by scaled epilogues.
template<
  class ElementScalar_,
  class ElementScalingFactor_,
  class ElementAccumulator_,
  class ElementCompute_,
  class TensorC_, // (M, N, L)
  class TensorD_, // (M, N, L)
  class VectorBias_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),  // (M, 1)
  class TensorAux_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})),   // (M, N, L)
  class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
  class VectorBeta_ = VectorAlpha_, // (M, 1)
  class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
  class TensorSFD_ = TensorD_,
  class SFD_VectorSize_ = cute::Int<0>,
  class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
  bool PerColumnBias_ = false
  ,
  SfStrategy SfGenStrategy_ = SfStrategy::None
>
struct GettEpilogueParams {
  using ElementScalar = ElementScalar_;
  using ElementScalingFactor = ElementScalingFactor_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using TensorC = TensorC_;
  using TensorD = TensorD_;
  using TensorAux = TensorAux_;
  using VectorBias = VectorBias_;
  using VectorAlpha = VectorAlpha_;
  using VectorBeta = VectorBeta_;
  using TensorSFD = TensorSFD_;
  using SFD_VectorSize = SFD_VectorSize_;
  using ActivationFunctor = ActivationFunctor_;
  using BiasBinaryOp = BiasBinaryOp_;

  using EngineC = typename TensorC::engine_type;
  using LayoutC = typename TensorC::layout_type;
  using EngineD = typename TensorD::engine_type;
  using LayoutD = typename TensorD::layout_type;
  using EngineSfD = typename TensorSFD::engine_type;
  using LayoutSfD = typename TensorSFD::layout_type;
  static constexpr bool PerColumnBias = PerColumnBias_;
  static constexpr SfStrategy SfGenStrategy = SfGenStrategy_;

  // Scalar epilogue coefficients: D = alpha * acc + beta * C (by default).
  ElementScalar alpha = ElementScalar(1);
  ElementScalar beta = ElementScalar(0);

  TensorC C{};        // source tensor
  TensorD D{};        // destination tensor
  VectorBias Bias{};  // optional bias vector
  TensorAux Aux{};    // optional auxiliary output tensor
  VectorAlpha Valpha{};  // optional per-row alpha
  VectorBeta Vbeta{};    // optional per-row beta
  TensorSFD SfD{};       // optional output scale-factor tensor
  ElementCompute st = ElementCompute(1);  // scale-factor generation coefficient

  // Optional outputs: absolute maximum of D / Aux, written when non-null.
  ElementAccumulator* abs_max_D = nullptr;
  ElementAccumulator* abs_max_Aux = nullptr;

  // Per-tensor scaling factors (e.g. FP8-style scaled epilogues).
  ElementScalingFactor scale_a = ElementScalingFactor(1);
  ElementScalingFactor scale_b = ElementScalingFactor(1);
  ElementScalingFactor scale_c = ElementScalingFactor(1);
  ElementScalingFactor scale_d = ElementScalingFactor(1);
  ElementScalingFactor scale_aux = ElementScalingFactor(1);

  bool beta_per_channel_scaling = false;
  GettEpilogueParams() {}

  // Basic construction: scalar alpha/beta plus C and D tensors.
  GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
      : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {}

  // Construction with output scale-factor tensor and generation coefficient.
  GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
      : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {}

  // Construction with bias, auxiliary tensor, and per-row alpha/beta vectors.
  GettEpilogueParams(
      ElementScalar alpha, ElementScalar beta,
      TensorC tensor_C, TensorD tensor_D,
      VectorBias bias, TensorAux tensor_aux,
      VectorAlpha vector_alpha, VectorBeta vector_beta)
      : alpha(alpha), beta(beta),
        C(tensor_C), D(tensor_D),
        Bias(bias), Aux(tensor_aux),
        Valpha(vector_alpha), Vbeta(vector_beta) {}
};
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
////////////////////////////////////////////////////////////////////////
|
| 264 |
+
//
|
| 265 |
+
// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels
|
| 266 |
+
//
|
| 267 |
+
////////////////////////////////////////////////////////////////////////
|
| 268 |
+
|
| 269 |
+
/// Epilogue parameters specialized for block-scaled GEMM kernels: fixes the
/// optional template arguments of GettEpilogueParams (bias/aux/vector types,
/// identity activation, plus-bias op) and exposes only the knobs relevant to
/// block-scaled output (TensorSfD, SFD vector size, scale-factor strategy).
template<
  class ElementScalar_,
  class ElementAccumulator_,
  class ElementCompute_,
  class TensorC_,
  class TensorD_,
  class TensorSfD_ = TensorD_,
  class SFD_VectorSize_ = cute::Int<0>,
  SfStrategy SfGenStrategy_ = SfStrategy::None
>
struct GettBlockScalingEpilogueParams : public GettEpilogueParams<
    ElementScalar_,       // ElementScalar
    ElementScalar_,       // ElementScalingFactor
    ElementAccumulator_,  // ElementAccumulator
    ElementCompute_,      // ElementCompute
    TensorC_,             // TensorC (M, N, L)
    TensorD_,             // TensorD (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
    cutlass::epilogue::thread::Identity<ElementCompute_>, // ActivationFunctor
    TensorSfD_,           // TensorSfD
    SFD_VectorSize_,      // SFD_VectorSize
    cutlass::plus<ElementCompute_>, // BiasBinaryOp
    false,                // PerColumnBias
    SfGenStrategy_        // SfGenStrategy
  > {
  using Base = GettEpilogueParams<
    ElementScalar_,       // ElementScalar
    ElementScalar_,       // ElementScalingFactor
    ElementAccumulator_,  // ElementAccumulator
    ElementCompute_,      // ElementCompute
    TensorC_,             // TensorC (M, N, L)
    TensorD_,             // TensorD (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
    decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
    cutlass::epilogue::thread::Identity<ElementCompute_>, // ActivationFunctor
    TensorSfD_,           // TensorSfD
    SFD_VectorSize_,      // SFD_VectorSize
    cutlass::plus<ElementCompute_>, // BiasBinaryOp
    false,                // PerColumnBias
    SfGenStrategy_        // SfGenStrategy
  >;
  // Re-export the base's aliases and constants for convenient access.
  using ElementScalar = typename Base::ElementScalar;
  using ElementScalingFactor = typename Base::ElementScalingFactor;
  using ElementAccumulator = typename Base::ElementAccumulator;
  using ElementCompute = typename Base::ElementCompute;
  using TensorC = typename Base::TensorC;
  using TensorD = typename Base::TensorD;
  using TensorAux = typename Base::TensorAux;
  using VectorBias = typename Base::VectorBias;
  using VectorAlpha = typename Base::VectorAlpha;
  using VectorBeta = typename Base::VectorBeta;
  using TensorSFD = typename Base::TensorSFD;
  using SFD_VectorSize = typename Base::SFD_VectorSize;
  using ActivationFunctor = typename Base::ActivationFunctor;
  using BiasBinaryOp = typename Base::BiasBinaryOp;

  using EngineC = typename Base::EngineC;
  using LayoutC = typename Base::LayoutC;
  using EngineD = typename Base::EngineD;
  using LayoutD = typename Base::LayoutD;
  using EngineSfD = typename Base::EngineSfD;
  using LayoutSfD = typename Base::LayoutSfD;
  static constexpr bool PerColumnBias = Base::PerColumnBias;
  static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy;

  GettBlockScalingEpilogueParams() {}

  // Scalar alpha/beta with C and D tensors; no scale-factor output.
  GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
      : Base(alpha, beta, tensor_C, tensor_D) {}

  // With scale-factor tensor; generation coefficient defaults to zero.
  GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD)
      : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {}

  // With scale-factor tensor and explicit generation coefficient.
  GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
      : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {}
};
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
///////////////////////////////////////////////////////////
|
| 356 |
+
//
|
| 357 |
+
// Generic Gett 3x Implementation
|
| 358 |
+
//
|
| 359 |
+
///////////////////////////////////////////////////////////
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 363 |
+
/// Computes per-vector (1D block) scale factors for the output tile and scales
/// the accumulator values in place so they can subsequently be quantized to
/// the (narrower) output element type ElementD.
///
/// For each kVectorSize-long vector of the accumulator tile (along M for
/// MN-major/column-major output, along N otherwise):
///   1. find the maximum absolute accumulator value in the vector,
///   2. derive a scale factor qpvscale = cast<ElementSfD>(max * st / ElementD_max)
///      and store it into tensor_SfD,
///   3. rescale the vector's accumulators by st / qpvscale (INF clamped to
///      float max) so that the later cast to ElementD stays in range.
///
/// \param epilogue_params  epilogue parameter bundle; provides st and the type aliases
/// \param tensor_D         output tensor (used here only for its M/N/L extents)
/// \param tensor_SfD       destination tensor for the computed scale factors
/// \param m, n, l          origin of the current tile within the full problem
/// \param acc              kBlockM x kBlockN accumulator tile, scaled in place
template <int kVectorSize, class EpilogueParams, class TensorD, class TensorSFD, class ElementCompute, int kBlockM, int kBlockN>
void compute_1d_scaling_factor_and_quantized_output(
    EpilogueParams const& epilogue_params,
    TensorD &tensor_D,
    TensorSFD &tensor_SfD,
    int64_t m,
    int64_t n,
    int64_t l,
    ElementCompute (&acc)[kBlockM][kBlockN])
{
  using ElementD = typename ElementTraits<typename EpilogueParams::EngineD::value_type>::type;
  using ElementSfD = typename ElementTraits<typename EpilogueParams::EngineSfD::value_type>::type;

  // Full-problem extents, used to mask off out-of-bounds tile positions.
  int const M = cute::size<0>(tensor_D.layout());
  int const N = cute::size<1>(tensor_D.layout());
  int const L = cute::size<2>(tensor_D.layout());

  auto mul = cutlass::multiplies<ElementCompute>{};
  auto div = divides<ElementCompute>{};
  // Get FP max
  ElementCompute fp_max = ElementCompute(std::numeric_limits<ElementD>::max());
  float scale_down_factor = div(1.0f, fp_max);
  // Get st' = st / FP max
  ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor);

  absolute_value_op<ElementCompute> abs_op;
  maximum_with_nan_propogation<ElementCompute> max_op;

  // Branch on the output's major mode: stride-1 in mode (0,0) means the
  // scale-factor vectors run along M (column-major style); otherwise along N.
  if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) {
    // MN major output
    int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
    // Col major output
    for (int n_b = 0; n_b < kBlockN; ++n_b) {
      for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
        int64_t col = n + n_b;

        /// Step1: get max across a vector
        ElementCompute accum_max = ElementCompute(0);
        for (int v = 0; v < kVectorSize; v++) {
          int accum_row = v_b * kVectorSize + v;
          int64_t output_row = accum_row + m;
          if (output_row < M && col < N) {
            accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b]));
          }
        }

        /// Step2: Compute Scale
        ElementCompute pvscale = mul(accum_max, st_scaled_down);
        ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
        // Store the Scaling Factors
        int64_t sf_row = m + kVectorSize * v_b;
        if (sf_row < M && col < N) {
          tensor_SfD(sf_row, col, l) = qpvscale;
        }

        /// Step3: Compute quantized output values
        ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
        // Get float reciprocal
        ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
        ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
        // Map INF to fp32::max
        acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
        // Store the intermediate_accum
        for (int v = 0; v < kVectorSize; v++) {
          int accum_row = v_b * kVectorSize + v;
          int64_t output_row = accum_row + m;
          if (output_row < M && col < N) {
            acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale);
          }
        }
      }
    }
  }
  else {
    int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize);
    // row major output
    for (int m_b = 0; m_b < kBlockM; ++m_b) {
      for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
        int64_t row = m + m_b;

        /// Step1: get max across a vector
        ElementCompute accum_max = ElementCompute(0);
        for (int v = 0; v < kVectorSize; v++) {
          int accum_col = v_b * kVectorSize + v;
          int64_t output_col = accum_col + n;
          if (row < M && output_col < N) {
            accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col]));
          }
        }

        /// Step2: Compute Scale
        ElementCompute pvscale = mul(accum_max, st_scaled_down);
        ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
        // Store the Scaling Factors
        int64_t sf_col = n + kVectorSize * v_b;

        if (row < M && sf_col < N) {
          tensor_SfD(row, sf_col, l) = qpvscale;
        }

        /// Step3: Compute quantized output values
        ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
        // Get float reciprocal
        ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
        ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
        // Map INF to fp32::max
        acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
        // Store the intermediate_accum
        for (int v = 0; v < kVectorSize; v++) {
          int accum_col = v_b * kVectorSize + v;
          int64_t output_col = accum_col + n;
          if (row < M && output_col < N) {
            acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale);
          }
        }
      }
    }
  }
}
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 485 |
+
|
| 486 |
+
/// GETT - General Tensor-Tensor contraction reference kernel
|
| 487 |
+
/// GETT - General Tensor-Tensor contraction reference kernel.
///
/// Tiles the (M, N, L) output space into kBlockM x kBlockN accumulator tiles,
/// runs the reference mainloop to fill each tile, then applies the epilogue.
/// The three tile loops are independent, so they are parallelized with OpenMP
/// when available.
template <
  class MainloopParams,
  class EpilogueParams
>
void Gett(
  MainloopParams const& mainloop_params,
  EpilogueParams const& epilogue_params)
{

  // Tile extents; sized so the per-tile accumulator array stays on the stack.
  static int constexpr kBlockM = 64;
  static int constexpr kBlockN = 64;

#if defined(_OPENMP)
  #pragma omp parallel for collapse(3)
#endif
  for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {          // batch
    for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { // rows from A's mode 0
      for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { // cols from B's mode 0
        // Per-tile accumulator, filled by the mainloop and consumed by the epilogue.
        typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
        gett_mainloop(mainloop_params, m, n, l, acc);
        gett_epilogue(epilogue_params, m, n, l, acc);
      }
    }
  }
}
|
| 512 |
+
|
| 513 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 514 |
+
|
| 515 |
+
/// GETT - Mainloop
|
| 516 |
+
template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
|
| 517 |
+
void gett_mainloop(
|
| 518 |
+
MainloopParams const& mainloop_params,
|
| 519 |
+
int64_t m,
|
| 520 |
+
int64_t n,
|
| 521 |
+
int64_t l,
|
| 522 |
+
ElementAccumulator (&acc)[kBlockM][kBlockN])
|
| 523 |
+
{
|
| 524 |
+
|
| 525 |
+
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
|
| 526 |
+
static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
|
| 527 |
+
|
| 528 |
+
using cute::raw_pointer_cast;
|
| 529 |
+
|
| 530 |
+
using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
|
| 531 |
+
using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
using ElementSFA = typename ElementTraits<typename MainloopParams::EngineSfA::value_type>::type;
|
| 535 |
+
using ElementSFB = typename ElementTraits<typename MainloopParams::EngineSfB::value_type>::type;
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
|
| 539 |
+
RingOp fma_op;
|
| 540 |
+
|
| 541 |
+
// Zero out accumulators
|
| 542 |
+
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
| 543 |
+
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
| 544 |
+
acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
| 545 |
+
}
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
// Compute on this k-block
|
| 549 |
+
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
|
| 550 |
+
// Load A
|
| 551 |
+
ElementAccumulator a_frag[kBlockM];
|
| 552 |
+
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
| 553 |
+
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
| 554 |
+
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
| 555 |
+
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
if constexpr (not cute::is_same_v<ElementSFA, ElementA>){
|
| 559 |
+
// Load SFA
|
| 560 |
+
auto sfa = static_cast<ElementAccumulator>(mainloop_params.SfA(m + m_b, k, l));
|
| 561 |
+
a_frag[m_b] *= sfa;
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
|
| 566 |
+
a_frag[m_b] = conj(a_frag[m_b]);
|
| 567 |
+
}
|
| 568 |
+
} else {
|
| 569 |
+
a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
| 570 |
+
}
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
// Load B
|
| 574 |
+
ElementAccumulator b_frag[kBlockN];
|
| 575 |
+
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
| 576 |
+
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
| 577 |
+
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
| 578 |
+
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
if constexpr (not cute::is_same_v<ElementSFB, ElementB>){
|
| 582 |
+
// Load SFB
|
| 583 |
+
auto sfb = static_cast<ElementAccumulator>(mainloop_params.SfB(n + n_b, k, l));
|
| 584 |
+
b_frag[n_b] *= sfb;
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
|
| 589 |
+
b_frag[n_b] = conj(b_frag[n_b]);
|
| 590 |
+
}
|
| 591 |
+
} else {
|
| 592 |
+
b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
|
| 593 |
+
}
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
// do compute
|
| 597 |
+
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
| 598 |
+
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
| 599 |
+
acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]);
|
| 600 |
+
}
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
}
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 607 |
+
|
| 608 |
+
/// GETT - Epilogue
|
| 609 |
+
/// GETT - Epilogue
///
/// Applies the full fused epilogue to one kBlockM x kBlockN accumulator tile:
/// alpha/beta scaling (scalar or per-element via Valpha/Vbeta), optional bias,
/// optional source C, activation (including backprop dGELU/dReLU fusions with
/// an auxiliary input and a dBias reduction), optional auxiliary output with
/// its own amax/scale handling, optional 1-D scale-factor generation for D,
/// and finally conversion + store of D. Absolute-maximum reductions for D and
/// Aux are merged into the global scalars under an OpenMP critical section.
template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
void gett_epilogue(
    EpilogueParams const& epilogue_params,
    int64_t m,
    int64_t n,
    int64_t l,
    ElementAccumulator (&acc)[kBlockM][kBlockN])
{
  static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
  static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");

  using cute::raw_pointer_cast;

  using ElementCompute = typename EpilogueParams::ElementCompute;
  using ElementC = typename EpilogueParams::TensorC::value_type;
  using ElementD = typename EpilogueParams::TensorD::value_type;
  using ElementSfD = typename EpilogueParams::TensorSFD::value_type;
  using ElementAux = typename EpilogueParams::TensorAux::value_type;
  using ElementBias = typename EpilogueParams::VectorBias::value_type;
  using ElementScalar = typename EpilogueParams::ElementScalar;
  using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
  using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
  using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;

  constexpr bool PerColBias = EpilogueParams::PerColumnBias;
  constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy;

  // FP8 outputs require amax tracking plus scale_d before conversion.
  constexpr bool IsScalingAndAmaxOutputNeeded =
    cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
    cute::is_same_v<ElementD, cutlass::float_e5m2_t>;

  // Same requirement for an FP8 auxiliary output.
  constexpr bool IsScalingAndAmaxAuxOutputNeeded =
    cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
    cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;

  // ReLU-style activations with a 1-bit Aux tensor store the sign predicate.
  constexpr bool IsReLUAuxNeeded =
    (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
     cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
    cute::is_same_v<ElementAux, cutlass::uint1b_t>;
  constexpr bool UseReLU =
    cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>; // Treat Clamp as ReLU

  // Backprop fusions take Aux as an *input* and reduce a dBias vector.
  constexpr bool IsBackpropFusion =
    cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
    cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;

  // Input related converter
  NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
  NumericConverter<ElementCompute, ElementC> source_converter;
  NumericConverter<ElementCompute, ElementBias> bias_converter;
  [[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;

  // Scale related converter
  NumericConverter<ElementCompute, ElementScalar> scale_converter;
  NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;

  // Abs max converter
  [[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;

  // Output related converter
  NumericConverter<ElementD, ElementCompute> destination_converter;
  [[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
  NumericConverter<ElementBias, ElementCompute> dBias_converter;

  // Epilogue operations
  multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
  multiplies<ElementCompute> mul;
  plus<ElementCompute> add;

  // Activation operation
  ActivationFunctor activation;

  // Bias binary operation
  BiasBinaryOp bias_op;

  // Do conversion
  ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
  ElementCompute converted_beta = scale_converter(epilogue_params.beta);
  ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
  ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
  ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
  ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
  ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);

  // Init local var
  [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
  [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);

  // Fold the A/B (resp. C) dequantization scales into alpha (resp. beta) once
  // up front; the per-element Valpha/Vbeta branches below redo this folding
  // whenever they overwrite the scalar.
  converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
  converted_beta = mul(converted_beta, converted_scale_c);

  // Tile-local staging buffer so D can be written after optional SF generation.
  ElementCompute inter_accum[kBlockM][kBlockN];

  for (int m_b = 0; m_b < kBlockM; ++m_b) {
    // Running dBias reduction for this row of the tile (backprop fusion only).
    ElementCompute local_dBias = ElementCompute(0);

    for (int n_b = 0; n_b < kBlockN; ++n_b) {
      if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
        // Convert every type to ElementCompute first, do compute, convert to output type, write it out
        ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
        // vector alpha
        if (raw_pointer_cast(epilogue_params.Valpha.data())) {
          converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l));
          converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
        }
        ElementCompute output = mul(converted_alpha, converted_acc);

        // Bias is applied pre-activation (skipped for backprop fusions, which
        // instead reduce dBias below).
        if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
          ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
          output = bias_op(output, converted_bias);
        }

        if (raw_pointer_cast(epilogue_params.C.data())) {
          ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
          // vector beta
          // NOTE(review): Valpha above is checked via raw_pointer_cast while
          // Vbeta is checked directly; presumably equivalent here — confirm.
          if (epilogue_params.Vbeta.data()) {
            converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l));
            converted_beta = mul(converted_beta, converted_scale_c);
          }
          output = epilogue_fma(converted_beta, converted_src, output);
        }

        if constexpr (IsBackpropFusion) {
          // Aux is an input (e.g. forward activation data) for dGELU/dReLU.
          ElementAux aux_input = ElementAux(0);
          if (raw_pointer_cast(epilogue_params.Aux.data())) {
            aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
          }

          output = activation(output, aux_source_converter(aux_input));
          local_dBias = add(local_dBias, output);
        }
        else {
          // Aux is an output: pre-activation value, optionally amax-tracked,
          // scaled, and/or reduced to a ReLU predicate bit.
          if (raw_pointer_cast(epilogue_params.Aux.data())) {
            auto aux_output = output;
            if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
              maximum_absolute_value_reduction<ElementCompute, true> amax_op;
              local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
              aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
            }

            if constexpr (IsReLUAuxNeeded) {
              epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
            } else {
              epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
            }
          }

          if constexpr (UseReLU) {
            cutlass::epilogue::thread::ReLU<ElementCompute> relu;
            output = relu(output);
          }
          else {
            output = activation(output);
          }
        }

        if constexpr (IsScalingAndAmaxOutputNeeded) {
          // Track |output| max before applying the D quantization scale.
          maximum_absolute_value_reduction<ElementCompute, true> amax_op;
          local_abs_max_output = amax_op(local_abs_max_output, output);
          output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
        }

        inter_accum[m_b][n_b] = ElementCompute(output);
      }
    } // n_b

    if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
      // Accumulate this tile's row reduction into the global dBias vector.
      if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
        ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
        local_dBias = add(local_dBias, converted_dBias);
        epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
      }
    }
  } // m_b

  if constexpr (
    SfGenStrategy == SfStrategy::SfDGen
  ) {
    // 1d scale factor generation: quantizes inter_accum in place and emits SfD.
    constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{};
    if (epilogue_params.SfD.data() != nullptr) {
      compute_1d_scaling_factor_and_quantized_output<kVectorSize>(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum);
    }
  }

  // Final conversion and store of the (possibly quantized) tile into D.
  for (int m_b = 0; m_b < kBlockM; ++m_b) {
    for (int n_b = 0; n_b < kBlockN; ++n_b) {
      if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
        epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
      }
    }
  }

  // Merge tile-local amax values into the global scalars; serialized because
  // Gett runs tiles in parallel under OpenMP.
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif
  {
    if constexpr (IsScalingAndAmaxOutputNeeded) {
      if (epilogue_params.abs_max_D) {
        *epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
          *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
      }
    }

    if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
      if (epilogue_params.abs_max_Aux) {
        *epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
          *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
      }
    }
  }
}
|
| 821 |
+
|
| 822 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 823 |
+
|
| 824 |
+
template <class TensorType>
|
| 825 |
+
auto make_layout_rank3(const TensorType& tensor) {
|
| 826 |
+
// append a batch mode of size 1 if we do not have tensors that are rank 3
|
| 827 |
+
return make_layout(
|
| 828 |
+
make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}),
|
| 829 |
+
make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout()))));
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
/// GEMM - General Matrix-Matrix contraction without conjugation options
|
| 833 |
+
template <
|
| 834 |
+
class MainloopParams,
|
| 835 |
+
class EpilogueParams
|
| 836 |
+
>
|
| 837 |
+
void Gemm3x(
|
| 838 |
+
MainloopParams const& mainloop_params,
|
| 839 |
+
EpilogueParams const& epilogue_params)
|
| 840 |
+
{
|
| 841 |
+
using namespace cute;
|
| 842 |
+
|
| 843 |
+
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
|
| 844 |
+
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
|
| 845 |
+
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
|
| 846 |
+
|
| 847 |
+
if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
|
| 848 |
+
cute::Layout layout_A = make_layout_rank3(mainloop_params.A);
|
| 849 |
+
cute::Layout layout_B = make_layout_rank3(mainloop_params.B);
|
| 850 |
+
cute::Layout layout_C = make_layout_rank3(epilogue_params.C);
|
| 851 |
+
cute::Layout layout_D = make_layout_rank3(epilogue_params.D);
|
| 852 |
+
cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux);
|
| 853 |
+
cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias);
|
| 854 |
+
cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha);
|
| 855 |
+
cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta);
|
| 856 |
+
|
| 857 |
+
auto TensorA = make_tensor(mainloop_params.A.data(), layout_A);
|
| 858 |
+
auto TensorB = make_tensor(mainloop_params.B.data(), layout_B);
|
| 859 |
+
auto TensorC = make_tensor(epilogue_params.C.data(), layout_C);
|
| 860 |
+
auto TensorD = make_tensor(epilogue_params.D.data(), layout_D);
|
| 861 |
+
auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux);
|
| 862 |
+
auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias);
|
| 863 |
+
auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha);
|
| 864 |
+
auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta);
|
| 865 |
+
|
| 866 |
+
// Reconstruct mainloop params
|
| 867 |
+
GettMainloopParams<typename MainloopParams::ElementAccumulator,
|
| 868 |
+
decltype(TensorA),
|
| 869 |
+
decltype(TensorB)>
|
| 870 |
+
mainloop_params_converted{TensorA,
|
| 871 |
+
TensorB,
|
| 872 |
+
mainloop_params.transform_A,
|
| 873 |
+
mainloop_params.transform_B};
|
| 874 |
+
|
| 875 |
+
// Reconstruct epilogue params
|
| 876 |
+
GettEpilogueParams<typename EpilogueParams::ElementScalar,
|
| 877 |
+
typename EpilogueParams::ElementScalingFactor,
|
| 878 |
+
typename EpilogueParams::ElementAccumulator,
|
| 879 |
+
typename EpilogueParams::ElementCompute,
|
| 880 |
+
decltype(TensorC),
|
| 881 |
+
decltype(TensorD),
|
| 882 |
+
decltype(VectorBias),
|
| 883 |
+
decltype(TensorAux),
|
| 884 |
+
decltype(VectorAlpha),
|
| 885 |
+
decltype(VectorBeta)
|
| 886 |
+
>
|
| 887 |
+
epilogue_params_converted{epilogue_params.alpha,
|
| 888 |
+
epilogue_params.beta,
|
| 889 |
+
TensorC,
|
| 890 |
+
TensorD,
|
| 891 |
+
VectorBias,
|
| 892 |
+
TensorAux,
|
| 893 |
+
VectorAlpha,
|
| 894 |
+
VectorBeta,
|
| 895 |
+
epilogue_params.abs_amax_D,
|
| 896 |
+
epilogue_params.abs_amax_Aux,
|
| 897 |
+
epilogue_params.scale_a,
|
| 898 |
+
epilogue_params.scale_b,
|
| 899 |
+
epilogue_params.scale_c,
|
| 900 |
+
epilogue_params.scale_d,
|
| 901 |
+
epilogue_params.scale_aux
|
| 902 |
+
};
|
| 903 |
+
|
| 904 |
+
Gett(mainloop_params_converted, epilogue_params_converted);
|
| 905 |
+
}
|
| 906 |
+
else {
|
| 907 |
+
// if we already have a batch mode, just pass it through
|
| 908 |
+
Gett(mainloop_params, epilogue_params);
|
| 909 |
+
}
|
| 910 |
+
}
|
| 911 |
+
|
| 912 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 913 |
+
|
| 914 |
+
} // cutlass::reference::host
|
| 915 |
+
|
| 916 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for Rank 2k update in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/blas3.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
#include "cutlass/arch/mma.h"
|
| 45 |
+
#include "cutlass/util/host_tensor.h"
|
| 46 |
+
#include "cutlass/util/reference/host/gemm.h"
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reference {
|
| 50 |
+
namespace host {
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 55 |
+
/// objects.
|
| 56 |
+
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  FillMode FillModeC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_rank2k(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRef<ElementA, LayoutA> tensor_a,
  TensorRef<ElementB, LayoutB> tensor_b,
  ScalarType beta,
  TensorRef<ElementC, LayoutC> tensor_c,
  TensorRef<ElementC, LayoutC> tensor_d,
  ComputeType initial_accum) {

  // Reference rank-2k update:
  //   D = alpha * (A * B^T + B * A^T) + beta * C
  // Only the triangle selected by FillModeC (lower or upper) is computed and
  // written; the opposite triangle of D is left untouched.
  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2,
    "Tensors must be of rank 2");

  static_assert(
    FillModeC == FillMode::kLower ||
    FillModeC == FillMode::kUpper,
    "Fill Mode can either be Lower or Upper.");

  // Lower fill keeps row >= col; upper fill keeps row <= col.
  using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower),
                                                   std::greater_equal<int>,
                                                   std::less_equal<int>>::type;

  // Note: batch is ignored.
  // Note: M is same as N for Rank 2k update
  int const N = problem_size.n();
  int const K = problem_size.k();

  // Blocking necessary to speedup reference implementation
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
  CompareOp compare_op;

  for (int row_block = 0; row_block < N; row_block += Nblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      ComputeType accum[Nblock][Nblock];

      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Nblock; i++) {
          accum[i][j] = initial_accum;
        }
      }

      // Accumulate both symmetric products over the K extent for every
      // in-bounds element on the stored triangle.
      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Nblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            if (row < N && col < N && compare_op(row, col))
            {

              // A x B^T
              ElementA a = tensor_a.at(MatrixCoord(row, k_block));
              ElementB b_t = tensor_b.at(MatrixCoord(col, k_block));

              ComputeType compute_a(cast_if_scalar<ComputeType>(a));
              ComputeType compute_b_t(cast_if_scalar<ComputeType>(b_t));

              accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]);

              // B x A^T
              ElementB b = tensor_b.at(MatrixCoord(row, k_block));
              ElementA a_t = tensor_a.at(MatrixCoord(col, k_block));

              ComputeType compute_b(cast_if_scalar<ComputeType>(b));
              ComputeType compute_a_t(cast_if_scalar<ComputeType>(a_t));

              accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]);
            }
          }
        }
      }

      // Epilogue: D = alpha * accum + beta * C on the stored triangle only.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Nblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          MatrixCoord coord = MatrixCoord(row, col);

          if (row < N && col < N &&
              ( (FillModeC == FillMode::kLower && row >= col) ||
                (FillModeC == FillMode::kUpper && row <= col) )
          ) {
            tensor_d.at(coord) = convert_op(
              alpha * ScalarType(accum[i][j]) +
              beta * ScalarType(tensor_c.at(coord)));
          }
        }
      }
    }
  }
}
|
| 168 |
+
|
| 169 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 170 |
+
|
| 171 |
+
/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef
|
| 172 |
+
/// objects.
|
| 173 |
+
/// In-place convenience overload: the rank-2k update is written back into
/// tensor_c, which serves as both the source (C) and destination (D).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  FillMode FillModeC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_rank2k(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRef<ElementA, LayoutA> tensor_a,
  TensorRef<ElementB, LayoutB> tensor_b,
  ScalarType beta,
  TensorRef<ElementC, LayoutC> tensor_c,
  ComputeType initial_accum) {
  // Forward to the general form with tensor_c passed as both C and D.
  compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
                 ScalarType, ComputeType, InnerProductOp, ConvertOp>(
    problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
    initial_accum);
}
|
| 199 |
+
|
| 200 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 201 |
+
|
| 202 |
+
/// Rank-2k update reference functor; primary template. Specializations are
/// selected by the InnerProductOp tag (e.g. arch::OpMultiplyAdd below).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  FillMode FillModeC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
>
struct Rank2K;
|
| 215 |
+
|
| 216 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 217 |
+
|
| 218 |
+
/// Partial specialization for multiply-add
|
| 219 |
+
template <typename ElementA, typename LayoutA,
          typename ElementB, typename LayoutB,
          typename ElementC, typename LayoutC, FillMode FillModeC,
          typename ScalarType, typename ComputeType>
struct Rank2K<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC, ScalarType,
              ComputeType, arch::OpMultiplyAdd> {

  /// In-place rank-2k update: tensor_c is both source and destination.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
                   ScalarType, ComputeType, multiply_add<ComputeType>>(
      problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
  }

  /// Out-of-place rank-2k update: reads tensor_c, writes tensor_d.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
                  TensorRef<ElementC, LayoutC> tensor_c,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
                   ScalarType, ComputeType, multiply_add<ComputeType>>(
      problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
  }
};
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 258 |
+
|
| 259 |
+
} // namespace host
|
| 260 |
+
} // namespace reference
|
| 261 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued Rank 2K update in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/blas3.h"
|
| 40 |
+
#include "cutlass/complex.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
#include <cassert>
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace reference {
|
| 48 |
+
namespace host {
|
| 49 |
+
|
| 50 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 53 |
+
/// objects.
|
| 54 |
+
///
|
| 55 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 56 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 57 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 58 |
+
/// arguments explicitly.
|
| 59 |
+
template <
|
| 60 |
+
typename ElementA,
|
| 61 |
+
typename LayoutA,
|
| 62 |
+
typename ElementB,
|
| 63 |
+
typename LayoutB,
|
| 64 |
+
typename ElementC,
|
| 65 |
+
typename LayoutC,
|
| 66 |
+
typename ScalarType,
|
| 67 |
+
typename ComputeType,
|
| 68 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 69 |
+
typename InnerProductOp = multiply_add<ComputeType>
|
| 70 |
+
>
|
| 71 |
+
void Rank2KComplex(
|
| 72 |
+
gemm::GemmCoord problem_size,
|
| 73 |
+
ScalarType alpha,
|
| 74 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 75 |
+
ComplexTransform transform_a,
|
| 76 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 77 |
+
ComplexTransform transform_b,
|
| 78 |
+
ScalarType beta,
|
| 79 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 80 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 81 |
+
ComputeType initial_accum,
|
| 82 |
+
FillMode fill_mode_c,
|
| 83 |
+
BlasMode blas_mode,
|
| 84 |
+
int batch_count = 1,
|
| 85 |
+
int64_t batch_stride_A = 0,
|
| 86 |
+
int64_t batch_stride_B = 0,
|
| 87 |
+
int64_t batch_stride_C = 0,
|
| 88 |
+
int64_t batch_stride_D = 0) {
|
| 89 |
+
|
| 90 |
+
static_assert(
|
| 91 |
+
LayoutA::kRank == 2 &&
|
| 92 |
+
LayoutB::kRank == 2 &&
|
| 93 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 94 |
+
|
| 95 |
+
// Note: batch is ignored.
|
| 96 |
+
int const M = problem_size.m();
|
| 97 |
+
int const N = problem_size.n();
|
| 98 |
+
int const K = problem_size.k();
|
| 99 |
+
|
| 100 |
+
// Rank2K update operates on A=NxK, B=NxK, and C=NxN
|
| 101 |
+
assert(M==N);
|
| 102 |
+
|
| 103 |
+
// Blocking necessary to speedup reference implementation
|
| 104 |
+
int const Mblock = 16;
|
| 105 |
+
int const Nblock = 16;
|
| 106 |
+
|
| 107 |
+
ConvertOp convert_op;
|
| 108 |
+
InnerProductOp inner_product_op;
|
| 109 |
+
|
| 110 |
+
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
| 111 |
+
|
| 112 |
+
// Compute matrix product using blocks
|
| 113 |
+
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 114 |
+
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 115 |
+
|
| 116 |
+
ComputeType accum[Mblock][Nblock];
|
| 117 |
+
|
| 118 |
+
for (int j = 0; j < Nblock; j++) {
|
| 119 |
+
for (int i = 0; i < Mblock; i++) {
|
| 120 |
+
accum[i][j] = initial_accum;
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 125 |
+
for (int j = 0; j < Nblock; j++) {
|
| 126 |
+
for (int i = 0; i < Mblock; i++) {
|
| 127 |
+
int row = row_block + i;
|
| 128 |
+
int col = col_block + j;
|
| 129 |
+
|
| 130 |
+
if (row < M && col < N &&
|
| 131 |
+
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
| 132 |
+
(fill_mode_c == FillMode::kUpper && row <= col) )
|
| 133 |
+
) {
|
| 134 |
+
|
| 135 |
+
// A x B^T (Symmetric) or A x B^H (Hermitian)
|
| 136 |
+
// complex conjugation on operandB (b_t) is function of blas3 computation
|
| 137 |
+
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 138 |
+
ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
|
| 139 |
+
conj(tensor_b.at(MatrixCoord(col, k_block))) :
|
| 140 |
+
tensor_b.at(MatrixCoord(col, k_block));
|
| 141 |
+
|
| 142 |
+
ComputeType a_ik = ComputeType(a);
|
| 143 |
+
ComputeType b_jk = ComputeType(b_t);
|
| 144 |
+
|
| 145 |
+
// complex conjugation is a function of operand layouts
|
| 146 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 147 |
+
a_ik = conj(a_ik);
|
| 148 |
+
}
|
| 149 |
+
// complex conjugation is a function of operand layouts
|
| 150 |
+
if (transform_b == ComplexTransform::kConjugate) {
|
| 151 |
+
b_jk = conj(b_jk);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/* HER2K need two epilogues to handle complex alpha value */
|
| 161 |
+
if ( blas_mode == BlasMode::kHermitian ) {
|
| 162 |
+
for (int j = 0; j < Nblock; j++) {
|
| 163 |
+
for (int i = 0; i < Mblock; i++) {
|
| 164 |
+
int row = row_block + i;
|
| 165 |
+
int col = col_block + j;
|
| 166 |
+
|
| 167 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 168 |
+
|
| 169 |
+
if (row < M && col < N &&
|
| 170 |
+
((fill_mode_c == FillMode::kLower && row >= col) ||
|
| 171 |
+
(fill_mode_c == FillMode::kUpper && row <= col))
|
| 172 |
+
) {
|
| 173 |
+
|
| 174 |
+
ScalarType c = tensor_c.at(coord);
|
| 175 |
+
// The imaginary parts of the diagonal elements of
|
| 176 |
+
// a complex data type are assumed and set to zero
|
| 177 |
+
if (blas_mode == BlasMode::kHermitian) {
|
| 178 |
+
c = (row == col) ? real(c) : c;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
tensor_d.at(coord) = convert_op(alpha *
|
| 182 |
+
ScalarType(accum[i][j]) +
|
| 183 |
+
beta * c);
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/* Zeoring out accum for second HERK */
|
| 189 |
+
for (int j = 0; j < Nblock; j++) {
|
| 190 |
+
for (int i = 0; i < Mblock; i++) {
|
| 191 |
+
accum[i][j] = initial_accum;
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 197 |
+
for (int j = 0; j < Nblock; j++) {
|
| 198 |
+
for (int i = 0; i < Mblock; i++) {
|
| 199 |
+
int row = row_block + i;
|
| 200 |
+
int col = col_block + j;
|
| 201 |
+
|
| 202 |
+
if (row < M && col < N &&
|
| 203 |
+
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
| 204 |
+
(fill_mode_c == FillMode::kUpper && row <= col) )
|
| 205 |
+
) {
|
| 206 |
+
|
| 207 |
+
// B x A^T (Symmetric) or B x A^H (Hermitian)
|
| 208 |
+
// complex conjugation on operandB (a_t) is function of blas3 computation
|
| 209 |
+
ElementB b = tensor_b.at(MatrixCoord(row, k_block));
|
| 210 |
+
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
| 211 |
+
conj(tensor_a.at(MatrixCoord(col, k_block))):
|
| 212 |
+
tensor_a.at(MatrixCoord(col, k_block));
|
| 213 |
+
|
| 214 |
+
ComputeType b_ik = ComputeType(b);
|
| 215 |
+
ComputeType a_jk = ComputeType(a_t);
|
| 216 |
+
|
| 217 |
+
// complex conjugation here is a function of operand layouts
|
| 218 |
+
if (transform_b == ComplexTransform::kConjugate) {
|
| 219 |
+
b_ik = conj(b_ik);
|
| 220 |
+
}
|
| 221 |
+
// complex conjugation here is a function of operand layouts
|
| 222 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 223 |
+
a_jk = conj(a_jk);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ?
|
| 233 |
+
conj(alpha) : alpha;
|
| 234 |
+
ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ?
|
| 235 |
+
1 : beta;
|
| 236 |
+
|
| 237 |
+
for (int j = 0; j < Nblock; j++) {
|
| 238 |
+
for (int i = 0; i < Mblock; i++) {
|
| 239 |
+
int row = row_block + i;
|
| 240 |
+
int col = col_block + j;
|
| 241 |
+
|
| 242 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 243 |
+
|
| 244 |
+
if (row < M && col < N &&
|
| 245 |
+
((fill_mode_c == FillMode::kLower && row >= col) ||
|
| 246 |
+
(fill_mode_c == FillMode::kUpper && row <= col))
|
| 247 |
+
) {
|
| 248 |
+
|
| 249 |
+
ScalarType d = (blas_mode == BlasMode::kHermitian) ?
|
| 250 |
+
tensor_d.at(coord) : tensor_c.at(coord);
|
| 251 |
+
|
| 252 |
+
ScalarType tmp_d = convert_op(
|
| 253 |
+
alpha_hermitian * ScalarType(accum[i][j]) +
|
| 254 |
+
beta_hermitian * d);
|
| 255 |
+
|
| 256 |
+
if (blas_mode == BlasMode::kHermitian && row == col ) {
|
| 257 |
+
tensor_d.at(coord) = real(tmp_d);
|
| 258 |
+
} else {
|
| 259 |
+
tensor_d.at(coord) = tmp_d;
|
| 260 |
+
}
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
} // for (col_block)
|
| 266 |
+
} // for (row_block)
|
| 267 |
+
|
| 268 |
+
tensor_a.add_pointer_offset(batch_stride_A);
|
| 269 |
+
tensor_b.add_pointer_offset(batch_stride_B);
|
| 270 |
+
tensor_c.add_pointer_offset(batch_stride_C);
|
| 271 |
+
tensor_d.add_pointer_offset(batch_stride_D);
|
| 272 |
+
|
| 273 |
+
} // for (batch_idx)
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 277 |
+
|
| 278 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 279 |
+
/// objects.
|
| 280 |
+
///
|
| 281 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 282 |
+
template <
|
| 283 |
+
typename ElementA,
|
| 284 |
+
typename LayoutA,
|
| 285 |
+
typename ElementB,
|
| 286 |
+
typename LayoutB,
|
| 287 |
+
typename ElementC,
|
| 288 |
+
typename LayoutC,
|
| 289 |
+
typename ScalarType
|
| 290 |
+
>
|
| 291 |
+
void Rank2KComplex(
|
| 292 |
+
gemm::GemmCoord problem_size,
|
| 293 |
+
ScalarType alpha,
|
| 294 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 295 |
+
ComplexTransform transform_a,
|
| 296 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 297 |
+
ComplexTransform transform_b,
|
| 298 |
+
ScalarType beta,
|
| 299 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 300 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 301 |
+
FillMode fill_mode_c,
|
| 302 |
+
BlasMode blas_mode) {
|
| 303 |
+
|
| 304 |
+
Rank2KComplex(
|
| 305 |
+
problem_size, alpha,
|
| 306 |
+
tensor_a, transform_a,
|
| 307 |
+
tensor_b, transform_b,
|
| 308 |
+
beta, tensor_c, tensor_d,
|
| 309 |
+
ScalarType(0),
|
| 310 |
+
fill_mode_c,
|
| 311 |
+
blas_mode);
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 315 |
+
|
| 316 |
+
} // namespace host
|
| 317 |
+
} // namespace reference
|
| 318 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued Rank 2K update in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/blas3.h"
|
| 40 |
+
#include "cutlass/complex.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
#include <cassert>
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace reference {
|
| 48 |
+
namespace host {
|
| 49 |
+
|
| 50 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 53 |
+
/// objects.
|
| 54 |
+
///
|
| 55 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 56 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 57 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 58 |
+
/// arguments explicitly.
|
| 59 |
+
template <
|
| 60 |
+
typename ElementA,
|
| 61 |
+
typename LayoutA,
|
| 62 |
+
typename ElementC,
|
| 63 |
+
typename LayoutC,
|
| 64 |
+
typename ScalarType,
|
| 65 |
+
typename ComputeType,
|
| 66 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>,
|
| 67 |
+
typename InnerProductOp = multiply_add<ComputeType>
|
| 68 |
+
>
|
| 69 |
+
void Rank2KComplex(
|
| 70 |
+
gemm::GemmCoord problem_size,
|
| 71 |
+
ScalarType alpha,
|
| 72 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 73 |
+
ComplexTransform transform_a,
|
| 74 |
+
ScalarType beta,
|
| 75 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 76 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 77 |
+
ComputeType initial_accum,
|
| 78 |
+
FillMode fill_mode_c,
|
| 79 |
+
BlasMode blas_mode,
|
| 80 |
+
int batch_count = 1,
|
| 81 |
+
int64_t batch_stride_A = 0,
|
| 82 |
+
int64_t batch_stride_C = 0,
|
| 83 |
+
int64_t batch_stride_D = 0) {
|
| 84 |
+
|
| 85 |
+
static_assert(
|
| 86 |
+
LayoutA::kRank == 2 &&
|
| 87 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 88 |
+
|
| 89 |
+
// Note: batch is ignored.
|
| 90 |
+
int const M = problem_size.m();
|
| 91 |
+
int const N = problem_size.n();
|
| 92 |
+
int const K = problem_size.k();
|
| 93 |
+
|
| 94 |
+
// Rank2K update operates on A=NxK, B=NxK, and C=NxN
|
| 95 |
+
assert(M==N);
|
| 96 |
+
|
| 97 |
+
// Blocking necessary to speedup reference implementation
|
| 98 |
+
int const Mblock = 16;
|
| 99 |
+
int const Nblock = 16;
|
| 100 |
+
|
| 101 |
+
ConvertOp convert_op;
|
| 102 |
+
InnerProductOp inner_product_op;
|
| 103 |
+
|
| 104 |
+
for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
|
| 105 |
+
|
| 106 |
+
// Compute matrix product using blocks
|
| 107 |
+
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 108 |
+
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 109 |
+
|
| 110 |
+
ComputeType accum[Mblock][Nblock];
|
| 111 |
+
|
| 112 |
+
for (int j = 0; j < Nblock; j++) {
|
| 113 |
+
for (int i = 0; i < Mblock; i++) {
|
| 114 |
+
accum[i][j] = initial_accum;
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 119 |
+
for (int j = 0; j < Nblock; j++) {
|
| 120 |
+
for (int i = 0; i < Mblock; i++) {
|
| 121 |
+
int row = row_block + i;
|
| 122 |
+
int col = col_block + j;
|
| 123 |
+
|
| 124 |
+
if (row < M && col < N &&
|
| 125 |
+
( (fill_mode_c == FillMode::kLower && row >= col) ||
|
| 126 |
+
(fill_mode_c == FillMode::kUpper && row <= col) )
|
| 127 |
+
) {
|
| 128 |
+
|
| 129 |
+
// A x A^T (Symmetric) or A x A^H (Hermitian)
|
| 130 |
+
// complex conjugation on operandB (a_t) (function of blas3 computation)
|
| 131 |
+
ElementA a = tensor_a.at(MatrixCoord(row, k_block));
|
| 132 |
+
ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
|
| 133 |
+
conj(tensor_a.at(MatrixCoord(col, k_block))) :
|
| 134 |
+
tensor_a.at(MatrixCoord(col, k_block));
|
| 135 |
+
|
| 136 |
+
ComputeType a_ik = ComputeType(a);
|
| 137 |
+
ComputeType b_jk = ComputeType(a_t);
|
| 138 |
+
|
| 139 |
+
// complex conjugation (function of input layouts)
|
| 140 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 141 |
+
a_ik = conj(a_ik);
|
| 142 |
+
}
|
| 143 |
+
// complex conjugation (function of input layouts)
|
| 144 |
+
if (transform_a == ComplexTransform::kConjugate) {
|
| 145 |
+
b_jk = conj(b_jk);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
|
| 149 |
+
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
for (int j = 0; j < Nblock; j++) {
|
| 156 |
+
for (int i = 0; i < Mblock; i++) {
|
| 157 |
+
int row = row_block + i;
|
| 158 |
+
int col = col_block + j;
|
| 159 |
+
|
| 160 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 161 |
+
|
| 162 |
+
if (row < M && col < N &&
|
| 163 |
+
((fill_mode_c == FillMode::kLower && row >= col) ||
|
| 164 |
+
(fill_mode_c == FillMode::kUpper && row <= col))
|
| 165 |
+
) {
|
| 166 |
+
|
| 167 |
+
ScalarType c = tensor_c.at(coord);
|
| 168 |
+
// The imaginary parts of the diagonal elements of
|
| 169 |
+
// a complex data type are assumed and set to zero
|
| 170 |
+
if (blas_mode == BlasMode::kHermitian) {
|
| 171 |
+
c = (row == col) ? real(c) : c;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
ScalarType tmp_d = convert_op(
|
| 175 |
+
alpha * ScalarType(accum[i][j]) +
|
| 176 |
+
beta * c);
|
| 177 |
+
|
| 178 |
+
if (blas_mode == BlasMode::kHermitian && row == col ) {
|
| 179 |
+
tensor_d.at(coord) = real(tmp_d);
|
| 180 |
+
} else {
|
| 181 |
+
tensor_d.at(coord) = tmp_d;
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
} // for (col_block)
|
| 188 |
+
} // for (row_block)
|
| 189 |
+
|
| 190 |
+
tensor_a.add_pointer_offset(batch_stride_A);
|
| 191 |
+
tensor_c.add_pointer_offset(batch_stride_C);
|
| 192 |
+
tensor_d.add_pointer_offset(batch_stride_D);
|
| 193 |
+
|
| 194 |
+
} // for (batch_idx)
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 198 |
+
|
| 199 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 200 |
+
/// objects.
|
| 201 |
+
///
|
| 202 |
+
/// This assumes the accumulator type is the same type as the scalars.
|
| 203 |
+
template <
|
| 204 |
+
typename ElementA,
|
| 205 |
+
typename LayoutA,
|
| 206 |
+
typename ElementC,
|
| 207 |
+
typename LayoutC,
|
| 208 |
+
typename ScalarType
|
| 209 |
+
>
|
| 210 |
+
void RankKComplex(
|
| 211 |
+
gemm::GemmCoord problem_size,
|
| 212 |
+
ScalarType alpha,
|
| 213 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 214 |
+
ComplexTransform transform_a,
|
| 215 |
+
ScalarType beta,
|
| 216 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 217 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 218 |
+
FillMode fill_mode_c,
|
| 219 |
+
BlasMode blas_mode) {
|
| 220 |
+
|
| 221 |
+
Rank2KComplex(
|
| 222 |
+
problem_size, alpha,
|
| 223 |
+
tensor_a, transform_a,
|
| 224 |
+
beta, tensor_c, tensor_d,
|
| 225 |
+
ScalarType(0),
|
| 226 |
+
fill_mode_c,
|
| 227 |
+
blas_mode);
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 231 |
+
|
| 232 |
+
} // namespace host
|
| 233 |
+
} // namespace reference
|
| 234 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for SYMM update in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/blas3.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/tensor_view.h"
|
| 44 |
+
#include "cutlass/gemm/gemm.h"
|
| 45 |
+
#include "cutlass/arch/mma.h"
|
| 46 |
+
#include "cutlass/util/host_tensor.h"
|
| 47 |
+
#include "cutlass/util/reference/host/gemm.h"
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace reference {
|
| 51 |
+
namespace host {
|
| 52 |
+
|
| 53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 56 |
+
/// objects.
|
| 57 |
+
/// Host-side reference implementation of SYMM:
///   SideMode::kLeft :  D = alpha * (A x B) + beta * C
///   SideMode::kRight:  D = alpha * (B x A) + beta * C
/// where A is symmetric and only the FillModeA triangle (including the
/// diagonal) is stored; the opposite triangle is reconstructed by reading
/// the stored triangle transposed.
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_symm(
  gemm::GemmCoord problem_size,                  ///< m/n/k extents of the update
  ScalarType alpha,                              ///< scale applied to the product
  TensorRef<ElementA, LayoutA> tensor_a,         ///< symmetric operand (one triangle stored)
  TensorRef<ElementB, LayoutB> tensor_b,         ///< dense operand
  ScalarType beta,                               ///< scale applied to the source accumulator C
  TensorRef<ElementC, LayoutC> tensor_c,         ///< source accumulator
  TensorRef<ElementC, LayoutC> tensor_d,         ///< destination (written for every (row, col) in M x N)
  ComputeType initial_accum) {                   ///< initial value of each accumulator (typically 0)

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2,
    "Tensors must be of rank 2");

  static_assert(SideModeA != SideMode::kInvalid
    , "Side Mode can either be Left or Right.");

  static_assert(
    FillModeA == FillMode::kLower ||
    FillModeA == FillMode::kUpper,
    "Fill Mode can either be Lower or Upper.");

  // Predicates selecting the stored triangle: compare_op_1 includes the
  // diagonal, compare_op_2 excludes it so diagonal elements are not counted twice.
  using CompareOp_w_diag = typename TrMatrixCompareOp<FillModeA, DiagType::kNonUnit>::Type;
  using CompareOp_wo_diag = typename TrMatrixCompareOp<FillModeA, DiagType::kZero>::Type;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  // Assuming correct k-dimension value is passed
  int const K = problem_size.k();

  // Blocking necessary to speedup reference implementation
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
  CompareOp_w_diag compare_op_1;
  CompareOp_wo_diag compare_op_2;

  // Iterate over Mblock x Nblock tiles of the output.
  for (int row_block = 0; row_block < M; row_block += Mblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      ComputeType accum[Mblock][Nblock];

      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          accum[i][j] = initial_accum;
        }
      }

      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            if (row < M && col < N) {
              ElementA a_1 = ElementA();
              ElementB b_1 = ElementB();
              ElementA a_2 = ElementA();
              ElementB b_2 = ElementB();

              // A x B or B x A (with diagonal)
              // First partial product: reads A from the STORED triangle
              // (diagonal included); out-of-triangle reads contribute zero.
              if (SideModeA == SideMode::kLeft) {
                a_1 = (compare_op_1(row, k_block)) ?
                  (tensor_a.at(MatrixCoord(row, k_block))) : ElementA();
                b_1 = tensor_b.at(MatrixCoord(k_block, col));
              } else if (SideModeA == SideMode::kRight) {
                a_1 = tensor_b.at(MatrixCoord(row, k_block));
                b_1 = (compare_op_1(k_block, col)) ?
                  tensor_a.at(MatrixCoord(k_block, col)) : ElementA();
              }

              ComputeType compute_a_1(cast_if_scalar<ComputeType>(a_1));
              ComputeType compute_b_1(cast_if_scalar<ComputeType>(b_1));

              accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]);

              // A^T x B or B x A^T (without diagonal)
              // Second partial product: the implicit opposite triangle,
              // obtained by reading the stored triangle transposed; the
              // diagonal is excluded here since it was accumulated above.
              if (SideModeA == SideMode::kLeft) {
                a_2 = (compare_op_2(k_block, row)) ?
                  (tensor_a.at(MatrixCoord(k_block, row))) : ElementA();
                b_2 = tensor_b.at(MatrixCoord(k_block, col));
              } else if (SideModeA == SideMode::kRight) {
                a_2 = tensor_b.at(MatrixCoord(row, k_block));
                b_2 = (compare_op_2(col, k_block)) ?
                  tensor_a.at(MatrixCoord(col, k_block)) : ElementA();
              }

              ComputeType compute_a_2(cast_if_scalar<ComputeType>(a_2));
              ComputeType compute_b_2(cast_if_scalar<ComputeType>(b_2));

              accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]);
            }
          }
        }
      }

      // Epilogue: D = convert(alpha * accum + beta * C) for the in-bounds
      // portion of this tile.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          MatrixCoord coord = MatrixCoord(row, col);

          if (row < M && col < N) {
            tensor_d.at(coord) = convert_op(
              alpha * ScalarType(accum[i][j]) +
              beta * ScalarType(tensor_c.at(coord)));
          }
        }
      }
    }
  }
}
|
| 189 |
+
|
| 190 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 191 |
+
|
| 192 |
+
/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef
|
| 193 |
+
/// objects.
|
| 194 |
+
template <
|
| 195 |
+
typename ElementA,
|
| 196 |
+
typename LayoutA,
|
| 197 |
+
SideMode SideModeA,
|
| 198 |
+
FillMode FillModeA,
|
| 199 |
+
typename ElementB,
|
| 200 |
+
typename LayoutB,
|
| 201 |
+
typename ElementC,
|
| 202 |
+
typename LayoutC,
|
| 203 |
+
typename ScalarType,
|
| 204 |
+
typename ComputeType,
|
| 205 |
+
typename InnerProductOp = multiply_add<ComputeType>,
|
| 206 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 207 |
+
>
|
| 208 |
+
void compute_symm(
|
| 209 |
+
gemm::GemmCoord problem_size,
|
| 210 |
+
ScalarType alpha,
|
| 211 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 212 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 213 |
+
ScalarType beta,
|
| 214 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 215 |
+
ComputeType initial_accum) {
|
| 216 |
+
compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
|
| 217 |
+
ScalarType, ComputeType, InnerProductOp, ConvertOp>(
|
| 218 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
|
| 219 |
+
initial_accum);
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 223 |
+
|
| 224 |
+
/// Functor encapsulating the reference SYMM computation; the primary template
/// is only declared here and specialized below on the inner-product operator
/// tag (e.g. arch::OpMultiplyAdd).
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
>
struct Symm;
|
| 238 |
+
|
| 239 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 240 |
+
|
| 241 |
+
/// Partial specialization for multiply-add
|
| 242 |
+
template <typename ElementA, typename LayoutA,
|
| 243 |
+
SideMode SideModeA, FillMode FillModeA,
|
| 244 |
+
typename ElementB, typename LayoutB,
|
| 245 |
+
typename ElementC, typename LayoutC,
|
| 246 |
+
typename ScalarType, typename ComputeType>
|
| 247 |
+
struct Symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
|
| 248 |
+
ComputeType, arch::OpMultiplyAdd> {
|
| 249 |
+
|
| 250 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 251 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 252 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 253 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 254 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 255 |
+
static_assert(
|
| 256 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 257 |
+
"Tensors must be of rank 2");
|
| 258 |
+
|
| 259 |
+
compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
|
| 260 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 261 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 265 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 266 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 267 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 268 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 269 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 270 |
+
static_assert(
|
| 271 |
+
LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
|
| 272 |
+
"Tensors must be of rank 2");
|
| 273 |
+
|
| 274 |
+
compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
|
| 275 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 276 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 277 |
+
}
|
| 278 |
+
};
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 282 |
+
|
| 283 |
+
} // namespace host
|
| 284 |
+
} // namespace reference
|
| 285 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued SYMM update in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/blas3.h"
|
| 40 |
+
#include "cutlass/complex.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
#include <cassert>
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace reference {
|
| 48 |
+
namespace host {
|
| 49 |
+
|
| 50 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
|
| 53 |
+
/// objects.
|
| 54 |
+
///
|
| 55 |
+
/// Explicitly naming types needed by this template can be cumbersome, particularly for the
|
| 56 |
+
/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
|
| 57 |
+
/// AccumulatorType(0) as the last function argument can be easier than naming all template
|
| 58 |
+
/// arguments explicitly.
|
| 59 |
+
/// Host-side reference implementation of complex-valued SYMM/HEMM:
///   SideMode::kLeft :  D = alpha * (A x B) + beta * C
///   SideMode::kRight:  D = alpha * (B x A) + beta * C
/// where A is symmetric (BlasMode::kSymmetric) or Hermitian
/// (BlasMode::kHermitian) with only the FillModeA triangle stored.
/// Supports strided-batched operation via batch_count / batch_stride_*.
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  BlasMode BlasMode_ = BlasMode::kSymmetric,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_symm_complex(
  gemm::GemmCoord problem_size,                  ///< m/n/k extents of the update
  ScalarType alpha,                              ///< scale applied to the product
  TensorRef<ElementA, LayoutA> tensor_a,         ///< symmetric/Hermitian operand (one triangle stored)
  TensorRef<ElementB, LayoutB> tensor_b,         ///< dense operand
  ScalarType beta,                               ///< scale applied to the source accumulator C
  TensorRef<ElementC, LayoutC> tensor_c,         ///< source accumulator
  TensorRef<ElementC, LayoutC> tensor_d,         ///< destination
  ComputeType initial_accum,                     ///< initial value of each accumulator (typically 0)
  int batch_count = 1,                           ///< number of batched problems
  int64_t batch_stride_A = 0,                    ///< element stride between consecutive A matrices
  int64_t batch_stride_B = 0,                    ///< element stride between consecutive B matrices
  int64_t batch_stride_C = 0,                    ///< element stride between consecutive C matrices
  int64_t batch_stride_D = 0) {                  ///< element stride between consecutive D matrices

  static SideMode const kSideModeA = SideModeA;
  static FillMode const kFillModeA = FillModeA;
  static BlasMode const kBlasMode = BlasMode_;

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutB::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  static_assert(kSideModeA != SideMode::kInvalid
    , "Side Mode can either be Left or Right.");

  static_assert(
    kFillModeA == FillMode::kLower ||
    kFillModeA == FillMode::kUpper,
    "Fill Mode can either be Lower or Upper.");

  // Predicates selecting the stored triangle: compare_op_1 includes the
  // diagonal, compare_op_2 excludes it so diagonal elements are not counted twice.
  using CompareOp_w_diag = typename TrMatrixCompareOp<kFillModeA, DiagType::kNonUnit>::Type;
  using CompareOp_wo_diag = typename TrMatrixCompareOp<kFillModeA, DiagType::kZero>::Type;

  // Problem extents; K is trusted to be consistent with the side mode.
  int const M = problem_size.m();
  int const N = problem_size.n();
  // Assuming correct k-dimension value is passed
  int const K = problem_size.k();

  // Blocking necessary to speedup reference implementation
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
  CompareOp_w_diag compare_op_1;
  CompareOp_wo_diag compare_op_2;

  // Each batch is computed in full, then every tensor is advanced by its
  // batch stride before the next iteration.
  for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {

    // Compute matrix product using blocks
    for (int row_block = 0; row_block < M; row_block += Mblock) {
      for (int col_block = 0; col_block < N; col_block += Nblock) {

        ComputeType accum[Mblock][Nblock];

        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            accum[i][j] = initial_accum;
          }
        }

        for (int k_block = 0; k_block < K; ++k_block) {
          for (int j = 0; j < Nblock; j++) {
            for (int i = 0; i < Mblock; i++) {
              int row = row_block + i;
              int col = col_block + j;

              if (row < M && col < N)
              {
                ElementA a_1 = ElementA();
                ElementB b_1 = ElementB();
                ElementA a_2 = ElementA();
                ElementB b_2 = ElementB();

                // A x B or B x A (with diagonal)
                // First partial product: reads A from the STORED triangle
                // (diagonal included); out-of-triangle reads contribute zero.
                if (kSideModeA == SideMode::kLeft) {
                  a_1 = (compare_op_1(row, k_block)) ?
                    (tensor_a.at(MatrixCoord(row, k_block))) : ElementA();
                  b_1 = tensor_b.at(MatrixCoord(k_block, col));
                } else if (kSideModeA == SideMode::kRight) {
                  a_1 = tensor_b.at(MatrixCoord(row, k_block));
                  b_1 = (compare_op_1(k_block, col)) ?
                    tensor_a.at(MatrixCoord(k_block, col)) : ElementA();
                }
                ComputeType compute_a_1 = ComputeType(a_1);
                ComputeType compute_b_1 = ComputeType(b_1);

                // The imaginary parts of the diagonal elements of
                // a complex data type are assumed and set to zero
                // (Hermitian matrices have a real-valued diagonal).
                if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) {
                  compute_a_1 = real(compute_a_1);
                } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) {
                  compute_b_1 = real(compute_b_1);
                }

                accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]);

                // A^T x B or B x A^T (without diagonal)
                // Second partial product: the implicit opposite triangle read
                // transposed; conjugated in the Hermitian case; the diagonal
                // is excluded since it was accumulated above.
                if (kSideModeA == SideMode::kLeft) {
                  a_2 = (compare_op_2(k_block, row)) ?
                    (tensor_a.at(MatrixCoord(k_block, row))) : ElementA();
                  b_2 = tensor_b.at(MatrixCoord(k_block, col));
                  if (kBlasMode == BlasMode::kHermitian)
                    a_2 = conj(a_2);
                } else if (kSideModeA == SideMode::kRight) {
                  a_2 = tensor_b.at(MatrixCoord(row, k_block));
                  b_2 = (compare_op_2(col, k_block)) ?
                    tensor_a.at(MatrixCoord(col, k_block)) : ElementA();
                  if (kBlasMode == BlasMode::kHermitian)
                    b_2 = conj(b_2);
                }

                ComputeType compute_a_2 = ComputeType(a_2);
                ComputeType compute_b_2 = ComputeType(b_2);

                accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]);
              }
            }
          }
        }

        // Epilogue: D = convert(alpha * accum + beta * C) for the in-bounds
        // portion of this tile.
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            MatrixCoord coord = MatrixCoord(row, col);

            if (row < M && col < N) {

              ScalarType c = tensor_c.at(coord);

              tensor_d.at(coord) = convert_op(
                alpha * ScalarType(accum[i][j]) +
                beta * c);
            }
          }
        }

      } // for (col_block)
    } // for (row_block)

    // Advance every operand to the next problem in the batch.
    tensor_a.add_pointer_offset(batch_stride_A);
    tensor_b.add_pointer_offset(batch_stride_B);
    tensor_c.add_pointer_offset(batch_stride_C);
    tensor_d.add_pointer_offset(batch_stride_D);

  } // for (batch_idx)
}
|
| 226 |
+
|
| 227 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 228 |
+
|
| 229 |
+
/// Functor encapsulating the complex-valued reference SYMM/HEMM computation;
/// the primary template is only declared here and specialized below on the
/// inner-product operator tag (arch::OpMultiplyAddComplex or
/// arch::OpMultiplyAddGaussianComplex).
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric,
  typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex
>
struct SymmComplex;
|
| 244 |
+
|
| 245 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 246 |
+
|
| 247 |
+
/// Partial specialization for multiply-add
|
| 248 |
+
template <typename ElementA, typename LayoutA,
|
| 249 |
+
SideMode SideModeA, FillMode FillModeA,
|
| 250 |
+
typename ElementB, typename LayoutB,
|
| 251 |
+
typename ElementC, typename LayoutC,
|
| 252 |
+
typename ScalarType, typename ComputeType,
|
| 253 |
+
BlasMode BlasMode_>
|
| 254 |
+
struct SymmComplex<ElementA, LayoutA,
|
| 255 |
+
SideModeA, FillModeA,
|
| 256 |
+
ElementB, LayoutB,
|
| 257 |
+
ElementC, LayoutC, ScalarType,
|
| 258 |
+
ComputeType, BlasMode_,
|
| 259 |
+
arch::OpMultiplyAddComplex> {
|
| 260 |
+
|
| 261 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 262 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 263 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 264 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 265 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 266 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 267 |
+
static_assert(
|
| 268 |
+
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
| 269 |
+
"Tensors must be of rank 2");
|
| 270 |
+
|
| 271 |
+
compute_symm_complex<ElementA, LayoutA,
|
| 272 |
+
SideModeA, FillModeA,
|
| 273 |
+
ElementB, LayoutB,
|
| 274 |
+
ElementC, LayoutC,
|
| 275 |
+
ScalarType, ComputeType, BlasMode_, multiply_add<ComputeType>>(
|
| 276 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 277 |
+
}
|
| 278 |
+
};
|
| 279 |
+
|
| 280 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 281 |
+
|
| 282 |
+
/// Partial specialization for gaussian multiply-add
|
| 283 |
+
template <typename ElementA, typename LayoutA,
|
| 284 |
+
SideMode SideModeA, FillMode FillModeA,
|
| 285 |
+
typename ElementB, typename LayoutB,
|
| 286 |
+
typename ElementC, typename LayoutC,
|
| 287 |
+
typename ScalarType, typename ComputeType,
|
| 288 |
+
BlasMode BlasMode_>
|
| 289 |
+
struct SymmComplex<ElementA, LayoutA,
|
| 290 |
+
SideModeA, FillModeA,
|
| 291 |
+
ElementB, LayoutB,
|
| 292 |
+
ElementC, LayoutC, ScalarType,
|
| 293 |
+
ComputeType, BlasMode_,
|
| 294 |
+
arch::OpMultiplyAddGaussianComplex> {
|
| 295 |
+
|
| 296 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 297 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 298 |
+
TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
|
| 299 |
+
TensorRef<ElementC, LayoutC> tensor_c,
|
| 300 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 301 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 302 |
+
static_assert(
|
| 303 |
+
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
| 304 |
+
"Tensors must be of rank 2");
|
| 305 |
+
|
| 306 |
+
compute_symm_complex<ElementA, LayoutA,
|
| 307 |
+
SideModeA, FillModeA,
|
| 308 |
+
ElementB, LayoutB,
|
| 309 |
+
ElementC, LayoutC,
|
| 310 |
+
ScalarType, ComputeType, BlasMode_, multiply_add<ComputeType>>(
|
| 311 |
+
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
|
| 312 |
+
}
|
| 313 |
+
};
|
| 314 |
+
|
| 315 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 316 |
+
|
| 317 |
+
} // namespace host
|
| 318 |
+
} // namespace reference
|
| 319 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Defines host-side elementwise operations on TensorView.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Standard Library includes
|
| 38 |
+
#include <utility>
|
| 39 |
+
|
| 40 |
+
// Cutlass includes
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "cutlass/relatively_equal.h"
|
| 43 |
+
#include "cutlass/tensor_view.h"
|
| 44 |
+
#include "cutlass/tensor_view_planar_complex.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/util/distribution.h"
|
| 47 |
+
#include "tensor_foreach.h"
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace reference {
|
| 51 |
+
namespace host {
|
| 52 |
+
|
| 53 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
namespace detail {
|
| 57 |
+
|
| 58 |
+
template <
|
| 59 |
+
typename Element, ///< Element type
|
| 60 |
+
typename Layout> ///< Layout function
|
| 61 |
+
struct TensorGreatestErrorFunc {
|
| 62 |
+
|
| 63 |
+
//
|
| 64 |
+
// Data members
|
| 65 |
+
//
|
| 66 |
+
|
| 67 |
+
TensorView<Element, Layout> lhs;
|
| 68 |
+
TensorView<Element, Layout> rhs;
|
| 69 |
+
double result;
|
| 70 |
+
|
| 71 |
+
/// Ctor
|
| 72 |
+
TensorGreatestErrorFunc(
|
| 73 |
+
TensorView<Element, Layout> const &lhs_,
|
| 74 |
+
TensorView<Element, Layout> const &rhs_
|
| 75 |
+
) :
|
| 76 |
+
lhs(lhs_),
|
| 77 |
+
rhs(rhs_),
|
| 78 |
+
result(0.0) { }
|
| 79 |
+
|
| 80 |
+
/// Visits a coordinate
|
| 81 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 82 |
+
|
| 83 |
+
Element lhs_ = lhs.at(coord);
|
| 84 |
+
Element rhs_ = rhs.at(coord);
|
| 85 |
+
|
| 86 |
+
result = std::max(result, std::abs(double(lhs_) - double(rhs_)));
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
/// Returns true if equal
|
| 90 |
+
operator double() const {
|
| 91 |
+
return result;
|
| 92 |
+
}
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
template <
|
| 96 |
+
typename Element, ///< Element type
|
| 97 |
+
typename Layout> ///< Layout function
|
| 98 |
+
struct TensorMREFunc {
|
| 99 |
+
|
| 100 |
+
//
|
| 101 |
+
// Data members
|
| 102 |
+
//
|
| 103 |
+
|
| 104 |
+
TensorView<Element, Layout> lhs;
|
| 105 |
+
TensorView<Element, Layout> rhs;
|
| 106 |
+
double sum;
|
| 107 |
+
uint64_t count;
|
| 108 |
+
static constexpr double epsilon = 1e-6;
|
| 109 |
+
|
| 110 |
+
/// Ctor
|
| 111 |
+
TensorMREFunc(
|
| 112 |
+
TensorView<Element, Layout> const &lhs_,
|
| 113 |
+
TensorView<Element, Layout> const &rhs_
|
| 114 |
+
) :
|
| 115 |
+
lhs(lhs_),
|
| 116 |
+
rhs(rhs_),
|
| 117 |
+
sum(0.0),
|
| 118 |
+
count(0) { }
|
| 119 |
+
|
| 120 |
+
/// Visits a coordinate
|
| 121 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 122 |
+
|
| 123 |
+
Element lhs_ = lhs.at(coord);
|
| 124 |
+
Element rhs_ = rhs.at(coord);
|
| 125 |
+
|
| 126 |
+
sum += std::abs(double(lhs_) - double(rhs_) / (double(rhs_) + epsilon));
|
| 127 |
+
++count;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Returns true if equal
|
| 131 |
+
operator double() const {
|
| 132 |
+
return sum / double(count);
|
| 133 |
+
}
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
template <
|
| 137 |
+
typename Element, ///< Element type
|
| 138 |
+
typename Layout> ///< Layout function
|
| 139 |
+
struct TensorMSEFunc {
|
| 140 |
+
|
| 141 |
+
//
|
| 142 |
+
// Data members
|
| 143 |
+
//
|
| 144 |
+
|
| 145 |
+
TensorView<Element, Layout> lhs;
|
| 146 |
+
TensorView<Element, Layout> rhs;
|
| 147 |
+
double sum;
|
| 148 |
+
uint64_t count;
|
| 149 |
+
|
| 150 |
+
/// Ctor
|
| 151 |
+
TensorMSEFunc(
|
| 152 |
+
TensorView<Element, Layout> const &lhs_,
|
| 153 |
+
TensorView<Element, Layout> const &rhs_
|
| 154 |
+
) :
|
| 155 |
+
lhs(lhs_),
|
| 156 |
+
rhs(rhs_),
|
| 157 |
+
sum(0.0),
|
| 158 |
+
count(0) { }
|
| 159 |
+
|
| 160 |
+
/// Visits a coordinate
|
| 161 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 162 |
+
|
| 163 |
+
Element lhs_ = lhs.at(coord);
|
| 164 |
+
Element rhs_ = rhs.at(coord);
|
| 165 |
+
|
| 166 |
+
sum += std::pow((double(lhs_) - double(rhs_)), 2);
|
| 167 |
+
++count;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
/// Returns true if equal
|
| 171 |
+
operator double() const {
|
| 172 |
+
return sum / double(count);
|
| 173 |
+
}
|
| 174 |
+
};
|
| 175 |
+
|
| 176 |
+
template <
|
| 177 |
+
typename Element, ///< Element type
|
| 178 |
+
typename Layout> ///< Layout function
|
| 179 |
+
struct TensorEqualsFunc {
|
| 180 |
+
|
| 181 |
+
//
|
| 182 |
+
// Data members
|
| 183 |
+
//
|
| 184 |
+
|
| 185 |
+
TensorView<Element, Layout> lhs;
|
| 186 |
+
TensorView<Element, Layout> rhs;
|
| 187 |
+
bool result;
|
| 188 |
+
|
| 189 |
+
/// Ctor
|
| 190 |
+
TensorEqualsFunc(): result(true) { }
|
| 191 |
+
|
| 192 |
+
/// Ctor
|
| 193 |
+
TensorEqualsFunc(
|
| 194 |
+
TensorView<Element, Layout> const &lhs_,
|
| 195 |
+
TensorView<Element, Layout> const &rhs_
|
| 196 |
+
) :
|
| 197 |
+
lhs(lhs_), rhs(rhs_), result(true) { }
|
| 198 |
+
|
| 199 |
+
/// Visits a coordinate
|
| 200 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 201 |
+
|
| 202 |
+
Element lhs_ = lhs.at(coord);
|
| 203 |
+
Element rhs_ = rhs.at(coord);
|
| 204 |
+
|
| 205 |
+
if (lhs_ != rhs_) {
|
| 206 |
+
result = false;
|
| 207 |
+
}
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
/// Returns true if equal
|
| 211 |
+
operator bool() const {
|
| 212 |
+
return result;
|
| 213 |
+
}
|
| 214 |
+
};
|
| 215 |
+
|
| 216 |
+
template <
|
| 217 |
+
typename Element, ///< Element type
|
| 218 |
+
typename Layout> ///< Layout function
|
| 219 |
+
struct TensorRelativelyEqualsFunc {
|
| 220 |
+
|
| 221 |
+
//
|
| 222 |
+
// Data members
|
| 223 |
+
//
|
| 224 |
+
|
| 225 |
+
TensorView<Element, Layout> lhs;
|
| 226 |
+
TensorView<Element, Layout> rhs;
|
| 227 |
+
Element epsilon;
|
| 228 |
+
Element nonzero_floor;
|
| 229 |
+
bool result;
|
| 230 |
+
|
| 231 |
+
/// Ctor
|
| 232 |
+
TensorRelativelyEqualsFunc(
|
| 233 |
+
TensorView<Element, Layout> const &lhs_,
|
| 234 |
+
TensorView<Element, Layout> const &rhs_,
|
| 235 |
+
Element epsilon_,
|
| 236 |
+
Element nonzero_floor_
|
| 237 |
+
) :
|
| 238 |
+
lhs(lhs_),
|
| 239 |
+
rhs(rhs_),
|
| 240 |
+
epsilon(epsilon_),
|
| 241 |
+
nonzero_floor(nonzero_floor_),
|
| 242 |
+
result(true) { }
|
| 243 |
+
|
| 244 |
+
/// Visits a coordinate
|
| 245 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 246 |
+
|
| 247 |
+
Element lhs_ = lhs.at(coord);
|
| 248 |
+
Element rhs_ = rhs.at(coord);
|
| 249 |
+
|
| 250 |
+
if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) {
|
| 251 |
+
result = false;
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/// Returns true if equal
|
| 256 |
+
operator bool() const {
|
| 257 |
+
return result;
|
| 258 |
+
}
|
| 259 |
+
};
|
| 260 |
+
|
| 261 |
+
} // namespace detail
|
| 262 |
+
|
| 263 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 264 |
+
|
| 265 |
+
/// Returns the Mean Squared Error between two tensors.
|
| 266 |
+
template <
|
| 267 |
+
typename Element, ///< Element type
|
| 268 |
+
typename Layout> ///< Layout function
|
| 269 |
+
double TensorMSE(
|
| 270 |
+
TensorView<Element, Layout> const &lhs,
|
| 271 |
+
TensorView<Element, Layout> const &rhs) {
|
| 272 |
+
|
| 273 |
+
// Extents must be identical
|
| 274 |
+
if (lhs.extent() != rhs.extent()) {
|
| 275 |
+
return -1;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
detail::TensorMSEFunc<Element, Layout> func(lhs, rhs);
|
| 279 |
+
TensorForEach(
|
| 280 |
+
lhs.extent(),
|
| 281 |
+
func
|
| 282 |
+
);
|
| 283 |
+
|
| 284 |
+
return double(func);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 288 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 289 |
+
|
| 290 |
+
/// Returns the Mean Relative Error between two tensors.
|
| 291 |
+
template <
|
| 292 |
+
typename Element, ///< Element type
|
| 293 |
+
typename Layout> ///< Layout function
|
| 294 |
+
double TensorMRE(
|
| 295 |
+
TensorView<Element, Layout> const &lhs,
|
| 296 |
+
TensorView<Element, Layout> const &rhs) {
|
| 297 |
+
|
| 298 |
+
// Extents must be identical
|
| 299 |
+
if (lhs.extent() != rhs.extent()) {
|
| 300 |
+
return -1;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
detail::TensorMREFunc<Element, Layout> func(lhs, rhs);
|
| 304 |
+
TensorForEach(
|
| 305 |
+
lhs.extent(),
|
| 306 |
+
func
|
| 307 |
+
);
|
| 308 |
+
|
| 309 |
+
return double(func);
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 313 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 314 |
+
|
| 315 |
+
/// Returns the greatest error between two tensors.
|
| 316 |
+
template <
|
| 317 |
+
typename Element, ///< Element type
|
| 318 |
+
typename Layout> ///< Layout function
|
| 319 |
+
double TensorGreatestError(
|
| 320 |
+
TensorView<Element, Layout> const &lhs,
|
| 321 |
+
TensorView<Element, Layout> const &rhs) {
|
| 322 |
+
|
| 323 |
+
// Extents must be identical
|
| 324 |
+
if (lhs.extent() != rhs.extent()) {
|
| 325 |
+
return -1;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
detail::TensorGreatestErrorFunc<Element, Layout> func(lhs, rhs);
|
| 329 |
+
TensorForEach(
|
| 330 |
+
lhs.extent(),
|
| 331 |
+
func
|
| 332 |
+
);
|
| 333 |
+
|
| 334 |
+
return double(func);
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 338 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 339 |
+
|
| 340 |
+
/// Returns true if two tensor views are equal.
|
| 341 |
+
template <
|
| 342 |
+
typename Element, ///< Element type
|
| 343 |
+
typename Layout> ///< Layout function
|
| 344 |
+
bool TensorEquals(
|
| 345 |
+
TensorView<Element, Layout> const &lhs,
|
| 346 |
+
TensorView<Element, Layout> const &rhs) {
|
| 347 |
+
|
| 348 |
+
// Extents must be identical
|
| 349 |
+
if (lhs.extent() != rhs.extent()) {
|
| 350 |
+
return false;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);
|
| 354 |
+
TensorForEach(
|
| 355 |
+
lhs.extent(),
|
| 356 |
+
func
|
| 357 |
+
);
|
| 358 |
+
|
| 359 |
+
return bool(func);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
/// Returns true if two tensor views are equal.
|
| 363 |
+
template <
|
| 364 |
+
typename Element, ///< Element type
|
| 365 |
+
typename Layout> ///< Layout function
|
| 366 |
+
bool TensorEquals(
|
| 367 |
+
TensorViewPlanarComplex<Element, Layout> const &lhs,
|
| 368 |
+
TensorViewPlanarComplex<Element, Layout> const &rhs) {
|
| 369 |
+
|
| 370 |
+
// Extents must be identical
|
| 371 |
+
if (lhs.extent() != rhs.extent()) {
|
| 372 |
+
return false;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
detail::TensorEqualsFunc<Element, Layout> real_func(
|
| 376 |
+
{lhs.data(), lhs.layout(), lhs.extent()},
|
| 377 |
+
{rhs.data(), rhs.layout(), rhs.extent()}
|
| 378 |
+
);
|
| 379 |
+
|
| 380 |
+
TensorForEach(
|
| 381 |
+
lhs.extent(),
|
| 382 |
+
real_func
|
| 383 |
+
);
|
| 384 |
+
|
| 385 |
+
if (!bool(real_func)) {
|
| 386 |
+
return false;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
detail::TensorEqualsFunc<Element, Layout> imag_func(
|
| 390 |
+
{lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
|
| 391 |
+
{rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}
|
| 392 |
+
);
|
| 393 |
+
|
| 394 |
+
TensorForEach(
|
| 395 |
+
lhs.extent(),
|
| 396 |
+
imag_func
|
| 397 |
+
);
|
| 398 |
+
|
| 399 |
+
return bool(imag_func);
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 403 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 404 |
+
|
| 405 |
+
/// Returns true if two tensor views are relatively equal.
|
| 406 |
+
template <
|
| 407 |
+
typename Element, ///< Element type
|
| 408 |
+
typename Layout> ///< Layout function
|
| 409 |
+
bool TensorRelativelyEquals(
|
| 410 |
+
TensorView<Element, Layout> const &lhs,
|
| 411 |
+
TensorView<Element, Layout> const &rhs,
|
| 412 |
+
Element epsilon,
|
| 413 |
+
Element nonzero_floor) {
|
| 414 |
+
|
| 415 |
+
// Extents must be identical
|
| 416 |
+
if (lhs.extent() != rhs.extent()) {
|
| 417 |
+
return false;
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
detail::TensorRelativelyEqualsFunc<Element, Layout> func(lhs, rhs, epsilon, nonzero_floor);
|
| 421 |
+
TensorForEach(
|
| 422 |
+
lhs.extent(),
|
| 423 |
+
func
|
| 424 |
+
);
|
| 425 |
+
|
| 426 |
+
return bool(func);
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
/// Returns true if two tensor views are relatively equal.
|
| 430 |
+
template <
|
| 431 |
+
typename Element, ///< Element type
|
| 432 |
+
typename Layout> ///< Layout function
|
| 433 |
+
bool TensorRelativelyEquals(
|
| 434 |
+
TensorViewPlanarComplex<Element, Layout> const &lhs,
|
| 435 |
+
TensorViewPlanarComplex<Element, Layout> const &rhs,
|
| 436 |
+
Element epsilon,
|
| 437 |
+
Element nonzero_floor) {
|
| 438 |
+
|
| 439 |
+
// Extents must be identical
|
| 440 |
+
if (lhs.extent() != rhs.extent()) {
|
| 441 |
+
return false;
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
detail::TensorRelativelyEqualsFunc<Element, Layout> real_func(
|
| 445 |
+
{lhs.data(), lhs.layout(), lhs.extent()},
|
| 446 |
+
{rhs.data(), rhs.layout(), rhs.extent()},
|
| 447 |
+
epsilon,
|
| 448 |
+
nonzero_floor
|
| 449 |
+
);
|
| 450 |
+
|
| 451 |
+
TensorForEach(
|
| 452 |
+
lhs.extent(),
|
| 453 |
+
real_func
|
| 454 |
+
);
|
| 455 |
+
|
| 456 |
+
if (!bool(real_func)) {
|
| 457 |
+
return false;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
detail::TensorEqualsFunc<Element, Layout> imag_func(
|
| 461 |
+
{lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
|
| 462 |
+
{rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()},
|
| 463 |
+
epsilon,
|
| 464 |
+
nonzero_floor
|
| 465 |
+
);
|
| 466 |
+
|
| 467 |
+
TensorForEach(
|
| 468 |
+
lhs.extent(),
|
| 469 |
+
imag_func
|
| 470 |
+
);
|
| 471 |
+
|
| 472 |
+
return bool(imag_func);
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 476 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 477 |
+
|
| 478 |
+
/// Returns true if two tensor views are NOT equal.
|
| 479 |
+
template <
|
| 480 |
+
typename Element, ///< Element type
|
| 481 |
+
typename Layout> ///< Layout function
|
| 482 |
+
bool TensorNotEquals(
|
| 483 |
+
TensorView<Element, Layout> const &lhs,
|
| 484 |
+
TensorView<Element, Layout> const &rhs) {
|
| 485 |
+
|
| 486 |
+
// Extents must be identical
|
| 487 |
+
if (lhs.extent() != rhs.extent()) {
|
| 488 |
+
return true;
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);
|
| 492 |
+
TensorForEach(
|
| 493 |
+
lhs.extent(),
|
| 494 |
+
func
|
| 495 |
+
);
|
| 496 |
+
|
| 497 |
+
return !bool(func);
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
/// Returns true if two tensor views are equal.
|
| 501 |
+
template <
|
| 502 |
+
typename Element, ///< Element type
|
| 503 |
+
typename Layout> ///< Layout function
|
| 504 |
+
bool TensorNotEquals(
|
| 505 |
+
TensorViewPlanarComplex<Element, Layout> const &lhs,
|
| 506 |
+
TensorViewPlanarComplex<Element, Layout> const &rhs) {
|
| 507 |
+
|
| 508 |
+
return !TensorEquals(lhs, rhs);
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 512 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 513 |
+
|
| 514 |
+
namespace detail {
|
| 515 |
+
|
| 516 |
+
template <
|
| 517 |
+
typename Element, ///< Element type
|
| 518 |
+
typename Layout> ///< Layout function
|
| 519 |
+
struct TensorContainsFunc {
|
| 520 |
+
|
| 521 |
+
//
|
| 522 |
+
// Data members
|
| 523 |
+
//
|
| 524 |
+
|
| 525 |
+
TensorView<Element, Layout> view;
|
| 526 |
+
Element value;
|
| 527 |
+
bool contains;
|
| 528 |
+
Coord<Layout::kRank> location;
|
| 529 |
+
|
| 530 |
+
//
|
| 531 |
+
// Methods
|
| 532 |
+
//
|
| 533 |
+
|
| 534 |
+
/// Ctor
|
| 535 |
+
TensorContainsFunc(): contains(false) { }
|
| 536 |
+
|
| 537 |
+
/// Ctor
|
| 538 |
+
TensorContainsFunc(
|
| 539 |
+
TensorView<Element, Layout> const &view_,
|
| 540 |
+
Element value_
|
| 541 |
+
) :
|
| 542 |
+
view(view_), value(value_), contains(false) { }
|
| 543 |
+
|
| 544 |
+
/// Visits a coordinate
|
| 545 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 546 |
+
|
| 547 |
+
if (view.at(coord) == value) {
|
| 548 |
+
if (!contains) {
|
| 549 |
+
location = coord;
|
| 550 |
+
}
|
| 551 |
+
contains = true;
|
| 552 |
+
}
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
/// Returns true if equal
|
| 556 |
+
operator bool() const {
|
| 557 |
+
return contains;
|
| 558 |
+
}
|
| 559 |
+
};
|
| 560 |
+
|
| 561 |
+
} // namespace detail
|
| 562 |
+
|
| 563 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 564 |
+
|
| 565 |
+
/// Returns true if a value is present in a tensor
|
| 566 |
+
template <
|
| 567 |
+
typename Element, ///< Element type
|
| 568 |
+
typename Layout> ///< Layout function
|
| 569 |
+
bool TensorContains(
|
| 570 |
+
TensorView<Element, Layout> const & view,
|
| 571 |
+
Element value) {
|
| 572 |
+
|
| 573 |
+
detail::TensorContainsFunc<Element, Layout> func(
|
| 574 |
+
view,
|
| 575 |
+
value
|
| 576 |
+
);
|
| 577 |
+
|
| 578 |
+
TensorForEach(
|
| 579 |
+
view.extent(),
|
| 580 |
+
func
|
| 581 |
+
);
|
| 582 |
+
|
| 583 |
+
return bool(func);
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 587 |
+
|
| 588 |
+
/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of
|
| 589 |
+
/// of the first occurrence. If the value is not contained in the tensor, the second element of the
|
| 590 |
+
/// pair is undefined.
|
| 591 |
+
template <
|
| 592 |
+
typename Element, ///< Element type
|
| 593 |
+
typename Layout> ///< Layout function
|
| 594 |
+
std::pair<bool, Coord<Layout::kRank> > TensorFind(
|
| 595 |
+
TensorView<Element, Layout> const & view,
|
| 596 |
+
Element value) {
|
| 597 |
+
|
| 598 |
+
detail::TensorContainsFunc<Element, Layout> func(
|
| 599 |
+
view,
|
| 600 |
+
value
|
| 601 |
+
);
|
| 602 |
+
|
| 603 |
+
TensorForEach(
|
| 604 |
+
view.extent(),
|
| 605 |
+
func
|
| 606 |
+
);
|
| 607 |
+
|
| 608 |
+
return std::make_pair(bool(func), func.location);
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 612 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 613 |
+
|
| 614 |
+
} // namespace host
|
| 615 |
+
} // namespace reference
|
| 616 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Provides several functions for filling tensors with data.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Standard Library includes
|
| 38 |
+
#include <utility>
|
| 39 |
+
#include <cstdlib>
|
| 40 |
+
#include <cmath>
|
| 41 |
+
|
| 42 |
+
// Cute includes
|
| 43 |
+
#include "cute/tensor.hpp"
|
| 44 |
+
|
| 45 |
+
// Cutlass includes
|
| 46 |
+
#include "cutlass/cutlass.h"
|
| 47 |
+
#include "cutlass/complex.h"
|
| 48 |
+
#include "cutlass/quaternion.h"
|
| 49 |
+
#include "cutlass/array.h"
|
| 50 |
+
#include "cutlass/numeric_types.h"
|
| 51 |
+
|
| 52 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace cutlass {
|
| 55 |
+
namespace reference {
|
| 56 |
+
namespace host {
|
| 57 |
+
|
| 58 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
/// Returns true if two tensor views are equal.
|
| 61 |
+
template <
|
| 62 |
+
typename TensorL,
|
| 63 |
+
typename TensorR
|
| 64 |
+
>
|
| 65 |
+
bool TensorEquals(
|
| 66 |
+
TensorL lhs,
|
| 67 |
+
TensorR rhs) {
|
| 68 |
+
|
| 69 |
+
// Extents must be identical
|
| 70 |
+
if (cute::size(lhs) != cute::size(rhs)) {
|
| 71 |
+
return false;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
for (int64_t idx = 0; idx < cute::size(lhs); ++idx) {
|
| 75 |
+
if (lhs(idx) != rhs(idx)) {
|
| 76 |
+
return false;
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
return true;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Returns true if two tensor views are NOT equal.
|
| 84 |
+
/// Returns true if two tensor views are NOT equal.
template <
  typename TensorL,
  typename TensorR
>
bool TensorNotEquals(
  TensorL lhs,
  TensorR rhs) {
  // Bug fix: the result of TensorEquals must be negated — this function
  // previously returned TensorEquals(lhs, rhs) directly, i.e. equality.
  return !TensorEquals(lhs, rhs);
}
|
| 94 |
+
|
| 95 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 96 |
+
|
| 97 |
+
} // namespace host
|
| 98 |
+
} // namespace reference
|
| 99 |
+
} // namespace cutlass
|
| 100 |
+
|
| 101 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Defines host-side elementwise operations on TensorView.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Standard Library includes
|
| 38 |
+
#include <utility>
|
| 39 |
+
|
| 40 |
+
// Cutlass includes
|
| 41 |
+
#include "cutlass/cutlass.h"
|
| 42 |
+
#include "tensor_foreach.h"
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace reference {
|
| 46 |
+
namespace host {
|
| 47 |
+
|
| 48 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace detail {
|
| 51 |
+
|
| 52 |
+
/// Helper to convert between types
|
| 53 |
+
template <
|
| 54 |
+
typename DstElement,
|
| 55 |
+
typename SrcElement
|
| 56 |
+
>
|
| 57 |
+
struct TrivialConvert {
|
| 58 |
+
|
| 59 |
+
TrivialConvert() { }
|
| 60 |
+
|
| 61 |
+
DstElement operator()(SrcElement src) const {
|
| 62 |
+
return DstElement(src);
|
| 63 |
+
}
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
/// Helper to conditionally copy between tensor views.
|
| 67 |
+
template <
|
| 68 |
+
typename DstElement,
|
| 69 |
+
typename DstLayout,
|
| 70 |
+
typename SrcElement,
|
| 71 |
+
typename SrcLayout,
|
| 72 |
+
typename F
|
| 73 |
+
>
|
| 74 |
+
struct TensorCopyIf {
|
| 75 |
+
|
| 76 |
+
using DstTensorView = TensorView<DstElement, DstLayout>;
|
| 77 |
+
using SrcTensorView = TensorView<SrcElement, SrcLayout>;
|
| 78 |
+
|
| 79 |
+
//
|
| 80 |
+
// Data members
|
| 81 |
+
//
|
| 82 |
+
|
| 83 |
+
DstTensorView dst;
|
| 84 |
+
SrcTensorView src;
|
| 85 |
+
F convert;
|
| 86 |
+
|
| 87 |
+
//
|
| 88 |
+
// Methods
|
| 89 |
+
//
|
| 90 |
+
|
| 91 |
+
TensorCopyIf() { }
|
| 92 |
+
|
| 93 |
+
TensorCopyIf(
|
| 94 |
+
DstTensorView const &dst_,
|
| 95 |
+
SrcTensorView const &src_,
|
| 96 |
+
F const &convert_): dst(dst_), src(src_), convert(convert_) {}
|
| 97 |
+
|
| 98 |
+
/// Copies based on destination and source bounds
|
| 99 |
+
void operator()(Coord<DstLayout::kRank> const &coord) {
|
| 100 |
+
if (dst.contains(coord) && src.contains(coord)) {
|
| 101 |
+
dst.at(coord) = convert(src.at(coord));
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
} // namespace detail
|
| 107 |
+
|
| 108 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 109 |
+
|
| 110 |
+
/// Copies elements from one tensor view into another, satisfying bounds of each tensor.
|
| 111 |
+
template <
|
| 112 |
+
typename DstElement, /// Destination tensor's element type
|
| 113 |
+
typename DstLayout, /// Destination tensor's layout
|
| 114 |
+
typename SrcElement, /// Source tensor's element type
|
| 115 |
+
typename SrcLayout, /// Source tensor's layout
|
| 116 |
+
typename F /// Transformation functor
|
| 117 |
+
>
|
| 118 |
+
void TensorCopy(
|
| 119 |
+
TensorView<DstElement, DstLayout> dst,
|
| 120 |
+
TensorView<SrcElement, SrcLayout> src,
|
| 121 |
+
F const &transform) {
|
| 122 |
+
|
| 123 |
+
using CopyIf = detail::TensorCopyIf<
|
| 124 |
+
DstElement,
|
| 125 |
+
DstLayout,
|
| 126 |
+
SrcElement,
|
| 127 |
+
SrcLayout,
|
| 128 |
+
F>;
|
| 129 |
+
|
| 130 |
+
CopyIf copy_if(dst, src, transform);
|
| 131 |
+
|
| 132 |
+
TensorForEach(dst.extent(), copy_if);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 137 |
+
|
| 138 |
+
/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent
|
| 139 |
+
/// to avoid out of bounds accesses.
|
| 140 |
+
template <
|
| 141 |
+
typename DstElement, /// Destination tensor's element type
|
| 142 |
+
typename DstLayout, /// Destination tensor's layout
|
| 143 |
+
typename SrcElement, /// Source tensor's element type
|
| 144 |
+
typename SrcLayout, /// Source tensor's layout
|
| 145 |
+
typename F /// Transformation functor
|
| 146 |
+
>
|
| 147 |
+
void TensorCopy(
|
| 148 |
+
TensorView<DstElement, DstLayout> dst,
|
| 149 |
+
TensorRef<SrcElement, SrcLayout> src,
|
| 150 |
+
F const &transform) {
|
| 151 |
+
|
| 152 |
+
using CopyIf = detail::TensorCopyIf<
|
| 153 |
+
DstElement,
|
| 154 |
+
DstLayout,
|
| 155 |
+
SrcElement,
|
| 156 |
+
SrcLayout,
|
| 157 |
+
F>;
|
| 158 |
+
|
| 159 |
+
TensorView<SrcElement, SrcLayout> src_view(src, dst.extent());
|
| 160 |
+
|
| 161 |
+
CopyIf copy_if(dst, src_view, transform);
|
| 162 |
+
|
| 163 |
+
TensorForEach(dst.extent(), copy_if);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent
|
| 167 |
+
/// to avoid out of bounds accesses.
|
| 168 |
+
template <
|
| 169 |
+
typename DstElement, /// Destination tensor's element type
|
| 170 |
+
typename DstLayout, /// Destination tensor's layout
|
| 171 |
+
typename SrcElement, /// Source tensor's element type
|
| 172 |
+
typename SrcLayout, /// Source tensor's layout
|
| 173 |
+
typename F /// Transformation functor
|
| 174 |
+
>
|
| 175 |
+
void TensorCopy(
|
| 176 |
+
TensorRef<DstElement, DstLayout> dst,
|
| 177 |
+
TensorView<SrcElement, SrcLayout> src,
|
| 178 |
+
F const &transform) {
|
| 179 |
+
|
| 180 |
+
using CopyIf = detail::TensorCopyIf<
|
| 181 |
+
DstElement,
|
| 182 |
+
DstLayout,
|
| 183 |
+
SrcElement,
|
| 184 |
+
SrcLayout,
|
| 185 |
+
F>;
|
| 186 |
+
|
| 187 |
+
TensorView<DstElement, DstLayout> dst_view(dst, src.extent());
|
| 188 |
+
|
| 189 |
+
CopyIf copy_if(dst_view, src, transform);
|
| 190 |
+
|
| 191 |
+
TensorForEach(src.extent(), copy_if);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 195 |
+
|
| 196 |
+
/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
|
| 197 |
+
/// if SrcElement can be converted to DstElement.
|
| 198 |
+
template <
|
| 199 |
+
typename DstElement, /// Destination tensor's element type
|
| 200 |
+
typename DstLayout, /// Destination tensor's layout
|
| 201 |
+
typename SrcElement, /// Source tensor's element type
|
| 202 |
+
typename SrcLayout /// Source tensor's layout
|
| 203 |
+
>
|
| 204 |
+
void TensorCopy(
|
| 205 |
+
TensorView<DstElement, DstLayout> dst,
|
| 206 |
+
TensorView<SrcElement, SrcLayout> src) {
|
| 207 |
+
|
| 208 |
+
detail::TrivialConvert<DstElement, SrcElement> convert;
|
| 209 |
+
|
| 210 |
+
TensorCopy(dst, src, convert);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 214 |
+
|
| 215 |
+
/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
|
| 216 |
+
/// if SrcElement can be converted to DstElement.
|
| 217 |
+
template <
|
| 218 |
+
typename DstElement, /// Destination tensor's element type
|
| 219 |
+
typename DstLayout, /// Destination tensor's layout
|
| 220 |
+
typename SrcElement, /// Source tensor's element type
|
| 221 |
+
typename SrcLayout, /// Source tensor's layout
|
| 222 |
+
typename F /// Transformation functor
|
| 223 |
+
>
|
| 224 |
+
void TensorCopy(
|
| 225 |
+
TensorView<DstElement, DstLayout> dst,
|
| 226 |
+
TensorRef<SrcElement, SrcLayout> src) {
|
| 227 |
+
|
| 228 |
+
detail::TrivialConvert<DstElement, SrcElement> convert;
|
| 229 |
+
|
| 230 |
+
TensorCopy(dst, src, convert);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 234 |
+
|
| 235 |
+
/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
|
| 236 |
+
/// if SrcElement can be converted to DstElement.
|
| 237 |
+
template <
|
| 238 |
+
typename DstElement, /// Destination tensor's element type
|
| 239 |
+
typename DstLayout, /// Destination tensor's layout
|
| 240 |
+
typename SrcElement, /// Source tensor's element type
|
| 241 |
+
typename SrcLayout /// Source tensor's layout
|
| 242 |
+
>
|
| 243 |
+
void TensorCopy(
|
| 244 |
+
TensorRef<DstElement, DstLayout> dst,
|
| 245 |
+
TensorView<SrcElement, SrcLayout> src) {
|
| 246 |
+
|
| 247 |
+
detail::TrivialConvert<DstElement, SrcElement> convert;
|
| 248 |
+
|
| 249 |
+
TensorCopy(dst, src, convert);
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 253 |
+
|
| 254 |
+
} // namespace host
|
| 255 |
+
} // namespace reference
|
| 256 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Defines host-side elementwise operations on TensorView.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Cutlass includes
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/functional.h"
|
| 40 |
+
|
| 41 |
+
#include "tensor_foreach.h"
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace reference {
|
| 45 |
+
namespace host {
|
| 46 |
+
|
| 47 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace detail {
|
| 51 |
+
|
| 52 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Helper to apply a binary operator in place
|
| 55 |
+
template <
|
| 56 |
+
typename ElementA,
|
| 57 |
+
typename LayoutA,
|
| 58 |
+
typename ElementB,
|
| 59 |
+
typename LayoutB,
|
| 60 |
+
typename ElementD,
|
| 61 |
+
typename LayoutD,
|
| 62 |
+
typename BinaryFunc>
|
| 63 |
+
struct TensorFuncBinaryOp {
|
| 64 |
+
|
| 65 |
+
//
|
| 66 |
+
// Data members
|
| 67 |
+
//
|
| 68 |
+
|
| 69 |
+
/// View of left-hand-side tensor
|
| 70 |
+
TensorView<ElementD, LayoutD> view_d;
|
| 71 |
+
TensorRef<ElementA, LayoutA> view_a;
|
| 72 |
+
TensorRef<ElementB, LayoutB> view_b;
|
| 73 |
+
BinaryFunc func;
|
| 74 |
+
|
| 75 |
+
//
|
| 76 |
+
// Methods
|
| 77 |
+
//
|
| 78 |
+
|
| 79 |
+
/// Constructor
|
| 80 |
+
TensorFuncBinaryOp() { }
|
| 81 |
+
|
| 82 |
+
/// Constructor
|
| 83 |
+
TensorFuncBinaryOp(
|
| 84 |
+
TensorView<ElementD, LayoutD> const & view_d_,
|
| 85 |
+
TensorRef<ElementA, LayoutA> const & view_a_,
|
| 86 |
+
TensorRef<ElementB, LayoutB> const & view_b_,
|
| 87 |
+
BinaryFunc func = BinaryFunc()
|
| 88 |
+
):
|
| 89 |
+
view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { }
|
| 90 |
+
|
| 91 |
+
/// Equality check
|
| 92 |
+
void operator()(Coord<LayoutD::kRank> const &coord) const {
|
| 93 |
+
view_d.at(coord) = func(
|
| 94 |
+
ElementD(view_a.at(coord)),
|
| 95 |
+
ElementD(view_b.at(coord))
|
| 96 |
+
);
|
| 97 |
+
}
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
} // namespace detail
|
| 101 |
+
|
| 102 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 103 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 104 |
+
|
| 105 |
+
/// Adds two tensors and stores in the destination tensor: d = a + b
|
| 106 |
+
template <
|
| 107 |
+
typename ElementD,
|
| 108 |
+
typename LayoutD,
|
| 109 |
+
typename ElementA,
|
| 110 |
+
typename LayoutA,
|
| 111 |
+
typename ElementB,
|
| 112 |
+
typename LayoutB
|
| 113 |
+
>
|
| 114 |
+
void TensorAdd(
|
| 115 |
+
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
| 116 |
+
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
| 117 |
+
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
| 118 |
+
) {
|
| 119 |
+
|
| 120 |
+
detail::TensorFuncBinaryOp<
|
| 121 |
+
ElementD,
|
| 122 |
+
LayoutD,
|
| 123 |
+
ElementA,
|
| 124 |
+
LayoutA,
|
| 125 |
+
ElementB,
|
| 126 |
+
LayoutB,
|
| 127 |
+
cutlass::plus<ElementD>
|
| 128 |
+
> func(d, a, b);
|
| 129 |
+
|
| 130 |
+
TensorForEach(
|
| 131 |
+
d.extent(),
|
| 132 |
+
func);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
/// Adds a tensor in place: d = d .+ a
///
/// Delegates to the three-operand TensorAdd, with the destination view also
/// serving as the left-hand operand.
template <
  typename ElementD,
  typename LayoutD,
  typename ElementA,
  typename LayoutA
>
void TensorAdd(
  TensorView<ElementD, LayoutD> d, ///< destination tensor view
  TensorRef<ElementA, LayoutA> a   ///< A tensor reference
  ) {
  TensorAdd(d, d, a);
}
|
| 148 |
+
|
| 149 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 150 |
+
|
| 151 |
+
/// Subtracts two tensors and stores in the destination tensor: d = a - b
|
| 152 |
+
template <
|
| 153 |
+
typename ElementD,
|
| 154 |
+
typename LayoutD,
|
| 155 |
+
typename ElementA,
|
| 156 |
+
typename LayoutA,
|
| 157 |
+
typename ElementB,
|
| 158 |
+
typename LayoutB
|
| 159 |
+
>
|
| 160 |
+
void TensorSub(
|
| 161 |
+
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
| 162 |
+
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
| 163 |
+
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
| 164 |
+
) {
|
| 165 |
+
|
| 166 |
+
detail::TensorFuncBinaryOp<
|
| 167 |
+
ElementD,
|
| 168 |
+
LayoutD,
|
| 169 |
+
ElementA,
|
| 170 |
+
LayoutA,
|
| 171 |
+
ElementB,
|
| 172 |
+
LayoutB,
|
| 173 |
+
cutlass::minus<ElementD>
|
| 174 |
+
> func(d, a, b);
|
| 175 |
+
|
| 176 |
+
TensorForEach(
|
| 177 |
+
d.extent(),
|
| 178 |
+
func);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
/// Subtracts two tensors in place: d = d .- a
|
| 182 |
+
template <
|
| 183 |
+
typename ElementD,
|
| 184 |
+
typename LayoutD,
|
| 185 |
+
typename ElementA,
|
| 186 |
+
typename LayoutA,
|
| 187 |
+
typename ElementB,
|
| 188 |
+
typename LayoutB
|
| 189 |
+
>
|
| 190 |
+
void TensorSub(
|
| 191 |
+
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
| 192 |
+
TensorRef<ElementA, LayoutA> a ///< A tensor reference
|
| 193 |
+
) {
|
| 194 |
+
|
| 195 |
+
TensorSub(d, d, a);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 199 |
+
|
| 200 |
+
/// Multiplies two tensors and stores in the destination tensor: d = a .* b
|
| 201 |
+
template <
|
| 202 |
+
typename ElementD,
|
| 203 |
+
typename LayoutD,
|
| 204 |
+
typename ElementA,
|
| 205 |
+
typename LayoutA,
|
| 206 |
+
typename ElementB,
|
| 207 |
+
typename LayoutB
|
| 208 |
+
>
|
| 209 |
+
void TensorMul(
|
| 210 |
+
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
| 211 |
+
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
| 212 |
+
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
| 213 |
+
) {
|
| 214 |
+
|
| 215 |
+
detail::TensorFuncBinaryOp<
|
| 216 |
+
ElementD,
|
| 217 |
+
LayoutD,
|
| 218 |
+
ElementA,
|
| 219 |
+
LayoutA,
|
| 220 |
+
ElementB,
|
| 221 |
+
LayoutB,
|
| 222 |
+
cutlass::multiplies<ElementD>
|
| 223 |
+
> func(d, a, b);
|
| 224 |
+
|
| 225 |
+
TensorForEach(
|
| 226 |
+
d.extent(),
|
| 227 |
+
func);
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
/// Multiplies tensors in place: d = d .* a
///
/// Delegates to the three-operand TensorMul, with the destination view also
/// serving as the left-hand operand.
template <
  typename ElementD,
  typename LayoutD,
  typename ElementA,
  typename LayoutA
>
void TensorMul(
  TensorView<ElementD, LayoutD> d, ///< destination tensor view
  TensorRef<ElementA, LayoutA> a   ///< A tensor reference
  ) {
  TensorMul(d, d, a);
}
|
| 243 |
+
|
| 244 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 245 |
+
|
| 246 |
+
/// Divides two tensors and stores in the destination tensor: d = a ./ b
|
| 247 |
+
template <
|
| 248 |
+
typename ElementD,
|
| 249 |
+
typename LayoutD,
|
| 250 |
+
typename ElementA,
|
| 251 |
+
typename LayoutA,
|
| 252 |
+
typename ElementB,
|
| 253 |
+
typename LayoutB
|
| 254 |
+
>
|
| 255 |
+
void TensorDiv(
|
| 256 |
+
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
| 257 |
+
TensorRef<ElementA, LayoutA> a, ///< A tensor reference
|
| 258 |
+
TensorRef<ElementB, LayoutB> b ///< B tensor reference
|
| 259 |
+
) {
|
| 260 |
+
|
| 261 |
+
detail::TensorFuncBinaryOp<
|
| 262 |
+
ElementD,
|
| 263 |
+
LayoutD,
|
| 264 |
+
ElementA,
|
| 265 |
+
LayoutA,
|
| 266 |
+
ElementB,
|
| 267 |
+
LayoutB,
|
| 268 |
+
cutlass::divides<ElementD>
|
| 269 |
+
> func(d, a, b);
|
| 270 |
+
|
| 271 |
+
TensorForEach(
|
| 272 |
+
d.extent(),
|
| 273 |
+
func);
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
/// Divides tensors in place: d = d ./ a
///
/// Delegates to the three-operand TensorDiv, with the destination view also
/// serving as the left-hand operand.
template <
  typename ElementD,
  typename LayoutD,
  typename ElementA,
  typename LayoutA
>
void TensorDiv(
  TensorView<ElementD, LayoutD> d, ///< destination tensor view
  TensorRef<ElementA, LayoutA> a   ///< A tensor reference
  ) {
  TensorDiv(d, d, a);
}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 292 |
+
|
| 293 |
+
/// Divides two tensors and stores in the destination tensor: d = a ./ b
///
/// NOTE(review): despite the name, this function applies cutlass::divides —
/// not a modulus/remainder functor — so it behaves identically to TensorDiv.
/// This looks like a copy-paste from TensorDiv; confirm the intended
/// semantics before relying on this as a remainder operation.
template <
  typename ElementD,
  typename LayoutD,
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB
>
void TensorModulus(
  TensorView<ElementD, LayoutD> d, ///< destination tensor view
  TensorRef<ElementA, LayoutA> a,  ///< A tensor reference
  TensorRef<ElementB, LayoutB> b   ///< B tensor reference
  ) {

  detail::TensorFuncBinaryOp<
    ElementD,
    LayoutD,
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    cutlass::divides<ElementD>  // NOTE(review): divides, not a modulus functor
  > func(d, a, b);

  // Apply the functor at every coordinate of the destination's extent.
  TensorForEach(
    d.extent(),
    func);
}
|
| 322 |
+
|
| 323 |
+
/// Divides tensors in place: d = d ./ a
|
| 324 |
+
template <
|
| 325 |
+
typename ElementD,
|
| 326 |
+
typename LayoutD,
|
| 327 |
+
typename ElementA,
|
| 328 |
+
typename LayoutA
|
| 329 |
+
>
|
| 330 |
+
void TensorModulus(
|
| 331 |
+
TensorView<ElementD, LayoutD> d, ///< destination tensor view
|
| 332 |
+
TensorRef<ElementA, LayoutA> a ///< A tensor reference
|
| 333 |
+
) {
|
| 334 |
+
TensorDiv(d, d, a);
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 338 |
+
|
| 339 |
+
} // namespace host
|
| 340 |
+
} // namespace reference
|
| 341 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h
ADDED
|
@@ -0,0 +1,1718 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Provides several functions for filling tensors with data.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Standard Library includes
|
| 38 |
+
#include <utility>
|
| 39 |
+
#include <cstdlib>
|
| 40 |
+
#include <cmath>
|
| 41 |
+
#include <random>
|
| 42 |
+
#include <stdexcept>
|
| 43 |
+
|
| 44 |
+
// Cutlass includes
|
| 45 |
+
#include "cutlass/cutlass.h"
|
| 46 |
+
#include "cutlass/complex.h"
|
| 47 |
+
#include "cutlass/quaternion.h"
|
| 48 |
+
#include "cutlass/array.h"
|
| 49 |
+
#include "cutlass/numeric_types.h"
|
| 50 |
+
#include "cutlass/subbyte_reference.h"
|
| 51 |
+
#include "cutlass/tensor_view.h"
|
| 52 |
+
#include "cutlass/tensor_view_planar_complex.h"
|
| 53 |
+
#include "cutlass/blas3.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/util/distribution.h"
|
| 56 |
+
#include "tensor_foreach.h"
|
| 57 |
+
|
| 58 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace reference {
|
| 62 |
+
namespace host {
|
| 63 |
+
|
| 64 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
namespace detail {
|
| 68 |
+
|
| 69 |
+
template <
|
| 70 |
+
typename Element, ///< Element type
|
| 71 |
+
typename Layout> ///< Layout function
|
| 72 |
+
struct TensorFillFunc {
|
| 73 |
+
|
| 74 |
+
using TensorView = TensorView<Element, Layout>;
|
| 75 |
+
|
| 76 |
+
//
|
| 77 |
+
// Data members
|
| 78 |
+
//
|
| 79 |
+
|
| 80 |
+
TensorView view;
|
| 81 |
+
Element value;
|
| 82 |
+
|
| 83 |
+
//
|
| 84 |
+
// Methods
|
| 85 |
+
//
|
| 86 |
+
|
| 87 |
+
TensorFillFunc(
|
| 88 |
+
TensorView const &view_ = TensorView(),
|
| 89 |
+
Element value_ = Element(0)
|
| 90 |
+
): view(view_), value(value_) { }
|
| 91 |
+
|
| 92 |
+
void operator()(Coord<Layout::kRank> const & coord) const {
|
| 93 |
+
view.at(coord) = value;
|
| 94 |
+
}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method
|
| 98 |
+
struct BoxMullerFunc {
|
| 99 |
+
|
| 100 |
+
BoxMullerFunc() {}
|
| 101 |
+
|
| 102 |
+
void operator()(
|
| 103 |
+
double* rnd, ///< Size-2 vector to be filled with random values
|
| 104 |
+
double mean = 0, ///< Mean of the Gaussian distribution
|
| 105 |
+
double stddev = 1, ///< Standard deviation of the Gaussian distribution
|
| 106 |
+
double pi = std::acos(-1)) const {
|
| 107 |
+
|
| 108 |
+
double u1 = double(std::rand()) / double(RAND_MAX);
|
| 109 |
+
double u2 = double(std::rand()) / double(RAND_MAX);
|
| 110 |
+
rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
|
| 111 |
+
rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2);
|
| 112 |
+
rnd[0] = mean + stddev * rnd[0];
|
| 113 |
+
rnd[1] = mean + stddev * rnd[1];
|
| 114 |
+
}
|
| 115 |
+
};
|
| 116 |
+
} // namespace detail
|
| 117 |
+
|
| 118 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 119 |
+
|
| 120 |
+
/// Fills a tensor with a uniform value
|
| 121 |
+
template <
|
| 122 |
+
typename Element, ///< Element type
|
| 123 |
+
typename Layout> ///< Layout function
|
| 124 |
+
void TensorFill(
|
| 125 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 126 |
+
Element val = Element(0)) { ///< value to uniformly fill it with
|
| 127 |
+
|
| 128 |
+
detail::TensorFillFunc<Element, Layout> func(dst, val);
|
| 129 |
+
|
| 130 |
+
TensorForEach(
|
| 131 |
+
dst.extent(),
|
| 132 |
+
func
|
| 133 |
+
);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/// Fills a tensor with a uniform value
|
| 137 |
+
template <
|
| 138 |
+
typename Element, ///< Element type
|
| 139 |
+
typename Layout> ///< Layout function
|
| 140 |
+
void TensorFill(
|
| 141 |
+
TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
|
| 142 |
+
cutlass::complex<Element> val = cutlass::complex<Element>(0)) { ///< value to uniformly fill it with
|
| 143 |
+
|
| 144 |
+
TensorFill(dst.view_real(), val.real());
|
| 145 |
+
TensorFill(dst.view_imag(), val.imag());
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 149 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 150 |
+
|
| 151 |
+
namespace detail {
|
| 152 |
+
|
| 153 |
+
template <typename Element>
|
| 154 |
+
struct RandomGaussianFunc {
|
| 155 |
+
|
| 156 |
+
uint64_t seed;
|
| 157 |
+
double mean;
|
| 158 |
+
double stddev;
|
| 159 |
+
int int_scale;
|
| 160 |
+
double pi;
|
| 161 |
+
double pnz;
|
| 162 |
+
bool exclude_zero;
|
| 163 |
+
|
| 164 |
+
//
|
| 165 |
+
// Methods
|
| 166 |
+
//
|
| 167 |
+
RandomGaussianFunc(
|
| 168 |
+
uint64_t seed_ = 0,
|
| 169 |
+
double mean_ = 0,
|
| 170 |
+
double stddev_ = 1,
|
| 171 |
+
int int_scale_ = -1,
|
| 172 |
+
double pnz_ = 1.0,
|
| 173 |
+
bool exclude_zero_ = false
|
| 174 |
+
):
|
| 175 |
+
seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
|
| 176 |
+
std::srand((unsigned)seed);
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
/// Compute random value and update RNG state
|
| 180 |
+
Element operator()() const {
|
| 181 |
+
|
| 182 |
+
// Box-Muller transform to generate random numbers with Normal distribution
|
| 183 |
+
double u1 = double(std::rand()) / double(RAND_MAX);
|
| 184 |
+
double u2 = double(std::rand()) / double(RAND_MAX);
|
| 185 |
+
|
| 186 |
+
// Compute Gaussian random value
|
| 187 |
+
double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
|
| 188 |
+
rnd = mean + stddev * rnd;
|
| 189 |
+
|
| 190 |
+
// Scale and convert final result
|
| 191 |
+
Element result;
|
| 192 |
+
|
| 193 |
+
// Sample from the Bernoulli distribution, and use the result to sample from the Gaussian
|
| 194 |
+
std::random_device rnd_device;
|
| 195 |
+
std::mt19937 bernoulli_rnd(rnd_device());
|
| 196 |
+
std::bernoulli_distribution bernoulli_dist(pnz);
|
| 197 |
+
bool bernoulli_result = bernoulli_dist(bernoulli_rnd);
|
| 198 |
+
|
| 199 |
+
// Sample from the Gaussian distribution for a nonzero element
|
| 200 |
+
if (bernoulli_result) {
|
| 201 |
+
if (int_scale >= 0) {
|
| 202 |
+
rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
|
| 203 |
+
result = static_cast<Element>(rnd);
|
| 204 |
+
}
|
| 205 |
+
else {
|
| 206 |
+
result = static_cast<Element>(rnd);
|
| 207 |
+
}
|
| 208 |
+
}
|
| 209 |
+
else {
|
| 210 |
+
result = static_cast<Element>(0);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
// Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros
|
| 214 |
+
if (exclude_zero && result == Element(0)) {
|
| 215 |
+
if (rnd > 0) {
|
| 216 |
+
rnd += 1;
|
| 217 |
+
} else {
|
| 218 |
+
rnd -= 1;
|
| 219 |
+
}
|
| 220 |
+
result = Element(rnd);
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return result;
|
| 224 |
+
}
|
| 225 |
+
};
|
| 226 |
+
|
| 227 |
+
/// Partial specialization for initializing a complex value.
|
| 228 |
+
template <typename Element>
|
| 229 |
+
struct RandomGaussianFunc<complex<Element> > {
|
| 230 |
+
|
| 231 |
+
uint64_t seed;
|
| 232 |
+
double mean;
|
| 233 |
+
double stddev;
|
| 234 |
+
int int_scale;
|
| 235 |
+
double pi;
|
| 236 |
+
double pnz;
|
| 237 |
+
bool exclude_zero;
|
| 238 |
+
|
| 239 |
+
//
|
| 240 |
+
// Methods
|
| 241 |
+
//
|
| 242 |
+
RandomGaussianFunc(
|
| 243 |
+
uint64_t seed_ = 0,
|
| 244 |
+
double mean_ = 0,
|
| 245 |
+
double stddev_ = 1,
|
| 246 |
+
int int_scale_ = -1,
|
| 247 |
+
double pnz_ = 1.0,
|
| 248 |
+
bool exclude_zero_ = false
|
| 249 |
+
):
|
| 250 |
+
seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
|
| 251 |
+
std::srand((unsigned)seed);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
/// Compute random value and update RNG state
|
| 255 |
+
complex<Element> operator()() const {
|
| 256 |
+
|
| 257 |
+
Element reals[2];
|
| 258 |
+
|
| 259 |
+
double rnd[2];
|
| 260 |
+
detail::BoxMullerFunc func;
|
| 261 |
+
func(rnd, mean, stddev, pi);
|
| 262 |
+
|
| 263 |
+
// Sample from the Bernoulli distribution, and use the result to sample from the Gaussian
|
| 264 |
+
std::random_device rnd_device;
|
| 265 |
+
std::mt19937 bernoulli_rnd(rnd_device());
|
| 266 |
+
std::bernoulli_distribution bernoulli_dist(pnz);
|
| 267 |
+
bool bernoulli_result = bernoulli_dist(bernoulli_rnd);
|
| 268 |
+
|
| 269 |
+
// Sample from the Gaussian distribution for a nonzero element
|
| 270 |
+
if (bernoulli_result) {
|
| 271 |
+
if (int_scale >= 0) {
|
| 272 |
+
rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale)));
|
| 273 |
+
rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale)));
|
| 274 |
+
reals[0] = from_real<Element>(rnd[0] / double(1 << int_scale));
|
| 275 |
+
reals[1] = from_real<Element>(rnd[1] / double(1 << int_scale));
|
| 276 |
+
}
|
| 277 |
+
else {
|
| 278 |
+
reals[0] = from_real<Element>(rnd[0]);
|
| 279 |
+
reals[1] = from_real<Element>(rnd[1]);
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
else {
|
| 283 |
+
reals[0] = from_real<Element>(0);
|
| 284 |
+
reals[1] = from_real<Element>(0);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
// Note that this will invalidate the above else statement because it unsets zero elements
|
| 288 |
+
if (exclude_zero &&
|
| 289 |
+
reals[0] == from_real<Element>(0.0) &&
|
| 290 |
+
reals[1] == from_real<Element>(0.0)) {
|
| 291 |
+
|
| 292 |
+
if (rnd[0] > 0.0) {
|
| 293 |
+
rnd[0] += 1.0;
|
| 294 |
+
} else {
|
| 295 |
+
rnd[0] -= 1.0;
|
| 296 |
+
}
|
| 297 |
+
reals[0] = from_real<Element>(rnd[0]);
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
return complex<Element>(reals[0], reals[1]);
|
| 301 |
+
}
|
| 302 |
+
};
|
| 303 |
+
|
| 304 |
+
/// Partial specialization for initializing a complex value.
|
| 305 |
+
template <typename Element>
|
| 306 |
+
struct RandomGaussianFunc<Quaternion<Element> > {
|
| 307 |
+
|
| 308 |
+
uint64_t seed;
|
| 309 |
+
double mean;
|
| 310 |
+
double stddev;
|
| 311 |
+
int int_scale;
|
| 312 |
+
double pi;
|
| 313 |
+
double pnz;
|
| 314 |
+
bool exclude_zero;
|
| 315 |
+
|
| 316 |
+
//
|
| 317 |
+
// Methods
|
| 318 |
+
//
|
| 319 |
+
RandomGaussianFunc(
|
| 320 |
+
uint64_t seed_ = 0,
|
| 321 |
+
double mean_ = 0,
|
| 322 |
+
double stddev_ = 1,
|
| 323 |
+
int int_scale_ = -1,
|
| 324 |
+
double pnz_ = 1.0,
|
| 325 |
+
bool exclude_zero_ = false
|
| 326 |
+
):
|
| 327 |
+
seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
|
| 328 |
+
std::srand((unsigned)seed);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/// Compute random value and update RNG state
|
| 332 |
+
Quaternion<Element> operator()() const {
|
| 333 |
+
|
| 334 |
+
Element reals[4];
|
| 335 |
+
|
| 336 |
+
double rnd1[2];
|
| 337 |
+
double rnd2[2];
|
| 338 |
+
detail::BoxMullerFunc func;
|
| 339 |
+
func(rnd1, mean, stddev, pi);
|
| 340 |
+
func(rnd2, mean, stddev, pi);
|
| 341 |
+
|
| 342 |
+
// Sample from the Bernoulli distribution, and use the result to sample from the Gaussian
|
| 343 |
+
std::random_device rnd_device;
|
| 344 |
+
std::mt19937 bernoulli_rnd(rnd_device());
|
| 345 |
+
std::bernoulli_distribution bernoulli_dist(pnz);
|
| 346 |
+
bool bernoulli_result = bernoulli_dist(bernoulli_rnd);
|
| 347 |
+
|
| 348 |
+
// Sample from the Gaussian distribution for a nonzero element
|
| 349 |
+
if (bernoulli_result) {
|
| 350 |
+
if (int_scale >= 0) {
|
| 351 |
+
rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale)));
|
| 352 |
+
rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale)));
|
| 353 |
+
rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale)));
|
| 354 |
+
rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale)));
|
| 355 |
+
|
| 356 |
+
reals[0] = from_real<Element>(rnd1[0] / double(1 << int_scale));
|
| 357 |
+
reals[1] = from_real<Element>(rnd1[1] / double(1 << int_scale));
|
| 358 |
+
reals[2] = from_real<Element>(rnd2[0] / double(1 << int_scale));
|
| 359 |
+
reals[3] = from_real<Element>(rnd2[1] / double(1 << int_scale));
|
| 360 |
+
}
|
| 361 |
+
else {
|
| 362 |
+
reals[0] = from_real<Element>(rnd1[0]);
|
| 363 |
+
reals[1] = from_real<Element>(rnd1[1]);
|
| 364 |
+
reals[2] = from_real<Element>(rnd2[0]);
|
| 365 |
+
reals[3] = from_real<Element>(rnd2[1]);
|
| 366 |
+
}
|
| 367 |
+
}
|
| 368 |
+
else {
|
| 369 |
+
reals[0] = from_real<Element>(0);
|
| 370 |
+
reals[1] = from_real<Element>(0);
|
| 371 |
+
reals[2] = from_real<Element>(0);
|
| 372 |
+
reals[3] = from_real<Element>(0);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
// Note that this will invalidate the above else statement because it unsets zero elements
|
| 376 |
+
if (exclude_zero &&
|
| 377 |
+
reals[0] == from_real<Element>(0) &&
|
| 378 |
+
reals[1] == from_real<Element>(0) &&
|
| 379 |
+
reals[2] == from_real<Element>(0) &&
|
| 380 |
+
reals[3] == from_real<Element>(0)) {
|
| 381 |
+
|
| 382 |
+
if (rnd1[0] > 0.0) {
|
| 383 |
+
rnd1[0] += 1.0;
|
| 384 |
+
} else {
|
| 385 |
+
rnd1[0] -= 1.0;
|
| 386 |
+
}
|
| 387 |
+
reals[0] = from_real<Element>(rnd1[0]);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
return Quaternion<Element>(reals[0], reals[1], reals[2], reals[3]);
|
| 391 |
+
}
|
| 392 |
+
};
|
| 393 |
+
|
| 394 |
+
/// Computes a random Gaussian distribution
|
| 395 |
+
template <
|
| 396 |
+
typename Element, ///< Element type
|
| 397 |
+
typename Layout> ///< Layout function
|
| 398 |
+
struct TensorFillGaussianFunc {
|
| 399 |
+
|
| 400 |
+
using TensorView = TensorView<Element, Layout>;
|
| 401 |
+
|
| 402 |
+
//
|
| 403 |
+
// Data members
|
| 404 |
+
//
|
| 405 |
+
|
| 406 |
+
TensorView view;
|
| 407 |
+
RandomGaussianFunc<Element> func;
|
| 408 |
+
|
| 409 |
+
//
|
| 410 |
+
// Methods
|
| 411 |
+
//
|
| 412 |
+
|
| 413 |
+
/// Construction of Gaussian RNG functor.
|
| 414 |
+
TensorFillGaussianFunc(
|
| 415 |
+
TensorView view_ = TensorView(),
|
| 416 |
+
RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>()
|
| 417 |
+
):
|
| 418 |
+
view(view_), func(func_) {
|
| 419 |
+
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
/// Compute random value and update RNG state
|
| 423 |
+
void operator()(Coord<Layout::kRank> const &coord) const {
|
| 424 |
+
view.at(coord) = func();
|
| 425 |
+
}
|
| 426 |
+
};
|
| 427 |
+
|
| 428 |
+
/// Computes a random Gaussian distribution for a rank-2 tensor
|
| 429 |
+
template <
|
| 430 |
+
typename Element, ///< Element type
|
| 431 |
+
typename Layout> ///< Layout function
|
| 432 |
+
struct TensorFillSymmetricGaussianFunc {
|
| 433 |
+
|
| 434 |
+
using TensorView = TensorView<Element, Layout>;
|
| 435 |
+
|
| 436 |
+
//
|
| 437 |
+
// Data members
|
| 438 |
+
//
|
| 439 |
+
|
| 440 |
+
TensorView view;
|
| 441 |
+
RandomGaussianFunc<Element> func;
|
| 442 |
+
cutlass::FillMode fill_mode;
|
| 443 |
+
|
| 444 |
+
//
|
| 445 |
+
// Methods
|
| 446 |
+
//
|
| 447 |
+
|
| 448 |
+
/// Construction of Gaussian RNG functor.
|
| 449 |
+
TensorFillSymmetricGaussianFunc(
|
| 450 |
+
TensorView view_ = TensorView(),
|
| 451 |
+
RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>(),
|
| 452 |
+
cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
|
| 453 |
+
):
|
| 454 |
+
view(view_), func(func_), fill_mode(fill_mode_) {
|
| 455 |
+
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
/// Compute random value and update RNG state
|
| 459 |
+
void operator()(Coord<Layout::kRank> const &coord) const {
|
| 460 |
+
// Fill half of matrix based on FillMode
|
| 461 |
+
if (Layout::kRank == 2 &&
|
| 462 |
+
fill_mode == cutlass::FillMode::kLower &&
|
| 463 |
+
coord[0] >= coord[1]) {
|
| 464 |
+
view.at(coord) = func();
|
| 465 |
+
} else if (Layout::kRank == 2 &&
|
| 466 |
+
fill_mode == cutlass::FillMode::kUpper &&
|
| 467 |
+
coord[0] <= coord[1]) {
|
| 468 |
+
view.at(coord) = func();
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
};
|
| 472 |
+
|
| 473 |
+
} // namespace detail
|
| 474 |
+
|
| 475 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 476 |
+
|
| 477 |
+
/// Fills a tensor with random values with a Gaussian distribution.
|
| 478 |
+
template <
|
| 479 |
+
typename Element, ///< Element type
|
| 480 |
+
typename Layout> ///< Layout function
|
| 481 |
+
void TensorFillRandomGaussian(
|
| 482 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 483 |
+
uint64_t seed, ///< seed for RNG
|
| 484 |
+
double mean = 0, ///< Gaussian distribution's mean
|
| 485 |
+
double stddev = 1, ///< Gaussian distribution's standard deviation
|
| 486 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 487 |
+
double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of
|
| 488 |
+
/// data.
|
| 489 |
+
bool exclude_zero = false) { ///< Exclude zeros from tensor init.
|
| 490 |
+
|
| 491 |
+
detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz, exclude_zero);
|
| 492 |
+
|
| 493 |
+
detail::TensorFillGaussianFunc<Element, Layout> func(
|
| 494 |
+
dst,
|
| 495 |
+
random_func
|
| 496 |
+
);
|
| 497 |
+
|
| 498 |
+
TensorForEach(
|
| 499 |
+
dst.extent(),
|
| 500 |
+
func
|
| 501 |
+
);
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
/// Fills a tensor with random values with a Gaussian distribution.
|
| 505 |
+
template <
|
| 506 |
+
typename Element, ///< Element type
|
| 507 |
+
typename Layout> ///< Layout function
|
| 508 |
+
void TensorFillRandomGaussian(
|
| 509 |
+
TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
|
| 510 |
+
uint64_t seed, ///< seed for RNG
|
| 511 |
+
double mean = 0, ///< Gaussian distribution's mean
|
| 512 |
+
double stddev = 1, ///< Gaussian distribution's standard deviation
|
| 513 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 514 |
+
double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of
|
| 515 |
+
/// data.
|
| 516 |
+
bool exclude_zero = false) { ///< Exclude zeros from tensor init.
|
| 517 |
+
|
| 518 |
+
TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz);
|
| 519 |
+
TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz);
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 523 |
+
/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution.
|
| 524 |
+
template <
|
| 525 |
+
typename Element, ///< Element type
|
| 526 |
+
typename Layout> ///< Layout function
|
| 527 |
+
void TensorFillSymmetricRandomGaussian(
|
| 528 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 529 |
+
uint64_t seed, ///< seed for RNG
|
| 530 |
+
cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
|
| 531 |
+
double mean = 0, ///< Gaussian distribution's mean
|
| 532 |
+
double stddev = 1, ///< Gaussian distribution's standard deviation
|
| 533 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 534 |
+
double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of
|
| 535 |
+
/// data.
|
| 536 |
+
|
| 537 |
+
detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz);
|
| 538 |
+
|
| 539 |
+
detail::TensorFillSymmetricGaussianFunc<Element, Layout> func(
|
| 540 |
+
dst,
|
| 541 |
+
random_func,
|
| 542 |
+
fill_mode
|
| 543 |
+
);
|
| 544 |
+
|
| 545 |
+
TensorForEach(
|
| 546 |
+
dst.extent(),
|
| 547 |
+
func
|
| 548 |
+
);
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 552 |
+
|
| 553 |
+
/// Fills a tensor with random values of a Gaussian distribution.
|
| 554 |
+
template <
|
| 555 |
+
typename Element ///< Element type
|
| 556 |
+
>
|
| 557 |
+
void BlockFillRandomGaussian(
|
| 558 |
+
Element *ptr, ///< destination buffer
|
| 559 |
+
size_t capacity, ///< number of elements
|
| 560 |
+
uint64_t seed, ///< seed for RNG
|
| 561 |
+
double mean = 0, ///< Gaussian distribution's mean
|
| 562 |
+
double stddev = 1, ///< Gaussian distribution's standard deviation
|
| 563 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 564 |
+
double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of
|
| 565 |
+
/// data.
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz);
|
| 569 |
+
|
| 570 |
+
for (size_t i = 0; i < capacity; ++i) {
|
| 571 |
+
ReferenceFactory<Element>::get(ptr, i) = random_func();
|
| 572 |
+
}
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 576 |
+
|
| 577 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 578 |
+
|
| 579 |
+
namespace detail {
|
| 580 |
+
|
| 581 |
+
template <typename Element>
|
| 582 |
+
struct RandomUniformFunc {
|
| 583 |
+
|
| 584 |
+
using Real = typename RealType<Element>::Type;
|
| 585 |
+
|
| 586 |
+
uint64_t seed;
|
| 587 |
+
double range;
|
| 588 |
+
double min;
|
| 589 |
+
int int_scale;
|
| 590 |
+
|
| 591 |
+
double pnan;
|
| 592 |
+
private:
|
| 593 |
+
using engine_type = std::mt19937;
|
| 594 |
+
public:
|
| 595 |
+
engine_type bernoulli_rnd;
|
| 596 |
+
std::bernoulli_distribution bernoulli_dist;
|
| 597 |
+
|
| 598 |
+
bool exclude_zero;
|
| 599 |
+
|
| 600 |
+
RandomUniformFunc(
|
| 601 |
+
uint64_t seed_ = 0,
|
| 602 |
+
double max = 1,
|
| 603 |
+
double min_ = 0,
|
| 604 |
+
int int_scale_ = -1,
|
| 605 |
+
double pnan_ = 0,
|
| 606 |
+
bool exclude_zero_ = false
|
| 607 |
+
):
|
| 608 |
+
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_)
|
| 609 |
+
, bernoulli_rnd{static_cast<engine_type::result_type>(seed_)}
|
| 610 |
+
, bernoulli_dist(pnan_)
|
| 611 |
+
, exclude_zero(exclude_zero_)
|
| 612 |
+
{
|
| 613 |
+
std::srand((unsigned)seed);
|
| 614 |
+
|
| 615 |
+
// Handle cases where min = 0 or max = 0 for excluding zeros
|
| 616 |
+
if (exclude_zero) {
|
| 617 |
+
min = (min == 0.0) ? min + 1: min;
|
| 618 |
+
range = (max == 0.0) ? range - 1: range;
|
| 619 |
+
}
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
/// Compute random value and update RNG state
|
| 624 |
+
Element operator()() {
|
| 625 |
+
|
| 626 |
+
// Sample from NaN distribution.
|
| 627 |
+
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 628 |
+
if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
|
| 629 |
+
return Element(NAN);
|
| 630 |
+
}
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
double rnd = double(std::rand()) / double(RAND_MAX);
|
| 634 |
+
|
| 635 |
+
rnd = min + range * rnd;
|
| 636 |
+
|
| 637 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 638 |
+
// testing
|
| 639 |
+
Element result;
|
| 640 |
+
if (int_scale >= 0) {
|
| 641 |
+
rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
|
| 642 |
+
result = static_cast<Element>(Real(rnd));
|
| 643 |
+
}
|
| 644 |
+
else {
|
| 645 |
+
result = static_cast<Element>(Real(rnd));
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
if (exclude_zero && result == Element(0)) {
|
| 649 |
+
if (rnd > 0.0) {
|
| 650 |
+
rnd = std::min(min + range, rnd + 1.0);
|
| 651 |
+
} else {
|
| 652 |
+
rnd = std::max(min, rnd - 1.0);
|
| 653 |
+
}
|
| 654 |
+
result = static_cast<Element>(Real(rnd));
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
return result;
|
| 658 |
+
}
|
| 659 |
+
};
|
| 660 |
+
|
| 661 |
+
/// Partial specialization for initializing a complex value.
|
| 662 |
+
template <typename Element>
|
| 663 |
+
struct RandomUniformFunc<complex<Element> > {
|
| 664 |
+
|
| 665 |
+
using Real = typename RealType<Element>::Type;
|
| 666 |
+
|
| 667 |
+
uint64_t seed;
|
| 668 |
+
double range;
|
| 669 |
+
double min;
|
| 670 |
+
int int_scale;
|
| 671 |
+
|
| 672 |
+
double pnan;
|
| 673 |
+
private:
|
| 674 |
+
using engine_type = std::mt19937;
|
| 675 |
+
public:
|
| 676 |
+
engine_type bernoulli_rnd;
|
| 677 |
+
std::bernoulli_distribution bernoulli_dist;
|
| 678 |
+
|
| 679 |
+
bool exclude_zero;
|
| 680 |
+
|
| 681 |
+
//
|
| 682 |
+
// Methods
|
| 683 |
+
//
|
| 684 |
+
|
| 685 |
+
RandomUniformFunc(
|
| 686 |
+
uint64_t seed_ = 0,
|
| 687 |
+
double max = 1,
|
| 688 |
+
double min_ = 0,
|
| 689 |
+
int int_scale_ = -1,
|
| 690 |
+
double pnan_ = 0,
|
| 691 |
+
bool exclude_zero_ = false
|
| 692 |
+
):
|
| 693 |
+
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_)
|
| 694 |
+
, bernoulli_rnd{static_cast<engine_type::result_type>(seed_)}
|
| 695 |
+
, bernoulli_dist(pnan_)
|
| 696 |
+
, exclude_zero(exclude_zero_) {
|
| 697 |
+
std::srand((unsigned)seed);
|
| 698 |
+
|
| 699 |
+
// Handle cases where min = 0 or max = 0 for excluding zeros
|
| 700 |
+
if (exclude_zero) {
|
| 701 |
+
min = (min == 0.0) ? min + 1: min;
|
| 702 |
+
range = (max == 0.0) ? range - 1: range;
|
| 703 |
+
}
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
/// Compute random value and update RNG state
|
| 708 |
+
complex<Element> operator()() {
|
| 709 |
+
|
| 710 |
+
// Sample from NaN distribution.
|
| 711 |
+
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 712 |
+
if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
|
| 713 |
+
return Element(NAN);
|
| 714 |
+
}
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
Element reals[2];
|
| 718 |
+
|
| 719 |
+
for (int i = 0; i < 2; ++i) {
|
| 720 |
+
double rnd = double(std::rand()) / double(RAND_MAX);
|
| 721 |
+
|
| 722 |
+
rnd = min + range * rnd;
|
| 723 |
+
|
| 724 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 725 |
+
// testing
|
| 726 |
+
|
| 727 |
+
if (int_scale >= 0) {
|
| 728 |
+
rnd = double(std::llround(rnd * double(1 << int_scale)));
|
| 729 |
+
reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
|
| 730 |
+
}
|
| 731 |
+
else {
|
| 732 |
+
reals[i] = from_real<Element>(Real(rnd));
|
| 733 |
+
}
|
| 734 |
+
|
| 735 |
+
if (exclude_zero &&
|
| 736 |
+
i == 0 &&
|
| 737 |
+
reals[0] == from_real<Element>(0.0)) {
|
| 738 |
+
|
| 739 |
+
if (rnd > 0.0) {
|
| 740 |
+
rnd = std::min(min + range, rnd + 1.0);
|
| 741 |
+
} else {
|
| 742 |
+
rnd = std::max(min, rnd - 1.0);
|
| 743 |
+
}
|
| 744 |
+
reals[0] = from_real<Element>(Real(rnd));
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
}
|
| 748 |
+
|
| 749 |
+
return complex<Element>(reals[0], reals[1]);
|
| 750 |
+
}
|
| 751 |
+
};
|
| 752 |
+
|
| 753 |
+
/// Partial specialization for initializing a Quaternion value.
|
| 754 |
+
template <typename Element>
|
| 755 |
+
struct RandomUniformFunc<Quaternion<Element> > {
|
| 756 |
+
|
| 757 |
+
using Real = typename RealType<Element>::Type;
|
| 758 |
+
|
| 759 |
+
uint64_t seed;
|
| 760 |
+
double range;
|
| 761 |
+
double min;
|
| 762 |
+
int int_scale;
|
| 763 |
+
|
| 764 |
+
double pnan;
|
| 765 |
+
private:
|
| 766 |
+
using engine_type = std::mt19937;
|
| 767 |
+
public:
|
| 768 |
+
engine_type bernoulli_rnd;
|
| 769 |
+
std::bernoulli_distribution bernoulli_dist;
|
| 770 |
+
|
| 771 |
+
//
|
| 772 |
+
// Methods
|
| 773 |
+
//
|
| 774 |
+
|
| 775 |
+
RandomUniformFunc(
|
| 776 |
+
uint64_t seed_ = 0,
|
| 777 |
+
double max = 1,
|
| 778 |
+
double min_ = 0,
|
| 779 |
+
int int_scale_ = -1,
|
| 780 |
+
double pnan_ = 0
|
| 781 |
+
):
|
| 782 |
+
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_),
|
| 783 |
+
bernoulli_rnd{static_cast<engine_type::result_type>(seed_)},
|
| 784 |
+
bernoulli_dist(pnan_)
|
| 785 |
+
{
|
| 786 |
+
std::srand((unsigned)seed);
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
/// Compute random value and update RNG state
|
| 791 |
+
Quaternion<Element> operator()() {
|
| 792 |
+
|
| 793 |
+
// Sample from NaN distribution.
|
| 794 |
+
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
| 795 |
+
if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
|
| 796 |
+
return Element(NAN);
|
| 797 |
+
}
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
Element reals[4];
|
| 801 |
+
|
| 802 |
+
for (int i = 0; i < 4; ++i) {
|
| 803 |
+
double rnd = double(std::rand()) / double(RAND_MAX);
|
| 804 |
+
|
| 805 |
+
rnd = min + range * rnd;
|
| 806 |
+
|
| 807 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 808 |
+
// testing
|
| 809 |
+
|
| 810 |
+
if (int_scale >= 0) {
|
| 811 |
+
rnd = double(std::llround(rnd * double(1 << int_scale)));
|
| 812 |
+
reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
|
| 813 |
+
}
|
| 814 |
+
else {
|
| 815 |
+
reals[i] = from_real<Element>(Real(rnd));
|
| 816 |
+
}
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
return make_Quaternion(reals[0], reals[1], reals[2], reals[3]);
|
| 820 |
+
}
|
| 821 |
+
};
|
| 822 |
+
|
| 823 |
+
/// Computes a random uniform distribution
|
| 824 |
+
template <
|
| 825 |
+
typename Element, ///< Element type
|
| 826 |
+
typename Layout> ///< Layout function
|
| 827 |
+
struct TensorFillRandomUniformFunc {
|
| 828 |
+
|
| 829 |
+
using TensorView = TensorView<Element, Layout>;
|
| 830 |
+
|
| 831 |
+
//
|
| 832 |
+
// Data members
|
| 833 |
+
//
|
| 834 |
+
|
| 835 |
+
TensorView view;
|
| 836 |
+
RandomUniformFunc<Element> func;
|
| 837 |
+
|
| 838 |
+
//
|
| 839 |
+
// Methods
|
| 840 |
+
//
|
| 841 |
+
|
| 842 |
+
/// Construction of uniform RNG functor.
|
| 843 |
+
TensorFillRandomUniformFunc(
|
| 844 |
+
TensorView view_ = TensorView(),
|
| 845 |
+
RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>()
|
| 846 |
+
):
|
| 847 |
+
view(view_), func(func_) {
|
| 848 |
+
|
| 849 |
+
}
|
| 850 |
+
|
| 851 |
+
/// Compute random value and update RNG state
|
| 852 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 853 |
+
|
| 854 |
+
view.at(coord) = func();
|
| 855 |
+
}
|
| 856 |
+
};
|
| 857 |
+
|
| 858 |
+
/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution.
|
| 859 |
+
template <
|
| 860 |
+
typename Element, ///< Element type
|
| 861 |
+
typename Layout> ///< Layout function
|
| 862 |
+
struct TensorFillSymmetricRandomUniformFunc {
|
| 863 |
+
|
| 864 |
+
using TensorView = TensorView<Element, Layout>;
|
| 865 |
+
|
| 866 |
+
//
|
| 867 |
+
// Data members
|
| 868 |
+
//
|
| 869 |
+
|
| 870 |
+
TensorView view;
|
| 871 |
+
RandomUniformFunc<Element> func;
|
| 872 |
+
cutlass::FillMode fill_mode;
|
| 873 |
+
|
| 874 |
+
//
|
| 875 |
+
// Methods
|
| 876 |
+
//
|
| 877 |
+
|
| 878 |
+
/// Construction of uniform RNG functor.
|
| 879 |
+
TensorFillSymmetricRandomUniformFunc(
|
| 880 |
+
TensorView view_ = TensorView(),
|
| 881 |
+
RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>(),
|
| 882 |
+
cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
|
| 883 |
+
):
|
| 884 |
+
view(view_), func(func_), fill_mode(fill_mode_) {
|
| 885 |
+
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
/// Compute random value and update RNG state
|
| 889 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 890 |
+
// Fill half of matrix based on FillMode
|
| 891 |
+
if (Layout::kRank == 2 &&
|
| 892 |
+
fill_mode == cutlass::FillMode::kLower &&
|
| 893 |
+
coord[0] >= coord[1]) {
|
| 894 |
+
view.at(coord) = func();
|
| 895 |
+
} else if (Layout::kRank == 2 &&
|
| 896 |
+
fill_mode == cutlass::FillMode::kUpper &&
|
| 897 |
+
coord[0] <= coord[1]) {
|
| 898 |
+
view.at(coord) = func();
|
| 899 |
+
}
|
| 900 |
+
}
|
| 901 |
+
};
|
| 902 |
+
|
| 903 |
+
/// Computes a random Uniform distribution and pads diagonal with zeros
|
| 904 |
+
template <
|
| 905 |
+
typename Element, ///< Element type
|
| 906 |
+
typename Layout> ///< Layout function
|
| 907 |
+
struct TensorFillPadDiagonalRandomUniformFunc {
|
| 908 |
+
|
| 909 |
+
using TensorView = TensorView<Element, Layout>;
|
| 910 |
+
|
| 911 |
+
//
|
| 912 |
+
// Data members
|
| 913 |
+
//
|
| 914 |
+
|
| 915 |
+
TensorView view;
|
| 916 |
+
RandomUniformFunc<Element> func;
|
| 917 |
+
cutlass::FillMode fill_mode;
|
| 918 |
+
int alignment;
|
| 919 |
+
|
| 920 |
+
//
|
| 921 |
+
// Methods
|
| 922 |
+
//
|
| 923 |
+
|
| 924 |
+
/// Construction of uniform RNG functor.
|
| 925 |
+
TensorFillPadDiagonalRandomUniformFunc(
|
| 926 |
+
TensorView view_ = TensorView(),
|
| 927 |
+
RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>(),
|
| 928 |
+
cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid,
|
| 929 |
+
int alignment_ = 1
|
| 930 |
+
):
|
| 931 |
+
view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) {
|
| 932 |
+
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
/// Compute random value and update RNG state
|
| 936 |
+
void operator()(Coord<Layout::kRank> const &coord) {
|
| 937 |
+
// Fill half of matrix based on FillMode
|
| 938 |
+
if (Layout::kRank == 2 &&
|
| 939 |
+
(fill_mode == cutlass::FillMode::kLower) &&
|
| 940 |
+
(coord[0] >= coord[1]) ||
|
| 941 |
+
((coord[1] - coord[0]) >= alignment)) {
|
| 942 |
+
view.at(coord) = func();
|
| 943 |
+
} else if (Layout::kRank == 2 &&
|
| 944 |
+
fill_mode == cutlass::FillMode::kUpper &&
|
| 945 |
+
(coord[0] <= coord[1]) ||
|
| 946 |
+
((coord[0] - coord[1]) >= alignment)) {
|
| 947 |
+
view.at(coord) = func();
|
| 948 |
+
}
|
| 949 |
+
}
|
| 950 |
+
};
|
| 951 |
+
|
| 952 |
+
} // namespace detail
|
| 953 |
+
|
| 954 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 955 |
+
|
| 956 |
+
/// Fills a tensor with random values of a uniform random distribution.
|
| 957 |
+
template <
|
| 958 |
+
typename Element, ///< Element type
|
| 959 |
+
typename Layout> ///< Layout function
|
| 960 |
+
void TensorFillRandomUniform(
|
| 961 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 962 |
+
uint64_t seed, ///< seed for RNG
|
| 963 |
+
double max = 1, ///< upper bound of distribution
|
| 964 |
+
double min = 0, ///< lower bound for distribution
|
| 965 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 966 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 967 |
+
/// data.
|
| 968 |
+
double pnan = 0, ///< Percentage of NaN elements.
|
| 969 |
+
bool exclude_zero = false) { ///< Exclude zero from tensor init
|
| 970 |
+
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits, pnan, exclude_zero);
|
| 971 |
+
|
| 972 |
+
detail::TensorFillRandomUniformFunc<Element, Layout> func(
|
| 973 |
+
dst,
|
| 974 |
+
random_func
|
| 975 |
+
);
|
| 976 |
+
|
| 977 |
+
TensorForEach(
|
| 978 |
+
dst.extent(),
|
| 979 |
+
func
|
| 980 |
+
);
|
| 981 |
+
}
|
| 982 |
+
|
| 983 |
+
/// Fills a tensor with random values of a uniform random distribution.
|
| 984 |
+
template <
|
| 985 |
+
typename Element, ///< Element type
|
| 986 |
+
typename Layout> ///< Layout function
|
| 987 |
+
void TensorFillRandomUniform(
|
| 988 |
+
TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
|
| 989 |
+
uint64_t seed, ///< seed for RNG
|
| 990 |
+
double max = 1, ///< upper bound of distribution
|
| 991 |
+
double min = 0, ///< lower bound for distribution
|
| 992 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 993 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 994 |
+
/// data.
|
| 995 |
+
double pnan = 0, ///< Percentage of NaN elements.
|
| 996 |
+
bool exclude_zero = false) { ///< Exclude zero from tensor init
|
| 997 |
+
|
| 998 |
+
TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero);
|
| 999 |
+
TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero);
|
| 1000 |
+
}
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 1004 |
+
template <
|
| 1005 |
+
typename Element, ///< Element type
|
| 1006 |
+
typename Layout> ///< Layout function
|
| 1007 |
+
void TensorFillRandomUniform(
|
| 1008 |
+
TensorView<Quaternion<Element>, Layout> dst, ///< destination tensor
|
| 1009 |
+
uint64_t seed, ///< seed for RNG
|
| 1010 |
+
double max = 1, ///< upper bound of distribution
|
| 1011 |
+
double min = 0, ///< lower bound for distribution
|
| 1012 |
+
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
| 1013 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 1014 |
+
/// data.
|
| 1015 |
+
detail::RandomUniformFunc<Quaternion<Element>> random_func(seed, max, min, bits);
|
| 1016 |
+
|
| 1017 |
+
detail::TensorFillRandomUniformFunc<Quaternion<Element>, Layout> func(
|
| 1018 |
+
dst,
|
| 1019 |
+
random_func
|
| 1020 |
+
);
|
| 1021 |
+
|
| 1022 |
+
TensorForEach(
|
| 1023 |
+
dst.extent(),
|
| 1024 |
+
func
|
| 1025 |
+
);
|
| 1026 |
+
}
|
| 1027 |
+
|
| 1028 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1029 |
+
|
| 1030 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 1031 |
+
template <
|
| 1032 |
+
typename Element, ///< Element type
|
| 1033 |
+
typename Layout> ///< Layout function
|
| 1034 |
+
void TensorFillSymmetricRandomUniform(
|
| 1035 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1036 |
+
uint64_t seed, ///< seed for RNG
|
| 1037 |
+
cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
|
| 1038 |
+
double max = 1, ///< upper bound of distribution
|
| 1039 |
+
double min = 0, ///< lower bound for distribution
|
| 1040 |
+
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
| 1041 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 1042 |
+
/// data.
|
| 1043 |
+
|
| 1044 |
+
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
|
| 1045 |
+
|
| 1046 |
+
detail::TensorFillSymmetricRandomUniformFunc<Element, Layout> func(
|
| 1047 |
+
dst,
|
| 1048 |
+
random_func,
|
| 1049 |
+
fill_mode
|
| 1050 |
+
);
|
| 1051 |
+
|
| 1052 |
+
TensorForEach(
|
| 1053 |
+
dst.extent(),
|
| 1054 |
+
func
|
| 1055 |
+
);
|
| 1056 |
+
}
|
| 1057 |
+
|
| 1058 |
+
/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal
|
| 1059 |
+
template <
|
| 1060 |
+
typename Element, ///< Element type
|
| 1061 |
+
typename Layout> ///< Layout function
|
| 1062 |
+
void TensorFillPadDiagonalRandomUniform(
|
| 1063 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1064 |
+
uint64_t seed, ///< seed for RNG
|
| 1065 |
+
cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
|
| 1066 |
+
double max = 1, ///< upper bound of distribution
|
| 1067 |
+
double min = 0, ///< lower bound for distribution
|
| 1068 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 1069 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 1070 |
+
/// data.
|
| 1071 |
+
int alignment = 1
|
| 1072 |
+
) {
|
| 1073 |
+
|
| 1074 |
+
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
|
| 1075 |
+
|
| 1076 |
+
detail::TensorFillPadDiagonalRandomUniformFunc<Element, Layout> func(
|
| 1077 |
+
dst,
|
| 1078 |
+
random_func,
|
| 1079 |
+
fill_mode,
|
| 1080 |
+
alignment
|
| 1081 |
+
);
|
| 1082 |
+
|
| 1083 |
+
TensorForEach(
|
| 1084 |
+
dst.extent(),
|
| 1085 |
+
func
|
| 1086 |
+
);
|
| 1087 |
+
}
|
| 1088 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1089 |
+
|
| 1090 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1091 |
+
|
| 1092 |
+
/// Fills a tensor with a uniform value
|
| 1093 |
+
template <
|
| 1094 |
+
typename Element ///< Element type
|
| 1095 |
+
>
|
| 1096 |
+
void BlockFill(
|
| 1097 |
+
Element *ptr,
|
| 1098 |
+
size_t capacity,
|
| 1099 |
+
Element val
|
| 1100 |
+
) {
|
| 1101 |
+
for (size_t i = 0; i < capacity; ++i) {
|
| 1102 |
+
ReferenceFactory<Element>::get(ptr, i) = val;
|
| 1103 |
+
}
|
| 1104 |
+
}
|
| 1105 |
+
|
| 1106 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 1107 |
+
template <
|
| 1108 |
+
typename Element ///< Element type
|
| 1109 |
+
>
|
| 1110 |
+
void BlockFillRandomUniform(
|
| 1111 |
+
Element *ptr,
|
| 1112 |
+
size_t capacity,
|
| 1113 |
+
uint64_t seed, ///< seed for RNG
|
| 1114 |
+
double max = 1, ///< upper bound of distribution
|
| 1115 |
+
double min = 0, ///< lower bound for distribution
|
| 1116 |
+
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
| 1117 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 1118 |
+
/// data.
|
| 1119 |
+
double pnan = 0) { ///< Percentage of NaN elements.
|
| 1120 |
+
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits, pnan);
|
| 1121 |
+
|
| 1122 |
+
for (size_t i = 0; i < capacity; ++i) {
|
| 1123 |
+
ReferenceFactory<Element>::get(ptr, i) = random_func();
|
| 1124 |
+
}
|
| 1125 |
+
}
|
| 1126 |
+
|
| 1127 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1128 |
+
|
| 1129 |
+
namespace detail {
|
| 1130 |
+
|
| 1131 |
+
template <
|
| 1132 |
+
typename Element, ///< Element type
|
| 1133 |
+
typename Layout> ///< Layout function
|
| 1134 |
+
struct TensorFillDiagonalFunc {
|
| 1135 |
+
|
| 1136 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1137 |
+
|
| 1138 |
+
//
|
| 1139 |
+
// Data members
|
| 1140 |
+
//
|
| 1141 |
+
|
| 1142 |
+
TensorView view;
|
| 1143 |
+
Element diag;
|
| 1144 |
+
Element other;
|
| 1145 |
+
|
| 1146 |
+
//
|
| 1147 |
+
// Methods
|
| 1148 |
+
//
|
| 1149 |
+
|
| 1150 |
+
TensorFillDiagonalFunc(
|
| 1151 |
+
TensorView const &view_ = TensorView(),
|
| 1152 |
+
Element diag_ = Element(1),
|
| 1153 |
+
Element other_ = Element(0)
|
| 1154 |
+
):
|
| 1155 |
+
view(view_), diag(diag_), other(other_) { }
|
| 1156 |
+
|
| 1157 |
+
void operator()(Coord<Layout::kRank> const & coord) const {
|
| 1158 |
+
bool is_diag = true;
|
| 1159 |
+
|
| 1160 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1161 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1162 |
+
if (coord[i] != coord[i - 1]) {
|
| 1163 |
+
is_diag = false;
|
| 1164 |
+
break;
|
| 1165 |
+
}
|
| 1166 |
+
}
|
| 1167 |
+
|
| 1168 |
+
view.at(coord) = (is_diag ? diag : other);
|
| 1169 |
+
}
|
| 1170 |
+
};
|
| 1171 |
+
|
| 1172 |
+
} // namespace detail
|
| 1173 |
+
|
| 1174 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1175 |
+
|
| 1176 |
+
/// Fills a tensor everywhere with a unique value for its diagonal.
|
| 1177 |
+
template <
|
| 1178 |
+
typename Element, ///< Element type
|
| 1179 |
+
typename Layout> ///< Layout function
|
| 1180 |
+
void TensorFillDiagonal(
|
| 1181 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1182 |
+
Element diag = Element(1), ///< value to write in the diagonal
|
| 1183 |
+
Element other = Element(0)) { ///< value to write off the diagonal
|
| 1184 |
+
|
| 1185 |
+
detail::TensorFillDiagonalFunc<Element, Layout> func(
|
| 1186 |
+
dst,
|
| 1187 |
+
diag,
|
| 1188 |
+
other
|
| 1189 |
+
);
|
| 1190 |
+
|
| 1191 |
+
TensorForEach(
|
| 1192 |
+
dst.extent(),
|
| 1193 |
+
func
|
| 1194 |
+
);
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1198 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1199 |
+
|
| 1200 |
+
/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else.
|
| 1201 |
+
template <
|
| 1202 |
+
typename Element, ///< Element type
|
| 1203 |
+
typename Layout> ///< Layout function
|
| 1204 |
+
void TensorFillIdentity(
|
| 1205 |
+
TensorView<Element, Layout> dst) { ///< destination tensor
|
| 1206 |
+
|
| 1207 |
+
TensorFillDiagonal(dst, Element(1), Element(0));
|
| 1208 |
+
}
|
| 1209 |
+
|
| 1210 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1211 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1212 |
+
|
| 1213 |
+
/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements.
|
| 1214 |
+
template <
|
| 1215 |
+
typename Element, ///< Element type
|
| 1216 |
+
typename Layout> ///< Layout function
|
| 1217 |
+
void TensorUpdateDiagonal(
|
| 1218 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1219 |
+
Element val = Element(1)) {
|
| 1220 |
+
|
| 1221 |
+
typename Layout::Index extent = dst.extent().min();
|
| 1222 |
+
|
| 1223 |
+
for (typename Layout::Index i = 0; i < extent; ++i) {
|
| 1224 |
+
Coord<Layout::kRank> coord(i);
|
| 1225 |
+
dst.at(coord) = val;
|
| 1226 |
+
}
|
| 1227 |
+
}
|
| 1228 |
+
|
| 1229 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1230 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1231 |
+
|
| 1232 |
+
namespace detail {
|
| 1233 |
+
|
| 1234 |
+
template <
|
| 1235 |
+
typename Element, ///< Element type
|
| 1236 |
+
typename Layout> ///< Layout function
|
| 1237 |
+
struct TensorUpdateOffDiagonalFunc {
|
| 1238 |
+
|
| 1239 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1240 |
+
|
| 1241 |
+
//
|
| 1242 |
+
// Data members
|
| 1243 |
+
//
|
| 1244 |
+
|
| 1245 |
+
TensorView view;
|
| 1246 |
+
Element other;
|
| 1247 |
+
|
| 1248 |
+
//
|
| 1249 |
+
// Methods
|
| 1250 |
+
//
|
| 1251 |
+
|
| 1252 |
+
TensorUpdateOffDiagonalFunc(
|
| 1253 |
+
TensorView const &view_ = TensorView(),
|
| 1254 |
+
Element other_ = Element(0)
|
| 1255 |
+
):
|
| 1256 |
+
view(view_), other(other_) { }
|
| 1257 |
+
|
| 1258 |
+
void operator()(Coord<Layout::kRank> const & coord) const {
|
| 1259 |
+
bool is_diag = true;
|
| 1260 |
+
|
| 1261 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1262 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1263 |
+
if (coord[i] != coord[i - 1]) {
|
| 1264 |
+
is_diag = false;
|
| 1265 |
+
break;
|
| 1266 |
+
}
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
if (!is_diag) {
|
| 1270 |
+
view.at(coord) = other;
|
| 1271 |
+
}
|
| 1272 |
+
}
|
| 1273 |
+
};
|
| 1274 |
+
|
| 1275 |
+
} // namespace detail
|
| 1276 |
+
|
| 1277 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1278 |
+
|
| 1279 |
+
/// Writes a uniform value to all elements in the tensor without modifying diagonal elements.
|
| 1280 |
+
template <
|
| 1281 |
+
typename Element, ///< Element type
|
| 1282 |
+
typename Layout> ///< Layout function
|
| 1283 |
+
void TensorUpdateOffDiagonal(
|
| 1284 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1285 |
+
Element other = Element(1)) {
|
| 1286 |
+
|
| 1287 |
+
detail::TensorUpdateOffDiagonalFunc<Element, Layout> func(
|
| 1288 |
+
dst,
|
| 1289 |
+
other
|
| 1290 |
+
);
|
| 1291 |
+
|
| 1292 |
+
TensorForEach(
|
| 1293 |
+
dst.extent(),
|
| 1294 |
+
func
|
| 1295 |
+
);
|
| 1296 |
+
}
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1300 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1301 |
+
|
| 1302 |
+
namespace detail {
|
| 1303 |
+
|
| 1304 |
+
template <
|
| 1305 |
+
typename Element, ///< Element type
|
| 1306 |
+
typename Layout> ///< Layout function
|
| 1307 |
+
struct TensorFillLinearFunc {
|
| 1308 |
+
|
| 1309 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1310 |
+
|
| 1311 |
+
//
|
| 1312 |
+
// Data members
|
| 1313 |
+
//
|
| 1314 |
+
|
| 1315 |
+
TensorView view;
|
| 1316 |
+
Array<Element, Layout::kRank> v;
|
| 1317 |
+
Element s;
|
| 1318 |
+
|
| 1319 |
+
//
|
| 1320 |
+
// Methods
|
| 1321 |
+
//
|
| 1322 |
+
|
| 1323 |
+
TensorFillLinearFunc() { }
|
| 1324 |
+
|
| 1325 |
+
/// Constructs functor
|
| 1326 |
+
TensorFillLinearFunc(
|
| 1327 |
+
TensorView const &view_,
|
| 1328 |
+
Array<Element, Layout::kRank> const & v_,
|
| 1329 |
+
Element s_ = Element(0)
|
| 1330 |
+
):
|
| 1331 |
+
view(view_), v(v_), s(s_) { }
|
| 1332 |
+
|
| 1333 |
+
/// Updates the tensor
|
| 1334 |
+
void operator()(Coord<Layout::kRank> const & coord) const {
|
| 1335 |
+
|
| 1336 |
+
Element sum(s);
|
| 1337 |
+
|
| 1338 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1339 |
+
for (int i = 0; i < Layout::kRank; ++i) {
|
| 1340 |
+
sum += Element(coord[i]) * v[i];
|
| 1341 |
+
}
|
| 1342 |
+
|
| 1343 |
+
view.at(coord) = sum;
|
| 1344 |
+
}
|
| 1345 |
+
};
|
| 1346 |
+
|
| 1347 |
+
} // namespace detail
|
| 1348 |
+
|
| 1349 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1350 |
+
|
| 1351 |
+
/// Fills tensor with a linear combination of its coordinate and another vector
|
| 1352 |
+
template <
|
| 1353 |
+
typename Element, ///< Element type
|
| 1354 |
+
typename Layout> ///< Layout function
|
| 1355 |
+
void TensorFillLinear(
|
| 1356 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1357 |
+
Array<Element, Layout::kRank> const & v,
|
| 1358 |
+
Element s = Element(0)) {
|
| 1359 |
+
|
| 1360 |
+
detail::TensorFillLinearFunc<Element, Layout> func(
|
| 1361 |
+
dst,
|
| 1362 |
+
v,
|
| 1363 |
+
s
|
| 1364 |
+
);
|
| 1365 |
+
|
| 1366 |
+
TensorForEach(
|
| 1367 |
+
dst.extent(),
|
| 1368 |
+
func
|
| 1369 |
+
);
|
| 1370 |
+
}
|
| 1371 |
+
|
| 1372 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1373 |
+
|
| 1374 |
+
/// Fills tensor with a linear combination of its coordinate and another vector
|
| 1375 |
+
template <
|
| 1376 |
+
typename Element, ///< Element type
|
| 1377 |
+
typename Layout> ///< Layout function
|
| 1378 |
+
void TensorFillSequential(
|
| 1379 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1380 |
+
Element s = Element(0)) {
|
| 1381 |
+
|
| 1382 |
+
Array<Element, Layout::kRank> stride;
|
| 1383 |
+
|
| 1384 |
+
stride[0] = Element(1);
|
| 1385 |
+
|
| 1386 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1387 |
+
for (int i = 1; i < Layout::kRank; ++i) {
|
| 1388 |
+
stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]);
|
| 1389 |
+
}
|
| 1390 |
+
|
| 1391 |
+
TensorFillLinear(dst, stride, s);
|
| 1392 |
+
}
|
| 1393 |
+
|
| 1394 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1395 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1396 |
+
|
| 1397 |
+
/// Fills a tensor with random values from a distribution.
|
| 1398 |
+
template <
|
| 1399 |
+
typename Element, ///< Element type
|
| 1400 |
+
typename Layout> ///< Layout function
|
| 1401 |
+
void TensorFillRandom(
|
| 1402 |
+
TensorView<Element, Layout> view, ///< destination tensor
|
| 1403 |
+
uint64_t seed,
|
| 1404 |
+
Distribution dist,
|
| 1405 |
+
bool exclude_zero = false ///< If true, excludes 0.
|
| 1406 |
+
/// Note that setting this flag will result in more 1's,
|
| 1407 |
+
/// as we use a simple mechanism to replace 0's by adding/subtracting 1's.
|
| 1408 |
+
) {
|
| 1409 |
+
|
| 1410 |
+
using Real = typename RealType<Element>::Type;
|
| 1411 |
+
|
| 1412 |
+
if (dist.kind == Distribution::Gaussian) {
|
| 1413 |
+
TensorFillRandomGaussian(
|
| 1414 |
+
view,
|
| 1415 |
+
seed,
|
| 1416 |
+
dist.gaussian.mean,
|
| 1417 |
+
dist.gaussian.stddev,
|
| 1418 |
+
dist.int_scale,
|
| 1419 |
+
dist.gaussian.pnz,
|
| 1420 |
+
exclude_zero);
|
| 1421 |
+
} else if (dist.kind == Distribution::Uniform) {
|
| 1422 |
+
TensorFillRandomUniform(
|
| 1423 |
+
view,
|
| 1424 |
+
seed,
|
| 1425 |
+
dist.uniform.max,
|
| 1426 |
+
dist.uniform.min,
|
| 1427 |
+
dist.int_scale,
|
| 1428 |
+
dist.uniform.pnan,
|
| 1429 |
+
exclude_zero);
|
| 1430 |
+
}
|
| 1431 |
+
}
|
| 1432 |
+
|
| 1433 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1434 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1435 |
+
|
| 1436 |
+
/// Fills a block of data with sequential elements
|
| 1437 |
+
template <
|
| 1438 |
+
typename Element
|
| 1439 |
+
>
|
| 1440 |
+
void BlockFillSequential(
|
| 1441 |
+
Element *ptr,
|
| 1442 |
+
int64_t capacity,
|
| 1443 |
+
Element v = Element(1),
|
| 1444 |
+
Element s = Element(0)) {
|
| 1445 |
+
int i = 0;
|
| 1446 |
+
|
| 1447 |
+
while (i < capacity) {
|
| 1448 |
+
cutlass::ReferenceFactory<Element, (cutlass::sizeof_bits<Element>::value <
|
| 1449 |
+
8)>::get(ptr, i) = s;
|
| 1450 |
+
|
| 1451 |
+
s = Element(s + v);
|
| 1452 |
+
++i;
|
| 1453 |
+
}
|
| 1454 |
+
}
|
| 1455 |
+
|
| 1456 |
+
/// Fills a block of data with sequential elements
|
| 1457 |
+
template <
|
| 1458 |
+
typename Element
|
| 1459 |
+
>
|
| 1460 |
+
void BlockFillSequentialModN(
|
| 1461 |
+
Element *ptr,
|
| 1462 |
+
int64_t capacity,
|
| 1463 |
+
int64_t mod,
|
| 1464 |
+
int64_t v = int64_t(1),
|
| 1465 |
+
int64_t s = int64_t(0)) {
|
| 1466 |
+
int i = 0;
|
| 1467 |
+
|
| 1468 |
+
while (i < capacity) {
|
| 1469 |
+
cutlass::ReferenceFactory<Element, (cutlass::sizeof_bits<Element>::value <
|
| 1470 |
+
8)>::get(ptr, i) = Element(s);
|
| 1471 |
+
|
| 1472 |
+
s = int64_t(s + v) % mod;
|
| 1473 |
+
++i;
|
| 1474 |
+
}
|
| 1475 |
+
}
|
| 1476 |
+
|
| 1477 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1478 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1479 |
+
|
| 1480 |
+
/// Fills a block of data with sequential elements
|
| 1481 |
+
template <
|
| 1482 |
+
typename Element
|
| 1483 |
+
>
|
| 1484 |
+
void BlockFillRandom(
|
| 1485 |
+
Element *ptr,
|
| 1486 |
+
size_t capacity,
|
| 1487 |
+
uint64_t seed,
|
| 1488 |
+
Distribution dist) {
|
| 1489 |
+
|
| 1490 |
+
if (dist.kind == Distribution::Gaussian) {
|
| 1491 |
+
BlockFillRandomGaussian<Element>(
|
| 1492 |
+
ptr,
|
| 1493 |
+
capacity,
|
| 1494 |
+
seed,
|
| 1495 |
+
dist.gaussian.mean,
|
| 1496 |
+
dist.gaussian.stddev,
|
| 1497 |
+
dist.int_scale,
|
| 1498 |
+
dist.gaussian.pnz);
|
| 1499 |
+
}
|
| 1500 |
+
else if (dist.kind == Distribution::Uniform) {
|
| 1501 |
+
BlockFillRandomUniform<Element>(
|
| 1502 |
+
ptr,
|
| 1503 |
+
capacity,
|
| 1504 |
+
seed,
|
| 1505 |
+
dist.uniform.max,
|
| 1506 |
+
dist.uniform.min,
|
| 1507 |
+
dist.int_scale,
|
| 1508 |
+
dist.uniform.pnan);
|
| 1509 |
+
}
|
| 1510 |
+
}
|
| 1511 |
+
|
| 1512 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1513 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1514 |
+
|
| 1515 |
+
namespace detail {
|
| 1516 |
+
|
| 1517 |
+
template <typename Element>
|
| 1518 |
+
struct RandomSparseMetaFunc {
|
| 1519 |
+
|
| 1520 |
+
uint64_t seed;
|
| 1521 |
+
int range;
|
| 1522 |
+
int MetaSizeInBits;
|
| 1523 |
+
|
| 1524 |
+
//
|
| 1525 |
+
// Methods
|
| 1526 |
+
//
|
| 1527 |
+
|
| 1528 |
+
RandomSparseMetaFunc(
|
| 1529 |
+
uint64_t seed_ = 0,
|
| 1530 |
+
int MetaSizeInBits_ = 2
|
| 1531 |
+
):
|
| 1532 |
+
seed(seed_), MetaSizeInBits(MetaSizeInBits_) {
|
| 1533 |
+
std::srand((unsigned)seed);
|
| 1534 |
+
if (MetaSizeInBits_ == 2) {
|
| 1535 |
+
range = 6;
|
| 1536 |
+
}
|
| 1537 |
+
else if (MetaSizeInBits_ == 4) {
|
| 1538 |
+
range = 2;
|
| 1539 |
+
}
|
| 1540 |
+
else {
|
| 1541 |
+
throw std::invalid_argument("Invalid MetaSizeInBits");
|
| 1542 |
+
}
|
| 1543 |
+
}
|
| 1544 |
+
|
| 1545 |
+
/// Compute random value and update RNG state
|
| 1546 |
+
Element operator()() const {
|
| 1547 |
+
Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
|
| 1548 |
+
Element TwoToOneMeta[2] = {0x4, 0xe};
|
| 1549 |
+
|
| 1550 |
+
Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta;
|
| 1551 |
+
|
| 1552 |
+
Element result = 0x0;
|
| 1553 |
+
|
| 1554 |
+
for (int i = 0; i < cutlass::sizeof_bits<Element>::value / 4; ++i) {
|
| 1555 |
+
int rnd = std::rand() % range;
|
| 1556 |
+
Element meta = MetaArray[rnd];
|
| 1557 |
+
|
| 1558 |
+
result = (Element)(result | ((Element)(meta << (i * 4))));
|
| 1559 |
+
}
|
| 1560 |
+
|
| 1561 |
+
return result;
|
| 1562 |
+
}
|
| 1563 |
+
};
|
| 1564 |
+
|
| 1565 |
+
/// Computes a random sparse meta
|
| 1566 |
+
template <
|
| 1567 |
+
typename Element, ///< Element type
|
| 1568 |
+
typename Layout> ///< Layout function
|
| 1569 |
+
struct TensorFillRandomSparseMetaFunc {
|
| 1570 |
+
|
| 1571 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1572 |
+
|
| 1573 |
+
//
|
| 1574 |
+
// Data members
|
| 1575 |
+
//
|
| 1576 |
+
|
| 1577 |
+
TensorView view;
|
| 1578 |
+
RandomSparseMetaFunc<Element> func;
|
| 1579 |
+
|
| 1580 |
+
//
|
| 1581 |
+
// Methods
|
| 1582 |
+
//
|
| 1583 |
+
|
| 1584 |
+
/// Construction of Gaussian RNG functor.
|
| 1585 |
+
TensorFillRandomSparseMetaFunc(
|
| 1586 |
+
TensorView view_ = TensorView(),
|
| 1587 |
+
RandomSparseMetaFunc<Element> func_ = RandomSparseMetaFunc<Element>()
|
| 1588 |
+
):
|
| 1589 |
+
view(view_), func(func_) {
|
| 1590 |
+
|
| 1591 |
+
}
|
| 1592 |
+
|
| 1593 |
+
/// Compute random value and update RNG state
|
| 1594 |
+
void operator()(Coord<Layout::kRank> const &coord) const {
|
| 1595 |
+
|
| 1596 |
+
view.at(coord) = func();
|
| 1597 |
+
}
|
| 1598 |
+
};
|
| 1599 |
+
|
| 1600 |
+
} // namespace detail
|
| 1601 |
+
|
| 1602 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1603 |
+
|
| 1604 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 1605 |
+
template <
|
| 1606 |
+
typename Element, ///< Element type
|
| 1607 |
+
typename Layout> ///< Layout function
|
| 1608 |
+
void TensorFillRandomSparseMeta(
|
| 1609 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1610 |
+
uint64_t seed, ///< seed for RNG
|
| 1611 |
+
int MetaSizeInBits) { ///< 2 bit or 4 bit
|
| 1612 |
+
|
| 1613 |
+
detail::RandomSparseMetaFunc<Element> random_func(seed, MetaSizeInBits);
|
| 1614 |
+
|
| 1615 |
+
detail::TensorFillRandomSparseMetaFunc<Element, Layout> func(
|
| 1616 |
+
dst,
|
| 1617 |
+
random_func
|
| 1618 |
+
);
|
| 1619 |
+
|
| 1620 |
+
TensorForEach(
|
| 1621 |
+
dst.extent(),
|
| 1622 |
+
func
|
| 1623 |
+
);
|
| 1624 |
+
}
|
| 1625 |
+
|
| 1626 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1627 |
+
|
| 1628 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 1629 |
+
template <
|
| 1630 |
+
typename Element ///< Element type
|
| 1631 |
+
>
|
| 1632 |
+
void BlockFillRandomSparseMeta(
|
| 1633 |
+
Element *ptr,
|
| 1634 |
+
size_t capacity,
|
| 1635 |
+
uint64_t seed, ///< seed for RNG
|
| 1636 |
+
int MetaSizeInBits) { ///< 2 bit or 4bit
|
| 1637 |
+
|
| 1638 |
+
detail::RandomSparseMetaFunc<Element> random_func(seed, MetaSizeInBits);
|
| 1639 |
+
|
| 1640 |
+
for (size_t i = 0; i < capacity; ++i) {
|
| 1641 |
+
ptr[i] = random_func();
|
| 1642 |
+
}
|
| 1643 |
+
}
|
| 1644 |
+
|
| 1645 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1646 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1647 |
+
|
| 1648 |
+
/// Fills a ell block index matrix with random values with a uniform random distribution.
|
| 1649 |
+
template <
|
| 1650 |
+
typename Element, ///< Element type
|
| 1651 |
+
typename Layout> ///< Layout function
|
| 1652 |
+
void TensorFillRandomEllIdx(
|
| 1653 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1654 |
+
uint64_t seed, ///< seed for RNG
|
| 1655 |
+
int rows, int ell_cols, int cols) { ///< dimension of the matrix
|
| 1656 |
+
|
| 1657 |
+
std::srand((unsigned)seed);
|
| 1658 |
+
|
| 1659 |
+
for (int i = 0; i < rows; ++i) {
|
| 1660 |
+
int col_idx = std::rand() % cols;
|
| 1661 |
+
|
| 1662 |
+
for (int j = 0; j < ell_cols; ++j) {
|
| 1663 |
+
dst.at({i, j}) = col_idx;
|
| 1664 |
+
|
| 1665 |
+
if (col_idx != -1) {
|
| 1666 |
+
if (col_idx == (cols - 1)) {
|
| 1667 |
+
col_idx = -1;
|
| 1668 |
+
} else {
|
| 1669 |
+
col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1;
|
| 1670 |
+
}
|
| 1671 |
+
}
|
| 1672 |
+
}
|
| 1673 |
+
}
|
| 1674 |
+
}
|
| 1675 |
+
|
| 1676 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1677 |
+
|
| 1678 |
+
/// Copies a diagonal in from host memory without modifying off-diagonal elements.
|
| 1679 |
+
template <
|
| 1680 |
+
typename Element, ///< Element type
|
| 1681 |
+
typename Layout> ///< Layout function
|
| 1682 |
+
void TensorCopyDiagonalIn(
|
| 1683 |
+
TensorView<Element, Layout> dst, ///< destination tensor
|
| 1684 |
+
Element const *ptr) { ///< dense buffer of elements
|
| 1685 |
+
|
| 1686 |
+
typename Layout::Index extent = dst.extent().min();
|
| 1687 |
+
|
| 1688 |
+
for (typename Layout::Index i = 0; i < extent; ++i) {
|
| 1689 |
+
Coord<Layout::kRank> coord(i);
|
| 1690 |
+
dst.at(coord) = ReferenceFactory<Element>::get(ptr, i);
|
| 1691 |
+
}
|
| 1692 |
+
}
|
| 1693 |
+
|
| 1694 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1695 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1696 |
+
|
| 1697 |
+
/// Copies the diagonal of a tensor into a dense buffer in host memory.
|
| 1698 |
+
template <
|
| 1699 |
+
typename Element, ///< Element type
|
| 1700 |
+
typename Layout> ///< Layout function
|
| 1701 |
+
void TensorCopyDiagonalOut(
|
| 1702 |
+
Element *ptr, ///< dense buffer of elements
|
| 1703 |
+
TensorView<Element, Layout> src) { ///< source tensor
|
| 1704 |
+
|
| 1705 |
+
typename Layout::Index extent = src.extent().min();
|
| 1706 |
+
|
| 1707 |
+
for (typename Layout::Index i = 0; i < extent; ++i) {
|
| 1708 |
+
Coord<Layout::kRank> coord(i);
|
| 1709 |
+
ReferenceFactory<Element>::get(ptr, i) = src.at(coord);
|
| 1710 |
+
}
|
| 1711 |
+
}
|
| 1712 |
+
|
| 1713 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1714 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1715 |
+
|
| 1716 |
+
} // namespace host
|
| 1717 |
+
} // namespace reference
|
| 1718 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Provides several functions for filling tensors with data.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Standard Library includes
|
| 38 |
+
#include <utility>
|
| 39 |
+
#include <cstdlib>
|
| 40 |
+
#include <cmath>
|
| 41 |
+
|
| 42 |
+
// Cute includes
|
| 43 |
+
#include "cute/tensor.hpp"
|
| 44 |
+
|
| 45 |
+
// Cutlass includes
|
| 46 |
+
#include "cutlass/cutlass.h"
|
| 47 |
+
#include "cutlass/complex.h"
|
| 48 |
+
#include "cutlass/quaternion.h"
|
| 49 |
+
#include "cutlass/array.h"
|
| 50 |
+
#include "cutlass/numeric_types.h"
|
| 51 |
+
|
| 52 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace cutlass {
|
| 55 |
+
namespace reference {
|
| 56 |
+
namespace host {
|
| 57 |
+
|
| 58 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
//
|
| 60 |
+
// Uniform and procedural tensor fills
|
| 61 |
+
//
|
| 62 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
|
| 64 |
+
/// Fills a tensor with a scalar element
|
| 65 |
+
template <typename Tensor>
|
| 66 |
+
void TensorFill(Tensor dst, typename Tensor::value_type element) {
|
| 67 |
+
|
| 68 |
+
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
|
| 69 |
+
dst(idx) = element;
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/// Fills a tensor with the contents of its layout
|
| 74 |
+
template <typename Tensor>
|
| 75 |
+
void TensorFillSequential(Tensor dst) {
|
| 76 |
+
|
| 77 |
+
auto layout = dst.layout();
|
| 78 |
+
|
| 79 |
+
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
|
| 80 |
+
dst(idx) = layout(idx);
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 85 |
+
//
|
| 86 |
+
// Random uniform values
|
| 87 |
+
//
|
| 88 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 89 |
+
|
| 90 |
+
namespace detail {
|
| 91 |
+
|
| 92 |
+
template <typename Element>
|
| 93 |
+
struct RandomUniformFunc {
|
| 94 |
+
|
| 95 |
+
using Real = typename RealType<Element>::Type;
|
| 96 |
+
|
| 97 |
+
uint64_t seed;
|
| 98 |
+
double range;
|
| 99 |
+
double min;
|
| 100 |
+
int int_scale;
|
| 101 |
+
|
| 102 |
+
//
|
| 103 |
+
// Methods
|
| 104 |
+
//
|
| 105 |
+
|
| 106 |
+
RandomUniformFunc(
|
| 107 |
+
uint64_t seed_ = 0,
|
| 108 |
+
double max = 1,
|
| 109 |
+
double min_ = 0,
|
| 110 |
+
int int_scale_ = -1
|
| 111 |
+
):
|
| 112 |
+
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
|
| 113 |
+
std::srand((unsigned)seed);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
/// Compute random value and update RNG state
|
| 118 |
+
Element operator()() const {
|
| 119 |
+
|
| 120 |
+
double rnd = double(std::rand()) / double(RAND_MAX);
|
| 121 |
+
|
| 122 |
+
rnd = min + range * rnd;
|
| 123 |
+
|
| 124 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 125 |
+
// testing
|
| 126 |
+
Element result;
|
| 127 |
+
|
| 128 |
+
if (int_scale >= 0) {
|
| 129 |
+
rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
|
| 130 |
+
result = static_cast<Element>(Real(rnd));
|
| 131 |
+
}
|
| 132 |
+
else {
|
| 133 |
+
result = static_cast<Element>(Real(rnd));
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
return result;
|
| 137 |
+
}
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
/// Partial specialization for initializing a complex value.
|
| 141 |
+
template <typename Element>
|
| 142 |
+
struct RandomUniformFunc<complex<Element> > {
|
| 143 |
+
|
| 144 |
+
using Real = typename RealType<Element>::Type;
|
| 145 |
+
|
| 146 |
+
uint64_t seed;
|
| 147 |
+
double range;
|
| 148 |
+
double min;
|
| 149 |
+
int int_scale;
|
| 150 |
+
|
| 151 |
+
//
|
| 152 |
+
// Methods
|
| 153 |
+
//
|
| 154 |
+
|
| 155 |
+
RandomUniformFunc(
|
| 156 |
+
uint64_t seed_ = 0,
|
| 157 |
+
double max = 1,
|
| 158 |
+
double min_ = 0,
|
| 159 |
+
int int_scale_ = -1
|
| 160 |
+
):
|
| 161 |
+
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
|
| 162 |
+
std::srand((unsigned)seed);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
/// Compute random value and update RNG state
|
| 167 |
+
complex<Element> operator()() const {
|
| 168 |
+
|
| 169 |
+
Element reals[2];
|
| 170 |
+
|
| 171 |
+
for (int i = 0; i < 2; ++i) {
|
| 172 |
+
double rnd = double(std::rand()) / double(RAND_MAX);
|
| 173 |
+
|
| 174 |
+
rnd = min + range * rnd;
|
| 175 |
+
|
| 176 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 177 |
+
// testing
|
| 178 |
+
|
| 179 |
+
if (int_scale >= 0) {
|
| 180 |
+
rnd = double(int(rnd * double(1 << int_scale)));
|
| 181 |
+
reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
|
| 182 |
+
}
|
| 183 |
+
else {
|
| 184 |
+
reals[i] = from_real<Element>(Real(rnd));
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
return complex<Element>(reals[0], reals[1]);
|
| 189 |
+
}
|
| 190 |
+
};
|
| 191 |
+
|
| 192 |
+
/// Partial specialization for initializing a Quaternion value.
|
| 193 |
+
template <typename Element>
|
| 194 |
+
struct RandomUniformFunc<Quaternion<Element> > {
|
| 195 |
+
|
| 196 |
+
using Real = typename RealType<Element>::Type;
|
| 197 |
+
|
| 198 |
+
uint64_t seed;
|
| 199 |
+
double range;
|
| 200 |
+
double min;
|
| 201 |
+
int int_scale;
|
| 202 |
+
|
| 203 |
+
//
|
| 204 |
+
// Methods
|
| 205 |
+
//
|
| 206 |
+
|
| 207 |
+
RandomUniformFunc(
|
| 208 |
+
uint64_t seed_ = 0,
|
| 209 |
+
double max = 1,
|
| 210 |
+
double min_ = 0,
|
| 211 |
+
int int_scale_ = -1
|
| 212 |
+
):
|
| 213 |
+
seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
|
| 214 |
+
std::srand((unsigned)seed);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
/// Compute random value and update RNG state
|
| 219 |
+
Quaternion<Element> operator()() const {
|
| 220 |
+
|
| 221 |
+
Element reals[4];
|
| 222 |
+
|
| 223 |
+
for (int i = 0; i < 4; ++i) {
|
| 224 |
+
double rnd = double(std::rand()) / double(RAND_MAX);
|
| 225 |
+
|
| 226 |
+
rnd = min + range * rnd;
|
| 227 |
+
|
| 228 |
+
// Random values are cast to integer after scaling by a power of two to facilitate error
|
| 229 |
+
// testing
|
| 230 |
+
|
| 231 |
+
if (int_scale >= 0) {
|
| 232 |
+
rnd = double(int(rnd * double(1 << int_scale)));
|
| 233 |
+
reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
|
| 234 |
+
}
|
| 235 |
+
else {
|
| 236 |
+
reals[i] = from_real<Element>(Real(rnd));
|
| 237 |
+
}
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
return make_Quaternion(reals[0], reals[1], reals[2], reals[3]);
|
| 241 |
+
}
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
} // namespace detail
|
| 245 |
+
|
| 246 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 247 |
+
|
| 248 |
+
/// Fills a tensor with random values with a uniform random distribution.
|
| 249 |
+
template <typename Tensor> ///< Tensor object
|
| 250 |
+
void TensorFillRandomUniform(
|
| 251 |
+
Tensor dst, ///< destination tensor
|
| 252 |
+
uint64_t seed, ///< seed for RNG
|
| 253 |
+
double max = 1, ///< upper bound of distribution
|
| 254 |
+
double min = 0, ///< lower bound for distribution
|
| 255 |
+
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
| 256 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 257 |
+
/// data.
|
| 258 |
+
|
| 259 |
+
detail::RandomUniformFunc<typename Tensor::value_type> random_func(seed, max, min, bits);
|
| 260 |
+
|
| 261 |
+
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
|
| 262 |
+
dst(idx) = random_func();
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
/// Fills a block with random values with a uniform random distribution.
|
| 267 |
+
template <
|
| 268 |
+
typename Element ///< Element type
|
| 269 |
+
>
|
| 270 |
+
void BlockFillRandomUniform(
|
| 271 |
+
Element *ptr,
|
| 272 |
+
size_t capacity,
|
| 273 |
+
uint64_t seed, ///< seed for RNG
|
| 274 |
+
double max = 1, ///< upper bound of distribution
|
| 275 |
+
double min = 0, ///< lower bound for distribution
|
| 276 |
+
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
| 277 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 278 |
+
/// data.
|
| 279 |
+
detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
|
| 280 |
+
|
| 281 |
+
for (size_t i = 0; i < capacity; ++i) {
|
| 282 |
+
ptr[i] = random_func();
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 287 |
+
//
|
| 288 |
+
// Random Gaussian
|
| 289 |
+
//
|
| 290 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 291 |
+
|
| 292 |
+
namespace detail {
|
| 293 |
+
|
| 294 |
+
template <typename Element>
struct RandomGaussianFunc {

  uint64_t seed;      // seed handed to std::srand in the constructor
  double mean;        // Gaussian distribution's mean
  double stddev;      // Gaussian distribution's standard deviation
  int int_scale;      // if >= 0, number of fractional bits retained
  double pi;          // cached value of pi for the Box-Muller transform

  //
  // Methods
  //

  /// Seeds the C library RNG and records the distribution parameters.
  RandomGaussianFunc(
    uint64_t seed_ = 0,
    double mean_ = 0,
    double stddev_ = 1,
    int int_scale_ = -1
  ):
    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) {
    std::srand((unsigned)seed);
  }

  /// Compute random value and update RNG state
  Element operator()() const {

    // Box-Muller transform to generate random numbers with Normal distribution
    double u1 = double(std::rand()) / double(RAND_MAX);
    double u2 = double(std::rand()) / double(RAND_MAX);

    // Guard against u1 == 0: std::rand() may return 0, and std::log(0)
    // evaluates to -infinity, which would propagate inf/NaN into the
    // generated data. Substitute half the smallest nonzero sample.
    if (u1 <= 0) {
      u1 = 0.5 / double(RAND_MAX);
    }

    // Compute Gaussian random value
    double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
    rnd = mean + stddev * rnd;

    // Scale and convert final result. When int_scale is non-negative, snap
    // the sample to a grid of 2^int_scale steps so values round-trip exactly
    // through low-precision element types (facilitates error testing).
    Element result;

    if (int_scale >= 0) {
      rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
      result = static_cast<Element>(rnd);
    }
    else {
      result = static_cast<Element>(rnd);
    }

    return result;
  }
};
|
| 341 |
+
|
| 342 |
+
} // namespace detail
|
| 343 |
+
|
| 344 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 345 |
+
|
| 346 |
+
/// Fills a tensor with random values with a Gaussian distribution.
|
| 347 |
+
template <
|
| 348 |
+
typename Tensor
|
| 349 |
+
>
|
| 350 |
+
void TensorFillRandomGaussian(
|
| 351 |
+
Tensor dst, ///< destination tensor
|
| 352 |
+
uint64_t seed, ///< seed for RNG
|
| 353 |
+
double mean = 0, ///< Gaussian distribution's mean
|
| 354 |
+
double stddev = 1, ///< Gaussian distribution's standard deviation
|
| 355 |
+
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
| 356 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 357 |
+
/// data.
|
| 358 |
+
|
| 359 |
+
detail::RandomGaussianFunc<typename Tensor::value_type> random_func(seed, mean, stddev, bits);
|
| 360 |
+
|
| 361 |
+
for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
|
| 362 |
+
dst(idx) = random_func();
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
/// Fills a block with random values with a Gaussian distribution.
|
| 367 |
+
template <
|
| 368 |
+
typename Element ///< Element type
|
| 369 |
+
>
|
| 370 |
+
void BlockFillRandomGaussian(
|
| 371 |
+
Element *ptr, ///< destination buffer
|
| 372 |
+
size_t capacity, ///< number of elements
|
| 373 |
+
uint64_t seed, ///< seed for RNG
|
| 374 |
+
double mean = 0, ///< Gaussian distribution's mean
|
| 375 |
+
double stddev = 1, ///< Gaussian distribution's standard deviation
|
| 376 |
+
int bits = -1) { ///< If non-negative, specifies number of fractional bits that
|
| 377 |
+
/// are not truncated to zero. Permits reducing precision of
|
| 378 |
+
/// data.
|
| 379 |
+
|
| 380 |
+
detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits);
|
| 381 |
+
|
| 382 |
+
for (size_t i = 0; i < capacity; ++i) {
|
| 383 |
+
ptr[i] = random_func();
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 388 |
+
|
| 389 |
+
/// Fills a block of data with sequential elements: ptr[i] = s + i * v.
/// With the default arguments this produces 0, 1, 2, ...
template <
  typename Element
>
void BlockFillSequential(
  Element *ptr,              ///< destination buffer
  int64_t capacity,          ///< number of elements
  Element v = Element(1),    ///< step between consecutive elements
  Element s = Element(0)) {  ///< value of the first element

  // Accumulate rather than recompute s + i * v so Element only needs
  // addition. The previous implementation wrote the constant Element(s + v)
  // into every slot, contradicting the documented "sequential" behavior.
  // A 64-bit index also handles capacities beyond INT_MAX, which the prior
  // `int` loop counter did not.
  Element value = s;
  for (int64_t i = 0; i < capacity; ++i) {
    ptr[i] = value;
    value = Element(value + v);
  }
}
|
| 406 |
+
|
| 407 |
+
/// Fills a block of data with sequential elements reduced modulo `mod`:
/// ptr[i] = (s + i * v) % mod. Defaults yield 0, 1, ..., mod-1, 0, 1, ...
template <
  typename Element
>
void BlockFillSequentialModN(
  Element *ptr,               ///< destination buffer
  int64_t capacity,           ///< number of elements
  int64_t mod,                ///< modulus applied to each sequence value
  int64_t v = int64_t(1),     ///< step between consecutive sequence values
  int64_t s = int64_t(0)) {   ///< first sequence value (before the modulus)

  // Advance the running sequence value instead of writing the constant
  // (s + v) % mod into every slot, which contradicted the documented
  // "sequential" behavior. A 64-bit index also handles capacities beyond
  // INT_MAX, which the prior `int` loop counter did not.
  int64_t value = s;
  for (int64_t i = 0; i < capacity; ++i) {
    ptr[i] = static_cast<Element>(int32_t(value % mod));
    value += v;
  }
}
|
| 425 |
+
|
| 426 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 427 |
+
|
| 428 |
+
} // namespace host
|
| 429 |
+
} // namespace reference
|
| 430 |
+
} // namespace cutlass
|
| 431 |
+
|
| 432 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include <stdexcept>
|
| 34 |
+
#include "cutlass/cutlass.h"
|
| 35 |
+
|
| 36 |
+
namespace cutlass {
|
| 37 |
+
namespace reference {
|
| 38 |
+
namespace host {
|
| 39 |
+
|
| 40 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
/// Defines several helpers
|
| 43 |
+
namespace detail {
|
| 44 |
+
|
| 45 |
+
/// Helper to perform for-each operation
|
| 46 |
+
template <typename Func, int Rank, int RankRemaining>
|
| 47 |
+
struct TensorForEachHelper {
|
| 48 |
+
|
| 49 |
+
/// Index of the active rank
|
| 50 |
+
static int const kActiveRank = Rank - RankRemaining - 1;
|
| 51 |
+
|
| 52 |
+
/// Constructor for general rank
|
| 53 |
+
TensorForEachHelper(
|
| 54 |
+
Func &func,
|
| 55 |
+
Coord<Rank> const &extent,
|
| 56 |
+
Coord<Rank> &coord) {
|
| 57 |
+
|
| 58 |
+
for (int i = 0; i < extent.at(kActiveRank); ++i) {
|
| 59 |
+
coord[kActiveRank] = i;
|
| 60 |
+
TensorForEachHelper<Func, Rank, RankRemaining - 1>(func, extent, coord);
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
/// Helper to perform for-each operation
|
| 66 |
+
template <typename Func, int Rank>
|
| 67 |
+
struct TensorForEachHelper<Func, Rank, 0> {
|
| 68 |
+
|
| 69 |
+
/// Index of the active rank
|
| 70 |
+
static int const kActiveRank = Rank - 1;
|
| 71 |
+
|
| 72 |
+
/// Constructor for fastest changing rank
|
| 73 |
+
TensorForEachHelper(
|
| 74 |
+
Func &func,
|
| 75 |
+
Coord<Rank> const &extent,
|
| 76 |
+
Coord<Rank> &coord) {
|
| 77 |
+
|
| 78 |
+
for (int i = 0; i < extent.at(kActiveRank); ++i) {
|
| 79 |
+
coord[kActiveRank] = i;
|
| 80 |
+
func(coord);
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
} // namespace detail
|
| 86 |
+
|
| 87 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 88 |
+
|
| 89 |
+
/// Iterates over the index space of a tensor
|
| 90 |
+
template <
|
| 91 |
+
typename Func, ///< function applied to each point in a tensor's index space
|
| 92 |
+
int Rank> ///< rank of index space
|
| 93 |
+
void TensorForEach(Coord<Rank> extent, Func & func) {
|
| 94 |
+
Coord<Rank> coord;
|
| 95 |
+
detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, extent, coord);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 99 |
+
|
| 100 |
+
/// Iterates over the index space of a tensor and calls a C++ lambda
|
| 101 |
+
template <
|
| 102 |
+
typename Func, ///< function applied to each point in a tensor's index space
|
| 103 |
+
int Rank> ///< rank of index space
|
| 104 |
+
void TensorForEachLambda(Coord<Rank> extent, Func func) {
|
| 105 |
+
Coord<Rank> coord;
|
| 106 |
+
detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, extent, coord);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 110 |
+
|
| 111 |
+
template <typename Element, typename Func>
struct BlockForEach {

  /// Constructor performs the operation: builds a Func from `params` and
  /// assigns func() to each of the `capacity` elements starting at `ptr`.
  BlockForEach(
    Element *ptr,
    size_t capacity,
    typename Func::Params params = typename Func::Params()) {

    Func op(params);

    for (size_t pos = 0; pos != capacity; ++pos) {
      ptr[pos] = op();
    }
  }
};
|
| 127 |
+
|
| 128 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 129 |
+
|
| 130 |
+
} // namespace host
|
| 131 |
+
} // namespace reference
|
| 132 |
+
} // namespace cutlass
|
| 133 |
+
|
| 134 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
#include "cutlass/cutlass.h"
|
| 35 |
+
|
| 36 |
+
// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions.
|
| 37 |
+
|
| 38 |
+
#include "cutlass/util/reference/host/tensor_reduce.h"
|
| 39 |
+
|
| 40 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include <cmath>
|
| 34 |
+
|
| 35 |
+
#include "cutlass/cutlass.h"
|
| 36 |
+
#include "cutlass/complex.h"
|
| 37 |
+
#include "cutlass/tensor_ref.h"
|
| 38 |
+
|
| 39 |
+
#include "cutlass/util/reference/detail/linear_to_coordinate.h"
|
| 40 |
+
#include "cutlass/core_io.h"
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
namespace reference {
|
| 44 |
+
namespace host {
|
| 45 |
+
|
| 46 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 49 |
+
/// workspace
|
| 50 |
+
template <
|
| 51 |
+
typename Element,
|
| 52 |
+
typename Layout,
|
| 53 |
+
typename ComputeType,
|
| 54 |
+
typename ReduceOp,
|
| 55 |
+
typename TransformOp
|
| 56 |
+
>
|
| 57 |
+
ComputeType TensorTransformReduce(
|
| 58 |
+
TensorView<Element, Layout> view,
|
| 59 |
+
ComputeType identity,
|
| 60 |
+
ReduceOp reduce,
|
| 61 |
+
TransformOp transform
|
| 62 |
+
) {
|
| 63 |
+
|
| 64 |
+
for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) {
|
| 65 |
+
typename Layout::TensorCoord coord;
|
| 66 |
+
cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());
|
| 67 |
+
|
| 68 |
+
if (view.contains(coord)) {
|
| 69 |
+
Element x = view.at(coord);
|
| 70 |
+
identity = reduce(identity, transform(x));
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
return identity;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 78 |
+
/// workspace
|
| 79 |
+
template <
|
| 80 |
+
typename Element,
|
| 81 |
+
typename Layout,
|
| 82 |
+
typename ComputeType,
|
| 83 |
+
typename ReduceOp,
|
| 84 |
+
typename TransformOp
|
| 85 |
+
>
|
| 86 |
+
ComputeType TensorTransformReduce(
|
| 87 |
+
TensorView<Element, Layout> view_A,
|
| 88 |
+
TensorView<Element, Layout> view_B,
|
| 89 |
+
ComputeType identity,
|
| 90 |
+
ReduceOp reduce,
|
| 91 |
+
TransformOp transform) {
|
| 92 |
+
|
| 93 |
+
if (view_A.extent() != view_B.extent()) {
|
| 94 |
+
throw std::runtime_error("Tensor extents must match.");
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) {
|
| 98 |
+
|
| 99 |
+
typename Layout::TensorCoord coord;
|
| 100 |
+
cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());
|
| 101 |
+
|
| 102 |
+
if (view_A.contains(coord)) {
|
| 103 |
+
Element a = view_A.at(coord);
|
| 104 |
+
Element b = view_B.at(coord);
|
| 105 |
+
identity = reduce(identity, transform(a, b));
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
return identity;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
/// Helper to compute the sum of the elements of a tensor
|
| 113 |
+
template <
|
| 114 |
+
typename Element,
|
| 115 |
+
typename Layout,
|
| 116 |
+
typename ComputeType = Element
|
| 117 |
+
>
|
| 118 |
+
ComputeType TensorSum(
|
| 119 |
+
TensorView<Element, Layout> view,
|
| 120 |
+
ComputeType identity = ComputeType()
|
| 121 |
+
) {
|
| 122 |
+
|
| 123 |
+
plus<ComputeType> reduce;
|
| 124 |
+
NumericConverter<ComputeType, Element> transform;
|
| 125 |
+
|
| 126 |
+
return TensorTransformReduce(
|
| 127 |
+
view, identity, reduce, transform);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Helper to compute the sum of the squares of the elements of a tensor
|
| 131 |
+
template <
|
| 132 |
+
typename Element,
|
| 133 |
+
typename Layout,
|
| 134 |
+
typename ComputeType = Element
|
| 135 |
+
>
|
| 136 |
+
ComputeType TensorSumSq(
|
| 137 |
+
TensorView<Element, Layout> view,
|
| 138 |
+
ComputeType identity = ComputeType()
|
| 139 |
+
) {
|
| 140 |
+
|
| 141 |
+
plus<ComputeType> reduce;
|
| 142 |
+
magnitude_squared<Element, ComputeType> transform;
|
| 143 |
+
|
| 144 |
+
return TensorTransformReduce(
|
| 145 |
+
view, identity, reduce, transform);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/// Helper to compute the norm of the elements of a tensor.
|
| 149 |
+
template <
|
| 150 |
+
typename Element,
|
| 151 |
+
typename Layout,
|
| 152 |
+
typename ComputeType = double
|
| 153 |
+
>
|
| 154 |
+
ComputeType TensorNorm(
|
| 155 |
+
TensorView<Element, Layout> view,
|
| 156 |
+
ComputeType identity = ComputeType()
|
| 157 |
+
) {
|
| 158 |
+
|
| 159 |
+
return std::sqrt(TensorSumSq(view, identity));
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
/// Helper to compute the sum of the squares of the differences of two tensors
|
| 163 |
+
template <
|
| 164 |
+
typename Element,
|
| 165 |
+
typename Layout,
|
| 166 |
+
typename ComputeType = double
|
| 167 |
+
>
|
| 168 |
+
ComputeType TensorSumSqDiff(
|
| 169 |
+
TensorView<Element, Layout> view_A,
|
| 170 |
+
TensorView<Element, Layout> view_B,
|
| 171 |
+
ComputeType identity = ComputeType()
|
| 172 |
+
) {
|
| 173 |
+
|
| 174 |
+
plus<ComputeType> reduce;
|
| 175 |
+
magnitude_squared_difference<Element, ComputeType> transform;
|
| 176 |
+
|
| 177 |
+
return TensorTransformReduce(
|
| 178 |
+
view_A, view_B, identity, reduce, transform);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
|
| 183 |
+
template <
|
| 184 |
+
typename Element,
|
| 185 |
+
typename Layout,
|
| 186 |
+
typename ComputeType = double
|
| 187 |
+
>
|
| 188 |
+
ComputeType TensorNormDiff(
|
| 189 |
+
TensorView<Element, Layout> view_A,
|
| 190 |
+
TensorView<Element, Layout> view_B,
|
| 191 |
+
ComputeType identity = ComputeType()
|
| 192 |
+
) {
|
| 193 |
+
|
| 194 |
+
return std::sqrt(TensorSumSqDiff(view_A, view_B, identity));
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 198 |
+
|
| 199 |
+
} // namespace host
|
| 200 |
+
} // namespace reference
|
| 201 |
+
} // namespace cutlass
|
| 202 |
+
|
| 203 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Provides several functions for filling tensors with data.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// Standard Library includes
|
| 38 |
+
#include <utility>
|
| 39 |
+
#include <cstdlib>
|
| 40 |
+
#include <cmath>
|
| 41 |
+
|
| 42 |
+
// Cute includes
|
| 43 |
+
#include "cute/tensor.hpp"
|
| 44 |
+
|
| 45 |
+
// Cutlass includes
|
| 46 |
+
#include "cutlass/cutlass.h"
|
| 47 |
+
#include "cutlass/complex.h"
|
| 48 |
+
#include "cutlass/functional.h"
|
| 49 |
+
#include "cutlass/numeric_conversion.h"
|
| 50 |
+
#include "cutlass/quaternion.h"
|
| 51 |
+
#include "cutlass/array.h"
|
| 52 |
+
#include "cutlass/numeric_types.h"
|
| 53 |
+
|
| 54 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
namespace cutlass {
|
| 57 |
+
namespace reference {
|
| 58 |
+
namespace host {
|
| 59 |
+
|
| 60 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 61 |
+
//
|
| 62 |
+
// Tensor reductions
|
| 63 |
+
//
|
| 64 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 67 |
+
/// workspace
|
| 68 |
+
template <
|
| 69 |
+
typename Tensor,
|
| 70 |
+
typename ComputeType,
|
| 71 |
+
typename ReduceOp,
|
| 72 |
+
typename TransformOp
|
| 73 |
+
>
|
| 74 |
+
ComputeType TensorTransformReduce(
|
| 75 |
+
Tensor view,
|
| 76 |
+
ComputeType identity,
|
| 77 |
+
ReduceOp reduce,
|
| 78 |
+
TransformOp transform
|
| 79 |
+
) {
|
| 80 |
+
|
| 81 |
+
for (int64_t idx = 0; idx < cute::size(view); ++idx) {
|
| 82 |
+
identity = reduce(identity, transform(view(idx)));
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
return identity;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
|
| 89 |
+
/// workspace
|
| 90 |
+
template <
|
| 91 |
+
typename TensorA,
|
| 92 |
+
typename TensorB,
|
| 93 |
+
typename ComputeType,
|
| 94 |
+
typename ReduceOp,
|
| 95 |
+
typename TransformOp
|
| 96 |
+
>
|
| 97 |
+
ComputeType TensorTransformReduce(
|
| 98 |
+
TensorA view_A,
|
| 99 |
+
TensorB view_B,
|
| 100 |
+
ComputeType identity,
|
| 101 |
+
ReduceOp reduce,
|
| 102 |
+
TransformOp transform) {
|
| 103 |
+
|
| 104 |
+
if (cute::size(view_A) != cute::size(view_B)) {
|
| 105 |
+
throw std::runtime_error("Tensor sizes must match.");
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
for (int64_t idx = 0; idx < cute::size(view_A); ++idx) {
|
| 109 |
+
identity = reduce(identity, transform(view_A(idx), view_B(idx)));
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
return identity;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/// Helper to compute the sum of the elements of a tensor
|
| 116 |
+
template <
|
| 117 |
+
typename Tensor,
|
| 118 |
+
typename ComputeType = typename Tensor::value_type
|
| 119 |
+
>
|
| 120 |
+
ComputeType TensorSum(
|
| 121 |
+
Tensor view,
|
| 122 |
+
ComputeType identity = ComputeType()
|
| 123 |
+
) {
|
| 124 |
+
|
| 125 |
+
plus<ComputeType> reduce;
|
| 126 |
+
NumericConverter<ComputeType, typename Tensor::value_type> transform;
|
| 127 |
+
|
| 128 |
+
return TensorTransformReduce(
|
| 129 |
+
view, identity, reduce, transform);
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
/// Helper to compute the sum of the squares of the elements of a tensor
|
| 133 |
+
template <
|
| 134 |
+
typename Tensor,
|
| 135 |
+
typename ComputeType = typename Tensor::value_type
|
| 136 |
+
>
|
| 137 |
+
ComputeType TensorSumSq(
|
| 138 |
+
Tensor view,
|
| 139 |
+
ComputeType identity = ComputeType()
|
| 140 |
+
) {
|
| 141 |
+
|
| 142 |
+
plus<ComputeType> reduce;
|
| 143 |
+
magnitude_squared<typename Tensor::value_type, ComputeType> transform;
|
| 144 |
+
|
| 145 |
+
return TensorTransformReduce(
|
| 146 |
+
view, identity, reduce, transform);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
/// Frobenius-style norm of a CuTe tensor: sqrt of the sum of squared
/// element magnitudes. `identity` seeds the sum-of-squares accumulator.
template <
  typename Tensor,
  typename ComputeType = double
>
ComputeType TensorNorm(
  Tensor view,
  ComputeType identity = ComputeType()
) {
  ComputeType sum_sq = TensorSumSq(view, identity);
  return std::sqrt(sum_sq);
}
|
| 161 |
+
|
| 162 |
+
/// Helper to compute the sum of the squares of the differences of two tensors
|
| 163 |
+
template <
|
| 164 |
+
typename TensorA,
|
| 165 |
+
typename TensorB,
|
| 166 |
+
typename ComputeType = double
|
| 167 |
+
>
|
| 168 |
+
ComputeType TensorSumSqDiff(
|
| 169 |
+
TensorA view_A,
|
| 170 |
+
TensorB view_B,
|
| 171 |
+
ComputeType identity = ComputeType()
|
| 172 |
+
) {
|
| 173 |
+
|
| 174 |
+
plus<ComputeType> reduce;
|
| 175 |
+
magnitude_squared_difference<typename TensorA::value_type, ComputeType> transform;
|
| 176 |
+
|
| 177 |
+
return TensorTransformReduce(
|
| 178 |
+
view_A, view_B, identity, reduce, transform);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
/// Norm of the elementwise difference of two CuTe tensors:
/// sqrt(sum |a - b|^2). `identity` seeds the sum-of-squares accumulator.
template <
  typename TensorA,
  typename TensorB,
  typename ComputeType = double
>
ComputeType TensorNormDiff(
  TensorA view_A,
  TensorB view_B,
  ComputeType identity = ComputeType()
) {
  ComputeType sum_sq_diff = TensorSumSqDiff(view_A, view_B, identity);
  return std::sqrt(sum_sq_diff);
}
|
| 196 |
+
|
| 197 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 198 |
+
|
| 199 |
+
} // namespace host
|
| 200 |
+
} // namespace reference
|
| 201 |
+
} // namespace cutlass
|
| 202 |
+
|
| 203 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for TRMM in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/blas3.h"
|
| 40 |
+
#include "cutlass/numeric_conversion.h"
|
| 41 |
+
#include "cutlass/tensor_view.h"
|
| 42 |
+
#include "cutlass/gemm/gemm.h"
|
| 43 |
+
#include "cutlass/arch/mma.h"
|
| 44 |
+
#include "cutlass/util/host_tensor.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/util/reference/host/gemm.h"
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reference {
|
| 50 |
+
namespace host {
|
| 51 |
+
|
| 52 |
+
/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// Host reference implementation of TRMM: D = alpha * op(A) * B (SideMode::kLeft)
/// or D = alpha * B * op(A) (SideMode::kRight), where A is triangular per
/// FillModeA and DiagTypeA. Processes the output in 16x16 blocks purely to
/// speed up this reference loop nest; blocking does not change results.
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  DiagType DiagTypeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_trmm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRef<ElementA, LayoutA> tensor_a,
  TensorRef<ElementB, LayoutB> tensor_b,
  TensorRef<ElementC, LayoutC> tensor_d,
  ComputeType initial_accum) {

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  static_assert(SideModeA != SideMode::kInvalid
    , "Side Mode can either be Left or Right.");

  static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper
    , "Fill Mode can either be Lower or Upper.");

  // CompareOp selects which (row, col) pairs lie inside the triangular part
  // of A for the given fill mode / diag type.
  using CompareOp = typename TrMatrixCompareOp<FillModeA, DiagTypeA>::Type;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  // Assuming correct k-dimension value is passed
  int const K = problem_size.k();

  // Blocking necessary to speedup reference implementation
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
  CompareOp compare_op;

  for (int row_block = 0; row_block < M; row_block += Mblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      // Per-tile accumulators, each seeded with initial_accum.
      ComputeType accum[Mblock][Nblock];

      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          accum[i][j] = initial_accum;
        }
      }

      // Reduce over the k dimension; out-of-range (row, col) lanes of the
      // tile are simply skipped.
      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            if (row < M && col < N) {
              ElementA a = ElementA();
              ElementB b = ElementB();

              if (SideModeA == SideMode::kLeft) {
                // A is the left operand: mask A to its triangular part,
                // substituting 0 outside and 1 on a unit diagonal.
                a = (compare_op(row, k_block)) ?
                  (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0);
                if (row == k_block && DiagTypeA == DiagType::kUnit) {
                  a = ElementA(1);
                }
                b = tensor_b.at(MatrixCoord(k_block, col));
              } else if (SideModeA == SideMode::kRight) {
                // A is the right operand: B feeds the `a` slot and the
                // triangular-masked A feeds the `b` slot.
                // NOTE(review): this path constructs `a` from tensor_b and
                // uses ElementA(0)/ElementA(1) for the `b` value even though
                // `b` is declared ElementB — mirrors the upstream reference;
                // assumes ElementA and ElementB are mutually convertible.
                // TODO confirm.
                a = tensor_b.at(MatrixCoord(row, k_block));
                b = (compare_op(k_block, col)) ?
                  tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0);
                if (k_block == col && DiagTypeA == DiagType::kUnit) {
                  b = ElementA(1);
                }
              }

              ComputeType compute_a(cast_if_scalar<ComputeType>(a));
              ComputeType compute_b(cast_if_scalar<ComputeType>(b));

              accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
            }
          }
        }
      }

      // Scale by alpha, convert to ElementC, and write the tile back to D.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          MatrixCoord coord = MatrixCoord(row, col);

          if (row < M && col < N) {
            tensor_d.at(coord) = convert_op(
              alpha * ScalarType(accum[i][j]));
          }
        }
      }
    }
  }
}
|
| 165 |
+
|
| 166 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 167 |
+
|
| 168 |
+
/// Functor wrapping the TRMM reference computation; specialized on the
/// inner-product operator tag (e.g. arch::OpMultiplyAdd). Only declared here;
/// concrete behavior is provided by partial specializations.
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  DiagType DiagTypeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
>
struct Trmm;
|
| 183 |
+
|
| 184 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 185 |
+
|
| 186 |
+
/// Partial specialization for multiply-add
///
/// Dispatches to compute_trmm with multiply_add<ComputeType> as the
/// inner-product operation. The call operator's signature matches the
/// primary template's intended interface; initial_accum defaults to zero.
template <typename ElementA, typename LayoutA, SideMode SideModeA,
          FillMode FillModeA, DiagType DiagTypeA,
          typename ElementB, typename LayoutB,
          typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Trmm<ElementA, LayoutA, SideModeA, FillModeA, DiagTypeA, ElementB, LayoutB,
            ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpMultiplyAdd> {

  /// Computes D = alpha * op(A) * B (or B * op(A)) via the blocked host
  /// reference loop in compute_trmm.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_trmm<ElementA, LayoutA, SideModeA, FillModeA, DiagTypeA, ElementB, LayoutB,
      ElementC, LayoutC, ScalarType, ComputeType, multiply_add<ComputeType>>(
      problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
  }
};
|
| 210 |
+
|
| 211 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 212 |
+
|
| 213 |
+
} // namespace host
|
| 214 |
+
} // namespace reference
|
| 215 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Reference implementation for complex-valued TRMM in host-side code.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/blas3.h"
|
| 40 |
+
#include "cutlass/complex.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/tensor_view.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/util/reference/host/gemm.h"
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace reference {
|
| 49 |
+
namespace host {
|
| 50 |
+
|
| 51 |
+
/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef
|
| 52 |
+
/// objects.
|
| 53 |
+
template <
|
| 54 |
+
typename ElementA,
|
| 55 |
+
typename LayoutA,
|
| 56 |
+
ComplexTransform TransformA,
|
| 57 |
+
SideMode SideModeA,
|
| 58 |
+
FillMode FillModeA,
|
| 59 |
+
DiagType DiagTypeA,
|
| 60 |
+
typename ElementB,
|
| 61 |
+
typename LayoutB,
|
| 62 |
+
ComplexTransform TransformB,
|
| 63 |
+
typename ElementC,
|
| 64 |
+
typename LayoutC,
|
| 65 |
+
typename ScalarType,
|
| 66 |
+
typename ComputeType,
|
| 67 |
+
typename InnerProductOp = multiply_add<ComputeType>,
|
| 68 |
+
typename ConvertOp = NumericConverter<ElementC, ScalarType>
|
| 69 |
+
>
|
| 70 |
+
void compute_trmm_complex(
|
| 71 |
+
gemm::GemmCoord problem_size,
|
| 72 |
+
ScalarType alpha,
|
| 73 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 74 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 75 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 76 |
+
ComputeType initial_accum) {
|
| 77 |
+
|
| 78 |
+
static_assert(
|
| 79 |
+
LayoutA::kRank == 2 &&
|
| 80 |
+
LayoutC::kRank == 2, "Tensors must be of rank 2");
|
| 81 |
+
|
| 82 |
+
static_assert(SideModeA != SideMode::kInvalid
|
| 83 |
+
, "Side Mode can either be Left or Right.");
|
| 84 |
+
|
| 85 |
+
static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper
|
| 86 |
+
, "Fill Mode can either be Lower or Upper.");
|
| 87 |
+
|
| 88 |
+
using CompareOp = typename TrMatrixCompareOp<FillModeA, DiagTypeA>::Type;
|
| 89 |
+
|
| 90 |
+
// Note: batch is ignored.
|
| 91 |
+
int const M = problem_size.m();
|
| 92 |
+
int const N = problem_size.n();
|
| 93 |
+
// Assuming correct k-dimension value is passed
|
| 94 |
+
int const K = problem_size.k();
|
| 95 |
+
|
| 96 |
+
// Blocking necessary to speedup reference implementation
|
| 97 |
+
int const Mblock = 16;
|
| 98 |
+
int const Nblock = 16;
|
| 99 |
+
|
| 100 |
+
ConvertOp convert_op;
|
| 101 |
+
InnerProductOp inner_product_op;
|
| 102 |
+
CompareOp compare_op;
|
| 103 |
+
|
| 104 |
+
for (int row_block = 0; row_block < M; row_block += Mblock) {
|
| 105 |
+
for (int col_block = 0; col_block < N; col_block += Nblock) {
|
| 106 |
+
|
| 107 |
+
ComputeType accum[Mblock][Nblock];
|
| 108 |
+
|
| 109 |
+
for (int j = 0; j < Nblock; j++) {
|
| 110 |
+
for (int i = 0; i < Mblock; i++) {
|
| 111 |
+
accum[i][j] = initial_accum;
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
for (int k_block = 0; k_block < K; ++k_block) {
|
| 116 |
+
for (int j = 0; j < Nblock; j++) {
|
| 117 |
+
for (int i = 0; i < Mblock; i++) {
|
| 118 |
+
int row = row_block + i;
|
| 119 |
+
int col = col_block + j;
|
| 120 |
+
|
| 121 |
+
if (row < M && col < N) {
|
| 122 |
+
ElementA a = ElementA();
|
| 123 |
+
ElementB b = ElementB();
|
| 124 |
+
|
| 125 |
+
if (SideModeA == SideMode::kLeft) {
|
| 126 |
+
a = (compare_op(row, k_block)) ?
|
| 127 |
+
(tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0);
|
| 128 |
+
if (row == k_block && DiagTypeA == DiagType::kUnit) {
|
| 129 |
+
a = ElementA(1);
|
| 130 |
+
}
|
| 131 |
+
b = tensor_b.at(MatrixCoord(k_block, col));
|
| 132 |
+
} else if (SideModeA == SideMode::kRight) {
|
| 133 |
+
a = tensor_b.at(MatrixCoord(row, k_block));
|
| 134 |
+
b = (compare_op(k_block, col)) ?
|
| 135 |
+
tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0);
|
| 136 |
+
if (k_block == col && DiagTypeA == DiagType::kUnit) {
|
| 137 |
+
b = ElementA(1);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
ComputeType a_ik = ComputeType(a);
|
| 142 |
+
ComputeType b_kj = ComputeType(b);
|
| 143 |
+
|
| 144 |
+
// Conjugate, and hence hermitian, is only allowed for the triangular matrix
|
| 145 |
+
if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) {
|
| 146 |
+
a_ik = conj(a_ik);
|
| 147 |
+
} else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) {
|
| 148 |
+
b_kj = conj(b_kj);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
for (int j = 0; j < Nblock; j++) {
|
| 158 |
+
for (int i = 0; i < Mblock; i++) {
|
| 159 |
+
int row = row_block + i;
|
| 160 |
+
int col = col_block + j;
|
| 161 |
+
|
| 162 |
+
MatrixCoord coord = MatrixCoord(row, col);
|
| 163 |
+
|
| 164 |
+
if (row < M && col < N) {
|
| 165 |
+
tensor_d.at(coord) = convert_op(
|
| 166 |
+
alpha * ScalarType(accum[i][j]));
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 175 |
+
|
| 176 |
+
/// Functor wrapping the host-side complex TRMM reference. Only the partial
/// specializations below (selected by the InnerProductOp tag) are defined;
/// instantiating the primary template with any other operator is an error.
template <
  typename ElementA,
  typename LayoutA,
  ComplexTransform TransformA,
  SideMode SideModeA,
  FillMode FillModeA,
  DiagType DiagTypeA,
  typename ElementB,
  typename LayoutB,
  ComplexTransform TransformB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex
>
struct TrmmComplex;
|
| 193 |
+
|
| 194 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 195 |
+
|
| 196 |
+
/// Partial specialization for multiply-add
|
| 197 |
+
template <typename ElementA, typename LayoutA, ComplexTransform TransformA,
|
| 198 |
+
SideMode SideModeA, FillMode FillModeA, DiagType DiagTypeA,
|
| 199 |
+
typename ElementB, typename LayoutB, ComplexTransform TransformB,
|
| 200 |
+
typename ElementC, typename LayoutC,
|
| 201 |
+
typename ScalarType, typename ComputeType>
|
| 202 |
+
struct TrmmComplex<ElementA, LayoutA, TransformA,
|
| 203 |
+
SideModeA, FillModeA, DiagTypeA,
|
| 204 |
+
ElementB, LayoutB, TransformB,
|
| 205 |
+
ElementC, LayoutC, ScalarType,
|
| 206 |
+
ComputeType, arch::OpMultiplyAddComplex> {
|
| 207 |
+
|
| 208 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 209 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 210 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 211 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 212 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 213 |
+
static_assert(
|
| 214 |
+
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
| 215 |
+
"Tensors must be of rank 2");
|
| 216 |
+
|
| 217 |
+
compute_trmm_complex<ElementA, LayoutA, TransformA,
|
| 218 |
+
SideModeA, FillModeA, DiagTypeA,
|
| 219 |
+
ElementB, LayoutB, TransformB,
|
| 220 |
+
ElementC, LayoutC,
|
| 221 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 222 |
+
problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
|
| 223 |
+
}
|
| 224 |
+
};
|
| 225 |
+
|
| 226 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 227 |
+
|
| 228 |
+
/// Partial specialization for gaussian multiply-add
|
| 229 |
+
template <typename ElementA, typename LayoutA, ComplexTransform TransformA,
|
| 230 |
+
SideMode SideModeA, FillMode FillModeA, DiagType DiagTypeA,
|
| 231 |
+
typename ElementB, typename LayoutB, ComplexTransform TransformB,
|
| 232 |
+
typename ElementC, typename LayoutC,
|
| 233 |
+
typename ScalarType, typename ComputeType>
|
| 234 |
+
struct TrmmComplex<ElementA, LayoutA, TransformA,
|
| 235 |
+
SideModeA, FillModeA, DiagTypeA,
|
| 236 |
+
ElementB, LayoutB, TransformB,
|
| 237 |
+
ElementC, LayoutC, ScalarType,
|
| 238 |
+
ComputeType, arch::OpMultiplyAddGaussianComplex> {
|
| 239 |
+
|
| 240 |
+
void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
|
| 241 |
+
TensorRef<ElementA, LayoutA> tensor_a,
|
| 242 |
+
TensorRef<ElementB, LayoutB> tensor_b,
|
| 243 |
+
TensorRef<ElementC, LayoutC> tensor_d,
|
| 244 |
+
ComputeType initial_accum = ComputeType(0)) {
|
| 245 |
+
static_assert(
|
| 246 |
+
LayoutA::kRank == 2 && LayoutC::kRank == 2,
|
| 247 |
+
"Tensors must be of rank 2");
|
| 248 |
+
|
| 249 |
+
compute_trmm_complex<ElementA, LayoutA, TransformA,
|
| 250 |
+
SideModeA, FillModeA, DiagTypeA,
|
| 251 |
+
ElementB, LayoutB, TransformB,
|
| 252 |
+
ElementC, LayoutC,
|
| 253 |
+
ScalarType, ComputeType, multiply_add<ComputeType>>(
|
| 254 |
+
problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
|
| 255 |
+
}
|
| 256 |
+
};
|
| 257 |
+
|
| 258 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 259 |
+
|
| 260 |
+
} // namespace host
|
| 261 |
+
} // namespace reference
|
| 262 |
+
} // namespace cutlass
|