Kernels
danieldk HF Staff commited on
Commit
ff86389
·
verified ·
1 Parent(s): 38c7386

Build uploaded using `kernels` (batch 8/10).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h +111 -0
  3. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h +541 -0
  4. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h +591 -0
  5. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h +157 -0
  6. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h +38 -0
  7. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +472 -0
  8. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp +570 -0
  9. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp +341 -0
  10. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h +135 -0
  11. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h +94 -0
  12. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h +1549 -0
  13. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h +385 -0
  14. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h +350 -0
  15. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +311 -0
  16. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp +146 -0
  17. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +162 -0
  18. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h +168 -0
  19. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +159 -0
  20. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h +355 -0
  21. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h +250 -0
  22. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h +2075 -0
  23. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +142 -0
  24. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h +514 -0
  25. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h +141 -0
  26. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h +186 -0
  27. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp +782 -0
  28. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h +802 -0
  29. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h +66 -0
  30. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h +531 -0
  31. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h +210 -0
  32. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +228 -0
  33. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp +916 -0
  34. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h +261 -0
  35. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h +318 -0
  36. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h +234 -0
  37. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h +285 -0
  38. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h +319 -0
  39. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h +616 -0
  40. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp +101 -0
  41. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h +256 -0
  42. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h +341 -0
  43. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h +1718 -0
  44. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp +432 -0
  45. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h +134 -0
  46. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h +42 -0
  47. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h +203 -0
  48. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp +203 -0
  49. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h +215 -0
  50. build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h +262 -0
.gitattributes CHANGED
@@ -16,3 +16,4 @@ build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=
16
  build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
17
  build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
18
  build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
16
  build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
17
  build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
18
  build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
19
+ build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief reorder data from the host side
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/coord.h"
39
+ #include "cutlass/util/host_tensor.h"
40
+ #include "cutlass/tensor_view.h"
41
+ #include "cutlass/util/tensor_view_io.h"
42
+ #include "cutlass/util/reference/host/gemm.h"
43
+
44
+ namespace cutlass {
45
+
46
+ /// This is needed for the interleaved integer tensor core kernels. The purpose
47
+ /// is to use skip the shared memory part in the epilogue.
48
+ template <int Interleaved, typename Element, typename Layout>
49
+ void reorder_column(TensorRef<Element, Layout> dest,
50
+ TensorRef<Element, Layout> src,
51
+ cutlass::gemm::GemmCoord problem_size) {
52
+ const int InstructionShapeCol = 8;
53
+ // 4 threads per Quad
54
+ const int ElementsPerThread = InstructionShapeCol / 4;
55
+ // 4 threads per Quad
56
+ const int ReorderedElementsPerThread =
57
+ Interleaved / 4;
58
+
59
+ for (int n = 0; n < problem_size.n(); n++) {
60
+ for (int k = 0; k < problem_size.k(); k++) {
61
+ dest.at({k, (n / Interleaved) * Interleaved +
62
+ ((n % ReorderedElementsPerThread) / ElementsPerThread) *
63
+ InstructionShapeCol +
64
+ ((n % Interleaved) / ReorderedElementsPerThread) *
65
+ ElementsPerThread +
66
+ (n % ElementsPerThread)}) = src.at({k, n});
67
+ }
68
+ }
69
+ }
70
+
71
+ template <int ColumnInterleaved, int LayoutInterleaved = ColumnInterleaved, typename Element, typename Layout>
72
+ void reorder_convK(TensorRef<Element, Layout> dest,
73
+ TensorRef<Element, Layout> src,
74
+ cutlass::gemm::GemmCoord problem_size) {
75
+
76
+ TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedDest(dest.data(), dest.stride(0));
77
+ TensorRef<Element, layout::RowMajorInterleaved<LayoutInterleaved>> mappedSrc(src.data(), src.stride(0));
78
+
79
+ reorder_column<ColumnInterleaved>(
80
+ mappedDest, mappedSrc, problem_size);
81
+ }
82
+
83
+ /// This is needed for the sparse tensor core kernels. The purpose
84
+ /// is to use ldmatrix to load from shared memory to the register file.
85
+ template <typename Element, typename LayoutDest, typename LayoutSrc>
86
+ void reorder_meta(TensorRef<Element, LayoutDest> dest,
87
+ TensorRef<Element, LayoutSrc> src,
88
+ cutlass::gemm::GemmCoord problem_size) {
89
+ for (int m = 0; m < problem_size.m(); m++) {
90
+ for (int k = 0; k < problem_size.k(); k++) {
91
+ // First reorder the rows.
92
+ int group = (sizeof(Element) == 2) ? 32 : 16;
93
+ int interweave = (sizeof(Element) == 2) ? 4 : 2;
94
+
95
+ int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
96
+ int dest_col = k;
97
+
98
+ // Next swizzle the 2x2 blocks from Z to N.
99
+ if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
100
+ ++dest_row;
101
+ --dest_col;
102
+ } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
103
+ --dest_row;
104
+ ++dest_col;
105
+ }
106
+
107
+ dest.at({dest_row, dest_col}) = src.at({m, k});
108
+ }
109
+ }
110
+ }
111
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ /*! \file
34
+ \brief HostTensor contributes management for both host and device memory.
35
+
36
+ HostTensor allocates host and device memory upon construction. Basic element-wise operations on
37
+ host memory synchronize device memory automatically. Explicit copy operations provide abstractions
38
+ for CUDA memcpy operations.
39
+
40
+ Call {host, device}_{data, ref, view}() for accessing host or device memory.
41
+
42
+ See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
43
+ */
44
+
45
+ #include <vector>
46
+
47
+ #include "cutlass/cutlass.h"
48
+ #include "cutlass/tensor_ref.h"
49
+ #include "cutlass/tensor_view.h"
50
+ #include "cutlass/fast_math.h"
51
+
52
+ #include "device_memory.h"
53
+
54
+ namespace cutlass {
55
+
56
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ /// Host tensor
59
+ template <
60
+ /// Data type of element stored within tensor (concept: NumericType)
61
+ typename Element_,
62
+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
63
+ typename Layout_
64
+ >
65
+ class HostTensor {
66
+ public:
67
+
68
+ /// Data type of individual access
69
+ using Element = Element_;
70
+
71
+ /// Mapping function from logical coordinate to linear memory
72
+ using Layout = Layout_;
73
+
74
+ /// Logical rank of tensor index space
75
+ static int const kRank = Layout::kRank;
76
+
77
+ /// Index type
78
+ using Index = typename Layout::Index;
79
+
80
+ /// Long index used for pointer offsets
81
+ using LongIndex = typename Layout::LongIndex;
82
+
83
+ /// Coordinate in logical tensor space
84
+ using TensorCoord = typename Layout::TensorCoord;
85
+
86
+ /// Layout's stride vector
87
+ using Stride = typename Layout::Stride;
88
+
89
+ /// Tensor reference to device memory
90
+ using TensorRef = TensorRef<Element, Layout>;
91
+
92
+ /// Tensor reference to constant device memory
93
+ using ConstTensorRef = typename TensorRef::ConstTensorRef;
94
+
95
+ /// Tensor reference to device memory
96
+ using TensorView = TensorView<Element, Layout>;
97
+
98
+ /// Tensor reference to constant device memory
99
+ using ConstTensorView = typename TensorView::ConstTensorView;
100
+
101
+ /// Reference to element in tensor
102
+ using Reference = typename TensorRef::Reference;
103
+
104
+ /// Constant reference to element in tensor
105
+ using ConstReference = typename ConstTensorRef::Reference;
106
+
107
+ private:
108
+ using StorageUnit = typename platform::conditional_t<std::is_same_v<Element, bool>, uint8_t, // Avoid the std::vector<bool> specialization
109
+ typename platform::conditional_t<sizeof_bits<Element>::value % 8 == 0, // Handle subbyte types
110
+ Element, uint8_t>>;
111
+ using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
112
+ static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits;
113
+ static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements;
114
+ static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes;
115
+ static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit;
116
+
117
+ //
118
+ // Data members
119
+ //
120
+
121
+ /// Extent of tensor in logical dimensions
122
+ TensorCoord extent_;
123
+
124
+ /// Layout object
125
+ Layout layout_;
126
+
127
+ /// Host-side memory allocation
128
+ std::vector<StorageUnit> host_;
129
+
130
+ /// Device-side memory
131
+ device_memory::allocation<StorageUnit> device_;
132
+
133
+ /// number of containers
134
+ size_t count_to_container_storage_unit_count(size_t count) {
135
+ return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit;
136
+ }
137
+
138
+ public:
139
+ //
140
+ // Device and Host Methods
141
+ //
142
+
143
+ /// Default constructor
144
+ HostTensor() {}
145
+
146
+ /// Constructs a tensor given an extent. Assumes a packed layout
147
+ HostTensor(
148
+ TensorCoord const &extent,
149
+ bool device_backed = true
150
+ ) {
151
+
152
+ this->reset(extent, Layout::packed(extent), device_backed);
153
+ }
154
+
155
+ /// Constructs a tensor given an extent and layout
156
+ HostTensor(
157
+ TensorCoord const &extent,
158
+ Layout const &layout,
159
+ bool device_backed = true
160
+ ) {
161
+
162
+ this->reset(extent, layout, device_backed);
163
+ }
164
+
165
+ ~HostTensor() { }
166
+
167
+ /// Clears the HostTensor allocation to size/capacity = 0
168
+ void reset() {
169
+ extent_ = TensorCoord();
170
+ layout_ = Layout::packed(extent_);
171
+
172
+ host_.clear();
173
+ device_.reset();
174
+ }
175
+
176
+ /// Resizes internal memory allocations without affecting layout or extent
177
+ void reserve(
178
+ size_t count, ///< size of tensor in elements
179
+ bool device_backed_ = true) { ///< if true, device memory is also allocated
180
+ #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
181
+ CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")");
182
+ #endif
183
+
184
+ device_.reset();
185
+ host_.clear();
186
+
187
+ size_t count_container = count_to_container_storage_unit_count(count);
188
+ #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
189
+ CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")");
190
+ #endif
191
+ host_.resize(count_container);
192
+
193
+ // Allocate memory
194
+ StorageUnit* device_memory = nullptr;
195
+ if (device_backed_) {
196
+ #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
197
+ CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")");
198
+ #endif
199
+ device_memory = device_memory::allocate<StorageUnit>(count_container);
200
+ }
201
+ device_.reset(device_memory, device_backed_ ? count_container : 0);
202
+ }
203
+
204
+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
205
+ /// extent and layout.
206
+ void reset(
207
+ TensorCoord const &extent, ///< extent of logical tensor
208
+ Layout const &layout, ///< layout object of tensor
209
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
210
+
211
+ extent_ = extent;
212
+ layout_ = layout;
213
+
214
+ reserve(size_t(layout_.capacity(extent_)), device_backed_);
215
+ }
216
+
217
+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
218
+ /// extent and layout. Assumes a packed tensor configuration.
219
+ void reset(
220
+ TensorCoord const &extent, ///< extent of logical tensor
221
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
222
+
223
+ reset(extent, Layout::packed(extent), device_backed_);
224
+ }
225
+
226
+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
227
+ /// To force allocation, call reset().
228
+ void resize(
229
+ TensorCoord const &extent, ///< extent of logical tensor
230
+ Layout const &layout, ///< layout object of tensor
231
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
232
+
233
+ extent_ = extent;
234
+ layout_ = layout;
235
+
236
+ LongIndex new_size = size_t(layout_.capacity(extent_));
237
+ LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));
238
+
239
+ if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
240
+ reserve(new_size, device_backed_);
241
+ }
242
+ }
243
+
244
+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
245
+ /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
246
+ void resize(
247
+ TensorCoord const &extent, ///< extent of logical tensor
248
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
249
+
250
+ resize(extent, Layout::packed(extent), device_backed_);
251
+ }
252
+
253
+ /// Returns the logical number of elements stored in the host tensor
254
+ size_t size() const {
255
+ return layout_.capacity(extent_);
256
+ }
257
+
258
+ /// Returns the logical capacity in terms of number of elements. May be larger than the size().
259
+ LongIndex capacity() const {
260
+ return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements;
261
+ }
262
+
263
+ /// Gets pointer to host data
264
+ Element * host_data() { return reinterpret_cast<Element *>(host_.data()); }
265
+
266
+ /// Gets pointer to host data with a pointer offset
267
+ Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
268
+
269
+ /// Gets a reference to an element in host memory
270
+ Reference host_data(LongIndex idx) {
271
+ return ReferenceFactory<Element>::get(host_data(), idx);
272
+ }
273
+
274
+ /// Gets pointer to host data
275
+ Element const * host_data() const { return reinterpret_cast<Element const *>(host_.data()); }
276
+
277
+ /// Gets pointer to host data with a pointer offset
278
+ Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(host_data(), ptr_element_offset); }
279
+
280
+ /// Gets a constant reference to an element in host memory
281
+ ConstReference host_data(LongIndex idx) const {
282
+ return ReferenceFactory<Element const>::get(host_data(), idx);
283
+ }
284
+
285
+ /// Gets pointer to device data
286
+ Element * device_data() { return reinterpret_cast<Element *>(device_.get()); }
287
+
288
+ /// Gets pointer to device data
289
+ Element const * device_data() const { return reinterpret_cast<Element const *>(device_.get()); }
290
+
291
+ /// Gets pointer to device data with a pointer offset
292
+ Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
293
+
294
+ /// Gets pointer to device data with a pointer offset
295
+ Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory<Element>::get(device_data(), ptr_element_offset); }
296
+
297
+ /// Accesses the tensor reference pointing to data
298
+ TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
299
+
300
+ /// Accesses the tensor reference pointing to data
301
+ ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); }
302
+
303
+ /// Accesses the tensor reference pointing to data
304
+ TensorRef device_ref(LongIndex ptr_element_offset=0) {
305
+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
306
+ }
307
+
308
+ /// Accesses the tensor reference pointing to data
309
+ ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
310
+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_);
311
+ }
312
+
313
+ /// Accesses the tensor reference pointing to data
314
+ TensorView host_view(LongIndex ptr_element_offset=0) {
315
+ return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
316
+ }
317
+
318
+ /// Accesses the tensor reference pointing to data
319
+ ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
320
+ return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_);
321
+ }
322
+
323
+ /// Accesses the tensor reference pointing to data
324
+ TensorView device_view(LongIndex ptr_element_offset=0) {
325
+ return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
326
+ }
327
+
328
+ /// Accesses the tensor reference pointing to data
329
+ ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
330
+ return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_);
331
+ }
332
+
333
+ /// Returns true if device memory is allocated
334
+ bool device_backed() const {
335
+ return (device_.get() == nullptr) ? false : true;
336
+ }
337
+
338
+
339
+ /// Returns the layout object
340
+ Layout & layout() {
341
+ return layout_;
342
+ }
343
+
344
+ /// Returns the layout object
345
+ Layout layout() const {
346
+ return layout_;
347
+ }
348
+
349
+ /// Returns the layout object's stride vector
350
+ Stride stride() const {
351
+ return layout_.stride();
352
+ }
353
+
354
+ /// Returns the layout object's stride vector
355
+ Stride & stride() {
356
+ return layout_.stride();
357
+ }
358
+
359
+ /// Returns the layout object's stride in a given physical dimension
360
+ LongIndex stride(int dim) const {
361
+ return layout_.stride().at(dim);
362
+ }
363
+
364
+ /// Returns the layout object's stride in a given physical dimension
365
+ LongIndex & stride(int dim) {
366
+ return layout_.stride().at(dim);
367
+ }
368
+
369
+ /// Computes the offset of an index from the origin of the tensor
370
+ LongIndex offset(TensorCoord const& coord) const {
371
+ return layout_(coord);
372
+ }
373
+
374
+ /// Returns a reference to the element at the logical Coord in host memory
375
+ Reference at(TensorCoord const& coord) {
376
+ return host_data(offset(coord));
377
+ }
378
+
379
+ /// Returns a const reference to the element at the logical Coord in host memory
380
+ ConstReference at(TensorCoord const& coord) const {
381
+ return host_data(offset(coord));
382
+ }
383
+
384
+ /// Returns the extent of the tensor
385
+ TensorCoord extent() const {
386
+ return extent_;
387
+ }
388
+
389
+ /// Returns the extent of the tensor
390
+ TensorCoord & extent() {
391
+ return extent_;
392
+ }
393
+
394
+ /// Copies data from device to host
395
+ void sync_host() {
396
+ if (device_backed()) {
397
+ device_memory::copy_to_host(
398
+ host_.data(), device_.get(), device_.size());
399
+ }
400
+ }
401
+
402
+ /// Copies data from host to device
403
+ void sync_device() {
404
+ if (device_backed()) {
405
+ device_memory::copy_to_device(
406
+ device_.get(), host_.data(), host_.size());
407
+ }
408
+ }
409
+
410
+ /// Copy data from a caller-supplied device pointer into host memory.
411
+ void copy_in_device_to_host(
412
+ Element const* ptr_device, ///< source device memory
413
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
414
+
415
+ if (count < 0) {
416
+ count = capacity();
417
+ }
418
+ else {
419
+ count = __NV_STD_MIN(capacity(), count);
420
+ }
421
+ size_t container_count = count_to_container_storage_unit_count(count);
422
+ device_memory::copy_to_host(
423
+ host_.data(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
424
+ }
425
+
426
+ /// Copy data from a caller-supplied device pointer into host memory.
427
+ void copy_in_device_to_device(
428
+ Element const* ptr_device, ///< source device memory
429
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
430
+
431
+ if (count < 0) {
432
+ count = capacity();
433
+ }
434
+ else {
435
+ count = __NV_STD_MIN(capacity(), count);
436
+ }
437
+ size_t container_count = count_to_container_storage_unit_count(count);
438
+ device_memory::copy_device_to_device(
439
+ device_.get(), reinterpret_cast<StorageUnit const *>(ptr_device), container_count);
440
+ }
441
+
442
+ /// Copy data from a caller-supplied device pointer into host memory.
443
+ void copy_in_host_to_device(
444
+ Element const* ptr_host, ///< source host memory
445
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
446
+
447
+ if (count < 0) {
448
+ count = capacity();
449
+ }
450
+ else {
451
+ count = __NV_STD_MIN(capacity(), count);
452
+ }
453
+ size_t container_count = count_to_container_storage_unit_count(count);
454
+ device_memory::copy_to_device(
455
+ device_.get(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
456
+ }
457
+
458
+ /// Copy data from a caller-supplied device pointer into host memory.
459
+ void copy_in_host_to_host(
460
+ Element const* ptr_host, ///< source host memory
461
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
462
+
463
+ if (count < 0) {
464
+ count = capacity();
465
+ }
466
+ else {
467
+ count = __NV_STD_MIN(capacity(), count);
468
+ }
469
+ size_t container_count = count_to_container_storage_unit_count(count);
470
+ device_memory::copy_host_to_host(
471
+ host_.data(), reinterpret_cast<StorageUnit const *>(ptr_host), container_count);
472
+ }
473
+
474
+ /// Copy data from a caller-supplied device pointer into host memory.
475
+ void copy_out_device_to_host(
476
+ Element * ptr_host, ///< source device memory
477
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
478
+
479
+ if (count < 0) {
480
+ count = capacity();
481
+ }
482
+ else {
483
+ count = __NV_STD_MIN(capacity(), count);
484
+ }
485
+ size_t container_count = count_to_container_storage_unit_count(count);
486
+ device_memory::copy_to_host(
487
+ reinterpret_cast<StorageUnit *>(ptr_host), device_.get(), container_count);
488
+ }
489
+
490
+ /// Copy data from a caller-supplied device pointer into host memory.
491
+ void copy_out_device_to_device(
492
+ Element * ptr_device, ///< source device memory
493
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
494
+
495
+ if (count < 0) {
496
+ count = capacity();
497
+ }
498
+ else {
499
+ count = __NV_STD_MIN(capacity(), count);
500
+ }
501
+ size_t container_count = count_to_container_storage_unit_count(count);
502
+ device_memory::copy_device_to_device(
503
+ reinterpret_cast<StorageUnit *>(ptr_device), device_.get(), container_count);
504
+ }
505
+
506
+ /// Copy data from a caller-supplied device pointer into host memory.
507
+ void copy_out_host_to_device(
508
+ Element * ptr_device, ///< source host memory
509
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
510
+
511
+ if (count < 0) {
512
+ count = capacity();
513
+ }
514
+ else {
515
+ count = __NV_STD_MIN(capacity(), count);
516
+ }
517
+ size_t container_count = count_to_container_storage_unit_count(count);
518
+ device_memory::copy_to_device(
519
+ reinterpret_cast<StorageUnit *>(ptr_device), host_.data(), container_count);
520
+ }
521
+
522
+ /// Copy data from a caller-supplied device pointer into host memory.
523
+ void copy_out_host_to_host(
524
+ Element * ptr_host, ///< source host memory
525
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
526
+
527
+ if (count < 0) {
528
+ count = capacity();
529
+ }
530
+ else {
531
+ count = __NV_STD_MIN(capacity(), count);
532
+ }
533
+ size_t container_count = count_to_container_storage_unit_count(count);
534
+ device_memory::copy_host_to_host(
535
+ reinterpret_cast<StorageUnit *>(ptr_host), host_.data(), container_count);
536
+ }
537
+ };
538
+
539
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
540
+
541
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ /*! \file
34
+ \brief HostTensor contributes management for both host and device memory.
35
+
36
+ HostTensor allocates host and device memory upon construction. Basic element-wise operations on
37
+ host memory synchronize device memory automatically. Explicit copy operations provide abstractions
38
+ for CUDA memcpy operations.
39
+
40
+ Call {host, device}_{data, ref, view}() for accessing host or device memory.
41
+
42
+ See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
43
+ */
44
+
45
+ #include <vector>
46
+
47
+ #include "cutlass/cutlass.h"
48
+
49
+ #include "cutlass/tensor_ref_planar_complex.h"
50
+ #include "cutlass/tensor_view_planar_complex.h"
51
+
52
+ #include "device_memory.h"
53
+
54
+ namespace cutlass {
55
+
56
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ /// Host tensor
59
+ template <
60
+ /// Data type of element stored within tensor (concept: NumericType)
61
+ typename Element_,
62
+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
63
+ typename Layout_
64
+ >
65
+ class HostTensorPlanarComplex {
66
+ public:
67
+
68
+ /// Data type of individual access
69
+ using Element = Element_;
70
+
71
+ /// Mapping function from logical coordinate to linear memory
72
+ using Layout = Layout_;
73
+
74
+ /// Logical rank of tensor index space
75
+ static int const kRank = Layout::kRank;
76
+
77
+ /// Index type
78
+ using Index = typename Layout::Index;
79
+
80
+ /// Long index used for pointer offsets
81
+ using LongIndex = typename Layout::LongIndex;
82
+
83
+ /// Coordinate in logical tensor space
84
+ using TensorCoord = typename Layout::TensorCoord;
85
+
86
+ /// Layout's stride vector
87
+ using Stride = typename Layout::Stride;
88
+
89
+ /// Tensor reference to device memory
90
+ using TensorRef = TensorRefPlanarComplex<Element, Layout>;
91
+
92
+ /// Tensor reference to constant device memory
93
+ using ConstTensorRef = typename TensorRef::ConstTensorRef;
94
+
95
+ /// Tensor reference to device memory
96
+ using TensorView = TensorViewPlanarComplex<Element, Layout>;
97
+
98
+ /// Tensor reference to constant device memory
99
+ using ConstTensorView = typename TensorView::ConstTensorView;
100
+
101
+ /// Reference to element in tensor
102
+ using Reference = typename TensorRef::Reference;
103
+
104
+ /// Constant reference to element in tensor
105
+ using ConstReference = typename ConstTensorRef::Reference;
106
+
107
+ private:
108
+
109
+ //
110
+ // Data members
111
+ //
112
+
113
+ /// Extent of tensor in logical dimensions
114
+ TensorCoord extent_;
115
+
116
+ /// Layout object
117
+ Layout layout_;
118
+
119
+ /// Host-side memory allocation
120
+ std::vector<Element> host_;
121
+
122
+ /// Device-side memory
123
+ device_memory::allocation<Element> device_;
124
+
125
+ public:
126
+ //
127
+ // Device and Host Methods
128
+ //
129
+
130
+ /// Default constructor
131
+ HostTensorPlanarComplex() {}
132
+
133
+ /// Constructs a tensor given an extent. Assumes a packed layout
134
+ HostTensorPlanarComplex(
135
+ TensorCoord const &extent,
136
+ bool device_backed = true
137
+ ) {
138
+
139
+ this->reset(extent, Layout::packed(extent), device_backed);
140
+ }
141
+
142
+ /// Constructs a tensor given an extent and layout
143
+ HostTensorPlanarComplex(
144
+ TensorCoord const &extent,
145
+ Layout const &layout,
146
+ bool device_backed = true
147
+ ) {
148
+
149
+ this->reset(extent, layout, device_backed);
150
+ }
151
+
152
+ ~HostTensorPlanarComplex() { }
153
+
154
+ /// Clears the HostTensor allocation to size/capacity = 0
155
+ void reset() {
156
+ extent_ = TensorCoord();
157
+ layout_ = Layout::packed(extent_);
158
+
159
+ host_.clear();
160
+ device_.reset();
161
+ }
162
+
163
+ /// Resizes internal memory allocations without affecting layout or extent
164
+ void reserve(
165
+ size_t count, ///< size of tensor in elements
166
+ bool device_backed_ = true) { ///< if true, device memory is also allocated
167
+
168
+ device_.reset();
169
+ host_.clear();
170
+
171
+ host_.resize(count * 2);
172
+
173
+ // Allocate memory
174
+ Element* device_memory = nullptr;
175
+ if (device_backed_) {
176
+ device_memory = device_memory::allocate<Element>(count * 2);
177
+ }
178
+ device_.reset(device_memory, device_backed_ ? count * 2 : 0);
179
+ }
180
+
181
+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
182
+ /// extent and layout.
183
+ void reset(
184
+ TensorCoord const &extent, ///< extent of logical tensor
185
+ Layout const &layout, ///< layout object of tensor
186
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
187
+
188
+ extent_ = extent;
189
+ layout_ = layout;
190
+
191
+ reserve(size_t(layout_.capacity(extent_)), device_backed_);
192
+ }
193
+
194
+ /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
195
+ /// extent and layout. Assumes a packed tensor configuration.
196
+ void reset(
197
+ TensorCoord const &extent, ///< extent of logical tensor
198
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
199
+
200
+ reset(extent, Layout::packed(extent), device_backed_);
201
+ }
202
+
203
+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
204
+ /// To force allocation, call reset().
205
+ void resize(
206
+ TensorCoord const &extent, ///< extent of logical tensor
207
+ Layout const &layout, ///< layout object of tensor
208
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
209
+
210
+ extent_ = extent;
211
+ layout_ = layout;
212
+
213
+ LongIndex new_size = size_t(layout_.capacity(extent_));
214
+
215
+ if (static_cast<decltype(host_.size())>(new_size * 2) > host_.size()) {
216
+ reserve(new_size);
217
+ }
218
+ }
219
+
220
+ /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
221
+ /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration.
222
+ void resize(
223
+ TensorCoord const &extent, ///< extent of logical tensor
224
+ bool device_backed_ = true) { ///< if true, device memory is also allocated.
225
+
226
+ resize(extent, Layout::packed(extent), device_backed_);
227
+ }
228
+
229
+ /// Returns the number of elements stored in the host tensor
230
+ size_t size() const {
231
+ return host_.size() / 2;
232
+ }
233
+
234
+ /// Returns the logical capacity based on extent and layout. May differ from size().
235
+ LongIndex capacity() const {
236
+ return layout_.capacity(extent_);
237
+ }
238
+
239
+ /// Stride between real and imaginary parts
240
+ LongIndex imaginary_stride() const {
241
+ return host_.size() / 2;
242
+ }
243
+
244
+ /// Gets pointer to host data
245
+ Element * host_data() { return host_.data(); }
246
+
247
+ /// Gets pointer to host data imaginary part
248
+ Element * host_data_imag() { return host_.data() + imaginary_stride(); }
249
+
250
+ /// Gets pointer to host data with a pointer offset
251
+ Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; }
252
+
253
+ /// Gets pointer to host data with a pointer offset
254
+ Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; }
255
+
256
+ /// Gets a reference to an element in host memory
257
+ Reference host_data(LongIndex idx) {
258
+ return PlanarComplexReference<Element>(host_data() + idx, host_data_imag() + idx);
259
+ }
260
+
261
+ /// Gets pointer to host data
262
+ Element const * host_data() const { return host_.data(); }
263
+
264
+ /// Gets pointer to host data imaginary part
265
+ Element const * host_data_imag() const { return host_.data() + imaginary_stride(); }
266
+
267
+ /// Gets a constant reference to an element in host memory
268
+ ConstReference host_data(LongIndex idx) const {
269
+ return PlanarComplexReference<Element const>(host_data() + idx, host_data_imag() + idx);
270
+ }
271
+
272
+ /// Gets pointer to device data
273
+ Element * device_data() { return device_.get(); }
274
+
275
+ /// Gets pointer to device data with a pointer offset
276
+ Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; }
277
+
278
+ /// Gets pointer to device data
279
+ Element const * device_data() const { return device_.get(); }
280
+
281
+ /// Gets pointer to device data with a pointer offset
282
+ Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }
283
+
284
+ /// Gets a pointer to the device data imaginary part
285
+ Element * device_data_imag() { return device_.get() + imaginary_stride(); }
286
+
287
+ /// Accesses the tensor reference pointing to data
288
+ TensorRef host_ref(LongIndex ptr_element_offset=0) {
289
+ return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
290
+ }
291
+
292
+ /// Returns a tensor reference to the real part of the tensor
293
+ cutlass::TensorRef<Element, Layout> host_ref_real() {
294
+ return cutlass::TensorRef<Element, Layout>(host_data(), layout_);
295
+ }
296
+
297
+ /// Returns a tensor reference to the real part of the tensor
298
+ cutlass::TensorRef<Element, Layout> host_ref_imag() {
299
+ return cutlass::TensorRef<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_);
300
+ }
301
+
302
+ /// Accesses the tensor reference pointing to data
303
+ ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const {
304
+ return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
305
+ }
306
+
307
+ /// Accesses the tensor reference pointing to data
308
+ TensorRef device_ref(LongIndex ptr_element_offset=0) {
309
+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
310
+ }
311
+
312
+ /// Accesses the tensor reference pointing to data
313
+ ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const {
314
+ return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
315
+ }
316
+
317
+ /// Returns a tensor reference to the real part of the tensor
318
+ cutlass::TensorRef<Element, Layout> device_ref_real() {
319
+ return cutlass::TensorRef<Element, Layout>(device_data(), layout_);
320
+ }
321
+
322
+ /// Returns a tensor reference to the real part of the tensor
323
+ cutlass::TensorRef<Element, Layout> device_ref_imag() {
324
+ return cutlass::TensorRef<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_);
325
+ }
326
+
327
+ /// Accesses the tensor reference pointing to data
328
+ TensorView host_view(LongIndex ptr_element_offset=0) {
329
+ return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
330
+ }
331
+
332
+ /// Accesses the tensor reference pointing to data
333
+ ConstTensorView host_view(LongIndex ptr_element_offset=0) const {
334
+ return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
335
+ }
336
+
337
+ /// Accesses the tensor reference pointing to data
338
+ cutlass::TensorView<Element, Layout> host_view_real() {
339
+ return cutlass::TensorView<Element, Layout>(host_data(), layout_, extent_);
340
+ }
341
+
342
+ /// Accesses the tensor reference pointing to data
343
+ cutlass::TensorView<Element, Layout> host_view_imag() {
344
+ return cutlass::TensorView<Element, Layout>(host_data_ptr_offset(imaginary_stride()), layout_, extent_);
345
+ }
346
+
347
+ /// Accesses the tensor reference pointing to data
348
+ TensorView device_view(LongIndex ptr_element_offset=0) {
349
+ return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
350
+ }
351
+
352
+ /// Accesses the tensor reference pointing to data
353
+ ConstTensorView device_view(LongIndex ptr_element_offset=0) const {
354
+ return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_);
355
+ }
356
+
357
+ /// Accesses the tensor reference pointing to data
358
+ cutlass::TensorView<Element, Layout> device_view_real() {
359
+ return cutlass::TensorView<Element, Layout>(device_data(), layout_, extent_);
360
+ }
361
+
362
+ /// Accesses the tensor reference pointing to data
363
+ cutlass::TensorView<Element, Layout> device_view_imag() {
364
+ return cutlass::TensorView<Element, Layout>(device_data_ptr_offset(imaginary_stride()), layout_, extent_);
365
+ }
366
+
367
+ /// Returns true if device memory is allocated
368
+ bool device_backed() const {
369
+ return (device_.get() == nullptr) ? false : true;
370
+ }
371
+
372
+ /// Returns the layout object
373
+ Layout layout() const {
374
+ return layout_;
375
+ }
376
+
377
+ /// Returns the layout object's stride vector
378
+ Stride stride() const {
379
+ return layout_.stride();
380
+ }
381
+
382
+ /// Returns the layout object's stride in a given physical dimension
383
+ Index stride(int dim) const {
384
+ return layout_.stride().at(dim);
385
+ }
386
+
387
+ /// Computes the offset of an index from the origin of the tensor
388
+ LongIndex offset(TensorCoord const& coord) const {
389
+ return layout_(coord);
390
+ }
391
+
392
+ /// Returns a reference to the element at the logical Coord in host memory
393
+ Reference at(TensorCoord const& coord) {
394
+ return host_data(offset(coord));
395
+ }
396
+
397
+ /// Returns a const reference to the element at the logical Coord in host memory
398
+ ConstReference at(TensorCoord const& coord) const {
399
+ return host_data(offset(coord));
400
+ }
401
+
402
+ /// Returns the extent of the tensor
403
+ TensorCoord extent() const {
404
+ return extent_;
405
+ }
406
+
407
+ /// Returns the extent of the tensor
408
+ TensorCoord & extent() {
409
+ return extent_;
410
+ }
411
+
412
+ /// Copies data from device to host
413
+ void sync_host() {
414
+ if (device_backed()) {
415
+ device_memory::copy_to_host(
416
+ host_data(), device_data(), imaginary_stride() * 2);
417
+ }
418
+ }
419
+
420
+ /// Copies data from host to device
421
+ void sync_device() {
422
+ if (device_backed()) {
423
+ device_memory::copy_to_device(
424
+ device_data(), host_data(), imaginary_stride() * 2);
425
+ }
426
+ }
427
+
428
+ /// Copy data from a caller-supplied device pointer into host memory.
429
+ void copy_in_device_to_host(
430
+ Element const* ptr_device_real, ///< source device memory
431
+ Element const* ptr_device_imag, ///< source device memory
432
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
433
+
434
+ if (count < 0) {
435
+ count = capacity();
436
+ }
437
+ else {
438
+ count = __NV_STD_MIN(capacity(), count);
439
+ }
440
+
441
+ device_memory::copy_to_host(
442
+ host_data(), ptr_device_real, count);
443
+
444
+ device_memory::copy_to_host(
445
+ host_data_imag(), ptr_device_imag, count);
446
+ }
447
+
448
+ /// Copy data from a caller-supplied device pointer into host memory.
449
+ void copy_in_device_to_device(
450
+ Element const* ptr_device_real, ///< source device memory
451
+ Element const* ptr_device_imag, ///< source device memory
452
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
453
+
454
+ if (count < 0) {
455
+ count = capacity();
456
+ }
457
+ else {
458
+ count = __NV_STD_MIN(capacity(), count);
459
+ }
460
+
461
+ device_memory::copy_device_to_device(
462
+ device_data(), ptr_device_real, count);
463
+
464
+ device_memory::copy_device_to_device(
465
+ device_data_imag(), ptr_device_imag, count);
466
+ }
467
+
468
+ /// Copy data from a caller-supplied device pointer into host memory.
469
+ void copy_in_host_to_device(
470
+ Element const* ptr_host_real, ///< source host memory
471
+ Element const* ptr_host_imag, ///< source host memory
472
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
473
+
474
+ if (count < 0) {
475
+ count = capacity();
476
+ }
477
+ else {
478
+ count = __NV_STD_MIN(capacity(), count);
479
+ }
480
+
481
+ device_memory::copy_to_device(
482
+ device_data(), ptr_host_real, count);
483
+
484
+ device_memory::copy_to_device(
485
+ device_data_imag(), ptr_host_imag, count);
486
+ }
487
+
488
+ /// Copy data from a caller-supplied device pointer into host memory.
489
+ void copy_in_host_to_host(
490
+ Element const* ptr_host_real, ///< source host memory
491
+ Element const* ptr_host_imag, ///< source host memory
492
+ LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten.
493
+
494
+ if (count < 0) {
495
+ count = capacity();
496
+ }
497
+ else {
498
+ count = __NV_STD_MIN(capacity(), count);
499
+ }
500
+
501
+ device_memory::copy_host_to_host(
502
+ host_data(), ptr_host_real, count);
503
+
504
+ device_memory::copy_host_to_host(
505
+ host_data_imag(), ptr_host_imag, count);
506
+ }
507
+
508
+ /// Copy data from a caller-supplied device pointer into host memory.
509
+ void copy_out_device_to_host(
510
+ Element * ptr_host_real, ///< source device memory
511
+ Element * ptr_host_imag, ///< source device memory
512
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
513
+
514
+ if (count < 0) {
515
+ count = capacity();
516
+ }
517
+ else {
518
+ count = __NV_STD_MIN(capacity(), count);
519
+ }
520
+
521
+ device_memory::copy_to_host(
522
+ ptr_host_real, device_data(), count);
523
+
524
+ device_memory::copy_to_host(
525
+ ptr_host_imag, device_data_imag(), count);
526
+ }
527
+
528
+ /// Copy data from a caller-supplied device pointer into host memory.
529
+ void copy_out_device_to_device(
530
+ Element * ptr_device_real, ///< source device memory
531
+ Element * ptr_device_imag, ///< source device memory
532
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
533
+
534
+ if (count < 0) {
535
+ count = capacity();
536
+ }
537
+ else {
538
+ count = __NV_STD_MIN(capacity(), count);
539
+ }
540
+
541
+ device_memory::copy_device_to_device(
542
+ ptr_device_real, device_data(), count);
543
+
544
+ device_memory::copy_device_to_device(
545
+ ptr_device_imag, device_data_imag(), count);
546
+ }
547
+
548
+ /// Copy data from a caller-supplied device pointer into host memory.
549
+ void copy_out_host_to_device(
550
+ Element * ptr_device_real, ///< source device memory
551
+ Element * ptr_device_imag, ///< source device memory
552
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
553
+
554
+ if (count < 0) {
555
+ count = capacity();
556
+ }
557
+ else {
558
+ count = __NV_STD_MIN(capacity(), count);
559
+ }
560
+
561
+ device_memory::copy_to_device(
562
+ ptr_device_real, host_data(), count);
563
+
564
+ device_memory::copy_to_device(
565
+ ptr_device_imag, host_data_imag(), count);
566
+ }
567
+
568
+ /// Copy data from a caller-supplied device pointer into host memory.
569
+ void copy_out_host_to_host(
570
+ Element * ptr_host_real, ///< source host memory
571
+ Element * ptr_host_imag, ///< source host memory
572
+ LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten.
573
+
574
+ if (count < 0) {
575
+ count = capacity();
576
+ }
577
+ else {
578
+ count = __NV_STD_MIN(capacity(), count);
579
+ }
580
+
581
+ device_memory::copy_host_to_host(
582
+ ptr_host_real, host_data(), count);
583
+
584
+ device_memory::copy_host_to_host(
585
+ ptr_host_imag, host_data_imag(), count);
586
+ }
587
+ };
588
+
589
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
590
+
591
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief uncompress sparse matrix from the host side
34
+ */
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/util/host_tensor.h"
39
+ #include "cutlass/tensor_view.h"
40
+ #include "cutlass/util/tensor_view_io.h"
41
+ #include "cutlass/util/reference/host/gemm.h"
42
+
43
+ namespace cutlass {
44
+
45
+ // uncompress sparse tensor core A matrix
46
+ template <typename ElementA, typename LayoutA, typename ElementE,
47
+ typename LayoutE>
48
+ void uncompress(TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
49
+ TensorRef<ElementA, LayoutA> tensor_a,
50
+ TensorRef<ElementE, LayoutE> tensor_e, int row, int col) {
51
+ // How many uncompressed data we can get with ElementE meta data
52
+ int DecompressedElementsPerElementE =
53
+ 256 / cutlass::sizeof_bits<ElementA>::value;
54
+
55
+ // Process 4bit meta data a time
56
+ int step;
57
+
58
+ // 1:2 or 2:4 or 4:8
59
+ int a, b;
60
+
61
+ if (cutlass::sizeof_bits<ElementA>::value == 4) {
62
+ step = 8;
63
+ a = 4;
64
+ b = 8;
65
+ } else if (cutlass::sizeof_bits<ElementA>::value == 8) {
66
+ step = 4;
67
+ a = 2;
68
+ b = 4;
69
+ } else if (cutlass::sizeof_bits<ElementA>::value == 16) {
70
+ step = 4;
71
+ a = 2;
72
+ b = 4;
73
+ } else if (cutlass::sizeof_bits<ElementA>::value == 32) {
74
+ step = 2;
75
+ a = 1;
76
+ b = 2;
77
+ }
78
+
79
+ int ElementsPerE = (cutlass::sizeof_bits<ElementA>::value == 4) ? 2 : 1;
80
+
81
+ for (int r = 0; r < row; ++r) {
82
+ for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) {
83
+
84
+ ElementE meta = tensor_e.at(MatrixCoord(r, c));
85
+
86
+ for (int i = 0; i < DecompressedElementsPerElementE; i += step) {
87
+ int e = (meta >> (i / step * 4)) & 0xf;
88
+ int idx0 = e & 0x3;
89
+ int idx1 = e >> 2;
90
+
91
+ if (a == 1) idx0 = idx0 / 2;
92
+
93
+ for (int ii = 0; ii < step; ii += ElementsPerE) {
94
+ int real_col =
95
+ c * DecompressedElementsPerElementE + i + ii;
96
+ int compressed_col = (real_col / b) * a;
97
+
98
+ if (ii == (idx0 * ElementsPerE)) {
99
+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
100
+ tensor_a.at(MatrixCoord(r, compressed_col));
101
+ if (ElementsPerE == 2)
102
+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
103
+ tensor_a.at(MatrixCoord(r, compressed_col + 1));
104
+ } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) {
105
+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
106
+ tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE));
107
+ if (ElementsPerE == 2)
108
+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
109
+ tensor_a.at(
110
+ MatrixCoord(r, compressed_col + ElementsPerE + 1));
111
+ } else {
112
+ uncompressed_tensor_a.at(MatrixCoord(r, real_col)) =
113
+ ElementA(0);
114
+ if (ElementsPerE == 2)
115
+ uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) =
116
+ ElementA(0);
117
+ }
118
+ }
119
+ }
120
+ }
121
+ }
122
+ }
123
+
124
+ // uncompress ELL block sparse matrix
125
+ template <typename ElementA, typename LayoutA,
126
+ typename ElementE, typename LayoutE>
127
+ void uncompress_ell_block_sparse(
128
+ TensorRef<ElementA, LayoutA> uncompressed_tensor_a,
129
+ TensorRef<ElementA, LayoutA> tensor_a,
130
+ TensorRef<ElementE, LayoutE> ell_idx,
131
+ int rows, int cols,
132
+ int ell_num_cols, int ell_blocksize) {
133
+
134
+ for (int r = 0; r < rows / ell_blocksize; ++r) {
135
+ for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) {
136
+
137
+ ElementE idx = ell_idx.at(MatrixCoord(r, c));
138
+
139
+ if (idx != -1) {
140
+ int row_begin = r * ell_blocksize;
141
+ int col_begin_real = idx * ell_blocksize;
142
+ int col_begin = c * ell_blocksize;
143
+
144
+ for (int i = 0; i < ell_blocksize; ++i) {
145
+ for (int j = 0; j < ell_blocksize; ++j) {
146
+ uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) =
147
+ tensor_a.at(
148
+ MatrixCoord(row_begin + i, col_begin +j));
149
+ }
150
+ }
151
+ }
152
+ }
153
+ }
154
+ }
155
+
156
+ } // namespace cutlass
157
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include "cutlass/cutlass.h"
35
+ #include "cutlass/numeric_types.h"
36
+
37
+ // integer_sequence moved to cutlass/numeric_types.h
38
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Utilities for mixed input data type kernels.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include <cuda.h>
38
+ #include "cute/layout.hpp"
39
+ #include "cute/tensor.hpp"
40
+ #include "cute/arch/mma_sm90.hpp"
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/util/device_memory.h"
43
+ #include "cutlass/util/reference/device/tensor_fill.h"
44
+ #include "cute/util/type_traits.hpp"
45
+
46
+ namespace cutlass {
47
+
48
// Abort the program with a diagnostic if a CUDA runtime call fails.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// remains safe in unbraced if/else contexts (the original bare-brace form
// made `if (c) CUDA_CHECK(x); else ...` ill-formed).
#define CUDA_CHECK(status)                                                \
  do {                                                                    \
    cudaError_t error = (status);                                         \
    if (error != cudaSuccess) {                                           \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error)   \
                << " at line: " << __LINE__ << std::endl;                 \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  } while (0)
57
+
58
+ template <
59
+ class QuantizedElement,
60
+ class DequantizedElement,
61
+ class OperandLayout,
62
+ class ElementScale,
63
+ class ElementZero,
64
+ class ScaleBroadCastLayout,
65
+ class ThrLayout>
66
+ __global__ void dequantize_kernel(DequantizedElement* dq_buffer,
67
+ QuantizedElement const* q_buffer,
68
+ OperandLayout const operand_layout,
69
+ ElementScale const* scale_buffer,
70
+ ElementZero const* zero_buffer,
71
+ ScaleBroadCastLayout const broadcasted_scale_layout,
72
+ ThrLayout thr_layout) {
73
+ using namespace cute;
74
+
75
+ // Represent the full tensors to gmem elements.
76
+ // These are expected to have shape [MN, K, L]
77
+ cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
78
+ cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr<QuantizedElement const>(q_buffer), operand_layout);
79
+ // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
80
+ // It is expected that K % G == 0
81
+ cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
82
+ cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
83
+
84
+ // Assign 1 thread per element in the thread block
85
+ auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); //
86
+ auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
87
+
88
+ // Tile across the block
89
+ auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
90
+ auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
91
+ auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
92
+ auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
93
+
94
+ auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
95
+ auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
96
+ auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
97
+ auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
98
+
99
+ // Make a fragment of registers to hold gmem loads
100
+ cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
101
+ cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
102
+ cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
103
+ cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
104
+ cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
105
+ cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
106
+
107
+ cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
108
+ auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
109
+ auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
110
+
111
+ const auto num_iters = cute::size<3>(tOpDq_gOpDq);
112
+
113
+ for (int ii = 0; ii < num_iters; ++ii) {
114
+ const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
115
+ if (thread_offset < cute::size<0>(operand_layout)) {
116
+ cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
117
+ cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
118
+ cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
119
+ cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
120
+ cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
121
+ cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{});
122
+ cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{});
123
+ cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
124
+ cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
125
+ }
126
+ }
127
+ }
128
+
129
+ template <
130
+ class QuantizedElement,
131
+ class DequantizedElement,
132
+ class OperandLayout,
133
+ class ElementScale,
134
+ class ElementZero,
135
+ class ScaleLayout>
136
+ static void dequantize(DequantizedElement* dq_buffer,
137
+ QuantizedElement const* q_buffer,
138
+ OperandLayout const operand_layout,
139
+ ElementScale const* scale_buffer,
140
+ ElementZero const* zero_buffer,
141
+ ScaleLayout const scale_layout,
142
+ int const group_size,
143
+ cudaStream_t &stream) {
144
+ using namespace cute;
145
+
146
+ constexpr int tpb = 128;
147
+ auto thr_layout = make_layout(make_shape(Int<tpb>{}));
148
+
149
+ const auto num_rows = get<0>(shape(operand_layout));
150
+ const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
151
+ const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
152
+ const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
153
+
154
+ if (num_rows != size<0>(scale_layout)) {
155
+ std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
156
+ << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
157
+ << std::endl;
158
+ exit(-1);
159
+ }
160
+
161
+ const auto scale_stride0 = get<0>(stride(scale_layout));
162
+ const auto scale_stride1 = get<1>(stride(scale_layout));
163
+ const auto scale_stride2 = get<2>(stride(scale_layout));
164
+
165
+ auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
166
+ auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
167
+ auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
168
+
169
+ const auto blocks_x = gemm_k;
170
+ const auto blocks_y = batches;
171
+
172
+ dim3 blocks(blocks_x, blocks_y, 1);
173
+ dequantize_kernel<<<blocks, tpb, 0, stream>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
174
+ CUDA_CHECK(cudaStreamSynchronize(stream));
175
+ }
176
+
177
+ template <typename T>
178
+ class packed_scale_t {
179
+ public:
180
+ static_assert(cute::is_same_v<T, cutlass::int8_t> ||
181
+ cute::is_same_v<T, cutlass::uint8_t> ||
182
+ cute::is_same_v<T, cutlass::float_e4m3_t> ||
183
+ cute::is_same_v<T, cutlass::float_e5m2_t>,
184
+ "only 8 bit arithmetic types are supported.");
185
+ CUTLASS_HOST_DEVICE
186
+ explicit packed_scale_t(T val) {
187
+ if constexpr (!cute::is_unsigned_v<T>) {
188
+ // Only pack negative values. The positive values are generated in flight in the mainloop.
189
+ storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
190
+ storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
191
+ }
192
+ else {
193
+ storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
194
+ storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
195
+ }
196
+ }
197
+ CUTLASS_HOST_DEVICE
198
+ packed_scale_t() = default;
199
+ CUTLASS_HOST_DEVICE
200
+ explicit operator float() const {
201
+ return float(get());
202
+ }
203
+ CUTLASS_HOST_DEVICE
204
+ bool operator==(packed_scale_t const& rhs) const {
205
+ return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
206
+ }
207
+ CUTLASS_HOST_DEVICE
208
+ bool operator!=(packed_scale_t const& rhs) const {
209
+ return !(*this == rhs);
210
+ }
211
+ CUTLASS_HOST_DEVICE
212
+ friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
213
+ return packed_scale_t(lhs.get() + rhs.get());
214
+ }
215
+ CUTLASS_HOST_DEVICE
216
+ friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
217
+ return packed_scale_t(lhs.get() - rhs.get());
218
+ }
219
+ CUTLASS_HOST_DEVICE
220
+ friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
221
+ return packed_scale_t(lhs.get() * rhs.get());
222
+ }
223
+ CUTLASS_HOST_DEVICE
224
+ friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
225
+ return packed_scale_t(lhs.get() / rhs.get());
226
+ }
227
+
228
+ private:
229
+ using Storage = uint32_t;
230
+ using Stage = uint8_t;
231
+
232
+ Storage storage[2] {};
233
+
234
+ CUTLASS_HOST_DEVICE
235
+ static Storage pack4(T c1, T c2, T c3, T c4) {
236
+ Storage result = 0;
237
+ result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
238
+ result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
239
+ result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
240
+ result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
241
+ return result;
242
+ }
243
+ CUTLASS_HOST_DEVICE
244
+ T get() const {
245
+ auto stage = static_cast<Stage>(storage[0] >> 8);
246
+ #if defined(__CUDA_ARCH__)
247
+ return reinterpret_cast<T const&>(stage);
248
+ #else
249
+ T tmp;
250
+ std::memcpy(&tmp, &stage, sizeof(Stage));
251
+ return tmp;
252
+ #endif
253
+ }
254
+ CUTLASS_HOST_DEVICE
255
+ T get(int idx) const {
256
+ Stage stage;
257
+ if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
258
+ else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
259
+ #if defined(__CUDA_ARCH__)
260
+ return reinterpret_cast<T const&>(stage);
261
+ #else
262
+ T tmp;
263
+ std::memcpy(&tmp, &stage, sizeof(Stage));
264
+ return tmp;
265
+ #endif
266
+ }
267
+ };
268
+
269
+ // In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
270
+ // Here the encodings of positive values and negative values are unified (except for the sign bit).
271
+ // For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
272
+ static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) {
273
+
274
+ using StorageType = cutlass::int4b_t::Storage;
275
+ constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
276
+ const size_t host_buf_size = block_size / pack;
277
+ std::vector<StorageType> host_buf(host_buf_size);
278
+ cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size);
279
+
280
+ for (auto&& d : host_buf) {
281
+ StorageType out = 0;
282
+ StorageType mask = 0x0f;
283
+ for (int i = 0; i < pack; i++) {
284
+ cutlass::int4b_t curr;
285
+ curr.storage = (d >> (i * 4)) & 0x0f;
286
+ switch (curr) {
287
+ case 1: curr.storage = StorageType(0b0111); break; // 2's complement
288
+ case 2: curr.storage = StorageType(0b0110); break; // 2's complement
289
+ case 3: curr.storage = StorageType(0b0101); break; // 2's complement
290
+ case 4: curr.storage = StorageType(0b0100); break; // 2's complement
291
+ case 5: curr.storage = StorageType(0b0011); break; // 2's complement
292
+ case 6: curr.storage = StorageType(0b0010); break; // 2's complement
293
+ case 7: curr.storage = StorageType(0b0001); break; // 2's complement
294
+ default: break;
295
+ }
296
+ out |= (curr.storage << (4 * i)) & mask;
297
+ mask <<= 4;
298
+ }
299
+ d = out;
300
+ }
301
+
302
+ cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size);
303
+ return true;
304
+ }
305
+
306
+ template <class ElementScale>
307
+ static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array<ElementScale, 8> *block_out, const size_t block_size) {
308
+ std::vector<ElementScale> data_in(block_size);
309
+ std::vector<cutlass::Array<ElementScale, 8>> data_out(block_size);
310
+
311
+ try {
312
+ cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size);
313
+ }
314
+ catch (cutlass::cuda_exception const& e) {
315
+ std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
316
+ return false;
317
+ }
318
+
319
+ for (size_t i = 0; i < block_size; i++) {
320
+ cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
321
+ data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
322
+ }
323
+
324
+ try {
325
+ cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size);
326
+ }
327
+ catch (cutlass::cuda_exception const& e) {
328
+ std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
329
+ return false;
330
+ }
331
+ return true;
332
+ }
333
+
334
+ template <class T, class = void>
335
+ struct UnderlyingElement {
336
+ using type = T;
337
+ };
338
+
339
+ template <class T>
340
+ struct UnderlyingElement<T, cute::void_t<typename T::Element>> {
341
+ using type = typename T::Element;
342
+ };
343
+
344
+ // Given a type of MMA instruction, compute a memory reordering atom that places all values
345
+ // owned by each thread in contiguous memory locations. This improves smem load vectorization,
346
+ // particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
347
+ // of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
348
+ // In addition, we can reorder the values across several MMA instructions to get even wider
349
+ // vectorization (AtomLayout parameter) and permute the values within each instruction to get
350
+ // more optimal conversion instruction sequences (ValLayout parameter).
351
+ template <class ElementMma,
352
+ class AtomLayout = cute::Layout<cute::_1>,
353
+ class ValLayout = cute::Layout<cute::_1>>
354
+ constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
355
+ {
356
+ using namespace cute;
357
+
358
+ static_assert(is_static_v<ValLayout>, "ValLayout must be static");
359
+ static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
360
+
361
+ // 1. Choose an MMA atom to access TV layout and MN shape
362
+ // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
363
+ using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
364
+ using MmaTraits = MMA_Traits<MmaAtom>;
365
+ auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
366
+ auto tv_layout_mma = typename MmaTraits::ALayout{};
367
+ static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
368
+
369
+ // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
370
+ // Note: this assumes A is partitioned between warps along M mode
371
+ auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
372
+ auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
373
+ auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
374
+ auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
375
+
376
+ // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
377
+ auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
378
+
379
+ // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
380
+ auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
381
+ auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
382
+ auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
383
+ auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
384
+
385
+ return layout_atom;
386
+ }
387
+
388
+ template <class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
389
+ __global__ void reorder_tensor_kernel(
390
+ cute::Tensor<EngineSrc, LayoutSrc> S,
391
+ cute::Tensor<EngineDst, LayoutDst> D,
392
+ TiledCopy tiled_copy)
393
+ {
394
+ using namespace cute;
395
+
396
+ using T = typename EngineDst::value_type;
397
+
398
+ Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
399
+ Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
400
+
401
+ auto thread_copy = tiled_copy.get_slice(threadIdx.x);
402
+ Tensor tS = thread_copy.partition_S(gS);
403
+ Tensor tD = thread_copy.partition_D(gD);
404
+
405
+ copy(tiled_copy, tS, tD);
406
+ }
407
+
408
+ template <class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
409
+ void reorder_tensor(
410
+ cute::Tensor<EngineSrc, LayoutSrc> S,
411
+ cute::Tensor<EngineDst, LayoutDst> D)
412
+ {
413
+ using namespace cute;
414
+
415
+ using T = typename EngineDst::value_type;
416
+ static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
417
+
418
+ // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
419
+ // This avoids a race condition when writing out subbyte types (e.g. int4b_t).
420
+ auto has_major_mode = [](auto s) {
421
+ return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; });
422
+ };
423
+ static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
424
+ "Could not find stride-1 mode in destination layout");
425
+ constexpr int N = shape_div(Int<8>{}, Int<sizeof_bits_v<T>>{});
426
+ auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
427
+ make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
428
+ make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
429
+
430
+ // Make a tiled copy with a simple row-major thread order and above layout
431
+ int constexpr NumThreads = 128;
432
+ auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
433
+ auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
434
+
435
+ // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
436
+ using TileShape = Shape<_16>;
437
+ auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
438
+ dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
439
+
440
+ reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
441
+ CUDA_CHECK(cudaDeviceSynchronize());
442
+ }
443
+
444
+ // In-place version
445
+ template <class T, class LayoutSrc, class LayoutDst>
446
+ void reorder_tensor(
447
+ T const* src,
448
+ LayoutSrc const& layout_src,
449
+ T * dst,
450
+ LayoutDst const& layout_dst)
451
+ {
452
+ using namespace cute;
453
+ reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
454
+ make_tensor(make_gmem_ptr<T>(dst), layout_dst));
455
+ }
456
+
457
+ // In-place version
458
+ template <class T, class LayoutSrc, class LayoutDst>
459
+ void reorder_tensor(
460
+ T * data,
461
+ LayoutSrc const& layout_src,
462
+ LayoutDst const& layout_dst)
463
+ {
464
+ using namespace cute;
465
+ cutlass::DeviceAllocation<T> temp(size(layout_src));
466
+ reorder_tensor(data, layout_src, temp.get(), layout_dst);
467
+ cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast<size_t>(size(layout_src)));
468
+ }
469
+
470
+ #undef CUDA_CHECK
471
+
472
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cute/layout.hpp"
38
+ #include "cute/container/array.hpp" // cute::array
39
+ #include "cutlass/conv/convolution.h" // cutlass::conv::Operator
40
+
41
+ /////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ // Strides without batch mode
48
+
49
+ template <class IntT>
50
+ CUTLASS_HOST_DEVICE
51
+ cute::Stride<IntT, cute::Int<1>>
52
+ make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int,int,int> shape_MKL) {
53
+ static_assert(std::is_integral_v<IntT>,
54
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
55
+ auto s_copy = s;
56
+ cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
57
+ return s_copy;
58
+ }
59
+
60
+ template <class IntT>
61
+ CUTLASS_HOST_DEVICE
62
+ cute::Stride<cute::Int<1>, IntT>
63
+ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int,int,int> shape_MKL) {
64
+ static_assert(std::is_integral_v<IntT>,
65
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
66
+ auto s_copy = s;
67
+ cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
68
+ return s_copy;
69
+ }
70
+
71
+ /////////////////////////////////////////////////////////////////////////////////////////////////
72
+
73
+ // Strides with batch mode
74
+
75
+ template <class IntT>
76
+ CUTLASS_HOST_DEVICE
77
+ cute::Stride<IntT, cute::Int<1>, int64_t>
78
+ make_cute_packed_stride(cute::Stride<IntT, cute::Int<1>, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
79
+ static_assert(std::is_integral_v<IntT>,
80
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
81
+ auto s_copy = s;
82
+ cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MKL));
83
+ int batch_count = cute::get<2>(shape_MKL);
84
+ if (batch_count > 1) {
85
+ cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
86
+ }
87
+ else {
88
+ cute::get<2>(s_copy) = static_cast<IntT>(0);
89
+ }
90
+ return s_copy;
91
+ }
92
+
93
+ template <class IntT>
94
+ CUTLASS_HOST_DEVICE
95
+ cute::Stride<cute::Int<1>, IntT, int64_t>
96
+ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape<int,int,int> shape_MKL) {
97
+ static_assert(std::is_integral_v<IntT>,
98
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
99
+ auto s_copy = s;
100
+ cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL));
101
+ int batch_count = cute::get<2>(shape_MKL);
102
+ if (batch_count > 1) {
103
+ cute::get<2>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL));
104
+ }
105
+ else {
106
+ cute::get<2>(s_copy) = static_cast<IntT>(0);
107
+ }
108
+ return s_copy;
109
+ }
110
+
111
+ /////////////////////////////////////////////////////////////////////////////////////////////////
112
+
113
+ // Strides with group mode
114
+
115
+ template <class StrideIntT>
116
+ CUTLASS_HOST_DEVICE
117
+ cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
118
+ make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
119
+ static_assert(std::is_integral_v<StrideIntT>,
120
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
121
+ auto s_copy = s;
122
+ cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
123
+ return s_copy;
124
+ }
125
+
126
+ template <class StrideIntT>
127
+ CUTLASS_HOST_DEVICE
128
+ cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
129
+ make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
130
+ static_assert(std::is_integral_v<StrideIntT>,
131
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
132
+ auto s_copy = s;
133
+ cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
134
+ return s_copy;
135
+ }
136
+
137
+ /////////////////////////////////////////////////////////////////////////////////////////////////
138
+
139
+ // Strides for convolutions
140
+
141
+ // Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
142
+ // Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order
143
+ // and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout
144
+ // right in KTRSC order and can be coalesced to just k.
145
+ // We enforce this condition here with asserts.
146
// Coalesce a layout-right conv output tensor into a rank-3 stride (InT,_1,_0).
// The asserts verify the layout-right packing assumption described in the
// comment above; only one dynamic stride survives the coalescing.
template <class IntT, size_t RankT_>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
    cute::array<int32_t, RankT_> shape_output,
    cute::array<IntT, RankT_> stride_output,
    cutlass::conv::Operator conv_op) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  static_assert(RankT_ >= 3u);
  constexpr static int RankT = static_cast<int>(RankT_);

  // Innermost (channel) mode must be contiguous.
  assert(stride_output[RankT-1] == 1);
  // Every other mode must be layout-right packed: its stride equals the
  // product of the extent and stride of the next-faster-varying mode.
  cute::for_each(cute::make_seq<RankT-2>{}, [&](auto i) {
    assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]);
  });

  auto s_copy = s;
  // Wgrad output is layout-right in KTRSC order, so the coalesced stride is
  // the leading (k) stride; fprop/dgrad outputs are in NZPQK/NDHWC order, so
  // it is the second-innermost (q/w) stride.
  cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ?
      stride_output[0] :
      stride_output[RankT-2];
  return s_copy;
}
170
+
171
+ //
172
+ // Activation tensor ((w, h, d, n), _1) for fprop kernel
173
+ //
174
+
175
+ // Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
176
+ template <class IntT>
177
+ CUTLASS_HOST_DEVICE
178
+ cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
179
+ make_cute_packed_stride(
180
+ cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
181
+ cute::array<IntT, 3> stride_nwc,
182
+ conv::Operator ConvOp) {
183
+ static_assert(std::is_integral_v<IntT>,
184
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
185
+ assert(stride_nwc[2] == 1);
186
+ auto s_copy = s;
187
+ cute::get<0,0>(s_copy) = stride_nwc[1];
188
+ cute::get<0,1>(s_copy) = stride_nwc[0];
189
+ return s_copy;
190
+ }
191
+
192
+ // Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
193
// Fprop activation, cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
    cute::array<IntT, 4> stride_nhwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
  // Channel mode must be contiguous.
  assert(stride_nhwc[3] == 1);
  auto s_copy = s;
  // Reverse the (n,h,w) strides into (W,H,N) order: nested mode i takes
  // stride_nhwc[2-i].
  cute::for_each(cute::make_seq<3>{}, [&](auto i) {
    cute::get<0,i>(s_copy) = stride_nhwc[2-i];
  });
  return s_copy;
}
209
+
210
+ // Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
211
// Fprop activation, cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
    cute::array<IntT, 5> stride_ndhwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_ndhwc[4] == 1);
  auto s_copy = s;
  // Reverse the (n,d,h,w) strides into (W,H,D,N) order: nested mode i takes
  // stride_ndhwc[3-i].
  cute::for_each(cute::make_seq<4>{}, [&](auto i) {
    cute::get<0,i>(s_copy) = stride_ndhwc[3-i];
  });
  return s_copy;
}
228
+
229
+ //
230
+ // Filter tensor (k, (_1, s, r, t)) for fprop kernel
231
+ //
232
+
233
+ // Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
234
+ template <class IntT>
235
+ CUTLASS_HOST_DEVICE
236
+ cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
237
+ make_cute_packed_stride(
238
+ cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
239
+ cute::array<IntT, 3> stride_ksc,
240
+ conv::Operator ConvOp) {
241
+ static_assert(std::is_integral_v<IntT>,
242
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
243
+
244
+ assert(stride_ksc[2] == 1);
245
+ auto s_copy = s;
246
+ cute::get<0,0>(s_copy) = stride_ksc[0];
247
+ cute::get<1,1>(s_copy) = stride_ksc[1];
248
+ return s_copy;
249
+ }
250
+
251
+ // Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
252
// Fprop filter, cutlass::layout::TensorNHWC (interpreted as KRSC) -> rank-2
// stride (k, (_1, s, r)).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
    cute::array<IntT, 4> stride_krsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_krsc[3] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_krsc[0];
  // Reverse the (r,s) strides into inner-mode (s,r) order (mode 0 of the
  // inner tuple is the static unit channel stride).
  cute::for_each(cute::make_seq<2>{}, [&](auto i) {
    cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
  });
  return s_copy;
}
270
+
271
+ // Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
272
// Fprop filter, cutlass::layout::TensorNDHWC (interpreted as KTRSC) -> rank-2
// stride (k, (_1, s, r, t)).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
    cute::array<IntT, 5> stride_ktrsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_ktrsc[4] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_ktrsc[0];
  // Reverse the (t,r,s) strides into inner-mode (s,r,t) order.
  cute::for_each(cute::make_seq<3>{}, [&](auto i) {
    cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
  });
  return s_copy;
}
290
+
291
+ //
292
+ // Activation tensor (_1, (w, h, d, n)) for wgrad kernel
293
+ //
294
+ // It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel
295
+ //
296
+
297
+ // Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
298
+ // Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
299
+ template <class IntT>
300
+ CUTLASS_HOST_DEVICE
301
+ cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
302
+ make_cute_packed_stride(
303
+ cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
304
+ cute::array<IntT, 3> stride_nwc,
305
+ conv::Operator ConvOp) {
306
+ static_assert(std::is_integral_v<IntT>,
307
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
308
+
309
+ assert(stride_nwc[2] == 1);
310
+ auto s_copy = s;
311
+ if (ConvOp == cutlass::conv::Operator::kWgrad) {
312
+ cute::get<1,0>(s_copy) = stride_nwc[1];
313
+ cute::get<1,1>(s_copy) = stride_nwc[0];
314
+ }
315
+ else if (ConvOp == cutlass::conv::Operator::kDgrad) {
316
+ // stride_nwc in dgrad is ksc.
317
+ cute::get<1,0>(s_copy) = stride_nwc[0];
318
+ cute::get<1,1>(s_copy) = stride_nwc[1];
319
+ }
320
+ return s_copy;
321
+ }
322
+
323
+ // Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
324
+ // Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
325
// Wgrad activation, cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)),
// or dgrad filter (array holds KRSC strides) -> ((_1), (k, s, r)).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
make_cute_packed_stride(
    cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
    cute::array<IntT, 4> stride_nhwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Innermost mode must be contiguous.
  assert(stride_nhwc[3] == 1);
  auto s_copy = s;
  if (ConvOp == cutlass::conv::Operator::kWgrad) {
    // Reverse (n,h,w) into (W,H,N) order.
    cute::for_each(cute::make_seq<3>{}, [&](auto i) {
      cute::get<1,i>(s_copy) = stride_nhwc[2-i];
    });
  }
  else if (ConvOp == cutlass::conv::Operator::kDgrad) {
    // stride_nhwc in dgrad is krsc.
    cute::get<1,0>(s_copy) = stride_nhwc[0];
    // Reverse the (r,s) strides into (s,r) order after the k stride.
    cute::for_each(cute::make_seq<2>{}, [&](auto i) {
      cute::get<1,2-i>(s_copy) = stride_nhwc[i+1];
    });
  }
  return s_copy;
}
351
+
352
+ // Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
353
+ // Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
354
// Wgrad activation, cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)),
// or dgrad filter (array holds KTRSC strides) -> ((_1), (k, s, r, t)).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
make_cute_packed_stride(
    cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
    cute::array<IntT, 5> stride_ndhwc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Innermost mode must be contiguous.
  assert(stride_ndhwc[4] == 1);
  auto s_copy = s;
  if (ConvOp == cutlass::conv::Operator::kWgrad) {
    // Reverse (n,d,h,w) into (W,H,D,N) order.
    cute::for_each(cute::make_seq<4>{}, [&](auto i) {
      cute::get<1,i>(s_copy) = stride_ndhwc[3-i];
    });
  }
  else if (ConvOp == cutlass::conv::Operator::kDgrad) {
    // stride_ndhwc in dgrad is ktrsc.
    cute::get<1,0>(s_copy) = stride_ndhwc[0];
    // Reverse the (t,r,s) strides into (s,r,t) order after the k stride.
    cute::for_each(cute::make_seq<3>{}, [&](auto i) {
      cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1];
    });
  }
  return s_copy;
}
380
+
381
+ //
382
+ // NZPQ tensor (_1, nzpq) for wgrad kernel
383
+ //
384
+
385
+ // cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
386
+ template <class IntT>
387
+ CUTLASS_HOST_DEVICE
388
+ cute::Stride<cute::Int<1>, IntT>
389
+ make_cute_packed_stride(
390
+ cute::Stride<cute::Int<1>, IntT> s,
391
+ cute::array<IntT, 3> stride_nqk,
392
+ conv::Operator ConvOp) {
393
+ static_assert(std::is_integral_v<IntT>,
394
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
395
+
396
+ assert(stride_nqk[2] == 1);
397
+ auto s_copy = s;
398
+ cute::get<1>(s_copy) = stride_nqk[1];
399
+ return s_copy;
400
+ }
401
+
402
+ // cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
403
// Wgrad NZPQ tensor, cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
    cute::Stride<cute::Int<1>, IntT> s,
    cute::array<IntT, 4> stride_npqk,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel (k) mode must be contiguous.
  assert(stride_npqk[3] == 1);
  auto s_copy = s;
  // The linearized nzpq mode advances by the q stride.
  cute::get<1>(s_copy) = stride_npqk[2];
  return s_copy;
}
418
+
419
+ // cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
420
// Wgrad NZPQ tensor, cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
    cute::Stride<cute::Int<1>, IntT> s,
    cute::array<IntT, 5> stride_nzpqk,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel (k) mode must be contiguous.
  assert(stride_nzpqk[4] == 1);
  auto s_copy = s;
  // The linearized nzpq mode advances by the q stride.
  cute::get<1>(s_copy) = stride_nzpqk[3];
  return s_copy;
}
435
+
436
+
437
+
438
+ //
439
+ // Wgrad output tensor (k, (_1, s, r, t), _0)
440
+ //
441
+
442
+ // Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
443
+ template <class IntT>
444
+ CUTLASS_HOST_DEVICE
445
+ cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
446
+ make_cute_packed_stride(
447
+ cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
448
+ [[maybe_unused]] cute::array<int32_t, 3> shape_output,
449
+ cute::array<IntT, 3> stride_ksc,
450
+ conv::Operator ConvOp) {
451
+ static_assert(std::is_integral_v<IntT>,
452
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
453
+
454
+ assert(stride_ksc[2] == 1);
455
+ auto s_copy = s;
456
+ cute::get<0,0>(s_copy) = stride_ksc[0];
457
+ cute::get<1,1>(s_copy) = stride_ksc[1];
458
+ return s_copy;
459
+ }
460
+
461
+ // Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0)
462
// Wgrad output filter, cutlass::layout::TensorKCSR -> rank-3 stride
// (k, (_1, s, r), _0).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 4> shape_output,
    cute::array<IntT, 4> stride_krsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_krsc[3] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_krsc[0];
  // Reverse the (r,s) strides into inner-mode (s,r) order.
  cute::for_each(cute::make_seq<2>{}, [&](auto i) {
    cute::get<1,2-i>(s_copy) = stride_krsc[i+1];
  });
  return s_copy;
}
481
+
482
+ // Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
483
// Wgrad output filter, cutlass::layout::TensorKCSRT -> rank-3 stride
// (k, (_1, s, r, t), _0).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 5> shape_output,
    cute::array<IntT, 5> stride_ktrsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_ktrsc[4] == 1);
  auto s_copy = s;
  cute::get<0,0>(s_copy) = stride_ktrsc[0];
  // Reverse the (t,r,s) strides into inner-mode (s,r,t) order.
  cute::for_each(cute::make_seq<3>{}, [&](auto i) {
    cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1];
  });
  return s_copy;
}
502
+
503
+
504
+ //
505
+ // Wgrad output tensor ((_1, s, r, t), k, _0)
506
+ //
507
+
508
+ // Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0)
509
+ template <class IntT>
510
+ CUTLASS_HOST_DEVICE
511
+ cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>>
512
+ make_cute_packed_stride(
513
+ cute::Stride<cute::Stride<cute::Int<1>, IntT>, IntT, cute::Int<0>> s,
514
+ [[maybe_unused]] cute::array<int32_t, 3> shape_output,
515
+ cute::array<IntT, 3> stride_ksc,
516
+ conv::Operator ConvOp) {
517
+ static_assert(std::is_integral_v<IntT>,
518
+ "Stride must have an integral type so it can be set dynamically. Static strides not supported.");
519
+
520
+ assert(stride_ksc[2] == 1);
521
+ auto s_copy = s;
522
+ cute::get<1,0>(s_copy) = stride_ksc[0];
523
+ cute::get<0,1>(s_copy) = stride_ksc[1];
524
+ return s_copy;
525
+ }
526
+
527
+ // Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0)
528
// Wgrad output filter, cutlass::layout::TensorCSRK -> rank-3 stride
// ((_1, s, r), k, _0).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT>, IntT, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 4> shape_output,
    cute::array<IntT, 4> stride_krsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_krsc[3] == 1);
  auto s_copy = s;
  cute::get<1,0>(s_copy) = stride_krsc[0];
  // Reverse the (r,s) strides into nested-mode (s,r) order.
  cute::for_each(cute::make_seq<2>{}, [&](auto i) {
    cute::get<0,2-i>(s_copy) = stride_krsc[i+1];
  });
  return s_copy;
}
547
+
548
+ // Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0)
549
// Wgrad output filter, cutlass::layout::TensorCSRTK -> rank-3 stride
// ((_1, s, r, t), k, _0).
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>>
make_cute_packed_stride(
    cute::Stride<cute::Stride<cute::Int<1>, IntT, IntT, IntT>, IntT, cute::Int<0>> s,
    [[maybe_unused]] cute::array<int32_t, 5> shape_output,
    cute::array<IntT, 5> stride_ktrsc,
    conv::Operator ConvOp) {
  static_assert(std::is_integral_v<IntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // Channel mode must be contiguous.
  assert(stride_ktrsc[4] == 1);
  auto s_copy = s;
  cute::get<1,0>(s_copy) = stride_ktrsc[0];
  // Reverse the (t,r,s) strides into nested-mode (s,r,t) order.
  cute::for_each(cute::make_seq<3>{}, [&](auto i) {
    cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1];
  });
  return s_copy;
}
568
+ /////////////////////////////////////////////////////////////////////////////////////////////////
569
+
570
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include <array>
35
+ #include <cassert>
36
+ #include <cmath>
37
+ #include <iostream>
38
+ #include <type_traits>
39
+
40
+ #include <cute/util/type_traits.hpp>
41
+ #include <cute/tensor.hpp>
42
+
43
+ #include <cute/numeric/numeric_types.hpp>
44
+ #include <cute/numeric/complex.hpp>
45
+
46
+ #include <cutlass/layout/layout.h>
47
+
48
+ // The computed infinity norm does not include
49
+ // any NaN column absolute-value sums.
50
// Result of an infinity-norm computation over a host matrix.
struct matrix_inf_norm_result {
  // Accumulate errors in double, as this is generally
  // the highest precision that the examples use.
  double inf_norm = 0.0;
  // True if any row absolute-value sum was NaN; such rows are excluded
  // from inf_norm by the functions that produce this struct.
  bool found_nan = false;
};
56
+
57
+ // In theory, cute::Tensor<ViewEngine<T*>, T> could be treated as a view type,
58
+ // and thus passed by value (as std::span or std::string_view would be).
59
+ // However, generic cute::Tensor are more like containers
60
+ // and thus are best passed by reference or const reference.
61
+ template <typename EngineType, typename LayoutType>
62
+ matrix_inf_norm_result
63
+ matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
64
+ {
65
+ using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
66
+ using element_type = typename EngineType::value_type;
67
+
68
+ error_type inf_norm = 0.0;
69
+ bool found_nan = false;
70
+
71
+ // Computing the infinity norm requires that we be able
72
+ // to treat the input as a matrix, with rows and columns.
73
+ const int64_t num_rows = cute::size<0>(host_matrix);
74
+ const int64_t num_cols = cute::size<1>(host_matrix);
75
+
76
+ auto abs_fn = [] (element_type A_ij) {
77
+ if constexpr (not std::is_unsigned_v<element_type>) {
78
+ using std::abs;
79
+ return abs(A_ij);
80
+ }
81
+ else {
82
+ return A_ij;
83
+ }
84
+ };
85
+
86
+ for (int64_t i = 0; i < num_rows; ++i) {
87
+ error_type row_abs_sum = 0.0;
88
+ for(int64_t j = 0; j < num_cols; ++j) {
89
+ row_abs_sum += abs_fn(host_matrix(i, j));
90
+ }
91
+ if (std::isnan(row_abs_sum)) {
92
+ found_nan = true;
93
+ }
94
+ else {
95
+ inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
96
+ }
97
+ }
98
+
99
+ return {inf_norm, found_nan};
100
+ }
101
+
102
+ // Infinity norm of (X - Y).
103
// Infinity norm of (X - Y): maximum row absolute-value sum of the
// element-wise difference. NaN row sums set found_nan and are excluded
// from the norm, matching matrix_inf_norm's behavior.
template <typename EngineType, typename LayoutType>
matrix_inf_norm_result
matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
  cute::Tensor<EngineType, LayoutType> const& Y)
{
  using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
  using element_type = typename EngineType::value_type;

  // abs() is unavailable for unsigned element types; return the value as-is.
  auto abs_fn = [] (element_type A_ij) {
    if constexpr (not std::is_unsigned_v<element_type>) {
      using std::abs;
      return abs(A_ij);
    }
    else {
      return A_ij;
    }
  };

  // The two matrices must have matching extents.
  assert(cute::size<0>(X) == cute::size<0>(Y));
  assert(cute::size<1>(X) == cute::size<1>(Y));

  // Computing the infinity norm requires that we be able
  // to treat the input as a matrix, with rows and columns.
  const int64_t num_rows = cute::size<0>(X);
  const int64_t num_cols = cute::size<1>(X);

  error_type inf_norm = 0.0;
  bool found_nan = false;

  for (int64_t i = 0; i < num_rows; ++i) {
    error_type row_abs_sum = 0.0;
    for (int64_t j = 0; j < num_cols; ++j) {
      row_abs_sum += error_type(abs_fn(element_type(X(i,j)) -
                                       element_type(Y(i,j))));
    }
    if (std::isnan(row_abs_sum)) {
      found_nan = true;
    }
    else {
      inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
    }
  }

  return {inf_norm, found_nan};
}
148
+
149
// Print the (mollified) relative error of a GEMM result C against a
// reference C_ref, scaled by ||A||_inf * ||B||_inf, and return it.
// When that product is zero the raw difference norm is reported instead
// ("mollified"), avoiding a division by zero.
template <typename EngineType_A, typename LayoutType_A,
  typename EngineType_B, typename LayoutType_B,
  typename EngineType_C, typename LayoutType_C,
  typename EngineType_C_ref, typename LayoutType_C_ref>
auto
print_matrix_multiply_mollified_relative_error(
  char const A_value_type_name[],
  cute::Tensor<EngineType_A, LayoutType_A> const& A,
  char const B_value_type_name[],
  cute::Tensor<EngineType_B, LayoutType_B> const& B,
  char const C_value_type_name[],
  cute::Tensor<EngineType_C, LayoutType_C> const& C,
  cute::Tensor<EngineType_C_ref, LayoutType_C_ref> const& C_ref)
{
  const auto [A_norm, A_has_nan] = matrix_inf_norm(A);
  const auto [B_norm, B_has_nan] = matrix_inf_norm(B);
  const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref);
  const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref);

  // Mollification: if the scale factor is zero, report the raw diff norm.
  const auto A_norm_times_B_norm = A_norm * B_norm;
  const auto relative_error = A_norm_times_B_norm == 0.0 ?
    diff_norm : (diff_norm / A_norm_times_B_norm);

  // For expected error bounds, please refer to the LAPACK Users' Guide,
  // in particular https://netlib.org/lapack/lug/node108.html .
  // Printing the infinity norm of C is a way to check
  // that both the function being tested (C)
  // and the reference implementation (C_ref)
  // don't just do nothing (or fill with zeros).
  using std::cout;
  using cute::shape;
  cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
       << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
       << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
       << std::scientific
       << "Infinity norm of A: " << A_norm << '\n'
       << "Infinity norm of B: " << B_norm << '\n'
       << "Infinity norm of C: " << C_norm << '\n'
       << "Infinity norm of (C - C_ref): " << diff_norm << '\n';

  // Label the value honestly: only the nonzero-denominator case is a
  // true relative error.
  if(A_norm_times_B_norm == 0.0) {
    cout << "Mollified relative error: " << relative_error << '\n';
  } else {
    cout << "Relative error: " << relative_error << '\n';
  }

  // Only print the NaN breakdown when at least one NaN was seen.
  if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
    cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
         << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
         << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
         << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
  }
  return relative_error;
}
203
+
204
+ template <typename EngineType, typename LayoutType>
205
+ auto
206
+ print_matrix_multiply_mollified_relative_error(
207
+ const char value_type_name[],
208
+ const cute::Tensor<EngineType, LayoutType>& A,
209
+ const cute::Tensor<EngineType, LayoutType>& B,
210
+ const cute::Tensor<EngineType, LayoutType>& C_computed,
211
+ const cute::Tensor<EngineType, LayoutType>& C_expected)
212
+ {
213
+ return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
214
+ value_type_name, C_computed, C_expected);
215
+ }
216
+
217
+ // Take a CUTLASS HostTensor (or the like) as input,
218
+ // and return a const CuTe Tensor.
219
+ // This is useful for use with the above error printing functions.
220
+ // This implicitly "transposes" if the layout is RowMajor.
221
+ // Note that the HostTensor must be captured by nonconst reference
222
+ // in order for X.host_ref().data() to compile.
223
+ // (CUTLASS is a bit more container-y than CuTe.)
224
+ template<class CutlassHostTensorType>
225
+ auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X)
226
+ {
227
+ // The tensors were created with post-transposed extents.
228
+ const auto extents = X.extent();
229
+ const auto shape = cute::Shape<int, int>{extents[0], extents[1]};
230
+ // Both RowMajor and ColumnMajor only store one stride.
231
+ const int LDX = X.stride(0);
232
+ const auto strides = [&]() {
233
+ using input_layout_type = typename std::decay_t<decltype(X)>::Layout;
234
+ if constexpr (std::is_same_v<input_layout_type, cutlass::layout::ColumnMajor>) {
235
+ return cute::Stride<int, int>{1, LDX};
236
+ }
237
+ else {
238
+ static_assert(std::is_same_v<input_layout_type, cutlass::layout::RowMajor>);
239
+ return cute::Stride<int, int>{LDX, 1};
240
+ }
241
+ }();
242
+ const auto layout = cute::make_layout(shape, strides);
243
+ auto X_data = X.host_ref().data();
244
+ auto X_data_const = const_cast<std::add_const_t< decltype(X_data)> >(X_data);
245
+ return cute::make_tensor(X_data_const, layout);
246
+ };
247
+
248
+
249
+ // Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE.
250
+ // This makes the return value suitable as the return value of main().
251
// Compare n elements of `data` against `reference`, printing several error
// statistics, and return EXIT_SUCCESS iff the vector (2-norm) relative
// error is within error_margin — suitable as the return value of main().
//
// T1/T2 need only support operator[]; values are accumulated in double
// (or cute::complex<double> for complex element types).
template <typename T1, typename T2>
int
print_relative_error(
  std::size_t n,
  T1 const& data,
  T2 const& reference,
  bool print_verbose = false,
  bool print_error = true,
  double error_margin = 0.00001) {
  using std::abs; using std::sqrt;

  // Use either double or complex<double> for error computation
  using value_type = cute::remove_cvref_t<decltype(reference[0])>;
  using error_type = std::conditional_t<cute::is_complex<value_type>::value,
                                        cute::complex<double>,
                                        double>;

  if (print_verbose) {
    std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl;
  }

  // Tiny additive guard so zero reference values don't divide by zero.
  double eps = 1e-200;

  double tot_error_sq = 0;
  double tot_norm_sq = 0;
  double tot_ind_rel_err = 0;
  double max_ind_rel_err = 0;
  double max_diff = 0;
  for (std::size_t i = 0; i < n; ++i) {
    error_type val = data[i];
    error_type ref = reference[i];

    double aref = abs(ref);
    double diff = abs(ref - val);
    double rel_error = diff / (aref + eps);

    // Individual relative error
    tot_ind_rel_err += rel_error;

    // Maximum relative error
    max_ind_rel_err = std::max(max_ind_rel_err, rel_error);

    // Maximum delta in value error
    max_diff = std::max(max_diff, diff);

    // Total relative error
    tot_error_sq += diff * diff;
    tot_norm_sq += aref * aref;

    if (print_verbose) {
      std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl;
    }
  }

  double ave_rel_err = tot_ind_rel_err / double(n);
  if (print_error) {
    printf("Average relative error: %.3e\n", ave_rel_err);
  }

  if (print_error) {
    printf("Maximum relative error: %.3e\n", max_ind_rel_err);
  }

  if (print_error) {
    printf("Maximum difference : %.3e\n", max_diff);
  }

  // Vector (2-norm) relative error — this is the pass/fail criterion.
  double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps));
  if (print_error) {
    printf("Vector relative error: %.3e\n", tot_rel_err);
  }

  // NOTE(review): unlike the statistics above, this line is printed even
  // when print_error is false — confirm whether that is intentional.
  printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq));

  return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE;
}
327
+
328
+ // Overload for cute::Tensor<>
329
+ template <class Engine, class Layout>
330
+ int
331
+ print_relative_error(
332
+ cute::Tensor<Engine, Layout> data,
333
+ cute::Tensor<Engine, Layout> reference,
334
+ bool print_verbose = false,
335
+ bool print_error = true,
336
+ double error_margin = 0.00001) {
337
+ assert(size(data) == size(reference));
338
+ return print_relative_error(static_cast<std::size_t>(size(data)),
339
+ data, reference,
340
+ print_verbose, print_error, error_margin);
341
+ }
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GEMM in host-side code.
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/array.h"
38
+
39
+ namespace cutlass {
40
+ namespace reference {
41
+ namespace detail {
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ /// Template function to compute an inner product.
46
+ #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a
47
+ // host-only type
48
+ template <typename Atype, typename Btype, typename Ctype>
49
+ CUTLASS_HOST_DEVICE
50
+ Ctype inner_product(Atype a, Btype b, Ctype c) {
51
+ return Ctype(a) * Ctype(b) + c;
52
+ }
53
+
54
+ /// Specialization for matrix multiplication with binary operands
55
+ template <>
56
+ CUTLASS_HOST_DEVICE
57
+ int inner_product<Array<bin1_t, 32>, Array<bin1_t, 32>, int>(
58
+ Array<bin1_t, 32> a,
59
+ Array<bin1_t, 32> b,
60
+ int c) {
61
+
62
+ int accum = 0;
63
+ for (int bit = 0; bit < 32; bit++) {
64
+ accum += a[bit] ^ b[bit];
65
+ }
66
+ return accum + c;
67
+ }
68
+
69
+ /*
70
+ /// Specialization for matrix multiplication with signed 4-bit integer operands
71
+ template <>
72
+ CUTLASS_HOST_DEVICE
73
+ int inner_product<Array<int4b_t, 8>, Array<int4b_t, 8>, int>(
74
+ Array<int4b_t, 8> a,
75
+ Array<int4b_t, 8> b,
76
+ int c) {
77
+
78
+ int accum = 0;
79
+ for (int k = 0; k < 8; k++) {
80
+ accum += a[k] * b[k];
81
+ }
82
+ return accum + c;
83
+ }
84
+
85
+ /// Specialization for matrix multiplication with unsigned 4-bit integer operands
86
+ template <>
87
+ CUTLASS_HOST_DEVICE
88
+ int inner_product<Array<uint4b_t, 8>, Array<uint4b_t, 8>, int>(
89
+ Array<uint4b_t, 8> a,
90
+ Array<uint4b_t, 8> b,
91
+ int c) {
92
+
93
+ int accum = 0;
94
+ for (int k = 0; k < 8; k++) {
95
+ accum += a[k] * b[k];
96
+ }
97
+ return accum + c;
98
+ }
99
+ */
100
+
101
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
102
+
103
+ template <typename SrcType, typename DstType>
104
+ struct Cast {
105
+ // Default behavior: convert to the destination type
106
+ #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
107
+ // host-only type
108
+ CUTLASS_HOST_DEVICE
109
+ static DstType apply(SrcType src) { return static_cast<DstType>(src); };
110
+ };
111
+
112
+ template <>
113
+ struct Cast<float, int8_t> {
114
+ CUTLASS_HOST_DEVICE
115
+ static int8_t apply(float src) {
116
+ // Clamp to the range of signed 8-bit integers.
117
+ return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
118
+ };
119
+ };
120
+
121
+ template <>
122
+ struct Cast<float, uint8_t> {
123
+ CUTLASS_HOST_DEVICE
124
+ static uint8_t apply(float src) {
125
+ // Clamp to the range of signed 8-bit integers.
126
+ return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
127
+ };
128
+ };
129
+
130
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
131
+
132
+ } // namespace detail
133
+ } // namespace reference
134
+ } // namespace cutlass
135
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GEMM in host-side code.
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/coord.h"
38
+
39
+ /////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ namespace cutlass {
42
+ namespace reference {
43
+ namespace detail {
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ template <int Rank, int Index>
48
+ struct LinearToCoordinateHelper {
49
+
50
+ CUTLASS_HOST_DEVICE
51
+ void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
52
+
53
+ int64_t prod = 1;
54
+
55
+ CUTLASS_PRAGMA_UNROLL
56
+ for (int i = Rank - Index; i < Rank; ++i) {
57
+ prod *= int64_t(extent[i]);
58
+ }
59
+
60
+ coord[Rank - Index - 1] = int(idx / prod);
61
+
62
+ int64_t residual = idx % prod;
63
+ LinearToCoordinateHelper<Rank, Index - 1>()(coord, residual, extent);
64
+ }
65
+ };
66
+
67
+ template <int Rank>
68
+ struct LinearToCoordinateHelper<Rank, 0> {
69
+
70
+ CUTLASS_HOST_DEVICE
71
+ void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &) const {
72
+ coord[Rank - 1] = int(idx);
73
+ }
74
+ };
75
+
76
+ /////////////////////////////////////////////////////////////////////////////////////////////////
77
+
78
+ template <int Rank>
79
+ struct LinearToCoordinate {
80
+
81
+ CUTLASS_HOST_DEVICE
82
+ void operator()(Coord<Rank> &coord, int64_t idx, Coord<Rank> const &extent) const {
83
+ LinearToCoordinateHelper<Rank, Rank - 1>()(coord, idx, extent);
84
+ }
85
+ };
86
+
87
+ /////////////////////////////////////////////////////////////////////////////////////////////////
88
+
89
+ } // namespace detail
90
+ } // namespace reference
91
+ } // namespace cutlass
92
+
93
+ /////////////////////////////////////////////////////////////////////////////////////////////////
94
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h ADDED
@@ -0,0 +1,1549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Reference implementation for convolution in device-side code.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/coord.h"
39
+ #include "cutlass/functional.h"
40
+ #include "cutlass/layout/tensor.h"
41
+ #include "cutlass/matrix_shape.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/numeric_types.h"
44
+ #include "cutlass/tensor_ref.h"
45
+ #include "cutlass/conv/convolution.h"
46
+ #include "cutlass/conv/conv2d_problem_size.h"
47
+ #include "cutlass/conv/conv3d_problem_size.h"
48
+
49
+ namespace cutlass {
50
+ namespace reference {
51
+ namespace device {
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ namespace kernel {
56
+
57
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
58
+ /// Conv2d device reference kernel
59
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ // Conv2d Fprop kernel - y = fprop(x, w)
62
+ template <
63
+ typename ElementA,
64
+ typename LayoutA,
65
+ typename ElementB,
66
+ typename LayoutB,
67
+ typename ElementC,
68
+ typename LayoutC,
69
+ typename ElementCompute,
70
+ typename ElementAccumulator = ElementCompute,
71
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
72
+ typename InnerProductOp = multiply_add<ElementAccumulator>,
73
+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
74
+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
75
+ int kCtaShapeM = 16, // shape of a threadblock in units of threads
76
+ int kCtaShapeN = 8 // shape of a threadblock in units of threads
77
+ >
78
+ __global__ void Conv2dFprop(
79
+ conv::Conv2dProblemSize problem_size,
80
+ TensorRef<ElementA, LayoutA> tensor_x,
81
+ TensorRef<ElementB, LayoutB> tensor_w,
82
+ TensorRef<ElementC, LayoutC> tensor_y_in,
83
+ TensorRef<ElementC, LayoutC> tensor_y_out,
84
+ ElementCompute alpha,
85
+ ElementCompute beta
86
+ ) {
87
+
88
+ ConvertOp convert_op;
89
+ InnerProductOp inner_product_op;
90
+
91
+ ElementAccumulator element_A[kThreadM];
92
+ ElementAccumulator element_B[kThreadN];
93
+ ElementAccumulator accum[kThreadM][kThreadN];
94
+
95
+ int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
96
+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
97
+
98
+ int thread_n[kThreadM];
99
+ int thread_p[kThreadM];
100
+ int thread_q[kThreadM];
101
+
102
+ // Compute N, P, Q coordinates for each row of a thread's tile
103
+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
104
+
105
+ CUTLASS_PRAGMA_UNROLL
106
+ for (int m = 0; m < kThreadM; ++m) {
107
+
108
+ int64_t npq = npq_start + m;
109
+
110
+ thread_n[m] = int(npq / PQ);
111
+
112
+ int64_t residual = npq % PQ;
113
+ thread_p[m] = int(residual / problem_size.Q);
114
+ thread_q[m] = int(residual % problem_size.Q);
115
+ }
116
+
117
+ // Clear accumulators
118
+ CUTLASS_PRAGMA_UNROLL
119
+ for (int m = 0; m < kThreadM; ++m) {
120
+ CUTLASS_PRAGMA_UNROLL
121
+ for (int n = 0; n < kThreadN; ++n) {
122
+ accum[m][n] = ElementAccumulator();
123
+ }
124
+ }
125
+
126
+ int c_per_group = problem_size.C / problem_size.groups;
127
+ int k_per_group = problem_size.K / problem_size.groups;
128
+
129
+ // Compute convolution
130
+ for (int R = 0; R < problem_size.R; ++R) {
131
+ for (int S = 0; S < problem_size.S; ++S) {
132
+ for (int C = 0; C < problem_size.C; ++C) {
133
+
134
+ // Get group id of currnet channel
135
+ int c_group_idx = C / c_per_group;
136
+
137
+ // Load from activations tensor
138
+ int filter_r = R;
139
+ int filter_s = S;
140
+
141
+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
142
+ filter_r = problem_size.R - 1 - R;
143
+ filter_s = problem_size.S - 1 - S;
144
+ }
145
+
146
+ CUTLASS_PRAGMA_UNROLL
147
+ for (int m = 0; m < kThreadM; ++m) {
148
+ int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
149
+ int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
150
+
151
+ if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {
152
+ element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C}));
153
+ }
154
+ else {
155
+ element_A[m] = ElementAccumulator();
156
+ }
157
+ }
158
+
159
+ // Load from filters tensor
160
+ CUTLASS_PRAGMA_UNROLL
161
+ for (int n = 0; n < kThreadN; ++n) {
162
+ int thread_k = k_start + n;
163
+ int k_group_idx = thread_k / k_per_group;
164
+
165
+ if (thread_k < problem_size.K && k_group_idx == c_group_idx) {
166
+ element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group}));
167
+ }
168
+ else {
169
+ element_B[n] = ElementAccumulator();
170
+ }
171
+ }
172
+
173
+ // Accumulate matrix product
174
+ CUTLASS_PRAGMA_UNROLL
175
+ for (int m = 0; m < kThreadM; ++m) {
176
+ CUTLASS_PRAGMA_UNROLL
177
+ for (int n = 0; n < kThreadN; ++n) {
178
+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
179
+ }
180
+ }
181
+ }
182
+ }
183
+ }
184
+
185
+ // Write out the results
186
+ CUTLASS_PRAGMA_UNROLL
187
+ for (int m = 0; m < kThreadM; ++m) {
188
+ if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
189
+ CUTLASS_PRAGMA_UNROLL
190
+ for (int n = 0; n < kThreadN; ++n) {
191
+ int thread_k = k_start + n;
192
+ if (thread_k < problem_size.K) {
193
+
194
+ ElementCompute c_ref = ElementCompute();
195
+ if (beta != ElementCompute()) {
196
+ c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
197
+ }
198
+
199
+ tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
200
+ alpha * ElementCompute(accum[m][n]) + beta * c_ref);
201
+ }
202
+ }
203
+ }
204
+ }
205
+ }
206
+
207
+ // Conv3d Fprop kernel - y = fprop(x, w)
208
+ template <
209
+ typename ElementA,
210
+ typename LayoutA,
211
+ typename ElementB,
212
+ typename LayoutB,
213
+ typename ElementC,
214
+ typename LayoutC,
215
+ typename ElementCompute,
216
+ typename ElementAccumulator = ElementCompute,
217
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
218
+ typename InnerProductOp = multiply_add<ElementAccumulator>,
219
+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
220
+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
221
+ int kCtaShapeM = 16, // shape of a threadblock in units of threads
222
+ int kCtaShapeN = 8 // shape of a threadblock in units of threads
223
+ >
224
+ __global__ void Conv3dFprop(
225
+ conv::Conv3dProblemSize problem_size,
226
+ TensorRef<ElementA, LayoutA> tensor_x,
227
+ TensorRef<ElementB, LayoutB> tensor_w,
228
+ TensorRef<ElementC, LayoutC> tensor_y_in,
229
+ TensorRef<ElementC, LayoutC> tensor_y_out,
230
+ ElementCompute alpha,
231
+ ElementCompute beta
232
+ ) {
233
+
234
+ ConvertOp convert_op;
235
+ InnerProductOp inner_product_op;
236
+
237
+ ElementAccumulator element_A[kThreadM];
238
+ ElementAccumulator element_B[kThreadN];
239
+ ElementAccumulator accum[kThreadM][kThreadN];
240
+
241
+ int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
242
+ int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
243
+
244
+ int thread_n[kThreadM];
245
+ int thread_z[kThreadM];
246
+ int thread_p[kThreadM];
247
+ int thread_q[kThreadM];
248
+
249
+ // Compute N, Z, P, Q coordinates for each row of a thread's tile
250
+ int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
251
+ int64_t ZPQ = PQ * problem_size.Z;
252
+
253
+ CUTLASS_PRAGMA_UNROLL
254
+ for (int m = 0; m < kThreadM; ++m) {
255
+
256
+ int64_t nzpq = nzpq_start + m;
257
+
258
+ thread_n[m] = int(nzpq / ZPQ);
259
+
260
+ int64_t residual = nzpq % ZPQ;
261
+ thread_z[m] = int(residual / PQ);
262
+
263
+ residual = residual % PQ;
264
+ thread_p[m] = int(residual / problem_size.Q);
265
+ thread_q[m] = int(residual % problem_size.Q);
266
+ }
267
+
268
+ // Clear accumulators
269
+ CUTLASS_PRAGMA_UNROLL
270
+ for (int m = 0; m < kThreadM; ++m) {
271
+ CUTLASS_PRAGMA_UNROLL
272
+ for (int n = 0; n < kThreadN; ++n) {
273
+ accum[m][n] = ElementAccumulator();
274
+ }
275
+ }
276
+
277
+ // Compute convolution
278
+ for (int T = 0; T < problem_size.T; ++T) {
279
+ for (int R = 0; R < problem_size.R; ++R) {
280
+ for (int S = 0; S < problem_size.S; ++S) {
281
+ for (int C = 0; C < problem_size.C; ++C) {
282
+
283
+ // Load from activations tensor
284
+ int filter_t = T;
285
+ int filter_r = R;
286
+ int filter_s = S;
287
+
288
+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
289
+ filter_t = problem_size.T - 1 - T;
290
+ filter_r = problem_size.R - 1 - R;
291
+ filter_s = problem_size.S - 1 - S;
292
+ }
293
+
294
+ CUTLASS_PRAGMA_UNROLL
295
+ for (int m = 0; m < kThreadM; ++m) {
296
+ int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
297
+ int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
298
+ int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;
299
+
300
+ if (thread_n[m] < problem_size.N &&
301
+ d >= 0 && d < problem_size.D &&
302
+ h >= 0 && h < problem_size.H &&
303
+ w >= 0 && w < problem_size.W) {
304
+
305
+ element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C}));
306
+ }
307
+ else {
308
+ element_A[m] = ElementAccumulator();
309
+ }
310
+ }
311
+
312
+ // Load from filters tensor
313
+ CUTLASS_PRAGMA_UNROLL
314
+ for (int n = 0; n < kThreadN; ++n) {
315
+ int thread_k = k_start + n;
316
+
317
+ if (thread_k < problem_size.K) {
318
+ element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C}));
319
+ }
320
+ else {
321
+ element_B[n] = ElementAccumulator();
322
+ }
323
+ }
324
+
325
+ // Accumulate matrix product
326
+ CUTLASS_PRAGMA_UNROLL
327
+ for (int m = 0; m < kThreadM; ++m) {
328
+ CUTLASS_PRAGMA_UNROLL
329
+ for (int n = 0; n < kThreadN; ++n) {
330
+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
331
+ }
332
+ }
333
+
334
+ } // for (C)
335
+ } // for (S)
336
+ } // for (R)
337
+ } // for (T)
338
+
339
+ // Write out the results
340
+ CUTLASS_PRAGMA_UNROLL
341
+ for (int m = 0; m < kThreadM; ++m) {
342
+
343
+ if (thread_n[m] < problem_size.N &&
344
+ thread_z[m] < problem_size.Z &&
345
+ thread_p[m] < problem_size.P &&
346
+ thread_q[m] < problem_size.Q) {
347
+
348
+ CUTLASS_PRAGMA_UNROLL
349
+ for (int n = 0; n < kThreadN; ++n) {
350
+ int thread_k = k_start + n;
351
+ if (thread_k < problem_size.K) {
352
+
353
+ ElementCompute c_ref = ElementCompute();
354
+ if (beta != ElementCompute()) {
355
+ c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}));
356
+ }
357
+
358
+ tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
359
+ alpha * ElementCompute(accum[m][n]) + beta * c_ref);
360
+ }
361
+ } // for (n)
362
+
363
+ }
364
+ } // for (m)
365
+ }
366
+
367
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
368
+
369
+ // Conv2d dgrad kernel - dx = dgrad(dy, w)
370
+ template <
371
+ typename ElementA,
372
+ typename LayoutA,
373
+ typename ElementB,
374
+ typename LayoutB,
375
+ typename ElementC,
376
+ typename LayoutC,
377
+ typename ElementCompute,
378
+ typename ElementAccumulator = ElementCompute,
379
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
380
+ typename InnerProductOp = multiply_add<ElementAccumulator>,
381
+ int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension
382
+ int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
383
+ int kCtaShapeM = 16, // shape of a threadblock in units of threads
384
+ int kCtaShapeN = 8 // shape of a threadblock in units of threads
385
+ >
386
+ __global__ void Conv2dDgrad(
387
+ conv::Conv2dProblemSize problem_size,
388
+ TensorRef<ElementA, LayoutA> tensor_dy,
389
+ TensorRef<ElementB, LayoutB> tensor_w,
390
+ TensorRef<ElementC, LayoutC> tensor_dx_in,
391
+ TensorRef<ElementC, LayoutC> tensor_dx_out,
392
+ ElementCompute alpha,
393
+ ElementCompute beta
394
+ ) {
395
+
396
+ ConvertOp convert_op;
397
+ InnerProductOp inner_product_op;
398
+
399
+ ElementAccumulator element_A[kThreadM];
400
+ ElementAccumulator element_B[kThreadN];
401
+ ElementAccumulator accum[kThreadM][kThreadN];
402
+
403
+ int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
404
+ int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
405
+
406
+ int thread_n[kThreadM];
407
+ int thread_h[kThreadM];
408
+ int thread_w[kThreadM];
409
+
410
+ // Compute N, H, W coordinates for each row of a thread's tile
411
+ int64_t HW = int64_t(problem_size.H) * problem_size.W;
412
+
413
+ CUTLASS_PRAGMA_UNROLL
414
+ for (int m = 0; m < kThreadM; ++m) {
415
+
416
+ int64_t nhw = nhw_start + m;
417
+
418
+ thread_n[m] = int(nhw / HW);
419
+
420
+ int64_t residual = nhw % HW;
421
+ thread_h[m] = int(residual / problem_size.W);
422
+ thread_w[m] = int(residual % problem_size.W);
423
+ }
424
+
425
+ // Clear accumulators
426
+ CUTLASS_PRAGMA_UNROLL
427
+ for (int m = 0; m < kThreadM; ++m) {
428
+ CUTLASS_PRAGMA_UNROLL
429
+ for (int n = 0; n < kThreadN; ++n) {
430
+ accum[m][n] = ElementAccumulator();
431
+ }
432
+ }
433
+
434
+ // Compute convolution
435
+ for (int R = 0; R < problem_size.R; ++R) {
436
+ for (int S = 0; S < problem_size.S; ++S) {
437
+ for (int K = 0; K < problem_size.K; ++K) {
438
+
439
+ // Load from activations tensor
440
+ int filter_r = R;
441
+ int filter_s = S;
442
+
443
+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
444
+ filter_r = problem_size.R - 1 - R;
445
+ filter_s = problem_size.S - 1 - S;
446
+ }
447
+
448
+ CUTLASS_PRAGMA_UNROLL
449
+ for (int m = 0; m < kThreadM; ++m) {
450
+
451
+ int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
452
+ int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;
453
+
454
+ element_A[m] = ElementAccumulator();
455
+
456
+ if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) {
457
+
458
+ p = p / problem_size.stride_h;
459
+ q = q / problem_size.stride_w;
460
+
461
+ if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) {
462
+ element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K}));
463
+ }
464
+ }
465
+ }
466
+
467
+ // Load from filters tensor
468
+ CUTLASS_PRAGMA_UNROLL
469
+ for (int n = 0; n < kThreadN; ++n) {
470
+ int thread_c = c_start + n;
471
+
472
+ if (thread_c < problem_size.C) {
473
+ element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c}));
474
+ }
475
+ else {
476
+ element_B[n] = ElementAccumulator();
477
+ }
478
+ }
479
+
480
+ // Accumulate matrix product
481
+ CUTLASS_PRAGMA_UNROLL
482
+ for (int m = 0; m < kThreadM; ++m) {
483
+ CUTLASS_PRAGMA_UNROLL
484
+ for (int n = 0; n < kThreadN; ++n) {
485
+ accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
486
+ }
487
+ }
488
+ }
489
+ }
490
+ }
491
+
492
+ // Write out the results
493
+ CUTLASS_PRAGMA_UNROLL
494
+ for (int m = 0; m < kThreadM; ++m) {
495
+
496
+ if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) {
497
+
498
+ CUTLASS_PRAGMA_UNROLL
499
+ for (int n = 0; n < kThreadN; ++n) {
500
+ int thread_c = c_start + n;
501
+ if (thread_c < problem_size.C) {
502
+
503
+ ElementCompute c_ref = ElementCompute();
504
+ if (beta != ElementCompute()) {
505
+ c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c}));
506
+ }
507
+
508
+ tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
509
+ alpha * ElementCompute(accum[m][n]) + beta * c_ref);
510
+ }
511
+ }
512
+ }
513
+ }
514
+ }
515
+
516
// Conv3d dgrad kernel - dx = dgrad(dy, w)
//
// Naive reference device kernel computing the 3-D convolution data gradient as an
// implicit GEMM: the GEMM M dimension enumerates dx positions (N, D, H, W) and the
// GEMM N dimension enumerates input channels C. Each thread accumulates a
// kThreadM x kThreadN tile and applies the epilogue
//
//   dx_out = alpha * accum + beta * dx_in
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 16,    // shape of a threadblock in units of threads
  int kCtaShapeN = 8      // shape of a threadblock in units of threads
>
__global__ void Conv3dDgrad(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,      // output gradient, indexed as (n, z, p, q, k)
  TensorRef<ElementB, LayoutB> tensor_w,       // filters, indexed as (k, t, r, s, c)
  TensorRef<ElementC, LayoutC> tensor_dx_in,   // source operand blended in when beta != 0
  TensorRef<ElementC, LayoutC> tensor_dx_out,  // computed data gradient, indexed as (n, d, h, w, c)
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Per-thread fragments of the implicit GEMM operands and accumulator tile.
  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // Starting flattened (n, d, h, w) index and starting channel for this thread's tile.
  int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_n[kThreadM];
  int thread_d[kThreadM];
  int thread_h[kThreadM];
  int thread_w[kThreadM];

  // Compute N, D, H, W coordinates for each row of a thread's tile
  int64_t HW = int64_t(problem_size.H) * problem_size.W;
  int64_t DHW = HW * problem_size.D;

  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    int64_t ndhw = ndhw_start + m;

    thread_n[m] = int(ndhw / DHW);

    int64_t residual = ndhw % DHW;
    thread_d[m] = int(residual / HW);

    residual = residual % HW;
    thread_h[m] = int(residual / problem_size.W);
    thread_w[m] = int(residual % problem_size.W);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution: reduce over filter positions (T, R, S) and output channels K.
  for (int T = 0; T < problem_size.T; ++T) {
    for (int R = 0; R < problem_size.R; ++R) {
      for (int S = 0; S < problem_size.S; ++S) {
        for (int K = 0; K < problem_size.K; ++K) {

          // Load from activations tensor
          int filter_t = T;
          int filter_r = R;
          int filter_s = S;

          // True convolution (as opposed to cross-correlation) mirrors the filter.
          if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
            filter_t = problem_size.T - 1 - T;
            filter_r = problem_size.R - 1 - R;
            filter_s = problem_size.S - 1 - S;
          }

          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {

            // Map the dx coordinate back to a (pre-stride-division) dy coordinate.
            int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d;
            int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h;
            int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w;

            element_A[m] = ElementAccumulator();

            // Only coordinates landing exactly on a stride multiple contribute;
            // all other taps read an implicit zero.
            if (z >= 0 && !(z % problem_size.stride_d) &&
              p >= 0 && !(p % problem_size.stride_h) &&
              q >= 0 && !(q % problem_size.stride_w)) {

              z = z / problem_size.stride_d;
              p = p / problem_size.stride_h;
              q = q / problem_size.stride_w;

              if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {
                element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K}));
              }
            }
          }

          // Load from filters tensor (zero-padded past the channel extent)
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            int thread_c = c_start + n;

            if (thread_c < problem_size.C) {
              element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c}));
            }
            else {
              element_B[n] = ElementAccumulator();
            }
          }

          // Accumulate matrix product
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
              accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
            }
          }

        } // for (K)
      } // for (S)
    } // for (R)
  } // for (T)

  // Write out the results with the linear-scaling epilogue.
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {

    if (thread_n[m] < problem_size.N &&
      thread_d[m] < problem_size.D &&
      thread_h[m] < problem_size.H &&
      thread_w[m] < problem_size.W) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {
        int thread_c = c_start + n;
        if (thread_c < problem_size.C) {

          ElementCompute c_ref = ElementCompute();

          // Skip the source read entirely when beta == 0.
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}));
          }

          tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
679
+
680
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
681
+
682
// Conv2d wgrad kernel - dw = wgrad(dy, x)
//
// Naive reference device kernel computing the 2-D convolution weight gradient as an
// implicit GEMM: the GEMM M dimension enumerates output channels K and the GEMM N
// dimension enumerates flattened filter positions (R, S, C). Each thread accumulates
// a kThreadM x kThreadN tile and applies the epilogue
//
//   dw_out = alpha * accum + beta * dw_in
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 8,     // shape of a threadblock in units of threads
  int kCtaShapeN = 16     // shape of a threadblock in units of threads
>
__global__ void Conv2dWgrad(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,      // output gradient, indexed as (n, p, q, k)
  TensorRef<ElementB, LayoutB> tensor_x,       // activations, indexed as (n, h, w, c)
  TensorRef<ElementC, LayoutC> tensor_dw_in,   // source operand blended in when beta != 0
  TensorRef<ElementC, LayoutC> tensor_dw_out,  // computed weight gradient, indexed as (k, r, s, c)
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Per-thread fragments of the implicit GEMM operands and accumulator tile.
  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // Starting output channel and starting flattened (r, s, c) index for this thread's tile.
  int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_r[kThreadN];
  int thread_s[kThreadN];
  int thread_c[kThreadN];

  // Compute R, S, C coordinates for each column of a thread's tile
  int64_t SC = int64_t(problem_size.S) * problem_size.C;

  CUTLASS_PRAGMA_UNROLL
  for (int n = 0; n < kThreadN; ++n) {

    int64_t rsc = rsc_start + n;
    int64_t residual = rsc % SC;

    thread_r[n] = int(rsc / SC);
    thread_s[n] = int(residual / problem_size.C);
    thread_c[n] = int(residual % problem_size.C);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution: reduce over all output positions (N, P, Q).
  for (int N = 0; N < problem_size.N; ++N) {
    for (int P = 0; P < problem_size.P; ++P) {
      for (int Q = 0; Q < problem_size.Q; ++Q) {

        // Load from output-gradient tensor (zero-padded past the K extent)
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          int thread_k = k_start + m;

          element_A[m] = ElementAccumulator();

          if (thread_k < problem_size.K) {
            element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k}));
          }
        }

        // Load from activations tensor
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < kThreadN; ++n) {

          int filter_r = thread_r[n];
          int filter_s = thread_s[n];

          // True convolution (as opposed to cross-correlation) mirrors the filter.
          if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
            filter_r = problem_size.R - 1 - filter_r;
            filter_s = problem_size.S - 1 - filter_s;
          }

          // Map the (output position, filter tap) pair to an input coordinate.
          int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
          int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

          element_B[n] = ElementAccumulator();

          // Out-of-bounds taps read an implicit zero (padding).
          if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) {
            element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]}));
          }
        }

        // Accumulate matrix product
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < kThreadM; ++m) {
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {
            accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
          }
        }
      }
    }
  }

  // Write out the results with the linear-scaling epilogue.
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    int thread_k = k_start + m;

    if (thread_k < problem_size.K) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {

        if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) {

          ElementCompute c_ref = ElementCompute();

          // Skip the source read entirely when beta == 0.
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}));
          }

          tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
822
+
823
// Conv3d wgrad kernel - dw = wgrad(dy, x)
//
// Naive reference device kernel computing the 3-D convolution weight gradient as an
// implicit GEMM: the GEMM M dimension enumerates output channels K and the GEMM N
// dimension enumerates flattened filter positions (T, R, S, C). Each thread
// accumulates a kThreadM x kThreadN tile and applies the epilogue
//
//   dw_out = alpha * accum + beta * dw_in
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>,
  int kThreadM = 2,       // shape of a thread's tile in the GEMM M dimension
  int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
  int kCtaShapeM = 8,     // shape of a threadblock in units of threads
  int kCtaShapeN = 16     // shape of a threadblock in units of threads
>
__global__ void Conv3dWgrad(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,      // output gradient, indexed as (n, z, p, q, k)
  TensorRef<ElementB, LayoutB> tensor_x,       // activations, indexed as (n, d, h, w, c)
  TensorRef<ElementC, LayoutC> tensor_dw_in,   // source operand blended in when beta != 0
  TensorRef<ElementC, LayoutC> tensor_dw_out,  // computed weight gradient, indexed as (k, t, r, s, c)
  ElementCompute alpha,
  ElementCompute beta
) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Per-thread fragments of the implicit GEMM operands and accumulator tile.
  ElementAccumulator element_A[kThreadM];
  ElementAccumulator element_B[kThreadN];
  ElementAccumulator accum[kThreadM][kThreadN];

  // Starting output channel and starting flattened (t, r, s, c) index for this thread's tile.
  int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
  int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

  int thread_t[kThreadN];
  int thread_r[kThreadN];
  int thread_s[kThreadN];
  int thread_c[kThreadN];

  // Compute T, R, S, C coordinates for each column of a thread's tile
  int64_t SC = int64_t(problem_size.S) * problem_size.C;
  int64_t RSC = SC * problem_size.R;

  CUTLASS_PRAGMA_UNROLL
  for (int n = 0; n < kThreadN; ++n) {

    int64_t trsc = trsc_start + n;

    thread_t[n] = int(trsc / RSC);

    int64_t residual = trsc % RSC;
    thread_r[n] = int(residual / SC);

    residual = residual % SC;
    thread_s[n] = int(residual / problem_size.C);
    thread_c[n] = int(residual % problem_size.C);
  }

  // Clear accumulators
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
      accum[m][n] = ElementAccumulator();
    }
  }

  // Compute convolution: reduce over all output positions (N, Z, P, Q).
  for (int N = 0; N < problem_size.N; ++N) {
    for (int Z = 0; Z < problem_size.Z; ++Z) {
      for (int P = 0; P < problem_size.P; ++P) {
        for (int Q = 0; Q < problem_size.Q; ++Q) {

          // Load from output-gradient tensor (zero-padded past the K extent)
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            int thread_k = k_start + m;

            element_A[m] = ElementAccumulator();

            if (thread_k < problem_size.K) {
              element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k}));
            }
          }

          // Load from activations tensor
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < kThreadN; ++n) {

            int filter_t = thread_t[n];
            int filter_r = thread_r[n];
            int filter_s = thread_s[n];

            // True convolution (as opposed to cross-correlation) mirrors the filter.
            if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
              filter_t = problem_size.T - 1 - filter_t;
              filter_r = problem_size.R - 1 - filter_r;
              filter_s = problem_size.S - 1 - filter_s;
            }

            // Map the (output position, filter tap) pair to an input coordinate.
            int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
            int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
            int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

            element_B[n] = ElementAccumulator();

            // Out-of-bounds taps read an implicit zero (padding).
            if (d >= 0 && d < problem_size.D &&
              h >= 0 && h < problem_size.H &&
              w >= 0 && w < problem_size.W &&
              thread_c[n] < problem_size.C) {

              element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]}));
            }
          }

          // Accumulate matrix product
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < kThreadM; ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
              accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]);
            }
          }

        } // for (Q)
      } // for (P)
    } // for (Z)
  } // for (N)

  // Write out the results with the linear-scaling epilogue.
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < kThreadM; ++m) {
    int thread_k = k_start + m;

    if (thread_k < problem_size.K) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < kThreadN; ++n) {

        if (thread_t[n] < problem_size.T &&
          thread_r[n] < problem_size.R &&
          thread_s[n] < problem_size.S &&
          thread_c[n] < problem_size.C) {

          ElementCompute c_ref = ElementCompute();

          // Skip the source read entirely when beta == 0.
          if (beta != ElementCompute()) {
            c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}));
          }

          tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op(
            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
  }
}
982
+
983
+ /////////////////////////////////////////////////////////////////////////////////////////////////
984
+
985
+ } // namespace kernel
986
+
987
+ /////////////////////////////////////////////////////////////////////////////////////////////////
988
+
989
+ /// Conv2d Fprop dispatcher - y = fprop(x, w)
990
+ template <
991
+ typename ElementA,
992
+ typename LayoutA,
993
+ typename ElementB,
994
+ typename LayoutB,
995
+ typename ElementC,
996
+ typename LayoutC,
997
+ typename ElementCompute,
998
+ typename ElementAccumulator = ElementCompute,
999
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1000
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1001
+ >
1002
+ Status Conv2dFprop(
1003
+ conv::Conv2dProblemSize problem_size,
1004
+ TensorRef<ElementA, LayoutA> tensor_x,
1005
+ TensorRef<ElementB, LayoutB> tensor_w,
1006
+ TensorRef<ElementC, LayoutC> tensor_y_in,
1007
+ TensorRef<ElementC, LayoutC> tensor_y_out,
1008
+ ElementCompute alpha,
1009
+ ElementCompute beta,
1010
+ cudaStream_t stream = nullptr) {
1011
+
1012
+ //
1013
+ // Blocking factors improve performance of reference implementation
1014
+ //
1015
+
1016
+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
1017
+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1018
+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1019
+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1020
+
1021
+ int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
1022
+ int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1023
+
1024
+ dim3 block(kCtaShapeM, kCtaShapeN);
1025
+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1026
+
1027
+ kernel::Conv2dFprop<
1028
+ ElementA,
1029
+ LayoutA,
1030
+ ElementB,
1031
+ LayoutB,
1032
+ ElementC,
1033
+ LayoutC,
1034
+ ElementCompute,
1035
+ ElementAccumulator,
1036
+ ConvertOp,
1037
+ InnerProductOp,
1038
+ kThreadM,
1039
+ kThreadN,
1040
+ kCtaShapeM,
1041
+ kCtaShapeN
1042
+ ><<< grid, block, 0, stream >>>(
1043
+ problem_size,
1044
+ tensor_x,
1045
+ tensor_w,
1046
+ tensor_y_in,
1047
+ tensor_y_out,
1048
+ alpha,
1049
+ beta
1050
+ );
1051
+
1052
+ cudaError_t result = cudaPeekAtLastError();
1053
+ if (result != cudaSuccess) {
1054
+ return Status::kErrorInternal;
1055
+ }
1056
+
1057
+ return Status::kSuccess;
1058
+ }
1059
+
1060
+ /// Conv3d Fprop dispatcher - y = fprop(x, w)
1061
+ template <
1062
+ typename ElementA,
1063
+ typename LayoutA,
1064
+ typename ElementB,
1065
+ typename LayoutB,
1066
+ typename ElementC,
1067
+ typename LayoutC,
1068
+ typename ElementCompute,
1069
+ typename ElementAccumulator = ElementCompute,
1070
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1071
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1072
+ >
1073
+ Status Conv3dFprop(
1074
+ conv::Conv3dProblemSize problem_size,
1075
+ TensorRef<ElementA, LayoutA> tensor_x,
1076
+ TensorRef<ElementB, LayoutB> tensor_w,
1077
+ TensorRef<ElementC, LayoutC> tensor_y_in,
1078
+ TensorRef<ElementC, LayoutC> tensor_y_out,
1079
+ ElementCompute alpha,
1080
+ ElementCompute beta,
1081
+ cudaStream_t stream = nullptr) {
1082
+
1083
+ //
1084
+ // Blocking factors improve performance of reference implementation
1085
+ //
1086
+
1087
+ int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
1088
+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1089
+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1090
+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1091
+
1092
+ int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q;
1093
+ int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1094
+
1095
+ dim3 block(kCtaShapeM, kCtaShapeN);
1096
+ dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1097
+
1098
+ kernel::Conv3dFprop<
1099
+ ElementA,
1100
+ LayoutA,
1101
+ ElementB,
1102
+ LayoutB,
1103
+ ElementC,
1104
+ LayoutC,
1105
+ ElementCompute,
1106
+ ElementAccumulator,
1107
+ ConvertOp,
1108
+ InnerProductOp,
1109
+ kThreadM,
1110
+ kThreadN,
1111
+ kCtaShapeM,
1112
+ kCtaShapeN
1113
+ ><<< grid, block, 0, stream >>>(
1114
+ problem_size,
1115
+ tensor_x,
1116
+ tensor_w,
1117
+ tensor_y_in,
1118
+ tensor_y_out,
1119
+ alpha,
1120
+ beta
1121
+ );
1122
+
1123
+ cudaError_t result = cudaPeekAtLastError();
1124
+ if (result != cudaSuccess) {
1125
+ return Status::kErrorInternal;
1126
+ }
1127
+
1128
+ return Status::kSuccess;
1129
+ }
1130
+
1131
+ /// Conv2d Dgrad dispatcher - dx = dgrad(dy, w)
1132
+ template <
1133
+ typename ElementA,
1134
+ typename LayoutA,
1135
+ typename ElementB,
1136
+ typename LayoutB,
1137
+ typename ElementC,
1138
+ typename LayoutC,
1139
+ typename ElementCompute,
1140
+ typename ElementAccumulator = ElementCompute,
1141
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1142
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1143
+ >
1144
+ Status Conv2dDgrad(
1145
+ conv::Conv2dProblemSize problem_size,
1146
+ TensorRef<ElementA, LayoutA> tensor_dy,
1147
+ TensorRef<ElementB, LayoutB> tensor_w,
1148
+ TensorRef<ElementC, LayoutC> tensor_dx_in,
1149
+ TensorRef<ElementC, LayoutC> tensor_dx_out,
1150
+ ElementCompute alpha,
1151
+ ElementCompute beta,
1152
+ cudaStream_t stream = nullptr) {
1153
+
1154
+ //
1155
+ // Blocking factors improve performance of reference implementation
1156
+ //
1157
+
1158
+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1159
+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1160
+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1161
+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1162
+
1163
+ int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W;
1164
+ int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1165
+
1166
+ dim3 block(kCtaShapeM, kCtaShapeN);
1167
+ dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1168
+
1169
+ kernel::Conv2dDgrad<
1170
+ ElementA,
1171
+ LayoutA,
1172
+ ElementB,
1173
+ LayoutB,
1174
+ ElementC,
1175
+ LayoutC,
1176
+ ElementCompute,
1177
+ ElementAccumulator,
1178
+ ConvertOp,
1179
+ InnerProductOp,
1180
+ kThreadM,
1181
+ kThreadN,
1182
+ kCtaShapeM,
1183
+ kCtaShapeN
1184
+ ><<< grid, block, 0, stream >>>(
1185
+ problem_size,
1186
+ tensor_dy,
1187
+ tensor_w,
1188
+ tensor_dx_in,
1189
+ tensor_dx_out,
1190
+ alpha,
1191
+ beta
1192
+ );
1193
+
1194
+ cudaError_t result = cudaPeekAtLastError();
1195
+ if (result != cudaSuccess) {
1196
+ return Status::kErrorInternal;
1197
+ }
1198
+
1199
+ return Status::kSuccess;
1200
+ }
1201
+
1202
+ /// Conv3d Dgrad dispatcher - dx = dgrad(dy, w)
1203
+ template <
1204
+ typename ElementA,
1205
+ typename LayoutA,
1206
+ typename ElementB,
1207
+ typename LayoutB,
1208
+ typename ElementC,
1209
+ typename LayoutC,
1210
+ typename ElementCompute,
1211
+ typename ElementAccumulator = ElementCompute,
1212
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1213
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1214
+ >
1215
+ Status Conv3dDgrad(
1216
+ conv::Conv3dProblemSize problem_size,
1217
+ TensorRef<ElementA, LayoutA> tensor_dy,
1218
+ TensorRef<ElementB, LayoutB> tensor_w,
1219
+ TensorRef<ElementC, LayoutC> tensor_dx_in,
1220
+ TensorRef<ElementC, LayoutC> tensor_dx_out,
1221
+ ElementCompute alpha,
1222
+ ElementCompute beta,
1223
+ cudaStream_t stream = nullptr) {
1224
+
1225
+ //
1226
+ // Blocking factors improve performance of reference implementation
1227
+ //
1228
+
1229
+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1230
+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1231
+ int const kCtaShapeM = 16; // shape of a threadblock in units of threads
1232
+ int const kCtaShapeN = 8; // shape of a threadblock in units of threads
1233
+
1234
+ int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W;
1235
+ int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
1236
+
1237
+ dim3 block(kCtaShapeM, kCtaShapeN);
1238
+ dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
1239
+
1240
+ kernel::Conv3dDgrad<
1241
+ ElementA,
1242
+ LayoutA,
1243
+ ElementB,
1244
+ LayoutB,
1245
+ ElementC,
1246
+ LayoutC,
1247
+ ElementCompute,
1248
+ ElementAccumulator,
1249
+ ConvertOp,
1250
+ InnerProductOp,
1251
+ kThreadM,
1252
+ kThreadN,
1253
+ kCtaShapeM,
1254
+ kCtaShapeN
1255
+ ><<< grid, block, 0, stream >>>(
1256
+ problem_size,
1257
+ tensor_dy,
1258
+ tensor_w,
1259
+ tensor_dx_in,
1260
+ tensor_dx_out,
1261
+ alpha,
1262
+ beta
1263
+ );
1264
+
1265
+ cudaError_t result = cudaPeekAtLastError();
1266
+ if (result != cudaSuccess) {
1267
+ return Status::kErrorInternal;
1268
+ }
1269
+
1270
+ return Status::kSuccess;
1271
+ }
1272
+
1273
+ /// Conv2d Wgrad dispatcher - dw = wgrad(dy, x)
1274
+ template <
1275
+ typename ElementA,
1276
+ typename LayoutA,
1277
+ typename ElementB,
1278
+ typename LayoutB,
1279
+ typename ElementC,
1280
+ typename LayoutC,
1281
+ typename ElementCompute,
1282
+ typename ElementAccumulator = ElementCompute,
1283
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1284
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1285
+ >
1286
+ Status Conv2dWgrad(
1287
+ conv::Conv2dProblemSize problem_size,
1288
+ TensorRef<ElementA, LayoutA> tensor_dy,
1289
+ TensorRef<ElementB, LayoutB> tensor_x,
1290
+ TensorRef<ElementC, LayoutC> tensor_dw_in,
1291
+ TensorRef<ElementC, LayoutC> tensor_dw_out,
1292
+ ElementCompute alpha,
1293
+ ElementCompute beta,
1294
+ cudaStream_t stream = nullptr) {
1295
+
1296
+ //
1297
+ // Blocking factors improve performance of reference implementation
1298
+ //
1299
+
1300
+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1301
+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1302
+ int const kCtaShapeM = 8; // shape of a threadblock in units of threads
1303
+ int const kCtaShapeN = 16; // shape of a threadblock in units of threads
1304
+
1305
+ int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C;
1306
+ int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
1307
+
1308
+ dim3 block(kCtaShapeM, kCtaShapeN);
1309
+ dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
1310
+
1311
+ kernel::Conv2dWgrad<
1312
+ ElementA,
1313
+ LayoutA,
1314
+ ElementB,
1315
+ LayoutB,
1316
+ ElementC,
1317
+ LayoutC,
1318
+ ElementCompute,
1319
+ ElementAccumulator,
1320
+ ConvertOp,
1321
+ InnerProductOp,
1322
+ kThreadM,
1323
+ kThreadN,
1324
+ kCtaShapeM,
1325
+ kCtaShapeN
1326
+ ><<< grid, block, 0, stream >>>(
1327
+ problem_size,
1328
+ tensor_dy,
1329
+ tensor_x,
1330
+ tensor_dw_in,
1331
+ tensor_dw_out,
1332
+ alpha,
1333
+ beta
1334
+ );
1335
+
1336
+ cudaError_t result = cudaPeekAtLastError();
1337
+ if (result != cudaSuccess) {
1338
+ return Status::kErrorInternal;
1339
+ }
1340
+
1341
+ return Status::kSuccess;
1342
+ }
1343
+
1344
+ /// Conv3d Wgrad dispatcher - dw = wgrad(dy, x)
1345
+ template <
1346
+ typename ElementA,
1347
+ typename LayoutA,
1348
+ typename ElementB,
1349
+ typename LayoutB,
1350
+ typename ElementC,
1351
+ typename LayoutC,
1352
+ typename ElementCompute,
1353
+ typename ElementAccumulator = ElementCompute,
1354
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1355
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1356
+ >
1357
+ Status Conv3dWgrad(
1358
+ conv::Conv3dProblemSize problem_size,
1359
+ TensorRef<ElementA, LayoutA> tensor_dy,
1360
+ TensorRef<ElementB, LayoutB> tensor_x,
1361
+ TensorRef<ElementC, LayoutC> tensor_dw_in,
1362
+ TensorRef<ElementC, LayoutC> tensor_dw_out,
1363
+ ElementCompute alpha,
1364
+ ElementCompute beta,
1365
+ cudaStream_t stream = nullptr) {
1366
+
1367
+ //
1368
+ // Blocking factors improve performance of reference implementation
1369
+ //
1370
+
1371
+ int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension
1372
+ int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
1373
+ int const kCtaShapeM = 8; // shape of a threadblock in units of threads
1374
+ int const kCtaShapeN = 16; // shape of a threadblock in units of threads
1375
+
1376
+ int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C;
1377
+ int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN);
1378
+
1379
+ dim3 block(kCtaShapeM, kCtaShapeN);
1380
+ dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n));
1381
+
1382
+ kernel::Conv3dWgrad<
1383
+ ElementA,
1384
+ LayoutA,
1385
+ ElementB,
1386
+ LayoutB,
1387
+ ElementC,
1388
+ LayoutC,
1389
+ ElementCompute,
1390
+ ElementAccumulator,
1391
+ ConvertOp,
1392
+ InnerProductOp,
1393
+ kThreadM,
1394
+ kThreadN,
1395
+ kCtaShapeM,
1396
+ kCtaShapeN
1397
+ ><<< grid, block, 0, stream >>>(
1398
+ problem_size,
1399
+ tensor_dy,
1400
+ tensor_x,
1401
+ tensor_dw_in,
1402
+ tensor_dw_out,
1403
+ alpha,
1404
+ beta
1405
+ );
1406
+
1407
+ cudaError_t result = cudaPeekAtLastError();
1408
+ if (result != cudaSuccess) {
1409
+ return Status::kErrorInternal;
1410
+ }
1411
+
1412
+ return Status::kSuccess;
1413
+ }
1414
+
1415
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1416
+
1417
+ /// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
1418
+ template <
1419
+ typename ElementA,
1420
+ typename LayoutA,
1421
+ typename ElementB,
1422
+ typename LayoutB,
1423
+ typename ElementC,
1424
+ typename LayoutC,
1425
+ typename ElementCompute,
1426
+ typename ElementAccumulator = ElementCompute,
1427
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1428
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1429
+ >
1430
+ Status Conv2d(
1431
+ conv::Operator convolutional_operator,
1432
+ conv::Conv2dProblemSize problem_size,
1433
+ TensorRef<ElementA, LayoutA> tensor_A,
1434
+ TensorRef<ElementB, LayoutB> tensor_B,
1435
+ TensorRef<ElementC, LayoutC> tensor_C,
1436
+ TensorRef<ElementC, LayoutC> tensor_D,
1437
+ ElementCompute alpha,
1438
+ ElementCompute beta,
1439
+ cudaStream_t stream = nullptr) {
1440
+
1441
+ switch (convolutional_operator) {
1442
+ case conv::Operator::kFprop:
1443
+ return Conv2dFprop<
1444
+ ElementA, LayoutA,
1445
+ ElementB, LayoutB,
1446
+ ElementC, LayoutC,
1447
+ ElementCompute,
1448
+ ElementAccumulator,
1449
+ ConvertOp, InnerProductOp
1450
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1451
+ break;
1452
+
1453
+ case conv::Operator::kDgrad:
1454
+ return Conv2dDgrad<
1455
+ ElementA, LayoutA,
1456
+ ElementB, LayoutB,
1457
+ ElementC, LayoutC,
1458
+ ElementCompute,
1459
+ ElementAccumulator,
1460
+ ConvertOp, InnerProductOp
1461
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1462
+ break;
1463
+
1464
+ case conv::Operator::kWgrad:
1465
+ return Conv2dWgrad<
1466
+ ElementA, LayoutA,
1467
+ ElementB, LayoutB,
1468
+ ElementC, LayoutC,
1469
+ ElementCompute,
1470
+ ElementAccumulator,
1471
+ ConvertOp, InnerProductOp
1472
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1473
+ break;
1474
+
1475
+ default: break;
1476
+ }
1477
+
1478
+ return Status::kErrorNotSupported;
1479
+ }
1480
+
1481
+ /// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
1482
+ template <
1483
+ typename ElementA,
1484
+ typename LayoutA,
1485
+ typename ElementB,
1486
+ typename LayoutB,
1487
+ typename ElementC,
1488
+ typename LayoutC,
1489
+ typename ElementCompute,
1490
+ typename ElementAccumulator = ElementCompute,
1491
+ typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
1492
+ typename InnerProductOp = multiply_add<ElementAccumulator>
1493
+ >
1494
+ Status Conv3d(
1495
+ conv::Operator convolutional_operator,
1496
+ conv::Conv3dProblemSize problem_size,
1497
+ TensorRef<ElementA, LayoutA> tensor_A,
1498
+ TensorRef<ElementB, LayoutB> tensor_B,
1499
+ TensorRef<ElementC, LayoutC> tensor_C,
1500
+ TensorRef<ElementC, LayoutC> tensor_D,
1501
+ ElementCompute alpha,
1502
+ ElementCompute beta,
1503
+ cudaStream_t stream = nullptr) {
1504
+
1505
+ switch (convolutional_operator) {
1506
+ case conv::Operator::kFprop:
1507
+ return Conv3dFprop<
1508
+ ElementA, LayoutA,
1509
+ ElementB, LayoutB,
1510
+ ElementC, LayoutC,
1511
+ ElementCompute,
1512
+ ElementAccumulator,
1513
+ ConvertOp, InnerProductOp
1514
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1515
+
1516
+ case conv::Operator::kDgrad:
1517
+ return Conv3dDgrad<
1518
+ ElementA, LayoutA,
1519
+ ElementB, LayoutB,
1520
+ ElementC, LayoutC,
1521
+ ElementCompute,
1522
+ ElementAccumulator,
1523
+ ConvertOp, InnerProductOp
1524
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1525
+
1526
+ case conv::Operator::kWgrad:
1527
+ return Conv3dWgrad<
1528
+ ElementA, LayoutA,
1529
+ ElementB, LayoutB,
1530
+ ElementC, LayoutC,
1531
+ ElementCompute,
1532
+ ElementAccumulator,
1533
+ ConvertOp, InnerProductOp
1534
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
1535
+
1536
+ default: break;
1537
+ }
1538
+
1539
+ return Status::kErrorNotSupported;
1540
+ }
1541
+
1542
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1543
+
1544
+ } // namespace device
1545
+ } // namespace reference
1546
+ } // namespace cutlass
1547
+
1548
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1549
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GEMM in device-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+
43
+ #include "cutlass/tensor_view.h"
44
+ #include "cutlass/gemm/gemm.h"
45
+
46
+ #include "cutlass/util/reference/device/kernel/gemm.h"
47
+
48
+ namespace cutlass {
49
+ namespace reference {
50
+ namespace device {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
+ /// objects.
56
+ ///
57
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
58
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
59
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
60
+ /// arguments explicitly.
61
+ template <
62
+ typename ElementA,
63
+ typename LayoutA,
64
+ typename ElementB,
65
+ typename LayoutB,
66
+ typename ElementC,
67
+ typename LayoutC,
68
+ typename ScalarType,
69
+ typename AccumulatorType,
70
+ typename InnerProductOp = multiply_add<AccumulatorType>,
71
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
72
+ >
73
+ void compute_gemm(
74
+ gemm::GemmCoord problem_size,
75
+ ScalarType alpha,
76
+ TensorRef<ElementA, LayoutA> tensor_a,
77
+ TensorRef<ElementB, LayoutB> tensor_b,
78
+ ScalarType beta,
79
+ TensorRef<ElementC, LayoutC> tensor_c,
80
+ TensorRef<ElementC, LayoutC> tensor_d,
81
+ AccumulatorType initial_accum) {
82
+
83
+ static_assert(
84
+ LayoutA::kRank == 2 &&
85
+ LayoutB::kRank == 2 &&
86
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
87
+
88
+ // Blocking structure potentially improves performance of reference implementation
89
+ // with a minor increase in complexity.
90
+ //
91
+ // Note, this reference implementation is NOT expected to approach peak performance.
92
+ using OutputTile = MatrixShape<4, 4>;
93
+
94
+ dim3 block(16, 8);
95
+
96
+ dim3 grid(
97
+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
98
+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
99
+ );
100
+
101
+ // Launch a GEMM kernel
102
+ kernel::Gemm<
103
+ TensorRef<ElementA, LayoutA>,
104
+ TensorRef<ElementB, LayoutB>,
105
+ TensorRef<ElementC, LayoutC>,
106
+ ScalarType,
107
+ AccumulatorType,
108
+ OutputTile,
109
+ InnerProductOp,
110
+ ConvertOp
111
+ ><<< grid, block >>>(
112
+ problem_size,
113
+ alpha,
114
+ tensor_a,
115
+ tensor_b,
116
+ beta,
117
+ tensor_c,
118
+ tensor_d,
119
+ initial_accum
120
+ );
121
+ }
122
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
123
+
124
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
125
+ /// objects.
126
+ ///
127
+ /// This assumes the accumulator type is the same type as the scalars.
128
+ template <
129
+ typename ElementA,
130
+ typename LayoutA,
131
+ typename ElementB,
132
+ typename LayoutB,
133
+ typename ElementC,
134
+ typename LayoutC,
135
+ typename ScalarType,
136
+ typename AccumulatorType,
137
+ typename InnerProductOp = multiply_add<AccumulatorType>,
138
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
139
+ >
140
+ void compute_gemm(
141
+ gemm::GemmCoord problem_size,
142
+ ScalarType alpha,
143
+ TensorRef<ElementA, LayoutA> tensor_a,
144
+ TensorRef<ElementB, LayoutB> tensor_b,
145
+ ScalarType beta,
146
+ TensorRef<ElementC, LayoutC> tensor_c,
147
+ AccumulatorType initial_accum) {
148
+
149
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
150
+ ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
151
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
152
+ initial_accum);
153
+ }
154
+
155
+ template <
156
+ typename ElementA,
157
+ typename LayoutA,
158
+ typename ElementB,
159
+ typename LayoutB,
160
+ typename ElementC,
161
+ typename LayoutC,
162
+ typename ScalarType,
163
+ typename AccumulatorType,
164
+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd
165
+ >
166
+ struct Gemm;
167
+
168
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
169
+
170
+ /// Partial specialization for multiply-add
171
+ template <typename ElementA, typename LayoutA, typename ElementB,
172
+ typename LayoutB, typename ElementC, typename LayoutC,
173
+ typename ScalarType, typename AccumulatorType>
174
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
175
+ ScalarType, AccumulatorType, arch::OpMultiplyAdd> {
176
+
177
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
178
+ TensorRef<ElementA, LayoutA> tensor_a,
179
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
180
+ TensorRef<ElementC, LayoutC> tensor_c,
181
+ AccumulatorType initial_accum = AccumulatorType(0)) {
182
+
183
+ static_assert(
184
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
185
+ "Tensors must be of rank 2");
186
+
187
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
188
+ ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
189
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
190
+ }
191
+
192
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
193
+ TensorRef<ElementA, LayoutA> tensor_a,
194
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
195
+ TensorRef<ElementC, LayoutC> tensor_c,
196
+ TensorRef<ElementC, LayoutC> tensor_d,
197
+ AccumulatorType initial_accum = AccumulatorType(0)) {
198
+ static_assert(
199
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
200
+ "Tensors must be of rank 2");
201
+
202
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
203
+ ScalarType, AccumulatorType, multiply_add<AccumulatorType>>(
204
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
205
+ }
206
+ };
207
+
208
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
209
+
210
+ /// Partial specialization for multiply-add-saturate
211
+ template <typename ElementA, typename LayoutA, typename ElementB,
212
+ typename LayoutB, typename ElementC, typename LayoutC,
213
+ typename ScalarType, typename AccumulatorType>
214
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
215
+ AccumulatorType, arch::OpMultiplyAddSaturate> {
216
+
217
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
218
+ TensorRef<ElementA, LayoutA> tensor_a,
219
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
220
+ TensorRef<ElementC, LayoutC> tensor_c,
221
+ AccumulatorType initial_accum = AccumulatorType(0)) {
222
+ static_assert(
223
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
224
+ "Tensors must be of rank 2");
225
+
226
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
227
+ ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
228
+ NumericConverterClamp<ElementC, ScalarType>>(
229
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
230
+ }
231
+
232
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
233
+ TensorRef<ElementA, LayoutA> tensor_a,
234
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
235
+ TensorRef<ElementC, LayoutC> tensor_c,
236
+ TensorRef<ElementC, LayoutC> tensor_d,
237
+ AccumulatorType initial_accum = AccumulatorType(0)) {
238
+ static_assert(
239
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
240
+ "Tensors must be of rank 2");
241
+
242
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
243
+ ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
244
+ NumericConverterClamp<ElementC, ScalarType>>(
245
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
246
+ }
247
+ };
248
+
249
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
250
+
251
+ /// Partial specialization for XOR-popc
252
+ template <typename ElementA, typename LayoutA, typename ElementB,
253
+ typename LayoutB, typename ElementC, typename LayoutC,
254
+ typename ScalarType, typename AccumulatorType>
255
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
256
+ AccumulatorType, arch::OpXorPopc> {
257
+
258
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
259
+ TensorRef<ElementA, LayoutA> tensor_a,
260
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
261
+ TensorRef<ElementC, LayoutC> tensor_c,
262
+ AccumulatorType initial_accum = AccumulatorType(0)) {
263
+ static_assert(
264
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
265
+ "Tensors must be of rank 2");
266
+
267
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
268
+ ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
269
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
270
+ }
271
+
272
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
273
+ TensorRef<ElementA, LayoutA> tensor_a,
274
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
275
+ TensorRef<ElementC, LayoutC> tensor_c,
276
+ TensorRef<ElementC, LayoutC> tensor_d,
277
+ AccumulatorType initial_accum = AccumulatorType(0)) {
278
+ static_assert(
279
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
280
+ "Tensors must be of rank 2");
281
+
282
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
283
+ ScalarType, AccumulatorType, xor_add<AccumulatorType>>(
284
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
285
+ }
286
+ };
287
+
288
+
289
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
290
+ //
291
+ // Batched GEMM
292
+ //
293
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
294
+
295
+ /// Computes a batch of GEMMs over a set of matrices of common dimension.
296
+ //
297
+ // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
298
+ //
299
+ template <
300
+ typename TensorRefCollectionA,
301
+ typename TensorRefCollectionB,
302
+ typename TensorRefCollectionC,
303
+ typename ScalarType,
304
+ typename AccumulatorType,
305
+ typename InnerProductOp,
306
+ typename ConvertOp
307
+ >
308
+ void BatchedGemm(
309
+ gemm::GemmCoord problem_size,
310
+ int batch_count,
311
+ ScalarType alpha,
312
+ TensorRefCollectionA const& tensor_a,
313
+ TensorRefCollectionB const& tensor_b,
314
+ ScalarType beta,
315
+ TensorRefCollectionC &tensor_c,
316
+ AccumulatorType initial_accum) {
317
+
318
+ static_assert(
319
+ TensorRefCollectionA::kRank == 2 &&
320
+ TensorRefCollectionB::kRank == 2 &&
321
+ TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2");
322
+
323
+ // Blocking structure potentially improves performance of reference implementation
324
+ // with a minor increase in complexity.
325
+ //
326
+ // Note, this reference implementation is NOT expected to approach peak performance.
327
+ using OutputTile = MatrixShape<4, 4>;
328
+
329
+ dim3 block(16, 8);
330
+ dim3 grid(
331
+ (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
332
+ (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
333
+ batch_count
334
+ );
335
+
336
+ // Launch a GEMM kernel
337
+ kernel::BatchedGemm<
338
+ TensorRefCollectionA,
339
+ TensorRefCollectionB,
340
+ TensorRefCollectionC,
341
+ ScalarType,
342
+ AccumulatorType,
343
+ OutputTile,
344
+ InnerProductOp,
345
+ ConvertOp
346
+ ><<< grid, block >>>(
347
+ problem_size,
348
+ alpha,
349
+ tensor_a,
350
+ tensor_b,
351
+ beta,
352
+ tensor_c,
353
+ initial_accum
354
+ );
355
+ }
356
+
357
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
358
+ /// objects.
359
+ //
360
+ // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
361
+ //
362
+ template <
363
+ typename TensorRefCollectionA,
364
+ typename TensorRefCollectionB,
365
+ typename TensorRefCollectionC,
366
+ typename ScalarType,
367
+ typename AccumulatorType
368
+ >
369
+ void BatchedGemm(
370
+ gemm::GemmCoord problem_size,
371
+ int batch_count,
372
+ ScalarType alpha,
373
+ TensorRefCollectionA const& tensor_a,
374
+ TensorRefCollectionB const& tensor_b,
375
+ ScalarType beta,
376
+ TensorRefCollectionC &tensor_c) {
377
+
378
+ BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
379
+ }
380
+
381
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
382
+
383
+ } // namespace device
384
+ } // namespace reference
385
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued GEMM in device-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+
43
+ #include "cutlass/tensor_view.h"
44
+ #include "cutlass/gemm/gemm.h"
45
+
46
+ namespace cutlass {
47
+ namespace reference {
48
+ namespace device {
49
+
50
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ namespace kernel {
53
+
54
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
+ /// objects.
56
+ ///
57
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
58
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
59
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
60
+ /// arguments explicitly.
61
+ template <
62
+ typename ElementA,
63
+ typename LayoutA,
64
+ typename ElementB,
65
+ typename LayoutB,
66
+ typename ElementC,
67
+ typename LayoutC,
68
+ typename ScalarType,
69
+ typename ComputeType,
70
+ typename ElementD = ElementC,
71
+ typename ConvertOp = NumericConverter<ElementD, ScalarType>,
72
+ typename InnerProductOp = multiply_add<ComputeType>,
73
+ int kMblock = 4,
74
+ int kNblock = 4
75
+ >
76
+ __global__ void GemmComplex(
77
+ gemm::GemmCoord problem_size,
78
+ ScalarType alpha,
79
+ TensorRef<ElementA, LayoutA> tensor_a,
80
+ ComplexTransform transform_a,
81
+ TensorRef<ElementB, LayoutB> tensor_b,
82
+ ComplexTransform transform_b,
83
+ ScalarType beta,
84
+ TensorRef<ElementC, LayoutC> tensor_c,
85
+ TensorRef<ElementD, LayoutC> tensor_d,
86
+ ComputeType initial_accum,
87
+ int batch_count = 1,
88
+ int64_t batch_stride_A = 0,
89
+ int64_t batch_stride_B = 0,
90
+ int64_t batch_stride_C = 0,
91
+ int64_t batch_stride_D = 0) {
92
+
93
+ static_assert(
94
+ LayoutA::kRank == 2 &&
95
+ LayoutB::kRank == 2 &&
96
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
97
+
98
+ int const M = problem_size.m();
99
+ int const N = problem_size.n();
100
+ int const K = problem_size.k();
101
+
102
+ ConvertOp convert_op;
103
+ InnerProductOp inner_product_op;
104
+
105
+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
106
+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
107
+ int batch_idx = blockIdx.z;
108
+
109
+ tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
110
+ tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
111
+ tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
112
+ tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
113
+
114
+ for (; batch_idx < batch_count; batch_idx += gridDim.z) {
115
+
116
+ // Compute matrix product using blocks
117
+ ComputeType accum[kMblock][kNblock];
118
+
119
+ CUTLASS_PRAGMA_UNROLL
120
+ for (int j = 0; j < kNblock; j++) {
121
+ CUTLASS_PRAGMA_UNROLL
122
+ for (int i = 0; i < kMblock; i++) {
123
+ accum[i][j] = initial_accum;
124
+ }
125
+ }
126
+
127
+ for (int k_block = 0; k_block < K; ++k_block) {
128
+ CUTLASS_PRAGMA_UNROLL
129
+ for (int j = 0; j < kNblock; j++) {
130
+ CUTLASS_PRAGMA_UNROLL
131
+ for (int i = 0; i < kMblock; i++) {
132
+ int row = row_block + i;
133
+ int col = col_block + j;
134
+
135
+ if (row < M && col < N) {
136
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
137
+ ElementB b = tensor_b.at(MatrixCoord(k_block, col));
138
+
139
+ ComputeType a_ik = ComputeType(a);
140
+ ComputeType b_kj = ComputeType(b);
141
+
142
+ if (transform_a == ComplexTransform::kConjugate) {
143
+ a_ik = conj(a_ik);
144
+ }
145
+
146
+ if (transform_b == ComplexTransform::kConjugate) {
147
+ b_kj = conj(b_kj);
148
+ }
149
+
150
+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
151
+ }
152
+ }
153
+ }
154
+ }
155
+
156
+ CUTLASS_PRAGMA_UNROLL
157
+ for (int j = 0; j < kNblock; j++) {
158
+ CUTLASS_PRAGMA_UNROLL
159
+ for (int i = 0; i < kMblock; i++) {
160
+ int row = row_block + i;
161
+ int col = col_block + j;
162
+
163
+ MatrixCoord coord = MatrixCoord(row, col);
164
+
165
+ if (row < M && col < N) {
166
+
167
+ tensor_d.at(coord) = convert_op(
168
+ alpha * ScalarType(accum[i][j]) +
169
+ beta * ScalarType(tensor_c.at(coord)));
170
+ }
171
+ }
172
+ }
173
+
174
+ tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
175
+ tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
176
+ tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
177
+ tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
178
+
179
+ } // for (batch_idx)
180
+ }
181
+
182
+ } // namespace kernel
183
+
184
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
185
+
186
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
187
+ /// objects.
188
+ ///
189
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
190
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
191
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
192
+ /// arguments explicitly.
193
+ template <
194
+ typename ElementA,
195
+ typename LayoutA,
196
+ typename ElementB,
197
+ typename LayoutB,
198
+ typename ElementC,
199
+ typename LayoutC,
200
+ typename ScalarType,
201
+ typename ComputeType,
202
+ typename ElementD = ElementC,
203
+ typename ConvertOp = NumericConverter<ElementD, ScalarType>,
204
+ typename InnerProductOp = multiply_add<ComputeType>
205
+ >
206
+ void GemmComplex(
207
+ gemm::GemmCoord problem_size,
208
+ ScalarType alpha,
209
+ TensorRef<ElementA, LayoutA> tensor_a,
210
+ ComplexTransform transform_a,
211
+ TensorRef<ElementB, LayoutB> tensor_b,
212
+ ComplexTransform transform_b,
213
+ ScalarType beta,
214
+ TensorRef<ElementC, LayoutC> tensor_c,
215
+ TensorRef<ElementD, LayoutC> tensor_d,
216
+ ComputeType initial_accum,
217
+ int batch_count = 1,
218
+ int64_t batch_stride_A = 0,
219
+ int64_t batch_stride_B = 0,
220
+ int64_t batch_stride_C = 0,
221
+ int64_t batch_stride_D = 0) {
222
+
223
+ static_assert(
224
+ LayoutA::kRank == 2 &&
225
+ LayoutB::kRank == 2 &&
226
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
227
+
228
+ int const kMblock = 4;
229
+ int const kNblock = 4;
230
+
231
+ dim3 block(16, 8);
232
+ dim3 grid(
233
+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
234
+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
235
+ batch_count % std::numeric_limits<uint16_t>::max()
236
+ );
237
+
238
+ if (grid.y <= std::numeric_limits<uint16_t>::max()) {
239
+ kernel::GemmComplex<
240
+ ElementA,
241
+ LayoutA,
242
+ ElementB,
243
+ LayoutB,
244
+ ElementC,
245
+ LayoutC,
246
+ ScalarType,
247
+ ComputeType,
248
+ ElementD,
249
+ ConvertOp,
250
+ InnerProductOp,
251
+ kMblock,
252
+ kNblock
253
+ ><<< grid, block >>>(
254
+ problem_size,
255
+ alpha,
256
+ tensor_a,
257
+ transform_a,
258
+ tensor_b,
259
+ transform_b,
260
+ beta,
261
+ tensor_c,
262
+ tensor_d,
263
+ initial_accum,
264
+ batch_count,
265
+ batch_stride_A,
266
+ batch_stride_B,
267
+ batch_stride_C,
268
+ batch_stride_D
269
+ );
270
+ } else {
271
+ // Using bigger thread tile size
272
+ int const kBigMblock = 4;
273
+ int const kBigNblock = 16;
274
+
275
+ dim3 Bigblock(16, 8);
276
+ dim3 Biggrid(
277
+ (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock),
278
+ (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock),
279
+ batch_count % std::numeric_limits<uint16_t>::max()
280
+ );
281
+
282
+ kernel::GemmComplex<
283
+ ElementA,
284
+ LayoutA,
285
+ ElementB,
286
+ LayoutB,
287
+ ElementC,
288
+ LayoutC,
289
+ ScalarType,
290
+ ComputeType,
291
+ ElementD,
292
+ ConvertOp,
293
+ InnerProductOp,
294
+ kBigMblock,
295
+ kBigNblock
296
+ ><<< Biggrid, Bigblock >>>(
297
+ problem_size,
298
+ alpha,
299
+ tensor_a,
300
+ transform_a,
301
+ tensor_b,
302
+ transform_b,
303
+ beta,
304
+ tensor_c,
305
+ tensor_d,
306
+ initial_accum,
307
+ batch_count,
308
+ batch_stride_A,
309
+ batch_stride_B,
310
+ batch_stride_C,
311
+ batch_stride_D
312
+ );
313
+ }
314
+ }
315
+
316
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
317
+
318
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
319
+ /// objects.
320
+ ///
321
+ /// This assumes the accumulator type is the same type as the scalars.
322
+ template <
323
+ typename ElementA,
324
+ typename LayoutA,
325
+ typename ElementB,
326
+ typename LayoutB,
327
+ typename ElementC,
328
+ typename LayoutC,
329
+ typename ScalarType,
330
+ typename ElementD = ElementC
331
+ >
332
+ void GemmComplex(
333
+ gemm::GemmCoord problem_size,
334
+ ScalarType alpha,
335
+ TensorRef<ElementA, LayoutA> tensor_a,
336
+ ComplexTransform transform_a,
337
+ TensorRef<ElementB, LayoutB> tensor_b,
338
+ ComplexTransform transform_b,
339
+ ScalarType beta,
340
+ TensorRef<ElementC, LayoutC> tensor_c,
341
+ TensorRef<ElementD, LayoutC> tensor_d) {
342
+
343
+ GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
344
+ }
345
+
346
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
347
+
348
+ } // namespace device
349
+ } // namespace reference
350
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued GEMM in device code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/matrix_coord.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/tensor_ref_planar_complex.h"
44
+
45
+ #include "cutlass/tensor_view.h"
46
+ #include "cutlass/gemm/gemm.h"
47
+
48
+ namespace cutlass {
49
+ namespace reference {
50
+ namespace device {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace kernel {
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ static int const kGemmPlanarComplexBlockSize = 4;
59
+
60
/// Reference device kernel for planar-complex GEMM: computes
/// D = alpha * op(A) * op(B) + beta * C over an M-by-N output, where op()
/// optionally conjugates an operand. The batch dimension of problem_size is ignored.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
  typename InnerProductOp = multiply_add<complex<ComputeType>>
>
__global__ void GemmPlanarComplex(
  gemm::GemmCoord problem_size,
  complex<ScalarType> alpha,
  TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
  ComplexTransform transform_a,
  TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
  ComplexTransform transform_b,
  complex<ScalarType> beta,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
  TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
  complex<ComputeType> initial_accum) {

  // Each thread computes a kMblock x kNblock sub-tile of the output.
  int const kMblock = kGemmPlanarComplexBlockSize;
  int const kNblock = kGemmPlanarComplexBlockSize;

  using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
  using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
  using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  int const K = problem_size.k();

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Per-thread accumulators for the sub-tile, seeded with the caller-provided value.
  complex<ComputeType> accum[kMblock][kNblock];

  // Top-left output coordinate of this thread's sub-tile.
  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;

  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kNblock; j++) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kMblock; i++) {
      accum[i][j] = initial_accum;
    }
  }

  // March over the K dimension, accumulating one outer-product contribution per step.
  CUTLASS_PRAGMA_NO_UNROLL
  for (int k_block = 0; k_block < K; ++k_block) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kNblock; j++) {

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kMblock; i++) {

        int row = row_block + i;
        int col = col_block + j;

        // Guard against partial tiles at the right/bottom edges of the problem.
        if (row < M && col < N) {

          ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
          ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));

          // Widen both operands to the compute type before multiplying.
          complex<ComputeType> a = complex<ComputeType>{
            ComputeType(a_ik.real()),
            ComputeType(a_ik.imag())
          };

          complex<ComputeType> b = complex<ComputeType>{
            ComputeType(b_kj.real()),
            ComputeType(b_kj.imag())
          };

          // Apply the requested per-operand complex transform (conjugation).
          if (transform_a == ComplexTransform::kConjugate) {
            a = conj(a);
          }

          if (transform_b == ComplexTransform::kConjugate) {
            b = conj(b);
          }

          accum[i][j] = inner_product_op(a, b, accum[i][j]);
        }
      }
    }
  }

  // Epilogue: scale by alpha/beta, convert to the output type, and store the sub-tile.
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kNblock; j++) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kMblock; i++) {

      int row = row_block + i;
      int col = col_block + j;

      MatrixCoord coord = MatrixCoord(row, col);

      if (row < M && col < N) {

        complex<ScalarType> acc{
          ScalarType(accum[i][j].real()),
          ScalarType(accum[i][j].imag())
        };

        ComplexC c_ij = ComplexC();

        // C is read only when beta is nonzero, so it is never dereferenced for beta == 0.
        if (beta.real() != ScalarType() || beta.imag() != ScalarType()) {
          c_ij = tensor_c.at(coord);
        }

        complex<ScalarType> src{
          ScalarType(c_ij.real()),
          ScalarType(c_ij.imag())
        };

        complex<ScalarType> result = alpha * acc + beta * src;

        ComplexC d_ij;

        // Real and imaginary parts are converted independently to the output element type.
        d_ij.real() = convert_op(result.real());
        d_ij.imag() = convert_op(result.imag());

        tensor_d.at(coord) = d_ij;
      }
    }
  }
}
193
+
194
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
195
+
196
+ } // namespace kernel
197
+
198
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
199
+
200
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
201
+ /// objects.
202
+ ///
203
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
204
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
205
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
206
+ /// arguments explicitly.
207
+ template <
208
+ typename ElementA,
209
+ typename LayoutA,
210
+ typename ElementB,
211
+ typename LayoutB,
212
+ typename ElementC,
213
+ typename LayoutC,
214
+ typename ScalarType,
215
+ typename ComputeType,
216
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>,
217
+ typename InnerProductOp = multiply_add<complex<ComputeType>>
218
+ >
219
+ void GemmPlanarComplex(
220
+ gemm::GemmCoord problem_size,
221
+ complex<ScalarType> alpha,
222
+ TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
223
+ ComplexTransform transform_a,
224
+ TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
225
+ ComplexTransform transform_b,
226
+ complex<ScalarType> beta,
227
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
228
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
229
+ complex<ComputeType> initial_accum) {
230
+
231
+ static_assert(
232
+ LayoutA::kRank == 2 &&
233
+ LayoutB::kRank == 2 &&
234
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
235
+
236
+ int const kMblock = kernel::kGemmPlanarComplexBlockSize;
237
+ int const kNblock = kernel::kGemmPlanarComplexBlockSize;
238
+
239
+ dim3 block(16, 8);
240
+
241
+ dim3 grid(
242
+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
243
+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
244
+ 1);
245
+
246
+ kernel::GemmPlanarComplex<
247
+ ElementA, LayoutA,
248
+ ElementB, LayoutB,
249
+ ElementC, LayoutC,
250
+ ScalarType,
251
+ ComputeType,
252
+ ConvertOp,
253
+ InnerProductOp
254
+ ><<< grid, block >>>(
255
+ problem_size,
256
+ alpha,
257
+ tensor_a,
258
+ transform_a,
259
+ tensor_b,
260
+ transform_b,
261
+ beta,
262
+ tensor_c,
263
+ tensor_d,
264
+ initial_accum
265
+ );
266
+ }
267
+
268
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
269
+
270
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
271
+ /// objects.
272
+ ///
273
+ /// This assumes the accumulator type is the same type as the scalars.
274
+ template <
275
+ typename ElementA,
276
+ typename LayoutA,
277
+ typename ElementB,
278
+ typename LayoutB,
279
+ typename ElementC,
280
+ typename LayoutC,
281
+ typename ScalarType
282
+ >
283
+ void GemmPlanarComplex(
284
+ gemm::GemmCoord problem_size,
285
+ complex<ScalarType> alpha,
286
+ TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
287
+ ComplexTransform transform_a,
288
+ TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
289
+ ComplexTransform transform_b,
290
+ complex<ScalarType> beta,
291
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
292
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
293
+
294
+ GemmPlanarComplex(
295
+ problem_size,
296
+ alpha,
297
+ tensor_a, transform_a,
298
+ tensor_b, transform_b,
299
+ beta,
300
+ tensor_c,
301
+ tensor_d,
302
+ complex<ScalarType>());
303
+ }
304
+
305
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
306
+
307
+ } // namespace device
308
+ } // namespace reference
309
+ } // namespace cutlass
310
+
311
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief GETT device reference code
33
+ */
34
+ #pragma once
35
+
36
+ #include <cute/tensor.hpp>
37
+
38
+ namespace cutlass::reference::device {
39
+
40
+ template <
41
+ class ATensor,
42
+ class BTensor,
43
+ class CTensor,
44
+ class DTensor,
45
+ class ElementAccumulator,
46
+ class ElementEpilogue>
47
+ __global__ static
48
+ void
49
+ gett_kernel(
50
+ DTensor D,
51
+ ATensor const A,
52
+ BTensor const B,
53
+ CTensor const C,
54
+ ElementEpilogue alpha, ElementEpilogue beta,
55
+ ElementAccumulator acc_init)
56
+ {
57
+ using namespace cute;
58
+
59
+ static_assert(DTensor::rank == 3, "(M,N,L)");
60
+ static_assert(ATensor::rank == 3, "(M,K,L)");
61
+ static_assert(BTensor::rank == 3, "(N,K,L)");
62
+ static_assert(CTensor::rank == 3, "(M,N,L)");
63
+
64
+ assert(size<0>(A) == size<0>(D)); // M
65
+ assert(size<0>(C) == size<0>(D)); // M
66
+ assert(size<0>(B) == size<1>(D)); // N
67
+ assert(size<1>(C) == size<1>(D)); // N
68
+ assert(size<1>(A) == size<1>(B)); // K
69
+ assert(size<2>(A) == size<2>(D)); // L
70
+ assert(size<2>(B) == size<2>(D)); // L
71
+ assert(size<2>(C) == size<2>(D)); // L
72
+
73
+ NumericConverter<ElementAccumulator, typename ATensor::value_type> a_converter;
74
+ NumericConverter<ElementAccumulator, typename BTensor::value_type> b_converter;
75
+ NumericConverter<ElementEpilogue, ElementAccumulator> acc_converter;
76
+ NumericConverter<ElementEpilogue, typename CTensor::value_type> source_converter;
77
+ NumericConverter<typename DTensor::value_type, ElementEpilogue> output_converter;
78
+
79
+ // Thread id to each element of D
80
+ for (int tid = threadIdx.x + blockDim.x * blockIdx.x;
81
+ tid < size(D);
82
+ tid += blockDim.x * gridDim.x) {
83
+ // (m,n,l) coordinate
84
+ auto mnl_coord = idx2crd(tid, product_each(shape(D)));
85
+ auto m = get<0>(mnl_coord);
86
+ auto n = get<1>(mnl_coord);
87
+ auto l = get<2>(mnl_coord);
88
+
89
+ auto A_ml = A(m,_,l);
90
+ auto B_nl = B(n,_,l);
91
+
92
+ ElementAccumulator accum = ElementAccumulator(0);
93
+ for (int k = 0; k < size<1>(A); ++k) {
94
+ ElementAccumulator a = a_converter(A_ml(k));
95
+ ElementAccumulator b = b_converter(B_nl(k));
96
+ accum += a * b;
97
+ }
98
+
99
+ ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l)));
100
+ D(m,n,l) = output_converter(scaled_output);
101
+ }
102
+ }
103
+
104
+ // Most general version
105
/// Host-side entry point for the GETT device reference: wraps the raw pointers and strides
/// into CuTe tensors and launches gett_kernel on 'stream'.
///
/// The unnamed ElementAccumulator parameter '_' communicates only the accumulator type;
/// its value is ignored and the kernel is always launched with ElementAccumulator(0).
template <
  class ProblemShapeMNKL,
  class ElementA,
  class StrideA,
  class ElementB,
  class StrideB,
  class ElementAccumulator,
  class ElementC,
  class StrideC,
  class ElementD,
  class StrideD,
  class ElementEpilogue>
void
gett(
    ProblemShapeMNKL problem_shape_mnkl,
    ElementA const* ptr_A, StrideA stride_a_mkl,
    ElementB const* ptr_B, StrideB stride_b_nkl,
    ElementAccumulator _,
    ElementC const* ptr_C, StrideC stride_c_mnl,
    ElementD * ptr_D, StrideD stride_d_mnl,
    ElementEpilogue alpha, ElementEpilogue beta,
    cudaStream_t stream = 0) {
  using namespace cute;

  // Problem shape is rank-4: (M, N, K, L), where L is the batch extent.
  static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
  auto M = get<0>(problem_shape_mnkl);
  auto N = get<1>(problem_shape_mnkl);
  auto K = get<2>(problem_shape_mnkl);
  auto L = get<3>(problem_shape_mnkl);

  // Represent the full tensors over global-memory pointers.
  auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L)
  auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L)
  auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L)
  auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L)

  // Fixed launch shape; the kernel's grid-stride loop covers all of D regardless of size.
  dim3 dimBlock(256);
  dim3 dimGrid(240);
  gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0));
}
145
+
146
+ } // namespace cutlass::reference::device
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GEMM in host-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/tensor_view.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+
41
+ #include "cutlass/util/reference/device/thread/gemm.h"
42
+
43
+ namespace cutlass {
44
+ namespace reference {
45
+ namespace device {
46
+ namespace kernel {
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
51
+ /// objects.
52
+ template <
53
+ typename TensorRefA,
54
+ typename TensorRefB,
55
+ typename TensorRefC,
56
+ typename ScalarType,
57
+ typename AccumulatorType,
58
+ typename OutputTile,
59
+ typename InnerProductOp,
60
+ typename ConvertOp
61
+ >
62
+ __global__ void Gemm(
63
+ gemm::GemmCoord problem_size,
64
+ ScalarType alpha,
65
+ TensorRefA tensor_a,
66
+ TensorRefB tensor_b,
67
+ ScalarType beta,
68
+ TensorRefC tensor_c,
69
+ TensorRefC tensor_d,
70
+ AccumulatorType initial_accum) {
71
+
72
+ // Map each thread to a unique tile of the output matrix
73
+ MatrixCoord output_coord(
74
+ MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
75
+ MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
76
+ );
77
+
78
+ // Compute the general matrix product
79
+ thread::Gemm<
80
+ TensorRefA,
81
+ TensorRefB,
82
+ TensorRefC,
83
+ ScalarType,
84
+ AccumulatorType,
85
+ OutputTile,
86
+ InnerProductOp,
87
+ ConvertOp
88
+ > gemm(initial_accum);
89
+
90
+ gemm.multiply_add(
91
+ problem_size,
92
+ tensor_a,
93
+ tensor_b,
94
+ output_coord);
95
+
96
+ gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord);
97
+ }
98
+
99
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
100
+
101
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
102
+ /// objects.
103
/// Device kernel computing a batch of general matrix products. blockIdx.z selects the
/// batch element from each TensorRef collection; each thread then computes one
/// OutputTile of that batch's output matrix.
template <
  typename TensorRefCollectionA,
  typename TensorRefCollectionB,
  typename TensorRefCollectionC,
  typename ScalarType,
  typename AccumulatorType,
  typename OutputTile,
  typename InnerProductOp,
  typename ConvertOp
>
__global__ void BatchedGemm(
  gemm::GemmCoord problem_size,
  ScalarType alpha,
  TensorRefCollectionA tensor_collection_a,
  TensorRefCollectionB tensor_collection_b,
  ScalarType beta,
  TensorRefCollectionC tensor_collection_c,
  AccumulatorType initial_accum) {

  // Obtain batch ID
  int batch_id = blockIdx.z;

  // Dereference based on batch_id
  typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
  typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
  typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);

  // Map each thread to a unique tile of the output matrix.
  // NOTE(review): the x-derived index is scaled by OutputTile::kColumn and the y-derived
  // index by OutputTile::kRow — transposed relative to Gemm above, which uses kRow on x.
  // Confirm this matches the coordinate order thread::Gemm expects.
  MatrixCoord output_coord(
    (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn,
    (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow
  );

  // Compute the general matrix product
  thread::Gemm<
    typename TensorRefCollectionA::TensorRef,
    typename TensorRefCollectionB::TensorRef,
    typename TensorRefCollectionC::TensorRef,
    ScalarType,
    AccumulatorType,
    OutputTile,
    InnerProductOp,
    ConvertOp
  > gemm(initial_accum);

  gemm.multiply_add(
    problem_size,
    tensor_a,
    tensor_b,
    output_coord);

  // Epilogue overload without a separate D tensor — presumably the scaled result is
  // written back through tensor_c; verify against thread::Gemm::epilogue.
  gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
}
156
+
157
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
158
+
159
+ } // namespace kernel
160
+ } // namespace device
161
+ } // namespace reference
162
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include <curand_kernel.h>
35
+
36
+ #include "cutlass/cutlass.h"
37
+
38
+ namespace cutlass {
39
+ namespace reference {
40
+ namespace device {
41
+ namespace kernel {
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ /// Kernel to initialize tensor to uniform random distribution
46
/// Kernel to initialize tensor to uniform random distribution.
///
/// Each thread owns one contiguous-dimension index (c_idx) and walks a run of blockDim.x
/// strided-dimension indices starting at s_idx, writing one element per loop step.
template <typename T>
__global__ void TensorInitializeUniform(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // One cuRAND state per thread of the block; assumes blockDim.x <= 1024 — TODO confirm
  // against launch configuration.
  __shared__ curandState_t rng_state[1024];

  // Globally unique thread id used as the cuRAND subsequence for reproducibility.
  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  // Position the pointer at this thread's first element.
  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      // Draw a uniform sample and map it onto [min, max].
      double range = dist.uniform.max - dist.uniform.min;

      double rnd = curand_uniform(&rng_state[threadIdx.x]);

      rnd = dist.uniform.min + range * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing
      if (dist.int_scale >= 0) {
        rnd = double(int(rnd * double(1 << dist.int_scale)));
        *tensor = T(rnd / double(1 << dist.int_scale));
      } else {
        *tensor = T(rnd);
      }

      // Advance to the next strided element for this thread.
      tensor += ldm;
    }
  }
}
81
+
82
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
83
+
84
+ /// Kernel to initialize tensor to uniform distribution
85
+ template <typename T>
86
+ __global__ void TensorInitializeGaussian(
87
+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
88
+ __shared__ curandState_t rng_state[1024];
89
+
90
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
91
+
92
+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
93
+
94
+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
95
+ int s_idx = blockIdx.y * blockDim.x;
96
+
97
+ tensor += s_idx * ldm + c_idx;
98
+
99
+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
100
+ if (s_idx < dim_strided && c_idx < dim_contiguous) {
101
+ // Random values are cast to integer after scaling by a power of two to facilitate error
102
+ // testing
103
+
104
+ double rnd = curand_normal(&rng_state[threadIdx.x]);
105
+
106
+ rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd;
107
+
108
+ if (dist.int_scale >= 0) {
109
+ rnd = double(int(rnd * double(1 << dist.int_scale)));
110
+ *tensor = T(rnd / double(1 << dist.int_scale));
111
+ } else {
112
+ *tensor = T(rnd);
113
+ }
114
+ }
115
+ }
116
+ }
117
+
118
+ /// Kernel to initialize tensor to an identity matrix
119
+ template <typename T>
120
+ __global__ void TensorInitializeLinear(
121
+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
122
+ __shared__ curandState_t rng_state[1024];
123
+
124
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
125
+
126
+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
127
+
128
+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
129
+ int s_idx = blockIdx.y * blockDim.x;
130
+
131
+ tensor += s_idx * ldm + c_idx;
132
+
133
+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
134
+ if (s_idx < dim_strided && c_idx < dim_contiguous) {
135
+ *tensor =
136
+ dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx;
137
+ }
138
+ }
139
+ }
140
+
141
+ /// Kernel to initialize tensor to an identity matrix
142
+ template <typename T>
143
+ __global__ void TensorInitializeIdentity(
144
+ Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
145
+ __shared__ curandState_t rng_state[1024];
146
+
147
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;
148
+
149
+ curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);
150
+
151
+ int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
152
+ int s_idx = blockIdx.y * blockDim.x;
153
+
154
+ tensor += s_idx * ldm + c_idx;
155
+
156
+ for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
157
+ if (s_idx < dim_strided && c_idx < dim_contiguous) {
158
+ *tensor = (c_idx == s_idx ? T(1) : T(0));
159
+ }
160
+ }
161
+ }
162
+
163
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
164
+
165
+ } // namespace kernel
166
+ } // namespace device
167
+ } // namespace reference
168
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include "cutlass/cutlass.h"
35
+ #include "cutlass/coord.h"
36
+ #include "cutlass/subbyte_reference.h"
37
+ #include "cutlass/fast_math.h"
38
+
39
+ namespace cutlass {
40
+ namespace reference {
41
+ namespace device {
42
+ namespace kernel {
43
+
44
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ /// Defines several helpers
47
+ namespace detail {
48
+
49
/// Helper to perform for-each operation.
/// Decomposes a linear index into a Coord<Rank> one rank at a time via
/// compile-time recursion on RankRemaining (Rank-1 down to 0).
template <typename Func, int Rank, int RankRemaining>
struct TensorForEachHelper {

  /// Constructor for general rank.
  /// Peels off coordinate (Rank - 1 - RankRemaining): `product` is the number
  /// of elements spanned by the RankRemaining faster-varying ranks, so
  /// index / product is this rank's coordinate and index % product is what
  /// remains for the faster ranks.
  __inline__ __device__
  TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {

    int64_t product = 1;

    CUTLASS_PRAGMA_UNROLL
    for (int i = Rank - RankRemaining; i < Rank; ++i) {
      product *= size[i];
    }

    coord[Rank - 1 - RankRemaining] = index / product;
    int64_t remaining = index % product;

    // Recurse; the RankRemaining == 0 specialization invokes the functor.
    TensorForEachHelper<Func, Rank, RankRemaining-1>(func, size, coord, remaining);
  }
};
70
+
71
/// Helper to perform for-each operation — terminal specialization.
template <typename Func, int Rank>
struct TensorForEachHelper<Func, Rank, 0> {

  /// Constructor for fastest changing rank.
  /// Assigns the last coordinate and invokes the functor only if the full
  /// coordinate lies inside the extent (guards tail indices of the linear
  /// range; Coord's operator< performs the bounds comparison).
  __inline__ __device__
  TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {

    coord[Rank - 1] = index;

    if (coord < size) {
      func(coord);
    }
  }
};
86
+
87
+ } // namespace detail
88
+
89
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
90
+
91
+ /// Kernel calls a functor for each element in a tensor's index space
92
+ template <typename Func, int Rank, typename Params>
93
+ __global__ void TensorForEach(Coord<Rank> size, Params params = Params()) {
94
+
95
+ Func func(params);
96
+
97
+ int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
98
+ int64_t max_index = 1;
99
+
100
+ CUTLASS_PRAGMA_UNROLL
101
+ for (int i = 0; i < Rank; ++i) {
102
+ max_index *= size[i];
103
+ }
104
+
105
+ CUTLASS_PRAGMA_NO_UNROLL
106
+ while (index < max_index) {
107
+ Coord<Rank> coord;
108
+
109
+ detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord, index);
110
+ index += blockDim.x * gridDim.x;
111
+ }
112
+ }
113
+
114
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
115
+
116
+ /// Kernel calls a functor for each element along a tensor's diagonal
117
+ template <typename Func, int Rank, typename Params>
118
+ __global__ void TensorDiagonalForEach(Coord<Rank> size, Params params, int start, int end) {
119
+
120
+ Func func(params);
121
+
122
+ int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start;
123
+
124
+ if (index < end) {
125
+ Coord<Rank> coord;
126
+
127
+ CUTLASS_PRAGMA_UNROLL
128
+ for (int i = 0; i < Rank; ++i) {
129
+ coord[i] = index;
130
+ }
131
+
132
+ func(coord);
133
+ }
134
+ }
135
+
136
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
137
+
138
+ template <typename Element, typename Func>
139
+ __global__ void BlockForEach(
140
+ Element *ptr,
141
+ size_t capacity,
142
+ typename Func::Params params) {
143
+
144
+ Func func(params);
145
+
146
+ size_t index = threadIdx.x + blockIdx.x * blockDim.x;
147
+
148
+ for (; index < capacity; index += blockDim.x * gridDim.x) {
149
+ ReferenceFactory<Element>::get(ptr, index) = func();
150
+ }
151
+ }
152
+
153
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
154
+
155
+ } // namespace kernel
156
+ } // namespace device
157
+ } // namespace reference
158
+ } // namespace cutlass
159
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued GEMM in device-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/blas3.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/numeric_conversion.h"
40
+ #include "cutlass/tensor_view.h"
41
+ #include "cutlass/gemm/gemm.h"
42
+
43
+ namespace cutlass {
44
+ namespace reference {
45
+ namespace device {
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace kernel {
50
+
51
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
52
+ /// objects.
53
+ ///
54
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
55
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
56
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
57
+ /// arguments explicitly.
58
+ template <
59
+ typename ElementA,
60
+ typename LayoutA,
61
+ typename ElementB,
62
+ typename LayoutB,
63
+ typename ElementC,
64
+ typename LayoutC,
65
+ typename ScalarType,
66
+ typename ComputeType,
67
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>,
68
+ typename InnerProductOp = multiply_add<ComputeType>,
69
+ int kMblock = 4,
70
+ int kNblock = 4
71
+ >
72
+ __global__ void Rank2KComplex(
73
+ gemm::GemmCoord problem_size,
74
+ ScalarType alpha,
75
+ TensorRef<ElementA, LayoutA> tensor_a,
76
+ ComplexTransform transform_a,
77
+ TensorRef<ElementB, LayoutB> tensor_b,
78
+ ComplexTransform transform_b,
79
+ ScalarType beta,
80
+ TensorRef<ElementC, LayoutC> tensor_c,
81
+ TensorRef<ElementC, LayoutC> tensor_d,
82
+ ComputeType initial_accum,
83
+ FillMode fill_mode_c,
84
+ BlasMode blas_mode,
85
+ int batch_count = 1,
86
+ int64_t batch_stride_A = 0,
87
+ int64_t batch_stride_B = 0,
88
+ int64_t batch_stride_C = 0,
89
+ int64_t batch_stride_D = 0) {
90
+
91
+ static_assert(
92
+ LayoutA::kRank == 2 &&
93
+ LayoutB::kRank == 2 &&
94
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
95
+
96
+ int const M = problem_size.m();
97
+ int const N = problem_size.n();
98
+ int const K = problem_size.k();
99
+
100
+ assert(M=N);
101
+
102
+ ConvertOp convert_op;
103
+ InnerProductOp inner_product_op;
104
+
105
+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
106
+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
107
+ int batch_idx = blockIdx.z;
108
+
109
+ tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
110
+ tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
111
+ tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
112
+ tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
113
+
114
+ for (; batch_idx < batch_count; batch_idx += gridDim.z) {
115
+
116
+ // Compute matrix product using blocks
117
+ ComputeType accum[kMblock][kNblock];
118
+
119
+ CUTLASS_PRAGMA_UNROLL
120
+ for (int j = 0; j < kNblock; j++) {
121
+ CUTLASS_PRAGMA_UNROLL
122
+ for (int i = 0; i < kMblock; i++) {
123
+ accum[i][j] = initial_accum;
124
+ }
125
+ }
126
+
127
+ for (int k_block = 0; k_block < K; ++k_block) {
128
+ CUTLASS_PRAGMA_UNROLL
129
+ for (int j = 0; j < kNblock; j++) {
130
+ CUTLASS_PRAGMA_UNROLL
131
+ for (int i = 0; i < kMblock; i++) {
132
+ int row = row_block + i;
133
+ int col = col_block + j;
134
+
135
+ if (row < M && col < N &&
136
+ ( (fill_mode_c == FillMode::kLower && row >= col) ||
137
+ (fill_mode_c == FillMode::kUpper && row <= col) )
138
+ ) {
139
+
140
+ // A x B^T (Symmetric) or A x B^H (Hermitian)
141
+ // complex conjugation on operandB (b_t) is function of blas3 computation
142
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
143
+ ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
144
+ conj(tensor_b.at(MatrixCoord(col, k_block))) :
145
+ tensor_b.at(MatrixCoord(col, k_block));
146
+
147
+ ComputeType a_ik = ComputeType(a);
148
+ ComputeType b_jk = ComputeType(b_t);
149
+
150
+ // complex conjugation is a function of operand layouts
151
+ if (transform_a == ComplexTransform::kConjugate) {
152
+ a_ik = conj(a_ik);
153
+ }
154
+ // complex conjugation is a function of operand layouts
155
+ if (transform_b == ComplexTransform::kConjugate) {
156
+ b_jk = conj(b_jk);
157
+ }
158
+
159
+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
160
+
161
+ // B x A^T (Symmetric) or B x A^H (Hermitian)
162
+ // complex conjugation on operandB (a_t) is function of blas3 computation
163
+ ElementB b = tensor_b.at(MatrixCoord(row, k_block));
164
+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
165
+ conj(tensor_a.at(MatrixCoord(col, k_block))):
166
+ tensor_a.at(MatrixCoord(col, k_block));
167
+
168
+ ComputeType b_ik = ComputeType(b);
169
+ ComputeType a_jk = ComputeType(a_t);
170
+
171
+ // complex conjugation here is a function of operand layouts
172
+ if (transform_b == ComplexTransform::kConjugate) {
173
+ b_ik = conj(b_ik);
174
+ }
175
+ // complex conjugation here is a function of operand layouts
176
+ if (transform_a == ComplexTransform::kConjugate) {
177
+ a_jk = conj(a_jk);
178
+ }
179
+
180
+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
181
+ }
182
+ }
183
+ }
184
+ }
185
+
186
+ CUTLASS_PRAGMA_UNROLL
187
+ for (int j = 0; j < kNblock; j++) {
188
+ CUTLASS_PRAGMA_UNROLL
189
+ for (int i = 0; i < kMblock; i++) {
190
+ int row = row_block + i;
191
+ int col = col_block + j;
192
+
193
+ MatrixCoord coord = MatrixCoord(row, col);
194
+
195
+ if (row < M && col < N &&
196
+ ((fill_mode_c == FillMode::kLower && row >= col) ||
197
+ (fill_mode_c == FillMode::kUpper && row <= col))
198
+ ) {
199
+
200
+ ScalarType c = tensor_c.at(coord);
201
+ // The imaginary parts of the diagonal elements of
202
+ // a complex data type are assumed and set to zero
203
+ if (blas_mode == BlasMode::kHermitian) {
204
+ c = (row == col) ? real(c) : c;
205
+ }
206
+
207
+ tensor_d.at(coord) = convert_op(
208
+ alpha * ScalarType(accum[i][j]) +
209
+ beta * c);
210
+ }
211
+ }
212
+ }
213
+
214
+ tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
215
+ tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
216
+ tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
217
+ tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
218
+
219
+ } // for (batch_idx)
220
+ }
221
+
222
+ } // namespace kernel
223
+
224
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
225
+
226
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
227
+ /// objects.
228
+ ///
229
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
230
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
231
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
232
+ /// arguments explicitly.
233
+ template <
234
+ typename ElementA,
235
+ typename LayoutA,
236
+ typename ElementB,
237
+ typename LayoutB,
238
+ typename ElementC,
239
+ typename LayoutC,
240
+ typename ScalarType,
241
+ typename ComputeType,
242
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>,
243
+ typename InnerProductOp = multiply_add<ComputeType>
244
+ >
245
+ void Rank2KComplex(
246
+ gemm::GemmCoord problem_size,
247
+ ScalarType alpha,
248
+ TensorRef<ElementA, LayoutA> tensor_a,
249
+ ComplexTransform transform_a,
250
+ TensorRef<ElementB, LayoutB> tensor_b,
251
+ ComplexTransform transform_b,
252
+ ScalarType beta,
253
+ TensorRef<ElementC, LayoutC> tensor_c,
254
+ TensorRef<ElementC, LayoutC> tensor_d,
255
+ ComputeType initial_accum,
256
+ FillMode fill_mode_c,
257
+ BlasMode blas_mode,
258
+ int batch_count = 1,
259
+ int64_t batch_stride_A = 0,
260
+ int64_t batch_stride_B = 0,
261
+ int64_t batch_stride_C = 0,
262
+ int64_t batch_stride_D = 0) {
263
+
264
+ static_assert(
265
+ LayoutA::kRank == 2 &&
266
+ LayoutB::kRank == 2 &&
267
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
268
+
269
+ int const kMblock = 4;
270
+ int const kNblock = 4;
271
+
272
+ dim3 block(16, 8);
273
+ dim3 grid(
274
+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
275
+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
276
+ batch_count % std::numeric_limits<uint16_t>::max()
277
+ );
278
+
279
+ kernel::Rank2KComplex<
280
+ ElementA,
281
+ LayoutA,
282
+ ElementB,
283
+ LayoutB,
284
+ ElementC,
285
+ LayoutC,
286
+ ScalarType,
287
+ ComputeType,
288
+ ConvertOp,
289
+ InnerProductOp,
290
+ kMblock,
291
+ kNblock
292
+ ><<< grid, block >>>(
293
+ problem_size,
294
+ alpha,
295
+ tensor_a,
296
+ transform_a,
297
+ tensor_b,
298
+ transform_b,
299
+ beta,
300
+ tensor_c,
301
+ tensor_d,
302
+ initial_accum,
303
+ fill_mode_c,
304
+ blas_mode,
305
+ batch_count,
306
+ batch_stride_A,
307
+ batch_stride_B,
308
+ batch_stride_C,
309
+ batch_stride_D
310
+ );
311
+ }
312
+
313
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
314
+
315
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
316
+ /// objects.
317
+ ///
318
+ /// This assumes the accumulator type is the same type as the scalars.
319
+ template <
320
+ typename ElementA,
321
+ typename LayoutA,
322
+ typename ElementB,
323
+ typename LayoutB,
324
+ typename ElementC,
325
+ typename LayoutC,
326
+ typename ScalarType
327
+ >
328
+ void Rank2KComplex(
329
+ gemm::GemmCoord problem_size,
330
+ ScalarType alpha,
331
+ TensorRef<ElementA, LayoutA> tensor_a,
332
+ ComplexTransform transform_a,
333
+ TensorRef<ElementB, LayoutB> tensor_b,
334
+ ComplexTransform transform_b,
335
+ ScalarType beta,
336
+ TensorRef<ElementC, LayoutC> tensor_c,
337
+ TensorRef<ElementC, LayoutC> tensor_d,
338
+ FillMode fill_mode_c,
339
+ BlasMode blas_mode) {
340
+
341
+ Rank2KComplex(
342
+ problem_size, alpha,
343
+ tensor_a, transform_a,
344
+ tensor_b, transform_b,
345
+ beta, tensor_c, tensor_d,
346
+ ScalarType(0),
347
+ fill_mode_c,
348
+ blas_mode);
349
+ }
350
+
351
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
352
+
353
+ } // namespace device
354
+ } // namespace reference
355
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Defines host-side elementwise operations on TensorView.
33
+ */
34
+
35
+ #pragma once
36
+ // Standard Library includes
37
+ #include <utility>
38
+
39
+ // Cutlass includes
40
+ #include "cutlass/cutlass.h"
41
+ #include "cutlass/relatively_equal.h"
42
+
43
+ #include "cutlass/util/distribution.h"
44
+
45
+ #include "tensor_foreach.h"
46
+
47
+ namespace cutlass {
48
+ namespace reference {
49
+ namespace device {
50
+
51
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace kernel {
54
+
55
+ template <typename Element>
56
+ __global__ void BlockCompareEqual(
57
+ int *equal,
58
+ Element const *ptr_A,
59
+ Element const *ptr_B,
60
+ size_t capacity) {
61
+
62
+ size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
63
+
64
+ for (; idx < capacity; idx += gridDim.x * blockDim.x) {
65
+
66
+ Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
67
+ Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
68
+
69
+ if (a != b) {
70
+ *equal = 0;
71
+
72
+ return;
73
+ }
74
+ }
75
+ }
76
+
77
+ template <typename Element>
78
+ __global__ void BlockCompareRelativelyEqual(
79
+ int *equal,
80
+ Element const *ptr_A,
81
+ Element const *ptr_B,
82
+ size_t capacity,
83
+ Element epsilon,
84
+ Element nonzero_floor) {
85
+
86
+ size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
87
+
88
+ for (; idx < capacity; idx += gridDim.x * blockDim.x) {
89
+
90
+ Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
91
+ Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);
92
+
93
+ if (!relatively_equal(a, b, epsilon, nonzero_floor)) {
94
+ *equal = 0;
95
+ return;
96
+ }
97
+ }
98
+ }
99
+
100
+ } // namespace kernel
101
+
102
+
103
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
104
+
105
+ /// Performs a bit-level equality check between two blocks
106
+ template <typename Element>
107
+ bool BlockCompareEqual(
108
+ Element const *ptr_A,
109
+ Element const *ptr_B,
110
+ size_t capacity,
111
+ int grid_size = 0,
112
+ int block_size = 0,
113
+ cudaStream_t stream = nullptr) {
114
+
115
+ int equal_flag = 1;
116
+ int *device_equal_flag = nullptr;
117
+
118
+ if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
119
+ throw std::runtime_error("Failed to allocate device flag.");
120
+ }
121
+
122
+ if (cudaMemcpy(
123
+ device_equal_flag,
124
+ &equal_flag,
125
+ sizeof(int),
126
+ cudaMemcpyHostToDevice) != cudaSuccess) {
127
+
128
+ throw std::runtime_error("Failed to copy equality flag to device.");
129
+ }
130
+
131
+ if (!grid_size || !block_size) {
132
+
133
+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
134
+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
135
+ &grid_size,
136
+ &block_size,
137
+ reinterpret_cast<void const *>(kernel::BlockCompareEqual<Element>));
138
+
139
+ if (result != cudaSuccess) {
140
+ throw std::runtime_error("Failed to query occupancy.");
141
+ }
142
+ // Limit block size. This has the effect of increasing the number of items processed by a
143
+ // single thread and reduces the impact of initialization overhead.
144
+ block_size = (block_size < 128 ? block_size : 128);
145
+ }
146
+
147
+ dim3 grid(grid_size, 1, 1);
148
+ dim3 block(block_size, 1, 1);
149
+
150
+ kernel::BlockCompareEqual<Element><<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity);
151
+
152
+ cudaStreamSynchronize(stream);
153
+
154
+ if (cudaMemcpy(
155
+ &equal_flag,
156
+ device_equal_flag,
157
+ sizeof(int),
158
+ cudaMemcpyDeviceToHost) != cudaSuccess) {
159
+
160
+ cudaFree(device_equal_flag);
161
+
162
+ throw std::runtime_error("Failed to copy equality flag from device.");
163
+ }
164
+
165
+ cudaFree(device_equal_flag);
166
+
167
+ return equal_flag;
168
+ }
169
+
170
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
171
+
172
+ /// Performs a bit-level equality check between two blocks
173
+ template <typename Element>
174
+ bool BlockCompareRelativelyEqual(
175
+ Element const *ptr_A,
176
+ Element const *ptr_B,
177
+ size_t capacity,
178
+ Element epsilon,
179
+ Element nonzero_floor,
180
+ int grid_size = 0,
181
+ int block_size = 0,
182
+ cudaStream_t stream = nullptr) {
183
+
184
+ int equal_flag = 1;
185
+ int *device_equal_flag = nullptr;
186
+
187
+ if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) {
188
+ throw std::runtime_error("Failed to allocate device flag.");
189
+ }
190
+
191
+ if (cudaMemcpy(
192
+ device_equal_flag,
193
+ &equal_flag,
194
+ sizeof(int),
195
+ cudaMemcpyHostToDevice) != cudaSuccess) {
196
+
197
+ throw std::runtime_error("Failed to copy equality flag to device.");
198
+ }
199
+
200
+ if (!grid_size || !block_size) {
201
+
202
+ // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
203
+ cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
204
+ &grid_size,
205
+ &block_size,
206
+ reinterpret_cast<void const *>(kernel::BlockCompareRelativelyEqual<Element>));
207
+
208
+ if (result != cudaSuccess) {
209
+ throw std::runtime_error("Failed to query occupancy.");
210
+ }
211
+ // Limit block size. This has the effect of increasing the number of items processed by a
212
+ // single thread and reduces the impact of initialization overhead.
213
+ block_size = (block_size < 128 ? block_size : 128);
214
+ }
215
+
216
+ dim3 grid(grid_size, 1, 1);
217
+ dim3 block(block_size, 1, 1);
218
+
219
+ kernel::BlockCompareRelativelyEqual<Element><<< grid, block, 0, stream >>>(
220
+ device_equal_flag,
221
+ ptr_A,
222
+ ptr_B,
223
+ capacity,
224
+ epsilon,
225
+ nonzero_floor
226
+ );
227
+
228
+ cudaStreamSynchronize(stream);
229
+
230
+ if (cudaMemcpy(
231
+ &equal_flag,
232
+ device_equal_flag,
233
+ sizeof(int),
234
+ cudaMemcpyDeviceToHost) != cudaSuccess) {
235
+
236
+ cudaFree(device_equal_flag);
237
+
238
+ throw std::runtime_error("Failed to copy equality flag from device.");
239
+ }
240
+
241
+ cudaFree(device_equal_flag);
242
+
243
+ return equal_flag;
244
+ }
245
+
246
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
247
+
248
+ } // device
249
+ } // reference
250
+ } // cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h ADDED
@@ -0,0 +1,2075 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined
33
+ in this header are not specialized for any particular data layout and are therefore not
34
+ intended to offer the best possible performance. Rather, they are intended to be generic
35
+ reference implementations to support the CUTLASS unit tests.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #if !defined(__CUDACC_RTC__)
41
+
42
+ // Standard Library includes
43
+ #include <utility>
44
+ #include <cstdlib>
45
+ #include <cmath>
46
+ #include <type_traits>
47
+ #include <cstdint>
48
+
49
+ #endif
50
+
51
+ // CUDA includes
52
+ #include <curand_kernel.h>
53
+
54
+ // Cutlass includes
55
+ #include "cutlass/cutlass.h"
56
+ #include "cutlass/array.h"
57
+ #include "cutlass/complex.h"
58
+ #include "cutlass/tensor_view.h"
59
+ #include "cutlass/blas3.h"
60
+ #include "cutlass/numeric_types.h"
61
+
62
+ #include "cutlass/layout/vector.h"
63
+
64
+ #include "cutlass/util/reference/device/tensor_foreach.h"
65
+ #include "cutlass/util/distribution.h"
66
+
67
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
68
+
69
+ namespace cutlass {
70
+ namespace reference {
71
+ namespace device {
72
+
73
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
74
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
75
+
76
+ namespace detail {
77
+
78
+ template <typename FloatType>
79
+ CUTLASS_DEVICE
80
+ FloatType random_normal_float(curandState_t *state) {
81
+ return curand_normal(state);
82
+ }
83
+
84
+ template <>
85
+ CUTLASS_DEVICE
86
+ double random_normal_float<double>(curandState_t *state) {
87
+ return curand_normal_double(state);
88
+ }
89
+
90
+ template <typename FloatType>
91
+ CUTLASS_DEVICE
92
+ FloatType random_uniform_float(curandState_t *state) {
93
+ return curand_uniform(state);
94
+ }
95
+
96
+ template <>
97
+ CUTLASS_DEVICE
98
+ double random_uniform_float<double>(curandState_t *state) {
99
+ return curand_uniform_double(state);
100
+ }
101
+
102
+ template <typename Element>
103
+ struct RandomGaussianFunc {
104
+
105
+ using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type;
106
+ using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type;
107
+
108
+ /// Parameters structure
109
+ struct Params {
110
+
111
+ //
112
+ // Data members
113
+ //
114
+
115
+ uint64_t seed;
116
+ FloatType mean;
117
+ FloatType stddev;
118
+ int int_scale;
119
+ FloatType float_scale_up;
120
+ FloatType float_scale_down;
121
+ int exclude_zero; ///< If non-negative, excludes zeros
122
+
123
+ //
124
+ // Methods
125
+ //
126
+
127
+ /// Construction of Gaussian RNG functor.
128
+ Params(
129
+ uint64_t seed_ = 0,
130
+ Element mean_ = 0,
131
+ Element stddev_ = 1,
132
+ int int_scale_ = -1,
133
+ int exclude_zero_ = -1
134
+ ):
135
+ seed(seed_),
136
+ mean(static_cast<FloatType>(mean_)),
137
+ stddev(static_cast<FloatType>(stddev_)),
138
+ int_scale(int_scale_),
139
+ exclude_zero(exclude_zero_) {
140
+
141
+ float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
142
+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
143
+ }
144
+ };
145
+
146
+ //
147
+ // Data members
148
+ //
149
+
150
+ /// Parameters object
151
+ Params params;
152
+
153
+ /// RNG state object
154
+ curandState_t rng_state;
155
+
156
+ //
157
+ // Methods
158
+ //
159
+
160
+ /// Device-side initialization of RNG
161
+ CUTLASS_DEVICE
162
+ RandomGaussianFunc(Params const &params): params(params) {
163
+
164
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
165
+
166
+ curand_init(params.seed, gtid, 0, &rng_state);
167
+ }
168
+
169
+ /// Compute random value and update RNG state
170
+ CUTLASS_DEVICE
171
+ Element operator()() {
172
+
173
+ FloatType rnd = random_normal_float<FloatType>(&rng_state);
174
+ rnd = params.mean + params.stddev * rnd;
175
+
176
+ Element result;
177
+ if (params.int_scale >= 0) {
178
+ rnd = FloatType(std::llround(rnd * params.float_scale_up));
179
+ result = Element(rnd * params.float_scale_down);
180
+ }
181
+ else {
182
+ result = Element(rnd);
183
+ }
184
+
185
+ if (params.exclude_zero >=0 && result == Element(0.0)) {
186
+ if (rnd > FloatType(0)) {
187
+ rnd += FloatType(1);
188
+ } else {
189
+ rnd -= FloatType(1);
190
+ }
191
+ result = Element(rnd);
192
+ }
193
+
194
+ return result;
195
+ }
196
+ };
197
+
198
+
199
+ template <typename Real>
200
+ struct RandomGaussianFunc<complex<Real>> {
201
+
202
+ using Element = complex<Real>;
203
+ using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type;
204
+ using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type;
205
+
206
+ /// Parameters structure
207
+ struct Params {
208
+
209
+ //
210
+ // Data members
211
+ //
212
+
213
+ uint64_t seed;
214
+ FloatType mean;
215
+ FloatType stddev;
216
+ int int_scale;
217
+ FloatType float_scale_up;
218
+ FloatType float_scale_down;
219
+ int exclude_zero; ///< If non-negative, excludes zeros
220
+
221
+ //
222
+ // Methods
223
+ //
224
+
225
+ /// Construction of Gaussian RNG functor.
226
+ Params(
227
+ uint64_t seed_ = 0,
228
+ Real mean_ = 0,
229
+ Real stddev_ = 1,
230
+ int int_scale_ = -1,
231
+ int exclude_zero_ = -1
232
+ ):
233
+ seed(seed_),
234
+ mean(static_cast<FloatType>(mean_)),
235
+ stddev(static_cast<FloatType>(stddev_)),
236
+ int_scale(int_scale_),
237
+ exclude_zero(exclude_zero_) {
238
+
239
+ float_scale_up = FloatType(IntType(1) << int_scale);
240
+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
241
+ }
242
+ };
243
+
244
+ //
245
+ // Data members
246
+ //
247
+
248
+ /// Parameters object
249
+ Params params;
250
+
251
+ /// RNG state object
252
+ curandState_t rng_state;
253
+
254
+ //
255
+ // Methods
256
+ //
257
+
258
+ /// Device-side initialization of RNG
259
+ CUTLASS_DEVICE
260
+ RandomGaussianFunc(Params const &params): params(params) {
261
+
262
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
263
+
264
+ curand_init(params.seed, gtid, 0, &rng_state);
265
+ }
266
+
267
+ /// Compute random value and update RNG state
268
+ CUTLASS_DEVICE
269
+ Element operator()() {
270
+
271
+ FloatType rnd_r = random_normal_float<FloatType>(&rng_state);
272
+ FloatType rnd_i = random_normal_float<FloatType>(&rng_state);
273
+ rnd_r = params.mean + params.stddev * rnd_r;
274
+ rnd_i = params.mean + params.stddev * rnd_i;
275
+
276
+ Element result;
277
+ if (params.int_scale >= 0) {
278
+ rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
279
+ rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
280
+
281
+ result = {
282
+ Real(rnd_r * params.float_scale_down),
283
+ Real(rnd_i * params.float_scale_down)
284
+ };
285
+ }
286
+ else {
287
+ result = Element(Real(rnd_r), Real(rnd_i));
288
+ }
289
+
290
+ if (params.exclude_zero >= 0 &&
291
+ result.real() == Real(0.0) &&
292
+ result.imag() == Real(0.0)) {
293
+
294
+ if (rnd_r > FloatType(0)) {
295
+ rnd_r += FloatType(1);
296
+ } else {
297
+ rnd_r -= FloatType(1);
298
+ }
299
+ result = Element(Real(rnd_r), Real(rnd_i));
300
+ }
301
+
302
+ return result;
303
+ }
304
+ };
305
+
306
+ /// Computes a random Gaussian distribution
307
+ template <
308
+ typename Element, ///< Element type
309
+ typename Layout> ///< Layout function
310
+ struct TensorFillRandomGaussianFunc {
311
+
312
+ /// View type
313
+ using TensorView = TensorView<Element, Layout>;
314
+
315
+ /// Scalar type
316
+ typedef typename TensorView::Element T;
317
+
318
+ /// Coordinate in tensor's index space
319
+ typedef typename TensorView::TensorCoord TensorCoord;
320
+
321
+ using RandomFunc = RandomGaussianFunc<Element>;
322
+
323
+ /// Parameters structure
324
+ struct Params {
325
+
326
+ //
327
+ // Data members
328
+ //
329
+
330
+ TensorView view;
331
+ typename RandomFunc::Params random;
332
+
333
+ //
334
+ // Methods
335
+ //
336
+
337
+ /// Construction of Gaussian RNG functor.
338
+ Params(
339
+ TensorView view_ = TensorView(),
340
+ typename RandomFunc::Params random_ = typename RandomFunc::Params()
341
+ ):
342
+ view(view_), random(random_) {
343
+
344
+ }
345
+ };
346
+
347
+ //
348
+ // Data members
349
+ //
350
+
351
+ Params params;
352
+ RandomFunc random;
353
+
354
+ //
355
+ // Methods
356
+ //
357
+
358
+ /// Device-side initialization of RNG
359
+ CUTLASS_DEVICE
360
+ TensorFillRandomGaussianFunc(Params const &params): params(params), random(params.random) {
361
+
362
+ }
363
+
364
+ /// Compute random value and update RNG state
365
+ CUTLASS_DEVICE
366
+ void operator()(TensorCoord const &coord) {
367
+
368
+ params.view.at(coord) = random();
369
+ }
370
+ };
371
+
372
+ } // namespace detail
373
+
374
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
375
+
376
+ /// Fills a tensor with random values with a Gaussian distribution.
377
+ template <
378
+ typename Element, ///< Element type
379
+ typename Layout> ///< Layout function
380
+ void TensorFillRandomGaussian(
381
+ TensorView<Element, Layout> view, ///< destination tensor
382
+ uint64_t seed, ///< seed for RNG
383
+ typename RealType<Element>::Type mean = Element(0), ///< Gaussian distribution's mean
384
+ typename RealType<Element>::Type stddev = Element(1), ///< Gaussian distribution's standard deviation
385
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
386
+ /// are not truncated to zero. Permits reducing precision of
387
+ /// data.
388
+ int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
389
+ cudaStream_t stream = nullptr) {
390
+
391
+ using RandomFunc = detail::RandomGaussianFunc<Element>;
392
+ using Func = detail::TensorFillRandomGaussianFunc<Element, Layout>;
393
+ using Params = typename Func::Params;
394
+
395
+ TensorForEach<Func, Layout::kRank, Params>(
396
+ view.extent(),
397
+ Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)),
398
+ /*grid_size*/0, /*block_size*/0,
399
+ stream
400
+ );
401
+ }
402
+
403
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
404
+
405
+ /// Fills a tensor with random values with a Gaussian distribution.
406
+ template <typename Element> ///< Element type
407
+ void BlockFillRandomGaussian(
408
+ Element *ptr,
409
+ size_t capacity,
410
+ uint64_t seed, ///< seed for RNG
411
+ typename RealType<Element>::Type mean, ///< Gaussian distribution's mean
412
+ typename RealType<Element>::Type stddev, ///< Gaussian distribution's standard deviation
413
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
414
+ /// are not truncated to zero. Permits reducing precision of
415
+ /// data.
416
+ cudaStream_t stream = nullptr) {
417
+
418
+ using RandomFunc = detail::RandomGaussianFunc<Element>;
419
+
420
+ typename RandomFunc::Params params(seed, mean, stddev, bits);
421
+
422
+ BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
423
+ }
424
+
425
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
426
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
427
+
428
+ namespace detail {
429
+
430
+ /// Computes a random uniform distribution
431
+ template <typename Element> ///< Element type
432
+ struct RandomUniformFunc {
433
+
434
+ using FloatType = typename std::conditional<
435
+ (sizeof(Element) > 4),
436
+ double,
437
+ float>::type;
438
+
439
+ using IntType = typename std::conditional<
440
+ (sizeof(Element) > 4),
441
+ int64_t,
442
+ int>::type;
443
+
444
+ /// Parameters structure
445
+ struct Params {
446
+
447
+ //
448
+ // Data members
449
+ //
450
+
451
+ uint64_t seed;
452
+ FloatType range;
453
+ FloatType max;
454
+ int int_scale;
455
+ double pnan;
456
+ FloatType float_scale_up;
457
+ FloatType float_scale_down;
458
+ int exclude_zero; ///< If non-negative, excludes zeros
459
+
460
+ /// Default ctor
461
+ CUTLASS_HOST_DEVICE
462
+ Params() { }
463
+
464
+ //
465
+ // Methods
466
+ //
467
+
468
+ /// Construction of Gaussian RNG functor.
469
+ Params(
470
+ uint64_t seed_ = 0,
471
+ Element max_ = 1,
472
+ Element min = 0,
473
+ int int_scale_ = -1,
474
+ double pnan_ = 0,
475
+ int exclude_zero_ = -1
476
+ ):
477
+ seed(seed_),
478
+ range(static_cast<FloatType>(max_) - static_cast<FloatType>(min)),
479
+ max(static_cast<FloatType>(max_)),
480
+ int_scale(int_scale_),
481
+ pnan(pnan_),
482
+ exclude_zero(exclude_zero_) {
483
+
484
+ float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits
485
+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
486
+
487
+ // Handle cases where min = 0 or max = 0 for excluding zeros
488
+ if (exclude_zero >= 0) {
489
+ range = (min == Element(0)) ? range - FloatType(1): range;
490
+ max = (max_ == Element(0)) ? max - FloatType(1): max;
491
+ }
492
+ }
493
+ };
494
+
495
+ //
496
+ // Data members
497
+ //
498
+
499
+ /// Parameters object
500
+ Params params;
501
+
502
+ /// RNG state object
503
+ curandState_t rng_state;
504
+
505
+ //
506
+ // Methods
507
+ //
508
+
509
+ /// Device-side initialization of RNG
510
+ CUTLASS_DEVICE
511
+ RandomUniformFunc(Params const &params): params(params) {
512
+
513
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
514
+
515
+ curand_init(params.seed, gtid, 0, &rng_state);
516
+ }
517
+
518
+ /// Compute random value and update RNG state
519
+ CUTLASS_DEVICE
520
+ Element operator()() {
521
+
522
+ // Draw random float in [0.0, 1.0] to determine if element should be NaN.
523
+ if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
524
+ if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
525
+ return Element(NAN);
526
+ }
527
+ }
528
+
529
+ FloatType rnd = random_uniform_float<FloatType>(&rng_state);
530
+ rnd = params.max - params.range * rnd;
531
+
532
+ // Random values are cast to integer after scaling by a power of two to facilitate error
533
+ // testing
534
+ Element result;
535
+
536
+ if (params.int_scale >= 0) {
537
+ rnd = FloatType(std::llround(rnd * params.float_scale_up));
538
+ result = Element(rnd * params.float_scale_down);
539
+ }
540
+ else {
541
+ result = Element(rnd);
542
+ }
543
+
544
+ if (params.exclude_zero >=0 && result == Element(0.0)) {
545
+ if (rnd > FloatType(0)) {
546
+ rnd = std::min(params.max, rnd + FloatType(1));
547
+ } else {
548
+ rnd = std::max((params.max - params.range), rnd - FloatType(1));
549
+ }
550
+ result = Element(rnd);
551
+ }
552
+
553
+ return result;
554
+ }
555
+ };
556
+
557
+ /// Computes a random Gaussian distribution
558
+ template <typename Real>
559
+ struct RandomUniformFunc<complex<Real>> {
560
+
561
+ using Element = complex<Real>;
562
+
563
+ using FloatType = typename std::conditional<
564
+ (sizeof(Real) > 4),
565
+ double,
566
+ float>::type;
567
+
568
+ using IntType = typename std::conditional<
569
+ (sizeof(Real) > 4),
570
+ int64_t,
571
+ int>::type;
572
+
573
+ /// Parameters structure
574
+ struct Params {
575
+
576
+ //
577
+ // Data members
578
+ //
579
+
580
+ uint64_t seed;
581
+ FloatType range;
582
+ FloatType min;
583
+ int int_scale;
584
+ double pnan;
585
+ FloatType float_scale_up;
586
+ FloatType float_scale_down;
587
+ int exclude_zero; ///< If non-negative, excludes zeros
588
+
589
+ /// Default ctor
590
+ CUTLASS_HOST_DEVICE
591
+ Params() { }
592
+
593
+ //
594
+ // Methods
595
+ //
596
+
597
+ /// Construction of Gaussian RNG functor.
598
+ Params(
599
+ uint64_t seed_ = 0,
600
+ FloatType max = 1,
601
+ FloatType min_ = 0,
602
+ int int_scale_ = -1,
603
+ double pnan_ = 0,
604
+ int exclude_zero_ = -1
605
+ ):
606
+ seed(seed_),
607
+ range(static_cast<FloatType>(max - min_)),
608
+ min(static_cast<FloatType>(min_)),
609
+ int_scale(int_scale_),
610
+ pnan(pnan_),
611
+ exclude_zero(exclude_zero_) {
612
+
613
+ float_scale_up = FloatType(IntType(1) << int_scale);
614
+ float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale);
615
+
616
+ // Handle cases where min = 0 or max = 0 for excluding zeros
617
+ if (exclude_zero >= 0) {
618
+ min = (min == FloatType(0)) ? min + FloatType(1): min;
619
+ range = (max == FloatType(0)) ? range - FloatType(1): range;
620
+ }
621
+ }
622
+ };
623
+
624
+ //
625
+ // Data members
626
+ //
627
+
628
+ /// Parameters object
629
+ Params params;
630
+
631
+ /// RNG state object
632
+ curandState_t rng_state;
633
+
634
+ //
635
+ // Methods
636
+ //
637
+
638
+ /// Device-side initialization of RNG
639
+ CUTLASS_DEVICE
640
+ RandomUniformFunc(Params const &params): params(params) {
641
+
642
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
643
+
644
+ curand_init(params.seed, gtid, 0, &rng_state);
645
+ }
646
+
647
+ /// Compute random value and update RNG state
648
+ CUTLASS_DEVICE
649
+ Element operator()() {
650
+
651
+ // Draw random float in [0.0, 1.0] to determine if element should be NaN.
652
+ if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
653
+ if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
654
+ return Element(Real(NAN), Real(NAN));
655
+ }
656
+ }
657
+
658
+ FloatType rnd_r = random_uniform_float<FloatType>(&rng_state);
659
+ FloatType rnd_i = random_uniform_float<FloatType>(&rng_state);
660
+
661
+ rnd_r = params.min + params.range * rnd_r;
662
+ rnd_i = params.min + params.range * rnd_i;
663
+
664
+ // Random values are cast to integer after scaling by a power of two to facilitate error
665
+ // testing
666
+ Element result;
667
+
668
+ if (params.int_scale >= 0) {
669
+ rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up));
670
+ rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up));
671
+
672
+ result = {
673
+ Real(rnd_r * params.float_scale_down),
674
+ Real(rnd_i * params.float_scale_down)
675
+ };
676
+ }
677
+ else {
678
+ result = Element(Real(rnd_r), Real(rnd_i));
679
+ }
680
+
681
+ if (params.exclude_zero >= 0 &&
682
+ result.real() == Real(0.0) &&
683
+ result.imag() == Real(0.0)) {
684
+
685
+ if (rnd_r > FloatType(0)) {
686
+ rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1));
687
+ } else {
688
+ rnd_r = std::max((params.min), rnd_r - FloatType(1));
689
+ }
690
+ result = Element(Real(rnd_r), Real(rnd_i));
691
+ }
692
+
693
+ return result;
694
+ }
695
+ };
696
+
697
+ /// Computes a random uniform distribution
698
+ template <
699
+ typename Element, ///< Element type
700
+ typename Layout> ///< Layout function
701
+ struct TensorFillRandomUniformFunc {
702
+
703
+ /// View type
704
+ using TensorView = TensorView<Element, Layout>;
705
+
706
+ /// Scalar type
707
+ typedef typename TensorView::Element T;
708
+
709
+ /// Coordinate in tensor's index space
710
+ typedef typename TensorView::TensorCoord TensorCoord;
711
+
712
+ using RandomFunc = RandomUniformFunc<Element>;
713
+
714
+ /// Parameters structure
715
+ struct Params {
716
+
717
+ //
718
+ // Data members
719
+ //
720
+
721
+ TensorView view;
722
+ typename RandomFunc::Params random;
723
+
724
+ /// Default ctor
725
+ CUTLASS_HOST_DEVICE
726
+ Params() { }
727
+
728
+ //
729
+ // Methods
730
+ //
731
+
732
+ /// Construction of Gaussian RNG functor.
733
+ Params(
734
+ TensorView view_ = TensorView(),
735
+ typename RandomFunc::Params random_ = RandomFunc::Params()
736
+ ):
737
+ view(view_), random(random_) {
738
+
739
+ }
740
+ };
741
+
742
+ //
743
+ // Data members
744
+ //
745
+
746
+ Params params;
747
+ RandomFunc random;
748
+
749
+ //
750
+ // Methods
751
+ //
752
+
753
+ /// Device-side initialization of RNG
754
+ CUTLASS_DEVICE
755
+ TensorFillRandomUniformFunc(Params const &params): params(params), random(params.random) {
756
+ }
757
+
758
+ /// Compute random value and update RNG state
759
+ CUTLASS_DEVICE
760
+ void operator()(TensorCoord const &coord) {
761
+
762
+ params.view.at(coord) = random();
763
+ }
764
+ };
765
+
766
+ } // namespace detail
767
+
768
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
769
+
770
+ /// Fills a tensor with random values with a uniform random distribution.
771
+ template <
772
+ typename Element, ///< Element type
773
+ typename Layout> ///< Layout function
774
+ void TensorFillRandomUniform(
775
+ TensorView<Element, Layout> view, ///< destination tensor
776
+ uint64_t seed, ///< seed for RNG
777
+ typename RealType<Element>::Type max = Element(1), ///< upper bound of distribution
778
+ typename RealType<Element>::Type min = Element(0), ///< lower bound for distribution
779
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
780
+ /// are not truncated to zero. Permits reducing precision of
781
+ /// data.
782
+ double pnan = 0, ///< Percentage of NaN elements.
783
+ int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init
784
+ cudaStream_t stream = nullptr) {
785
+
786
+ using RandomFunc = detail::RandomUniformFunc<Element>;
787
+ using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
788
+ using Params = typename Func::Params;
789
+
790
+ typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero);
791
+
792
+ TensorForEach<Func, Layout::kRank, Params>(
793
+ view.extent(),
794
+ Params(view, random),
795
+ /*grid_size*/0, /*block_size*/0,
796
+ stream
797
+ );
798
+ }
799
+
800
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
801
+
802
+ /// Fills a tensor with random values with a uniform random distribution.
803
+ template <typename Element>
804
+ void BlockFillRandomUniform(
805
+ Element *ptr,
806
+ size_t capacity,
807
+ uint64_t seed, ///< seed for RNG
808
+ typename RealType<Element>::Type max, ///< upper bound of distribution
809
+ typename RealType<Element>::Type min, ///< lower bound for distribution
810
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
811
+ /// are not truncated to zero. Permits reducing precision of
812
+ /// data.
813
+ double pnan = 0, ///< Percentage of NaN elements.
814
+ cudaStream_t stream = nullptr) {
815
+
816
+ using RandomFunc = detail::RandomUniformFunc<Element>;
817
+
818
+ typename RandomFunc::Params params(seed, max, min, bits, pnan);
819
+
820
+ BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
821
+ }
822
+
823
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
824
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
825
+
826
+ namespace detail {
827
+
828
+ /// Computes a random sparse meta
829
+ template <typename Element> ///< Element type
830
+ struct RandomSparseMetaFunc {
831
+
832
+ using FloatType = float;
833
+
834
+ using IntType = int32_t;
835
+
836
+ /// Parameters structure
837
+ struct Params {
838
+
839
+ //
840
+ // Data members
841
+ //
842
+
843
+ uint64_t seed;
844
+ FloatType range;
845
+ int MetaSizeInBits;
846
+
847
+ /// Default ctor
848
+ CUTLASS_HOST_DEVICE
849
+ Params() { }
850
+
851
+ //
852
+ // Methods
853
+ //
854
+
855
+ /// Construction of Gaussian RNG functor.
856
+ Params(
857
+ uint64_t seed_ = 0,
858
+ int MetaSizeInBits_ = 2
859
+ ):
860
+ seed(seed_),
861
+ MetaSizeInBits(MetaSizeInBits_) {
862
+ if (MetaSizeInBits_ == 2) {
863
+ range = 6;
864
+ }
865
+ else if (MetaSizeInBits_ == 4) {
866
+ range = 2;
867
+ }
868
+ else {
869
+ throw std::invalid_argument("Invalid MetaSizeInBits");
870
+ }
871
+ }
872
+ };
873
+
874
+ //
875
+ // Data members
876
+ //
877
+
878
+ /// Parameters object
879
+ Params params;
880
+
881
+ /// RNG state object
882
+ curandState_t rng_state;
883
+
884
+ //
885
+ // Methods
886
+ //
887
+
888
+ /// Device-side initialization of RNG
889
+ CUTLASS_DEVICE
890
+ RandomSparseMetaFunc(Params const &params): params(params) {
891
+
892
+ uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
893
+
894
+ curand_init(params.seed, gtid, 0, &rng_state);
895
+ }
896
+
897
+ /// Compute random value and update RNG state
898
+ CUTLASS_DEVICE
899
+ Element operator()() {
900
+ Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
901
+ Element TwoToOneMeta[2] = {0x4, 0xe};
902
+
903
+ Element *MetaArray =
904
+ (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta;
905
+
906
+ Element result = 0x0;
907
+
908
+ CUTLASS_PRAGMA_UNROLL
909
+ for (int i = 0; i < cutlass::sizeof_bits<Element>::value / 4; ++i) {
910
+ FloatType rnd = random_uniform_float<FloatType>(&rng_state);
911
+ rnd = params.range * rnd;
912
+ Element meta = MetaArray[(int)rnd];
913
+
914
+ result = (Element)(result | ((Element)(meta << (i * 4))));
915
+ }
916
+
917
+ return result;
918
+ }
919
+ };
920
+
921
+ /// Computes a random Gaussian distribution
922
+ template <
923
+ typename Element, ///< Element type
924
+ typename Layout> ///< Layout function
925
+ struct TensorFillRandomSparseMetaFunc {
926
+
927
+ /// View type
928
+ using TensorView = TensorView<Element, Layout>;
929
+
930
+ /// Scalar type
931
+ typedef typename TensorView::Element T;
932
+
933
+ /// Coordinate in tensor's index space
934
+ typedef typename TensorView::TensorCoord TensorCoord;
935
+
936
+ using RandomFunc = RandomSparseMetaFunc<Element>;
937
+
938
+ /// Parameters structure
939
+ struct Params {
940
+
941
+ //
942
+ // Data members
943
+ //
944
+
945
+ TensorView view;
946
+ typename RandomFunc::Params random;
947
+
948
+ /// Default ctor
949
+ CUTLASS_HOST_DEVICE
950
+ Params() { }
951
+
952
+ //
953
+ // Methods
954
+ //
955
+
956
+ /// Construction of Gaussian RNG functor.
957
+ Params(
958
+ TensorView view_ = TensorView(),
959
+ typename RandomFunc::Params random_ = RandomFunc::Params()
960
+ ):
961
+ view(view_), random(random_) {
962
+
963
+ }
964
+ };
965
+
966
+ //
967
+ // Data members
968
+ //
969
+
970
+ Params params;
971
+ RandomFunc random;
972
+
973
+ //
974
+ // Methods
975
+ //
976
+
977
+ /// Device-side initialization of RNG
978
+ CUTLASS_DEVICE
979
+ TensorFillRandomSparseMetaFunc(Params const &params): params(params), random(params.random) {
980
+ }
981
+
982
+ /// Compute random value and update RNG state
983
+ CUTLASS_DEVICE
984
+ void operator()(TensorCoord const &coord) {
985
+
986
+ params.view.at(coord) = random();
987
+ }
988
+ };
989
+
990
+ } // namespace detail
991
+
992
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
993
+
994
+ /// Fills a tensor with random values with a uniform random distribution.
995
+ template <
996
+ typename Element, ///< Element type
997
+ typename Layout> ///< Layout function
998
+ void TensorFillRandomSparseMeta(
999
+ TensorView<Element, Layout> view, ///< destination tensor
1000
+ uint64_t seed, ///< seed for RNG
1001
+ int MetaSizeInBits = 2, ///< meta data size
1002
+ cudaStream_t stream = nullptr) {
1003
+
1004
+ using RandomFunc = detail::RandomSparseMetaFunc<Element>;
1005
+ using Func = detail::TensorFillRandomUniformFunc<Element, Layout>;
1006
+ using Params = typename Func::Params;
1007
+
1008
+ typename RandomFunc::Params random(seed, MetaSizeInBits);
1009
+
1010
+ TensorForEach<Func, Layout::kRank, Params>(
1011
+ view.extent(),
1012
+ Params(view, random),
1013
+ /*grid_size*/0, /*block_size*/0,
1014
+ stream
1015
+ );
1016
+ }
1017
+
1018
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1019
+
1020
+ /// Fills a tensor with random values with a uniform random distribution.
1021
+ template <typename Element>
1022
+ void BlockFillRandomSparseMeta(
1023
+ Element *ptr,
1024
+ size_t capacity,
1025
+ uint64_t seed, ///< seed for RNG
1026
+ int MetaSizeInBits = 2, ///< meta data size
1027
+ cudaStream_t stream = nullptr) {
1028
+
1029
+ using RandomFunc = detail::RandomSparseMetaFunc<Element>;
1030
+
1031
+ typename RandomFunc::Params params(seed, MetaSizeInBits);
1032
+
1033
+ BlockForEach<Element, RandomFunc>(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream);
1034
+ }
1035
+
1036
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1037
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1038
+
1039
namespace detail {

/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorFillDiagonalFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;              ///< destination tensor view
    Element diag;                 ///< value written on the diagonal
    Element other;                ///< value written off the diagonal

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Params() { }

    //
    // Methods
    //

    /// Constructs parameters for the diagonal-fill functor.
    Params(
      TensorView view_ = TensorView(),
      Element diag_ = Element(1),
      Element other_ = Element(0)
    ):
      view(view_), diag(diag_), other(other_) {

    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorFillDiagonalFunc(Params const &params): params(params) {

  }

  /// Updates the tensor: writes 'diag' on the generalized diagonal, 'other' elsewhere.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    bool is_diag = true;

    // A coordinate lies on the generalized diagonal when all of its components are equal.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < Layout::kRank; ++i) {
      if (coord[i] != coord[i - 1]) {
        is_diag = false;
        break;
      }
    }

    params.view.at(coord) = (is_diag ? params.diag : params.other);
  }
};

// Overwrites the elements of a tensor with a uniform value depending on fill mode
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorFillPartialFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;              ///< destination tensor view
    Element element;              ///< value written to covered elements
    FillMode fill_mode;           ///< region of the tensor to overwrite

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Params(): fill_mode(FillMode::kNone) { }

    //
    // Methods
    //

    /// Constructs parameters for the partial-fill functor.
    Params(
      TensorView view_,
      Element element_,
      FillMode fill_mode_
    ):
      view(view_), element(element_), fill_mode(fill_mode_) {

    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorFillPartialFunc(Params const &params): params(params) {

  }

  /// Overwrites the element if it is within the covered region.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    bool predicate = true;

    switch (params.fill_mode) {
      case FillMode::kFull:
        predicate = true;
        break;

      // Lower region: every component must be >= the next one (coord[i-1] >= coord[i]).
      case FillMode::kLower:
        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < Layout::kRank; ++i) {
          if (coord[i - 1] < coord[i]) {
            predicate = false;
            break;
          }
        }
        break;

      // Upper region: every component must be <= the next one (coord[i-1] <= coord[i]).
      case FillMode::kUpper:
        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < Layout::kRank; ++i) {
          if (coord[i - 1] > coord[i]) {
            predicate = false;
            break;
          }
        }
        break;

      // Diagonal: all components must be equal.
      case FillMode::kDiagonal:
        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < Layout::kRank; ++i) {
          if (coord[i - 1] != coord[i]) {
            predicate = false;
            break;
          }
        }
        break;

      case FillMode::kNone: // fall-through

      default:
        predicate = false;
        break;
    }

    if (predicate) {
      params.view.at(coord) = params.element;
    }
  }
};

/// Functor that clears elements on the wrong side of the fill mode, up to an alignment band.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorClearPartialFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Only rank-2 tensors (matrices) are supported by this functor.
  static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices");

  /// Parameters structure
  struct Params {
    TensorView view{};            ///< destination tensor view
    Element element{};            ///< value written to cleared elements (typically zero)
    FillMode fill_mode{FillMode::kNone};  ///< region the tensor is expected to occupy
    int alignment{0};             ///< width of the band beyond the diagonal that is cleared
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorClearPartialFunc(Params const &params): params(params) {

  }

  /// Overwrites the element if it is within the covered region.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    bool predicate = true;

    switch (params.fill_mode) {

      // kLower keeps the lower triangle: clear only strictly-upper elements
      // within 'alignment' columns of the diagonal.
      case FillMode::kLower:
        if ((coord[0] >= coord[1]) ||
            ((coord[1] - coord[0]) >= params.alignment)) {
          predicate = false;
          break;
        }
        break;

      // kUpper keeps the upper triangle: clear only strictly-lower elements
      // within 'alignment' rows of the diagonal.
      case FillMode::kUpper:
        if ((coord[0] <= coord[1]) ||
            ((coord[0] - coord[1]) >= params.alignment)) {
          predicate = false;
          break;
        }
        break;

      case FillMode::kNone: // fall-through

      default:
        predicate = false;
        break;
    }

    if (predicate) {
      params.view.at(coord) = params.element;
    }
  }
};

} // namespace detail
1315
+
1316
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1317
+
1318
+ /// Fills a tensor everywhere with a unique value for its diagonal.
1319
+ template <
1320
+ typename Element, ///< Element type
1321
+ typename Layout> ///< Layout function
1322
+ void TensorFillDiagonal(
1323
+ TensorView<Element, Layout> view, ///< destination tensor
1324
+ Element diag = Element(1), ///< value to write in the diagonal
1325
+ Element other = Element(0), ///< value to write off the diagonal
1326
+ cudaStream_t stream = nullptr) {
1327
+
1328
+ typedef detail::TensorFillDiagonalFunc<Element, Layout> Func;
1329
+ typedef typename Func::Params Params;
1330
+
1331
+ TensorForEach<Func, Layout::kRank, Params>(
1332
+ view.extent(),
1333
+ Params(view, diag, other),
1334
+ /*grid_size*/0, /*block_size*/0,
1335
+ stream
1336
+ );
1337
+ }
1338
+
1339
+ /// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are
1340
+ /// not written.
1341
+ template <
1342
+ typename Element, ///< Element type
1343
+ typename Layout> ///< Layout function
1344
+ void TensorFillPartial(
1345
+ TensorView<Element, Layout> view, ///< destination tensor
1346
+ Element element,
1347
+ FillMode fill_mode,
1348
+ cudaStream_t stream = nullptr) {
1349
+
1350
+ typedef detail::TensorFillPartialFunc<Element, Layout> Func;
1351
+ typedef typename Func::Params Params;
1352
+
1353
+ TensorForEach<Func, Layout::kRank, Params>(
1354
+ view.extent(),
1355
+ Params(view, element, fill_mode),
1356
+ stream
1357
+ );
1358
+ }
1359
+
1360
+ /// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side
1361
+ /// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros)
1362
+ template <
1363
+ typename Element, ///< Element type
1364
+ typename Layout> ///< Layout function
1365
+ void TensorClearPartial(
1366
+ TensorView<Element, Layout> view, ///< destination tensor
1367
+ Element element,
1368
+ FillMode fill_mode,
1369
+ int alignment,
1370
+ cudaStream_t stream = nullptr) {
1371
+
1372
+ typedef detail::TensorClearPartialFunc<Element, Layout> Func;
1373
+ typedef typename Func::Params Params;
1374
+
1375
+ TensorForEach<Func, Layout::kRank, Params>(
1376
+ view.extent(),
1377
+ Params{view, element, fill_mode, alignment},
1378
+ /*grid_size*/0, /*block_size*/0,
1379
+ stream
1380
+ );
1381
+ }
1382
+
1383
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1384
+
1385
+ /// Fills a tensor with a uniform value
1386
+ template <
1387
+ typename Element, ///< Element type
1388
+ typename Layout> ///< Layout function
1389
+ void TensorFill(
1390
+ TensorView<Element, Layout> view, ///< destination tensor
1391
+ Element val = Element(0), ///< value to uniformly fill it with
1392
+ cudaStream_t stream = nullptr) {
1393
+
1394
+ TensorFillDiagonal(view, val, val, stream);
1395
+ }
1396
+
1397
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1398
+
1399
+ /// Fills a tensor's diagonal with 1 and 0 everywhere else.
1400
+ template <
1401
+ typename Element, ///< Element type
1402
+ typename Layout> ///< Layout function
1403
+ void TensorFillIdentity(
1404
+ TensorView<Element, Layout> view, ///< destination tensor
1405
+ cudaStream_t stream = nullptr) {
1406
+
1407
+ TensorFillDiagonal(view, Element(1), Element(0), stream);
1408
+ }
1409
+
1410
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1411
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1412
+
1413
namespace detail {

/// Functor that overwrites diagonal elements with a uniform value, leaving
/// off-diagonal elements untouched.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorUpdateDiagonalFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;              ///< destination tensor view
    Element diag;                 ///< value written on the diagonal

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Params() { }

    //
    // Methods
    //

    /// Constructs parameters for the update-diagonal functor.
    Params(
      TensorView view_ = TensorView(),
      Element diag_ = Element(1)
    ):
      view(view_), diag(diag_) {

    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorUpdateDiagonalFunc(Params const &params): params(params) {

  }

  /// Writes 'diag' only when the coordinate lies on the generalized diagonal.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    bool is_diag = true;

    // On the diagonal iff all coordinate components are equal.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < Layout::kRank; ++i) {
      if (coord[i] != coord[i - 1]) {
        is_diag = false;
        break;
      }
    }

    if (is_diag) {
      params.view.at(coord) = params.diag;
    }
  }
};

} // namespace detail
1496
+
1497
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1498
+
1499
+ /// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements.
1500
+ template <
1501
+ typename Element, ///< Element type
1502
+ typename Layout> ///< Layout function
1503
+ void TensorUpdateDiagonal(
1504
+ TensorView<Element, Layout> view, ///< destination tensor
1505
+ Element diag = Element(1),
1506
+ cudaStream_t stream = nullptr) {
1507
+
1508
+ typedef detail::TensorUpdateDiagonalFunc<Element, Layout> Func;
1509
+ typedef typename Func::Params Params;
1510
+
1511
+ TensorForEach<Func, Layout::kRank, Params>(
1512
+ view.extent(),
1513
+ Params(view, diag),
1514
+ /*grid_size*/0, /*block_size*/0,
1515
+ stream
1516
+ );
1517
+ }
1518
+
1519
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1520
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1521
+
1522
namespace detail {

/// Functor that overwrites off-diagonal elements with a uniform value, leaving
/// diagonal elements untouched.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorUpdateOffDiagonalFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;              ///< destination tensor view
    Element other;                ///< value written off the diagonal

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Params() { }

    //
    // Methods
    //

    /// Constructs parameters for the update-off-diagonal functor.
    Params(
      TensorView view_ = TensorView(),
      Element other_ = Element(0)
    ):
      view(view_), other(other_) {

    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorUpdateOffDiagonalFunc(Params const &params): params(params) {

  }

  /// Writes 'other' only when the coordinate is NOT on the generalized diagonal.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    bool is_diag = true;

    // On the diagonal iff all coordinate components are equal.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < Layout::kRank; ++i) {
      if (coord[i] != coord[i - 1]) {
        is_diag = false;
        break;
      }
    }

    if (!is_diag) {
      params.view.at(coord) = params.other;
    }
  }
};

} // namespace detail
1605
+
1606
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1607
+
1608
+ /// Writes a uniform value to all elements in the tensor without modifying diagonal elements.
1609
+ template <
1610
+ typename Element, ///< Element type
1611
+ typename Layout> ///< Layout function
1612
+ void TensorUpdateOffDiagonal(
1613
+ TensorView<Element, Layout> view, ///< destination tensor
1614
+ Element other = Element(1),
1615
+ cudaStream_t stream = nullptr) {
1616
+
1617
+ typedef detail::TensorUpdateOffDiagonalFunc<Element, Layout> Func;
1618
+ typedef typename Func::Params Params;
1619
+
1620
+ TensorForEach<Func, Layout::kRank, Params>(
1621
+ view.extent(),
1622
+ Params(view, other),
1623
+ /*grid_size*/0, /*block_size*/0,
1624
+ stream
1625
+ );
1626
+ }
1627
+
1628
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1629
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1630
+
1631
namespace detail {

/// Functor that fills each element with an affine function of its coordinate:
///   view(coord) = s + sum_i v[i] * coord[i]
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorFillLinearFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;              ///< destination tensor view
    Array<Element, Layout::kRank> v;  ///< per-rank coefficients of the linear combination
    Element s;                    ///< constant offset added to every element

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Params() { }

    //
    // Methods
    //

    /// Constructs parameters for the linear-fill functor.
    Params(
      TensorView view_,           ///< destination tensor
      Array<Element, Layout::kRank> const & v_,
      Element s_ = Element(0)
    ):
      view(view_), v(v_), s(s_) {

    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorFillLinearFunc(Params const &params): params(params) {

  }

  /// Computes s + dot(v, coord) and stores it at the coordinate.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    Element sum = params.s;

    // For narrow element types (<= 32 bits), accumulate in a wider type
    // (complex<float> / int32_t / float) to avoid rounding or overflow in the
    // narrow representation, converting back to Element only after each step.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Layout::kRank; ++i) {
      if constexpr (is_complex<Element>::value) {
        if constexpr (sizeof_bits<Element>::value <= 32) {
          sum = Element(static_cast<complex<float>>(sum) +
                        static_cast<complex<float>>(params.v[i]) * static_cast<complex<float>>(coord[i]));
        }
        // NOTE(review): complex elements wider than 32 bits fall through both
        // branches here and are never accumulated, so 'sum' stays at 's'.
        // This looks like a gap — confirm whether wide complex types are
        // intended to be supported by TensorFillLinear.
      }
      else if constexpr (sizeof_bits<Element>::value <= 32) {
        if constexpr (std::numeric_limits<Element>::is_integer) {
          sum = Element(static_cast<int32_t>(sum) +
                        static_cast<int32_t>(params.v[i]) * static_cast<int32_t>(coord[i]));
        }
        else {
          sum = Element(static_cast<float>(sum) +
                        static_cast<float>(params.v[i]) * static_cast<float>(coord[i]));
        }
      }
      else {
        // Wide non-complex types accumulate natively.
        sum += params.v[i] * coord[i];
      }
    }

    params.view.at(coord) = sum;
  }
};

} // namespace detail
1729
+
1730
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1731
+
1732
+ /// Fills tensor with a linear combination of its coordinate and another vector
1733
+ template <
1734
+ typename Element, ///< Element type
1735
+ typename Layout> ///< Layout function
1736
+ void TensorFillLinear(
1737
+ TensorView<Element, Layout> view, ///< destination tensor
1738
+ Array<Element, Layout::kRank> const & v,
1739
+ Element s = Element(0),
1740
+ cudaStream_t stream = nullptr) {
1741
+
1742
+ using Func = detail::TensorFillLinearFunc<Element, Layout>;
1743
+ using Params = typename Func::Params;
1744
+
1745
+ TensorForEach<Func, Layout::kRank, Params>(
1746
+ view.extent(),
1747
+ Params(view, v, s),
1748
+ /*grid_size*/0, /*block_size*/0,
1749
+ stream
1750
+ );
1751
+ }
1752
+
1753
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1754
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1755
+
1756
/// Fills a tensor with random values from a distribution.
///
/// Dispatches to TensorFillRandomGaussian or TensorFillRandomUniform based on
/// dist.kind.
/// NOTE(review): any other Distribution kind is silently ignored (the tensor is
/// left unmodified) — confirm whether that is the intended contract.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
void TensorFillRandom(
  TensorView<Element, Layout> view,       ///< destination tensor
  uint64_t seed,                          ///< seed for the RNG
  Distribution dist,                      ///< distribution descriptor
  cudaStream_t stream = nullptr,
  int exclude_zero = -1                   ///< If non-negative, excludes 0.
                                          /// Note that setting this flag will result in more 1's,
                                          /// as we use a simple mechanism to replace 0's by adding/subtracting 1's.
  ) {

  // Distribution parameters are stored as doubles; convert to the element's real type.
  using Real = typename RealType<Element>::Type;

  if (dist.kind == Distribution::Gaussian) {
    TensorFillRandomGaussian<Element, Layout>(
      view,
      seed,
      static_cast<Real>(dist.gaussian.mean),
      static_cast<Real>(dist.gaussian.stddev),
      dist.int_scale,
      exclude_zero,
      stream);
  } else if (dist.kind == Distribution::Uniform) {
    TensorFillRandomUniform<Element, Layout>(
      view,
      seed,
      static_cast<Real>(dist.uniform.max),
      static_cast<Real>(dist.uniform.min),
      dist.int_scale,
      dist.uniform.pnan,
      exclude_zero,
      stream);
  }
}
1793
+
1794
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1795
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1796
+
1797
+ /// Fills a block of data with sequential elements
1798
+ template <
1799
+ typename Element
1800
+ >
1801
+ void BlockFillSequential(
1802
+ Element *ptr,
1803
+ int64_t capacity,
1804
+ Element v = Element(1),
1805
+ Element s = Element(0)) {
1806
+
1807
+ using Layout = layout::PackedVectorLayout;
1808
+ Layout::TensorCoord size(static_cast<Layout::Index>(capacity)); // -Wconversion
1809
+ Layout layout = Layout::packed(size);
1810
+ TensorView<Element, Layout> view(ptr, layout, size);
1811
+
1812
+ Array<Element, Layout::kRank> c{};
1813
+ c[0] = v;
1814
+
1815
+ TensorFillLinear(view, c, s);
1816
+ }
1817
+
1818
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1819
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1820
+
1821
/// Fills a block of data with random values drawn from the given distribution.
/// (Header previously said "sequential elements" — copy-paste from BlockFillSequential.)
///
/// Dispatches to BlockFillRandomGaussian or BlockFillRandomUniform based on
/// dist.kind; other Distribution kinds leave the block unmodified.
template <
  typename Element
>
void BlockFillRandom(
  Element *ptr,                   ///< pointer to the destination block
  size_t capacity,                ///< number of elements to fill
  uint64_t seed,                  ///< seed for the RNG
  Distribution dist,              ///< distribution descriptor
  cudaStream_t stream = nullptr) {

  // Distribution parameters are stored as doubles; convert to the element's real type.
  using Real = typename RealType<Element>::Type;

  if (dist.kind == Distribution::Gaussian) {
    BlockFillRandomGaussian<Element>(
      ptr,
      capacity,
      seed,
      static_cast<Real>(dist.gaussian.mean),
      static_cast<Real>(dist.gaussian.stddev),
      dist.int_scale,
      stream);
  }
  else if (dist.kind == Distribution::Uniform) {
    BlockFillRandomUniform<Element>(
      ptr,
      capacity,
      seed,
      static_cast<Real>(dist.uniform.max),
      static_cast<Real>(dist.uniform.min),
      dist.int_scale,
      dist.uniform.pnan,
      stream);
  }
}
1856
+
1857
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1858
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1859
+
1860
+ namespace detail {
1861
+
1862
+ /// Computes a random Gaussian distribution
1863
+ template <
1864
+ typename Element, ///< Element type
1865
+ typename Layout> ///< Layout function
1866
+ struct TensorCopyDiagonalInFunc {
1867
+
1868
+ /// View type
1869
+ using TensorView = TensorView<Element, Layout>;
1870
+
1871
+ /// Scalar type
1872
+ typedef typename TensorView::Element T;
1873
+
1874
+ /// Coordinate in tensor's index space
1875
+ typedef typename TensorView::TensorCoord TensorCoord;
1876
+
1877
+ /// Parameters structure
1878
+ struct Params {
1879
+
1880
+ //
1881
+ // Data members
1882
+ //
1883
+
1884
+ TensorView view;
1885
+ Element const *ptr;
1886
+
1887
+ /// Default ctor
1888
+ CUTLASS_HOST_DEVICE
1889
+ Params() { }
1890
+
1891
+ //
1892
+ // Methods
1893
+ //
1894
+
1895
+ /// Construction of Gaussian RNG functor.
1896
+ Params(
1897
+ TensorView view_, ///< destination tensor
1898
+ Element const *ptr_
1899
+ ):
1900
+ view(view_), ptr(ptr_) {
1901
+
1902
+ }
1903
+ };
1904
+
1905
+ //
1906
+ // Data members
1907
+ //
1908
+
1909
+ /// Parameters object
1910
+ Params params;
1911
+
1912
+ //
1913
+ // Methods
1914
+ //
1915
+
1916
+ /// Device-side initialization of RNG
1917
+ CUTLASS_DEVICE
1918
+ TensorCopyDiagonalInFunc(Params const &params): params(params) {
1919
+
1920
+ }
1921
+
1922
+ /// Only update the diagonal element
1923
+ CUTLASS_DEVICE
1924
+ void operator()(TensorCoord const &coord) {
1925
+ bool is_diagonal = true;
1926
+
1927
+ CUTLASS_PRAGMA_UNROLL
1928
+ for (int i = 1; i < Layout::kRank; ++i) {
1929
+ if (coord[i] != coord[0]) {
1930
+ is_diagonal = false;
1931
+ }
1932
+ }
1933
+ if (is_diagonal) {
1934
+ params.view.at(coord) = params.ptr[coord[0]];
1935
+ }
1936
+ }
1937
+ };
1938
+
1939
+ } // namespace detail
1940
+
1941
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1942
+
1943
+ /// Copies a diagonal in from host memory without modifying off-diagonal elements.
1944
+ template <
1945
+ typename Element, ///< Element type
1946
+ typename Layout> ///< Layout function
1947
+ void TensorCopyDiagonalIn(
1948
+ TensorView<Element, Layout> view, ///< destination tensor
1949
+ Element const *ptr, ///< dense buffer of elements
1950
+ cudaStream_t stream = nullptr) {
1951
+
1952
+ using Func = detail::TensorCopyDiagonalInFunc<Element, Layout>;
1953
+ using Params = typename Func::Params;
1954
+
1955
+ TensorForEach<Func, Layout::kRank, Params>(
1956
+ view.extent(),
1957
+ Params(view, ptr),
1958
+ /*grid_size*/0, /*block_size*/0,
1959
+ stream
1960
+ );
1961
+ }
1962
+
1963
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1964
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1965
+
1966
+
1967
namespace detail {

/// Functor that copies a tensor's generalized diagonal into a dense linear buffer.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorCopyDiagonalOutFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Scalar type
  typedef typename TensorView::Element T;

  /// Coordinate in tensor's index space
  typedef typename TensorView::TensorCoord TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;              ///< source tensor view
    Element *ptr;                 ///< destination buffer, indexed by coord[0]

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Params() { }

    //
    // Methods
    //

    /// Constructs parameters for the copy-diagonal-out functor.
    Params(
      TensorView view_,           ///< source tensor
      Element *ptr_               ///< dense output buffer
    ):
      view(view_), ptr(ptr_) {

    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Device-side construction from parameters
  CUTLASS_DEVICE
  TensorCopyDiagonalOutFunc(Params const &params): params(params) {

  }

  /// Copies the element into the output buffer when the coordinate is diagonal.
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {
    bool is_diagonal = true;

    // A coordinate is on the generalized diagonal when all components equal the first.
    // (Loop does not break early; result is identical either way.)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < Layout::kRank; ++i) {
      if (coord[i] != coord[0]) {
        is_diagonal = false;
      }
    }
    if (is_diagonal) {
      params.ptr[coord[0]] = params.view.at(coord);
    }
  }
};

} // namespace detail
2047
+
2048
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
2049
+
2050
+ /// Copies the diagonal of a tensor into a dense buffer in host memory.
2051
+ template <
2052
+ typename Element, ///< Element type
2053
+ typename Layout> ///< Layout function
2054
+ void TensorCopyDiagonalOut(
2055
+ Element *ptr, ///< dense buffer of elements
2056
+ TensorView<Element, Layout> view, ///< source tensor
2057
+ cudaStream_t stream = nullptr) {
2058
+
2059
+ using Func = detail::TensorCopyDiagonalOutFunc<Element, Layout>;
2060
+ using Params = typename Func::Params;
2061
+
2062
+ TensorForEach<Func, Layout::kRank, Params>(
2063
+ view.extent(),
2064
+ Params(view, ptr),
2065
+ /*grid_size*/0, /*block_size*/0,
2066
+ stream
2067
+ );
2068
+ }
2069
+
2070
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
2071
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
2072
+
2073
+ } // namespace device
2074
+ } // namespace reference
2075
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include <stdexcept>
34
+ #include "cutlass/cutlass.h"
35
+ #include "cutlass/util/reference/device/kernel/tensor_foreach.h"
36
+
37
+ namespace cutlass {
38
+ namespace reference {
39
+ namespace device {
40
+
41
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
/// Launches a kernel calling a functor for each element in a tensor's index space.
template <typename Func, int Rank, typename Params>
struct TensorForEach {

  /// Constructor performs the operation.
  ///
  /// @param size       extent of the tensor's index space
  /// @param params     functor parameters forwarded to the device kernel
  /// @param grid_size  CUDA grid size; 0 requests an occupancy-based choice
  /// @param block_size CUDA block size; 0 requests an occupancy-based choice
  /// @param stream     CUDA stream on which the kernel is launched
  ///
  /// @throws std::runtime_error if the occupancy query fails
  TensorForEach(
    Coord<Rank> size, Params params = Params(),
    int grid_size = 0, int block_size = 0,
    cudaStream_t stream = nullptr) {

    if (!grid_size || !block_size) {

      // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
      cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
        &grid_size,
        &block_size,
        reinterpret_cast<void const *>(kernel::TensorForEach<Func, Rank, Params>));

      if (result != cudaSuccess) {
        throw std::runtime_error("Failed to query occupancy.");
      }
      // Limit block size. This has the effect of increasing the number of items processed by a
      // single thread and reduces the impact of initialization overhead.
      block_size = (block_size < 128 ? block_size : 128);
    }

    dim3 grid(grid_size, 1, 1);
    dim3 block(block_size, 1, 1);

    kernel::TensorForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(size, params);
  }
};
75
+
76
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
77
+
78
/// Launches a kernel calling a functor for each element along a tensor's diagonal
template <typename Func, int Rank, typename Params>
struct TensorDiagonalForEach {

  /// Constructor performs the operation
  ///
  /// @param size       extent of the tensor's index space
  /// @param params     functor parameters forwarded to the device kernel
  /// @param start      first diagonal index to process (inclusive)
  /// @param end        one past the last diagonal index; negative means the full
  ///                   diagonal, i.e. the minimum extent across all ranks
  /// @param block_size CUDA block size (grid size is derived from the range)
  /// @param stream     CUDA stream on which the kernel is launched
  TensorDiagonalForEach(
    Coord<Rank> size, Params params = Params(),
    int start = 0, int end = -1,
    int block_size = 128, cudaStream_t stream = nullptr) {

    // Default to the full diagonal: its length is the smallest extent.
    if (end < 0) {
      end = size.min();
    }

    // One thread per diagonal element, rounded up to whole blocks.
    dim3 block(block_size, 1, 1);
    dim3 grid((end - start + block_size - 1) / block_size, 1, 1);

    kernel::TensorDiagonalForEach<Func, Rank, Params><<< grid, block, 0, stream >>>(
      size, params, start, end);
  }
};
99
+
100
+
101
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
102
+
103
/// Launches a kernel applying a functor to each element of a linear block of memory.
template <typename Element, typename Func>
struct BlockForEach {

  /// Constructor performs the operation.
  ///
  /// @param ptr        pointer to the first element of the block
  /// @param capacity   number of elements in the block
  /// @param params     functor parameters forwarded to the device kernel
  /// @param grid_size  CUDA grid size; 0 requests an occupancy-based choice
  /// @param block_size CUDA block size; 0 requests an occupancy-based choice
  /// @param stream     CUDA stream on which the kernel is launched
  ///
  /// @throws std::runtime_error if the occupancy query fails
  BlockForEach(
    Element *ptr,
    size_t capacity,
    typename Func::Params params = typename Func::Params(),
    int grid_size = 0,
    int block_size = 0,
    cudaStream_t stream = nullptr) {

    if (!grid_size || !block_size) {

      // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
      cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
        &grid_size,
        &block_size,
        reinterpret_cast<void const *>(kernel::BlockForEach<Element, Func>));

      if (result != cudaSuccess) {
        throw std::runtime_error("Failed to query occupancy.");
      }
      // Limit block size. This has the effect of increasing the number of items processed by a
      // single thread and reduces the impact of initialization overhead.
      block_size = (block_size < 128 ? block_size : 128);
    }

    dim3 grid(grid_size, 1, 1);
    dim3 block(block_size, 1, 1);

    kernel::BlockForEach<Element, Func><<< grid, block, 0, stream >>>(ptr, capacity, params);
  }
};
137
+
138
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
139
+
140
+ } // namespace device
141
+ } // namespace reference
142
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include <cmath>
34
+
35
+ #include "cutlass/cutlass.h"
36
+ #include "cutlass/complex.h"
37
+ #include "cutlass/functional.h"
38
+ #include "cutlass/numeric_conversion.h"
39
+ #include "cutlass/tensor_view.h"
40
+ #include "cutlass/util/device_memory.h"
41
+ #include "cutlass/util/reference/detail/linear_to_coordinate.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace reference {
47
+ namespace device {
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace kernel {
52
+
53
template <
  typename Element,
  typename Layout,
  typename ComputeType,
  typename ReduceOp,
  typename TransformOp,
  int kBlockSize = 128
>
__global__ void TensorTransformReducePartial(
  TensorView<Element, Layout> view,        /// View of the tensor to reduce over
  ComputeType identity,                    /// Identity element of the reduction operation
  ReduceOp reduce,                         /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
  TransformOp transform,                   /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
  ComputeType *workspace) {                /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]

  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
  int64_t size = view.size();

  // One slot per thread of this block for the intra-block combine below.
  __shared__ ComputeType scratchpad[kBlockSize];

  // Grid-stride loop: each thread folds a strided subset of elements into its
  // private accumulator (reusing `identity` as that accumulator).
  for (; idx < size; idx += blockDim.x * gridDim.x) {

    // Map linear thread ID onto tensor coordinate
    typename Layout::TensorCoord coord;

    cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());

    if (view.contains(coord)) {

      // Fetch element
      Element x = view.at(coord);

      // Transform
      identity = reduce(identity, transform(x));
    }
  }

  // Publish the per-thread partial; threads that never entered the loop store
  // the unmodified identity, which is harmless for the reduction.
  scratchpad[threadIdx.x] = identity;

  __syncthreads();

  // One thread performs the final reduction and stores out. This could be enhanced via
  // a tree reduction and pipelining.
  if (threadIdx.x == 0) {

    for (int i = 1; i < kBlockSize; ++i) {
      identity = reduce(identity, scratchpad[i]);
    }

    // One partial result per thread block; a second kernel folds these together.
    workspace[blockIdx.x] = identity;
  }
}
105
+
106
template <
  typename Element,
  typename Layout,
  typename ComputeType,
  typename ReduceOp,
  typename TransformOp,
  int kBlockSize = 128
>
__global__ void TensorTransformReducePartial(
  TensorView<Element, Layout> view_A,      /// View of the tensor to reduce over
  TensorView<Element, Layout> view_B,      /// View of the tensor to reduce over
  ComputeType identity,                    /// Identity element of the reduction operation
  ReduceOp reduce,                         /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
  TransformOp transform,                   /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
  ComputeType *workspace) {                /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]

  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
  // Iteration bound is taken from view_A; the host wrapper verifies the extents match.
  auto size = static_cast<int64_t>(view_A.size());

  // One slot per thread of this block for the intra-block combine below.
  __shared__ ComputeType scratchpad[kBlockSize];

  // Grid-stride loop: each thread folds a strided subset of zipped element
  // pairs into its private accumulator (reusing `identity` as that accumulator).
  for (; idx < size; idx += blockDim.x * gridDim.x) {

    // Map linear thread ID onto tensor coordinate
    typename Layout::TensorCoord coord;

    cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());

    if (view_A.contains(coord)) {

      // Fetch element
      Element a = view_A.at(coord);
      Element b = view_B.at(coord);

      // Transform
      identity = reduce(identity, transform(a, b));
    }
  }

  // Publish the per-thread partial; threads that never entered the loop store
  // the unmodified identity, which is harmless for the reduction.
  scratchpad[threadIdx.x] = identity;

  __syncthreads();

  // One thread performs the final reduction and stores out. This could be enhanced via
  // a tree reduction and pipelining.
  if (threadIdx.x == 0) {

    for (int i = 1; i < kBlockSize; ++i) {
      identity = reduce(identity, scratchpad[i]);
    }

    // One partial result per thread block; a second kernel folds these together.
    workspace[blockIdx.x] = identity;
  }
}
160
+
161
+
162
template <
  typename ComputeType,
  typename ReduceOp,
  int kBlockSize = 32
>
__global__ void TensorTransformReduceFinalize(
  ComputeType *workspace,      // per-block partials from the first pass; final result overwrites workspace[0]
  ComputeType identity,        // identity element of the reduction
  int workspace_size,          // number of valid partial results in workspace
  ReduceOp reduce) {           // binary reduction functor

  // One slot per thread; this kernel is launched with a single block of kBlockSize threads.
  __shared__ ComputeType scratchpad[kBlockSize];

  // Each thread folds a strided subset of the partial results.
  for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) {
    identity = reduce(identity, workspace[idx]);
  }

  // Threads beyond workspace_size store the unmodified identity, which is harmless.
  scratchpad[threadIdx.x] = identity;

  __syncthreads();

  // Thread 0 serially combines the per-thread partials and stores the final value.
  if (threadIdx.x == 0) {

    for (int i = 1; i < kBlockSize; ++i) {
      identity = reduce(identity, scratchpad[i]);
    }

    workspace[0] = identity;
  }
}
192
+
193
+ } // namespace kernel
194
+
195
+ /////////////////////////////////////////////////////////////////////////////////////////////////
196
+
197
+ /// Transform-reduce operation over the elements of a tensor
198
+ template <
199
+ typename Element,
200
+ typename Layout,
201
+ typename ComputeType,
202
+ typename ReduceOp,
203
+ typename TransformOp
204
+ >
205
+ ComputeType TensorTransformReduce(
206
+ TensorView<Element, Layout> view, /// View of the tensor to reduce over
207
+ ComputeType identity, /// Identity element of the reduction operation
208
+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
209
+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
210
+ ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
211
+ int workspace_size, /// Number of elements in workspace
212
+ cudaStream_t stream = nullptr, /// CUDA stream to launch into
213
+ bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
214
+ ) {
215
+
216
+ int const kBlockSize = 128;
217
+
218
+ dim3 block(kBlockSize, 1);
219
+ dim3 grid(workspace_size, 1);
220
+
221
+ kernel::TensorTransformReducePartial<
222
+ Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
223
+ ><<< grid, block, 0, stream >>>(
224
+ view, identity, reduce, transform, workspace
225
+ );
226
+
227
+ int const kFinalizeBlockSize = 32;
228
+
229
+ kernel::TensorTransformReduceFinalize<
230
+ ComputeType, ReduceOp, kFinalizeBlockSize
231
+ ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
232
+ workspace, identity, workspace_size, reduce
233
+ );
234
+
235
+ cudaStreamSynchronize(stream);
236
+
237
+ if (copy_out) {
238
+ cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
239
+ if (result != cudaSuccess) {
240
+ throw std::runtime_error("cudaMemcpy() failed");
241
+ }
242
+ }
243
+
244
+ return identity;
245
+ }
246
+
247
+ /// Transform-reduce operation over the elements of two tensors, zipped together
248
+ template <
249
+ typename Element,
250
+ typename Layout,
251
+ typename ComputeType,
252
+ typename ReduceOp,
253
+ typename TransformOp
254
+ >
255
+ ComputeType TensorTransformReduce(
256
+ TensorView<Element, Layout> view_A, /// View of the tensor to reduce over
257
+ TensorView<Element, Layout> view_B, /// View of the tensor to reduce over
258
+ ComputeType identity, /// Identity element of the reduction operation
259
+ ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType
260
+ TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType
261
+ ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0]
262
+ int workspace_size, /// Number of elements in workspace
263
+ cudaStream_t stream = nullptr, /// CUDA stream to launch into
264
+ bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned.
265
+ ) {
266
+
267
+ if (view_A.extent() != view_B.extent()) {
268
+ throw std::runtime_error("Extents must be equal.");
269
+ }
270
+
271
+ int const kBlockSize = 128;
272
+
273
+ dim3 block(kBlockSize, 1);
274
+ dim3 grid(workspace_size, 1);
275
+
276
+ kernel::TensorTransformReducePartial<
277
+ Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize
278
+ ><<< grid, block, 0, stream >>>(
279
+ view_A, view_B, identity, reduce, transform, workspace
280
+ );
281
+
282
+ int const kFinalizeBlockSize = 32;
283
+
284
+ kernel::TensorTransformReduceFinalize<
285
+ ComputeType, ReduceOp, kFinalizeBlockSize
286
+ ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>(
287
+ workspace, identity, workspace_size, reduce
288
+ );
289
+
290
+ cudaStreamSynchronize(stream);
291
+
292
+ if (copy_out) {
293
+ cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost);
294
+ if (result != cudaSuccess) {
295
+ throw std::runtime_error("cudaMemcpy() failed");
296
+ }
297
+ }
298
+
299
+ return identity;
300
+ }
301
+
302
+ /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
303
+ /// workspace
304
+ template <
305
+ typename Element,
306
+ typename Layout,
307
+ typename ComputeType,
308
+ typename ReduceOp,
309
+ typename TransformOp
310
+ >
311
+ ComputeType TensorTransformReduce(
312
+ TensorView<Element, Layout> view,
313
+ ComputeType identity,
314
+ ReduceOp reduce,
315
+ TransformOp transform,
316
+ cudaStream_t stream = nullptr,
317
+ int workspace_size = 0
318
+ ) {
319
+
320
+ // Optionally query for the SM count to size the workspace.
321
+ if (!workspace_size) {
322
+
323
+ int device_idx = 0;
324
+ cudaDeviceProp prop;
325
+
326
+ cudaError_t result = cudaGetDevice(&device_idx);
327
+ if (result != cudaSuccess) {
328
+ throw std::runtime_error("cudaGetDevice() failed");
329
+ }
330
+
331
+ result = cudaGetDeviceProperties(&prop, device_idx);
332
+ if (result != cudaSuccess) {
333
+ throw std::runtime_error("cudaGetDeviceProp() failed");
334
+ }
335
+
336
+ workspace_size = int(prop.multiProcessorCount);
337
+ }
338
+
339
+ DeviceAllocation<ComputeType> workspace(workspace_size);
340
+
341
+ ComputeType output = TensorTransformReduce(
342
+ view,
343
+ identity,
344
+ reduce,
345
+ transform,
346
+ workspace.get(),
347
+ workspace_size,
348
+ stream,
349
+ true);
350
+
351
+ return output;
352
+ }
353
+
354
+
355
+ /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
356
+ /// workspace
357
+ template <
358
+ typename Element,
359
+ typename Layout,
360
+ typename ComputeType,
361
+ typename ReduceOp,
362
+ typename TransformOp
363
+ >
364
+ ComputeType TensorTransformReduce(
365
+ TensorView<Element, Layout> view_A,
366
+ TensorView<Element, Layout> view_B,
367
+ ComputeType identity,
368
+ ReduceOp reduce,
369
+ TransformOp transform,
370
+ cudaStream_t stream = nullptr,
371
+ int workspace_size = 0
372
+ ) {
373
+
374
+ // Optionally query for the SM count to size the workspace.
375
+ if (!workspace_size) {
376
+
377
+ int device_idx = 0;
378
+ cudaDeviceProp prop;
379
+
380
+ cudaError_t result = cudaGetDevice(&device_idx);
381
+ if (result != cudaSuccess) {
382
+ throw std::runtime_error("cudaGetDevice() failed");
383
+ }
384
+
385
+ result = cudaGetDeviceProperties(&prop, device_idx);
386
+ if (result != cudaSuccess) {
387
+ throw std::runtime_error("cudaGetDeviceProp() failed");
388
+ }
389
+
390
+ workspace_size = int(prop.multiProcessorCount);
391
+ }
392
+
393
+ DeviceAllocation<ComputeType> workspace(workspace_size);
394
+
395
+ ComputeType output = TensorTransformReduce(
396
+ view_A,
397
+ view_B,
398
+ identity,
399
+ reduce,
400
+ transform,
401
+ workspace.get(),
402
+ workspace_size,
403
+ stream,
404
+ true);
405
+
406
+ return output;
407
+ }
408
+
409
+ /////////////////////////////////////////////////////////////////////////////////////////////////
410
+
411
+ /// Helper to compute the sum of the elements of a tensor
412
+ template <
413
+ typename Element,
414
+ typename Layout,
415
+ typename ComputeType = Element
416
+ >
417
+ ComputeType TensorSum(
418
+ TensorView<Element, Layout> view,
419
+ ComputeType identity = ComputeType(),
420
+ cudaStream_t stream = nullptr,
421
+ int workspace_size = 0
422
+ ) {
423
+
424
+ plus<ComputeType> reduce;
425
+ NumericConverter<ComputeType, Element> transform;
426
+
427
+ return TensorTransformReduce(
428
+ view, identity, reduce, transform, stream, workspace_size);
429
+ }
430
+
431
+ /// Helper to compute the sum of the squares of the elements of a tensor
432
+ template <
433
+ typename Element,
434
+ typename Layout,
435
+ typename ComputeType = Element
436
+ >
437
+ ComputeType TensorSumSq(
438
+ TensorView<Element, Layout> view,
439
+ ComputeType identity = ComputeType(),
440
+ cudaStream_t stream = nullptr,
441
+ int workspace_size = 0
442
+ ) {
443
+
444
+ plus<ComputeType> reduce;
445
+ magnitude_squared<Element, ComputeType> transform;
446
+
447
+ return TensorTransformReduce(
448
+ view, identity, reduce, transform, stream, workspace_size);
449
+ }
450
+
451
+ /// Helper to compute the norm of the elements of a tensor.
452
+ template <
453
+ typename Element,
454
+ typename Layout,
455
+ typename ComputeType = double
456
+ >
457
+ ComputeType TensorNorm(
458
+ TensorView<Element, Layout> view,
459
+ ComputeType identity = ComputeType(),
460
+ cudaStream_t stream = nullptr,
461
+ int workspace_size = 0
462
+ ) {
463
+
464
+ return std::sqrt(TensorSumSq(view, identity, stream, workspace_size));
465
+ }
466
+
467
+ /////////////////////////////////////////////////////////////////////////////////////////////////
468
+
469
+ /// Helper to compute the sum of the squares of the differences of two tensors
470
+ template <
471
+ typename Element,
472
+ typename Layout,
473
+ typename ComputeType = double
474
+ >
475
+ ComputeType TensorSumSqDiff(
476
+ TensorView<Element, Layout> view_A,
477
+ TensorView<Element, Layout> view_B,
478
+ ComputeType identity = ComputeType(),
479
+ cudaStream_t stream = nullptr,
480
+ int workspace_size = 0
481
+ ) {
482
+
483
+ plus<ComputeType> reduce;
484
+ magnitude_squared_difference<Element, ComputeType> transform;
485
+
486
+ return TensorTransformReduce(
487
+ view_A, view_B, identity, reduce, transform, stream, workspace_size);
488
+ }
489
+
490
+
491
+ /// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
492
+ template <
493
+ typename Element,
494
+ typename Layout,
495
+ typename ComputeType = double
496
+ >
497
+ ComputeType TensorNormDiff(
498
+ TensorView<Element, Layout> view_A,
499
+ TensorView<Element, Layout> view_B,
500
+ ComputeType identity = ComputeType(),
501
+ cudaStream_t stream = nullptr,
502
+ int workspace_size = 0
503
+ ) {
504
+
505
+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size));
506
+ }
507
+
508
+ /////////////////////////////////////////////////////////////////////////////////////////////////
509
+
510
+ } // namespace device
511
+ } // namespace reference
512
+ } // namespace cutlass
513
+
514
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Defines device-side elementwise operations on TensorView. Note, the operations defined
33
+ in this header are not specialized for any particular data layout and are therefore not
34
+ intended to offer the best possible performance. Rather, they are intended to be generic
35
+ reference implementations to support the CUTLASS unit tests.
36
+ */
37
+
38
+ #pragma once
39
+
40
+ // Cutlass includes
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/tensor_view.h"
43
+
44
+ #include "cutlass/util/reference/device/tensor_foreach.h"
45
+
46
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace reference {
50
+ namespace device {
51
+
52
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
53
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ namespace detail {
56
+
57
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
struct TensorReLuFunc {

  /// View type
  using TensorView = TensorView<Element, Layout>;

  /// Coordinate in tensor's index space
  using TensorCoord = typename TensorView::TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view;      // tensor updated in place
    Element threshold;    // values below this are clamped up to it


    //
    // Methods
    //

    Params(
      TensorView view_ = TensorView(),
      Element threshold_ = Element(0)
    ):
      view(view_), threshold(threshold_) {

    }
  };

  //
  // Data members
  //

  Params params;

  //
  // Methods
  //

  CUTLASS_DEVICE
  TensorReLuFunc(Params const &params): params(params) {

  }

  /// Applies the thresholded ReLu to the element at `coord` in place:
  /// value := max(value, threshold).
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {

    Element const & value = params.view.at(coord);
    params.view.at(coord) = (value < params.threshold) ? params.threshold : value;
  }
};
114
+
115
+ } // namespace detail
116
+
117
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
118
+
119
+ /// Apply ReLu on a tensor
120
+ template <
121
+ typename Element, ///< Element type
122
+ typename Layout> ///< Layout function
123
+ void TensorReLu(
124
+ TensorView<Element, Layout> view, ///< destination tensor
125
+ Element threshold = Element(0)) { ///< ReLu threshold
126
+
127
+ using Func = detail::TensorReLuFunc<Element, Layout>;
128
+ using Params = typename Func::Params;
129
+
130
+ TensorForEach<Func, Layout::kRank, Params>(
131
+ view.extent(),
132
+ Params(view, threshold)
133
+ );
134
+ }
135
+
136
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
137
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
138
+
139
+ } // namespace device
140
+ } // namespace reference
141
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GEMM in host-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/tensor_view.h"
39
+ #include "cutlass/gemm/gemm.h"
40
+
41
+ namespace cutlass {
42
+ namespace reference {
43
+ namespace device {
44
+ namespace thread {
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ /// Thread-level blocked general matrix product.
49
+ //
50
+ // Note, this is a reference implementation. Performance is not expected to approach peak.
51
+ //
52
template <
  typename TensorRefA,
  typename TensorRefB,
  typename TensorRefC,
  typename ScalarType,
  typename AccumulatorType,
  typename OutputTile,
  typename InnerProductOp = multiply_add<AccumulatorType>,
  typename ConvertOp = NumericConverter<typename TensorRefC::Element, ScalarType>
>
struct Gemm {

  using ElementA = typename TensorRefA::Element;
  using ElementB = typename TensorRefB::Element;
  using ElementC = typename TensorRefC::Element;

  //
  // Data members
  //

  /// Tile for A operand
  ElementA A_tile[OutputTile::kColumn];

  /// Tile for B operand
  ElementB B_tile[OutputTile::kRow];

  /// Tile for Accumulator
  // NOTE(review): accum is declared [kColumn][kRow], but multiply_add/epilogue
  // below index accum[j][i] with j < kRow and i < kColumn, while the constructor
  // clears it with j < kColumn and i < kRow. These agree only when OutputTile is
  // square — verify intended orientation before using non-square tiles.
  AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow];

  //
  // Methods
  //

  /// Constructor
  CUTLASS_HOST_DEVICE
  Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {

    // Clear fetch registers
    for (int i = 0; i < OutputTile::kColumn; ++i) {
      A_tile[i] = ElementA(0);
    }

    for (int j = 0; j < OutputTile::kRow; ++j) {
      B_tile[j] = ElementB(0);
    }

    // Clear accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < OutputTile::kColumn; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < OutputTile::kRow; ++i) {
        accum[j][i] = initial_accum;
      }
    }
  }

  /// Computes a matrix product, accumulating an OutputTile::kColumn-by-kRow
  /// result anchored at `output_coord`; out-of-range rows/columns reuse the
  /// previously fetched (initially zero) tile values.
  CUTLASS_HOST_DEVICE
  Gemm & multiply_add(
    gemm::GemmCoord problem_size,
    TensorRefA tensor_a,
    TensorRefB tensor_b,
    MatrixCoord output_coord = MatrixCoord()) {

    InnerProductOp inner_product_op;

    // Loop over the GEMM K dimension
    CUTLASS_PRAGMA_NO_UNROLL
    for (int k = 0; k < problem_size.k(); ++k) {

      // Fetch a slice of the A matrix
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < OutputTile::kColumn; ++i) {
        if (output_coord.row() + i < problem_size.m()) {
          A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k));
        }
      }

      // Fetch a slice of the B matrix
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < OutputTile::kRow; ++j) {
        if (output_coord.column() + j < problem_size.n()) {
          B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j));
        }
      }

      // Compute an accumulated matrix product
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < OutputTile::kRow; ++j) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < OutputTile::kColumn; ++i) {
          accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]);
        }
      }
    }

    return *this;
  }

  /// Performs linear scaling of matrix product and updates output tensor:
  /// D = convert(alpha * accum + beta * C), guarded against writes outside
  /// the problem extent.
  CUTLASS_HOST_DEVICE
  Gemm & epilogue(
    gemm::GemmCoord problem_size,
    ScalarType alpha,
    ScalarType beta,
    TensorRefC tensor_c,
    TensorRefC tensor_d,
    MatrixCoord output_coord = MatrixCoord()) {

    ConvertOp convert_op;

    // Update the output tensor
    for (int j = 0; j < OutputTile::kRow; ++j) {
      for (int i = 0; i < OutputTile::kColumn; ++i) {
        // i offsets the row and j the column, matching the fetch pattern in multiply_add.
        MatrixCoord coord = output_coord + MatrixCoord(i, j);
        if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {

          tensor_d.at(coord) = convert_op(
            alpha * ScalarType(accum[j][i]) +
            beta * ScalarType(tensor_c.at(coord))
          );
        }
      }
    }

    return *this;
  }
};
180
+
181
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
182
+
183
+ } // namespace thread
184
+ } // namespace device
185
+ } // namespace reference
186
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for CONV in host-side code.
33
+ */
34
+ #pragma once
35
+
36
+ /////////////////////////////////////////////////////////////////////////////////////////////////
37
+
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/numeric_conversion.h"
40
+ #include "cutlass/epilogue/thread/activation.h"
41
+
42
+ #include "cute/tensor.hpp"
43
+
44
+ #include <cuda_runtime.h>
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass::reference::host {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ namespace detail {
53
+
54
+ template<class EngineAct, class LayoutAct>
55
+ bool
56
+ is_activation_in_bounds(
57
+ cute::Tensor<EngineAct, LayoutAct> const& activation,
58
+ int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
59
+ return ((g_ >= 0 && g_ < size<5>(activation)) &&
60
+ (n_ >= 0 && n_ < size<4>(activation)) &&
61
+ (d_ >= 0 && d_ < size<3>(activation)) &&
62
+ (h_ >= 0 && h_ < size<2>(activation)) &&
63
+ (w_ >= 0 && w_ < size<1>(activation)) &&
64
+ (c_ >= 0 && c_ < size<0>(activation)));
65
+ }
66
+
67
+ template<class EngineAct, class LayoutAct>
68
+ bool
69
+ is_activation_in_bounds(
70
+ cute::Tensor<EngineAct, LayoutAct> const& activation,
71
+ int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) {
72
+ return ((g_ >= 0 && g_ < size<4>(activation)) &&
73
+ (n_ >= 0 && n_ < size<3>(activation)) &&
74
+ (h_ >= 0 && h_ < size<2>(activation)) &&
75
+ (w_ >= 0 && w_ < size<1>(activation)) &&
76
+ (c_ >= 0 && c_ < size<0>(activation)));
77
+ }
78
+
79
+ template<class EngineAct, class LayoutAct>
80
+ bool
81
+ is_activation_in_bounds(
82
+ cute::Tensor<EngineAct, LayoutAct> const& activation,
83
+ int32_t n_, int32_t w_, int32_t c_, int32_t g_) {
84
+ return ((g_ >= 0 && g_ < size<3>(activation)) &&
85
+ (n_ >= 0 && n_ < size<2>(activation)) &&
86
+ (w_ >= 0 && w_ < size<1>(activation)) &&
87
+ (c_ >= 0 && c_ < size<0>(activation)));
88
+ }
89
+
90
+ } // namespace detail
91
+
92
+ template<
93
+ class ElementAcc_,
94
+ class ElementScalar_,
95
+ class ElementCompute_,
96
+ class ElementC_,
97
+ class ElementOut_,
98
+ bool ResidualAdd_,
99
+ class TensorAlpha_,
100
+ class TensorBeta_,
101
+ class TensorBias_,
102
+ class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>
103
+ >
104
+ struct ConvEpilogueFusionParams {
105
+ using ElementAcc = ElementAcc_;
106
+ using ElementScalar = ElementScalar_;
107
+ using ElementCompute = ElementCompute_;
108
+ using ElementC = ElementC_;
109
+ using ElementOut = ElementOut_;
110
+ using TensorAlpha = TensorAlpha_;
111
+ using TensorBeta = TensorBeta_;
112
+ using TensorBias = TensorBias_;
113
+ using ActivationFunctor = ActivationFunctor_;
114
+ static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation
115
+
116
+ ElementScalar alpha = ElementScalar(1);
117
+ ElementScalar beta = ElementScalar(0);
118
+
119
+ TensorAlpha tensor_alpha{};
120
+ TensorBeta tensor_beta{};
121
+ TensorBias tensor_bias{};
122
+ };
123
+
124
+ template<
125
+ cutlass::conv::Operator ConvOp,
126
+ int NumSpatialDims,
127
+ class TensorA,
128
+ class TensorB,
129
+ class TensorC,
130
+ class TensorD,
131
+ class ShapePadding,
132
+ class StrideTraversal,
133
+ class ShapeDilation,
134
+ class EpilogueFusionParams
135
+ >
136
+ struct ConvReferenceImpl {
137
+ // Hard code accumlulator type to float to avoid data lost in accumulating add.
138
+ using ElementAcc = cutlass::platform::conditional_t<cutlass::platform::is_same_v<typename EpilogueFusionParams::ElementAcc, double>, double, float>;
139
+ using ElementC = typename EpilogueFusionParams::ElementC;
140
+ using ElementOut = typename EpilogueFusionParams::ElementOut;
141
+ using ElementScalar = typename EpilogueFusionParams::ElementScalar;
142
+ using ElementCompute = typename EpilogueFusionParams::ElementCompute;
143
+ using ElementBias = typename EpilogueFusionParams::TensorBias::value_type;
144
+ using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor;
145
+
146
+ // Input related converter
147
+ NumericConverter<ElementCompute, ElementAcc> acc_converter;
148
+ NumericConverter<ElementCompute, ElementC> residual_converter;
149
+ NumericConverter<ElementCompute, ElementBias> bias_converter;
150
+ // Scale related converter
151
+ NumericConverter<ElementCompute, ElementScalar> scale_converter;
152
+ // Output related converter
153
+ NumericConverter<ElementOut, ElementCompute> output_converter;
154
+
155
+ EpilogueFusionParams& epi_fusion_params_;
156
+ TensorA const& tensor_a_;
157
+ TensorB const& tensor_b_;
158
+ TensorC const& tensor_c_;
159
+ TensorD& tensor_d_;
160
+
161
+ ShapePadding const& padding_;
162
+ StrideTraversal const& tstride_;
163
+ ShapeDilation const& dilation_;
164
+
165
+ // Epilogue activation operation
166
+ ActivationFunctor epi_activation;
167
+
168
+ ConvReferenceImpl(
169
+ TensorA const& tensor_a,
170
+ TensorB const& tensor_b,
171
+ TensorC const& tensor_c,
172
+ TensorD& tensor_d,
173
+ ShapePadding const& padding,
174
+ StrideTraversal const& tstride,
175
+ ShapeDilation const& dilation,
176
+ EpilogueFusionParams& epi_fusion_params)
177
+ : tensor_a_(tensor_a),
178
+ tensor_b_(tensor_b),
179
+ tensor_c_(tensor_c),
180
+ tensor_d_(tensor_d),
181
+ padding_(padding),
182
+ tstride_(tstride),
183
+ dilation_(dilation),
184
+ epi_fusion_params_(epi_fusion_params)
185
+ {
186
+ static_assert(rank(ShapePadding{}) == rank(ShapeDilation{}));
187
+ static_assert(rank(ShapePadding{}) == rank(StrideTraversal{}));
188
+ }
189
+
190
+ void compute_reference() {
191
+ if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
192
+ fprop_reference(cute::Int<NumSpatialDims>{});
193
+ }
194
+ else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
195
+ dgrad_reference(cute::Int<NumSpatialDims>{});
196
+ }
197
+ else {
198
+ wgrad_reference(cute::Int<NumSpatialDims>{});
199
+ }
200
+ }
201
+
202
+ private:
203
+ // Specialization for 1D fprop kernel
204
+ void fprop_reference(cute::Int<1> spatial_dims) {
205
+ int32_t G = size<3>(tensor_d_);
206
+ int32_t N = size<2>(tensor_d_);
207
+ int32_t Q = size<1>(tensor_d_);
208
+ int32_t K = size<0>(tensor_d_);
209
+ int32_t S = size<1>(tensor_b_);
210
+ int32_t C = size<0>(tensor_b_);
211
+
212
+ #if defined(_OPENMP)
213
+ #pragma omp parallel for collapse(2)
214
+ #endif
215
+ for (int32_t g = 0; g < G; ++g) {
216
+ for (int32_t n = 0; n < N; ++n) {
217
+ for (int32_t q = 0; q < Q; ++q) {
218
+ for (int32_t k = 0; k < K; ++k) {
219
+ auto accumulator = ElementAcc(0);
220
+ for (int32_t s = 0; s < S; ++s) {
221
+ for (int32_t c = 0; c < C; ++c) {
222
+ int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
223
+ if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) {
224
+ auto a = tensor_a_(c, w, n, g);
225
+ auto b = tensor_b_(c, s, k, g);
226
+ accumulator += ElementAcc(a * b);
227
+ }
228
+ }
229
+ }
230
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
231
+ epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
232
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
233
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
234
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
235
+ if (not EpilogueFusionParams::ResidualAdd) {
236
+ output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
237
+ }
238
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
239
+ output += bias_converter(epi_fusion_params_.tensor_bias[k]);
240
+ }
241
+ output = epi_activation(output);
242
+ if (EpilogueFusionParams::ResidualAdd) {
243
+ output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g));
244
+ }
245
+ tensor_d_(k, q, n, g) = output_converter(output);
246
+ }
247
+ }
248
+ }
249
+ }
250
+
251
+ }
252
+
253
+ // Specialization for 2D fprop kernel
254
+ void fprop_reference(cute::Int<2> spatial_dims) {
255
+ int32_t G = size<4>(tensor_d_);
256
+ int32_t N = size<3>(tensor_d_);
257
+ int32_t P = size<2>(tensor_d_);
258
+ int32_t Q = size<1>(tensor_d_);
259
+ int32_t K = size<0>(tensor_d_);
260
+ int32_t R = size<2>(tensor_b_);
261
+ int32_t S = size<1>(tensor_b_);
262
+ int32_t C = size<0>(tensor_b_);
263
+
264
+ #if defined(_OPENMP)
265
+ #pragma omp parallel for collapse(3)
266
+ #endif
267
+ for (int32_t g = 0; g < G; ++g) {
268
+ for (int32_t n = 0; n < N; ++n) {
269
+ for (int32_t p = 0; p < P; ++p) {
270
+ for (int32_t q = 0; q < Q; ++q) {
271
+ for (int32_t k = 0; k < K; ++k) {
272
+ auto accumulator = ElementAcc(0);
273
+ for (int32_t r = 0; r < R; ++r) {
274
+ for (int32_t s = 0; s < S; ++s) {
275
+ for (int32_t c = 0; c < C; ++c) {
276
+ int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
277
+ int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
278
+ if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) {
279
+ auto a = tensor_a_(c, w, h, n, g);
280
+ auto b = tensor_b_(c, s, r, k, g);
281
+ accumulator += ElementAcc(a * b);
282
+ }
283
+ }
284
+ }
285
+ }
286
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
287
+ epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
288
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
289
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
290
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
291
+ if (not EpilogueFusionParams::ResidualAdd) {
292
+ output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
293
+ }
294
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
295
+ output += bias_converter(epi_fusion_params_.tensor_bias[k]);
296
+ }
297
+ output = epi_activation(output);
298
+ if (EpilogueFusionParams::ResidualAdd) {
299
+ output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g));
300
+ }
301
+ tensor_d_(k, q, p, n, g) = output_converter(output);
302
+ }
303
+ }
304
+ }
305
+ }
306
+ }
307
+
308
+ }
309
+
310
+ // Specialization for 3D fprop kernel
311
+ void fprop_reference(cute::Int<3> spatial_dims) {
312
+ int32_t G = size<5>(tensor_d_);
313
+ int32_t N = size<4>(tensor_d_);
314
+ int32_t Z = size<3>(tensor_d_);
315
+ int32_t P = size<2>(tensor_d_);
316
+ int32_t Q = size<1>(tensor_d_);
317
+ int32_t K = size<0>(tensor_d_);
318
+ int32_t T = size<3>(tensor_b_);
319
+ int32_t R = size<2>(tensor_b_);
320
+ int32_t S = size<1>(tensor_b_);
321
+ int32_t C = size<0>(tensor_b_);
322
+
323
+ #if defined(_OPENMP)
324
+ #pragma omp parallel for collapse(3)
325
+ #endif
326
+ for (int32_t g = 0; g < G; ++g) {
327
+ for (int32_t n = 0; n < N; ++n) {
328
+ for (int32_t z = 0; z < Z; ++z) {
329
+ for (int32_t p = 0; p < P; ++p) {
330
+ for (int32_t q = 0; q < Q; ++q) {
331
+ for (int32_t k = 0; k < K; ++k) {
332
+ auto accumulator = ElementAcc(0);
333
+ for (int32_t t = 0; t < T; ++t) {
334
+ for (int32_t r = 0; r < R; ++r) {
335
+ for (int32_t s = 0; s < S; ++s) {
336
+ for (int32_t c = 0; c < C; ++c) {
337
+ int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
338
+ int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
339
+ int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
340
+ if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) {
341
+ auto a = tensor_a_(c, w, h, d, n, g);
342
+ auto b = tensor_b_(c, s, r, t, k, g);
343
+ accumulator += ElementAcc(a * b);
344
+ }
345
+ }
346
+ }
347
+ }
348
+ }
349
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
350
+ epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha;
351
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
352
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta;
353
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
354
+ if (not EpilogueFusionParams::ResidualAdd) {
355
+ output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
356
+ }
357
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
358
+ output += bias_converter(epi_fusion_params_.tensor_bias[k]);
359
+ }
360
+ output = epi_activation(output);
361
+ if (EpilogueFusionParams::ResidualAdd) {
362
+ output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g));
363
+ }
364
+ tensor_d_(k, q, p, z, n, g) = output_converter(output);
365
+ }
366
+ }
367
+ }
368
+ }
369
+ }
370
+ }
371
+
372
+ }
373
+
374
+ // Specialization for 1D dgrad kernel
375
+ void dgrad_reference(cute::Int<1> spatial_dims) {
376
+ int32_t G = size<3>(tensor_d_);
377
+ int32_t N = size<2>(tensor_d_);
378
+ int32_t W = size<1>(tensor_d_);
379
+ int32_t C = size<0>(tensor_d_);
380
+ int32_t K = size<2>(tensor_b_);
381
+ int32_t S = size<1>(tensor_b_);
382
+
383
+ #if defined(_OPENMP)
384
+ #pragma omp parallel for collapse(2)
385
+ #endif
386
+ for (int32_t g = 0; g < G; ++g) {
387
+ for (int32_t n = 0; n < N; ++n) {
388
+ for (int32_t w = 0; w < W; ++w) {
389
+ for (int32_t c = 0; c < C; ++c) {
390
+ auto accumulator = ElementAcc(0);
391
+ for (int32_t k = 0; k < K; ++k) {
392
+ for (int32_t s = 0; s < S; ++s) {
393
+ int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
394
+
395
+ if (q % cute::get<0>(tstride_) == 0) {
396
+ q /= cute::get<0>(tstride_);
397
+ } else {
398
+ continue;
399
+ }
400
+
401
+ if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) {
402
+ accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g));
403
+ }
404
+ }
405
+ }
406
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
407
+ ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
408
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
409
+ ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
410
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
411
+ if (not EpilogueFusionParams::ResidualAdd) {
412
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
413
+ }
414
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
415
+ output += bias_converter(epi_fusion_params_.tensor_bias[c]);
416
+ }
417
+ output = epi_activation(output);
418
+ if (EpilogueFusionParams::ResidualAdd) {
419
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g));
420
+ }
421
+ tensor_d_(c, w, n, g) = output_converter(output);
422
+ }
423
+ }
424
+ }
425
+ }
426
+
427
+ }
428
+
429
+ // Specialization for 2D dgrad kernel
430
+ void dgrad_reference(cute::Int<2> spatial_dims) {
431
+ int32_t G = size<4>(tensor_d_);
432
+ int32_t N = size<3>(tensor_d_);
433
+ int32_t H = size<2>(tensor_d_);
434
+ int32_t W = size<1>(tensor_d_);
435
+ int32_t C = size<0>(tensor_d_);
436
+ int32_t K = size<3>(tensor_b_);
437
+ int32_t R = size<2>(tensor_b_);
438
+ int32_t S = size<1>(tensor_b_);
439
+
440
+ #if defined(_OPENMP)
441
+ #pragma omp parallel for collapse(3)
442
+ #endif
443
+ for (int32_t g = 0; g < G; ++g) {
444
+ for (int32_t n = 0; n < N; ++n) {
445
+ for (int32_t h = 0; h < H; ++h) {
446
+ for (int32_t w = 0; w < W; ++w) {
447
+ for (int32_t c = 0; c < C; ++c) {
448
+ auto accumulator = ElementAcc(0);
449
+ for (int32_t k = 0; k < K; ++k) {
450
+ for (int32_t r = 0; r < R; ++r) {
451
+ for (int32_t s = 0; s < S; ++s) {
452
+ int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
453
+ int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
454
+
455
+ if (q % cute::get<0>(tstride_) == 0) {
456
+ q /= cute::get<0>(tstride_);
457
+ } else {
458
+ continue;
459
+ }
460
+
461
+ if (p % cute::get<1>(tstride_) == 0) {
462
+ p /= cute::get<1>(tstride_);
463
+ } else {
464
+ continue;
465
+ }
466
+
467
+ if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) {
468
+ accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g));
469
+ }
470
+ }
471
+ }
472
+ }
473
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
474
+ ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
475
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
476
+ ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
477
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
478
+ if (not EpilogueFusionParams::ResidualAdd) {
479
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
480
+ }
481
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
482
+ output += bias_converter(epi_fusion_params_.tensor_bias[c]);
483
+ }
484
+ output = epi_activation(output);
485
+ if (EpilogueFusionParams::ResidualAdd) {
486
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g));
487
+ }
488
+
489
+ tensor_d_(c, w, h, n, g) = output_converter(output);
490
+ }
491
+ }
492
+ }
493
+ }
494
+ }
495
+
496
+ }
497
+
498
+ // Specialization for 3D dgrad kernel
499
+ void dgrad_reference(cute::Int<3> spatial_dims) {
500
+ int32_t G = size<5>(tensor_d_);
501
+ int32_t N = size<4>(tensor_d_);
502
+ int32_t D = size<3>(tensor_d_);
503
+ int32_t H = size<2>(tensor_d_);
504
+ int32_t W = size<1>(tensor_d_);
505
+ int32_t C = size<0>(tensor_d_);
506
+ int32_t K = size<4>(tensor_b_);
507
+ int32_t T = size<3>(tensor_b_);
508
+ int32_t R = size<2>(tensor_b_);
509
+ int32_t S = size<1>(tensor_b_);
510
+
511
+ #if defined(_OPENMP)
512
+ #pragma omp parallel for collapse(3)
513
+ #endif
514
+ for (int32_t g = 0; g < G; ++g) {
515
+ for (int32_t n = 0; n < N; ++n) {
516
+ for (int32_t d = 0; d < D; ++d) {
517
+ for (int32_t h = 0; h < H; ++h) {
518
+ for (int32_t w = 0; w < W; ++w) {
519
+ for (int32_t c = 0; c < C; ++c) {
520
+ auto accumulator = ElementAcc(0);
521
+ for (int32_t k = 0; k < K; ++k) {
522
+ for (int32_t t = 0; t < T; ++t) {
523
+ for (int32_t r = 0; r < R; ++r) {
524
+ for (int32_t s = 0; s < S; ++s) {
525
+ int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_);
526
+ int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_);
527
+ int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_);
528
+
529
+ if (q % cute::get<0>(tstride_) == 0) {
530
+ q /= cute::get<0>(tstride_);
531
+ } else {
532
+ continue;
533
+ }
534
+
535
+ if (p % cute::get<1>(tstride_) == 0) {
536
+ p /= cute::get<1>(tstride_);
537
+ } else {
538
+ continue;
539
+ }
540
+
541
+ if (z % cute::get<2>(tstride_) == 0) {
542
+ z /= cute::get<2>(tstride_);
543
+ } else {
544
+ continue;
545
+ }
546
+
547
+ if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) {
548
+ accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g));
549
+ }
550
+ }
551
+ }
552
+ }
553
+ }
554
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data())
555
+ ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
556
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data())
557
+ ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
558
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
559
+ if (not EpilogueFusionParams::ResidualAdd) {
560
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
561
+ }
562
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
563
+ output += bias_converter(epi_fusion_params_.tensor_bias[c]);
564
+ }
565
+ output = epi_activation(output);
566
+ if (EpilogueFusionParams::ResidualAdd) {
567
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g));
568
+ }
569
+ tensor_d_(c, w, h, d, n, g) = output_converter(output);
570
+ }
571
+ }
572
+ }
573
+ }
574
+ }
575
+ }
576
+
577
+ }
578
+
579
+ // Specialization for 1D wgrad kernel
580
+ void wgrad_reference(cute::Int<1> spatial_dims) {
581
+ int32_t G = size<3>(tensor_d_);
582
+ int32_t N =
583
+ size<2>(tensor_a_);
584
+ int32_t Q =
585
+ size<1>(tensor_a_);
586
+ int32_t K =
587
+ size<0>(tensor_a_);
588
+ int32_t S = size<1>(tensor_d_);
589
+ int32_t C = size<0>(tensor_d_);
590
+
591
+ #if defined(_OPENMP)
592
+ #pragma omp parallel for collapse(2)
593
+ #endif
594
+ for (int32_t g = 0; g < G; ++g) {
595
+ for (int32_t k = 0; k < K; ++k) {
596
+ for (int32_t s = 0; s < S; ++s) {
597
+ for (int32_t c = 0; c < C; ++c) {
598
+ auto accumulator = ElementAcc(0);
599
+ for (int32_t n = 0; n < N; ++n) {
600
+ for (int32_t q = 0; q < Q; ++q) {
601
+ int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
602
+ bool is_in_bounds =
603
+ detail::is_activation_in_bounds(tensor_b_, n, w, c, g);
604
+ if (is_in_bounds) {
605
+ auto act =
606
+ tensor_b_(c, w, n, g);
607
+ auto xformed_act =
608
+ tensor_a_(k, q, n, g);
609
+ accumulator += ElementAcc(act * xformed_act);
610
+ }
611
+ }
612
+ }
613
+
614
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
615
+ epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
616
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
617
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
618
+
619
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
620
+ if (not EpilogueFusionParams::ResidualAdd) {
621
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
622
+ }
623
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
624
+ output += bias_converter(epi_fusion_params_.tensor_bias[c]);
625
+ }
626
+ output = epi_activation(output);
627
+ if (EpilogueFusionParams::ResidualAdd) {
628
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g));
629
+ }
630
+ tensor_d_(c, s, k, g) = output_converter(output);
631
+ }
632
+ }
633
+ }
634
+ }
635
+ }
636
+
637
+ // Specialization for 2D wgrad kernel
638
+ void wgrad_reference(cute::Int<2> spatial_dims) {
639
+ int32_t G = size<4>(tensor_d_);
640
+ int32_t N =
641
+ size<3>(tensor_a_);
642
+ int32_t P =
643
+ size<2>(tensor_a_);
644
+ int32_t Q =
645
+ size<1>(tensor_a_);
646
+ int32_t K =
647
+ size<0>(tensor_a_);
648
+ int32_t R = size<2>(tensor_d_);
649
+ int32_t S = size<1>(tensor_d_);
650
+ int32_t C = size<0>(tensor_d_);
651
+
652
+ #if defined(_OPENMP)
653
+ #pragma omp parallel for collapse(3)
654
+ #endif
655
+ for (int32_t g = 0; g < G; ++g) {
656
+ for (int32_t k = 0; k < K; ++k) {
657
+ for (int32_t r = 0; r < R; ++r) {
658
+ for (int32_t s = 0; s < S; ++s) {
659
+ for (int32_t c = 0; c < C; ++c) {
660
+ auto accumulator = ElementAcc(0);
661
+ for (int32_t n = 0; n < N; ++n) {
662
+ for (int32_t p = 0; p < P; ++p) {
663
+ for (int32_t q = 0; q < Q; ++q) {
664
+ int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
665
+ int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
666
+ bool is_in_bounds =
667
+ detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g);
668
+ if (is_in_bounds) {
669
+ auto act =
670
+ tensor_b_(c, w, h, n, g);
671
+ auto xformed_act =
672
+ tensor_a_(k, q, p, n, g);
673
+ accumulator += ElementAcc(act * xformed_act);
674
+ }
675
+ }
676
+ }
677
+ }
678
+
679
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
680
+ epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
681
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
682
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
683
+
684
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
685
+ if (not EpilogueFusionParams::ResidualAdd) {
686
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
687
+ }
688
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
689
+ output += bias_converter(epi_fusion_params_.tensor_bias[c]);
690
+ }
691
+ output = epi_activation(output);
692
+ if (EpilogueFusionParams::ResidualAdd) {
693
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g));
694
+ }
695
+ tensor_d_(c, s, r, k, g) = output_converter(output);
696
+ }
697
+ }
698
+ }
699
+ }
700
+ }
701
+ }
702
+
703
+ // Specialization for 3D wgrad kernel
704
+ void wgrad_reference(cute::Int<3> spatial_dims) {
705
+ int32_t G = size<5>(tensor_d_);
706
+ int32_t N =
707
+ size<4>(tensor_a_);
708
+ int32_t Z =
709
+ size<3>(tensor_a_);
710
+ int32_t P =
711
+ size<2>(tensor_a_);
712
+ int32_t Q =
713
+ size<1>(tensor_a_);
714
+ int32_t K =
715
+ size<0>(tensor_a_);
716
+ int32_t T = size<3>(tensor_d_);
717
+ int32_t R = size<2>(tensor_d_);
718
+ int32_t S = size<1>(tensor_d_);
719
+ int32_t C = size<0>(tensor_d_);
720
+
721
+ #if defined(_OPENMP)
722
+ #pragma omp parallel for collapse(3)
723
+ #endif
724
+ for (int32_t g = 0 ; g < G; ++g) {
725
+ for (int32_t k = 0; k < K; ++k) {
726
+ for (int32_t t = 0; t < T; ++t) {
727
+ for (int32_t r = 0; r < R; ++r) {
728
+ for (int32_t s = 0; s < S; ++s) {
729
+ for (int32_t c = 0; c < C; ++c) {
730
+ auto accumulator = ElementAcc(0);
731
+ for (int32_t n = 0; n < N; ++n) {
732
+ for (int32_t z = 0; z < Z; ++z) {
733
+ for (int32_t p = 0; p < P; ++p) {
734
+ for (int32_t q = 0; q < Q; ++q) {
735
+ int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_);
736
+ int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_);
737
+ int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_);
738
+ bool is_in_bounds =
739
+ detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g);
740
+ if (is_in_bounds) {
741
+ auto act =
742
+ tensor_b_(c, w, h, d, n, g);
743
+ auto xformed_act =
744
+ tensor_a_(k, q, p, z, n, g);
745
+ accumulator += ElementAcc(act * xformed_act);
746
+ }
747
+ }
748
+ }
749
+ }
750
+ }
751
+
752
+ ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ?
753
+ epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha;
754
+ ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ?
755
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta;
756
+
757
+ ElementCompute output = scale_converter(alpha) * acc_converter(accumulator);
758
+ if (not EpilogueFusionParams::ResidualAdd) {
759
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
760
+ }
761
+ if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
762
+ output += bias_converter(epi_fusion_params_.tensor_bias[c]);
763
+ }
764
+ output = epi_activation(output);
765
+ if (EpilogueFusionParams::ResidualAdd) {
766
+ output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g));
767
+ }
768
+ tensor_d_(c, s, r, t, k, g) = output_converter(output);
769
+ }
770
+ }
771
+ }
772
+ }
773
+ }
774
+ }
775
+ }
776
+ };
777
+
778
+ /////////////////////////////////////////////////////////////////////////////////////////////////
779
+
780
+ } // cutlass::reference::host
781
+
782
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Reference implementation for convolution in host-side code.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/coord.h"
39
+ #include "cutlass/functional.h"
40
+ #include "cutlass/layout/tensor.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/numeric_types.h"
43
+ #include "cutlass/tensor_ref.h"
44
+ #include "cutlass/tensor_view.h"
45
+ #include "cutlass/conv/convolution.h"
46
+ #include "cutlass/conv/conv2d_problem_size.h"
47
+ #include "cutlass/conv/conv3d_problem_size.h"
48
+ #include <iostream>
49
+
50
+ namespace cutlass {
51
+ namespace reference {
52
+ namespace host {
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
55
+ /// Forward propagation
56
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
/// y = conv2d(x, w)
///
/// Naive host-side reference for 2D forward-propagation convolution with
/// grouped-convolution support. For every output element (n, p, q, k) it
/// accumulates x * w over the filter taps (r, s) and the input channels of
/// output channel k's group, then applies the linear epilogue
///   y_out = convert(alpha * acc + beta * y_in).
/// Accumulation happens in ElementAccumulator with a fixed (r, s, c) loop
/// order, so results are reproducible for non-associative (floating-point)
/// accumulator types.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ElementD = ElementC,
  typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
void Conv2dFprop(
  conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_x,      // activations, coordinate order (n, h, w, c)
  TensorRef<ElementB, LayoutB> tensor_w,      // filters, coordinate order (k, r, s, c)
  TensorRef<ElementC, LayoutC> tensor_y_in,   // epilogue source; read only when beta != 0
  TensorRef<ElementD, LayoutC> tensor_y_out,  // destination, coordinate order (n, p, q, k)
  ElementCompute alpha,
  ElementCompute beta) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Apply MMA and accumulate ElementAccumulator
  for (int n = 0; n < problem_size.N; ++n) {
    for (int p = 0; p < problem_size.P; ++p) {
      for (int q = 0; q < problem_size.Q; ++q) {
        for (int k = 0; k < problem_size.K; ++k) {

          // Grouped convolution: output channel k only reads the input
          // channels belonging to its group.
          int group_idx = k / (problem_size.K / problem_size.groups);
          int channels_per_group = problem_size.C / problem_size.groups;

          ElementAccumulator acc = ElementAccumulator();

          for (int r = 0; r < problem_size.R; ++r) {
            for (int s = 0; s < problem_size.S; ++s) {
              for (int c = 0; c < channels_per_group; ++c) {

                int filter_r = r;
                int filter_s = s;

                // kConvolution flips the filter spatially; kCrossCorrelation
                // reads it as stored.
                if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
                  filter_r = problem_size.R - 1 - r;
                  filter_s = problem_size.S - 1 - s;
                }

                // Map output coordinate (p, q) plus filter tap to an input
                // coordinate using stride, padding and dilation.
                int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
                int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

                // Out-of-bounds taps contribute zero (implicit zero padding).
                if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) {

                  ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group});
                  ElementB b = tensor_w.at({k, r, s, c});

                  acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);

                }
              }
            }
          }

          // Apply Epilogue, compute ElementCompute, convert and store ElementC.
          // tensor_y_in is skipped entirely when beta == 0 so it may be left
          // unpopulated in that case.
          ElementC c_ref = ElementC();

          if (beta != ElementCompute()) {
            c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k));
          }

          tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) =
            convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
        }
      }
    }
  }
}
136
+
137
/// Depthwise-separable convolution
///
/// Host reference for depthwise 2D fprop: output channel g is produced by
/// convolving input channel g with its own single filter slice, read at
/// filter coordinate (g, r, s, 0). Extents are taken directly from the
/// tensor views rather than from a problem-size struct.
template <typename ElementA,
          typename LayoutA,
          typename ElementB,
          typename LayoutB,
          typename ElementC,
          typename LayoutC,
          typename ElementCompute,
          typename ElementAccumulator = ElementCompute,
          typename ElementD = ElementC,
          typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
void Depsep_Fprop(cutlass::TensorView<ElementA, LayoutA> tensor_A,   // activations (n, h, w, g)
                  cutlass::TensorView<ElementB, LayoutB> tensor_B,   // per-channel filters (g, r, s, 0)
                  cutlass::TensorView<ElementC, LayoutC> tensor_C,   // epilogue source
                  cutlass::TensorView<ElementD, LayoutC> tensor_D,   // destination
                  ElementCompute alpha,
                  ElementCompute beta,
                  cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(),
                  cutlass::Coord<2> conv_stride = cutlass::Coord<2>(),
                  cutlass::Coord<2> dilation = cutlass::Coord<2>(),
                  cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Apply MMA and accumulate ElementAccumulator
  for (int n = 0; n < tensor_C.extent().n(); ++n) {
    for (int p = 0; p < tensor_C.extent().h(); ++p) {
      for (int q = 0; q < tensor_C.extent().w(); ++q) {
        for (int g = 0; g < tensor_C.extent().c(); ++g) {
          ElementAccumulator acc = ElementAccumulator();
          for (int r = 0; r < tensor_B.extent().h(); ++r) {
            for (int s = 0; s < tensor_B.extent().w(); ++s) {

              // input activation H and W
              // NOTE(review): padding[0] is applied to H and padding[2] to W,
              // which implies the Tensor4DCoord packs the pads as
              // (h_front, h_back, w_front, w_back) — confirm against callers.
              int h = p * conv_stride[0] - padding[0] + r * dilation[0];
              int w = q * conv_stride[1] - padding[2] + s * dilation[1];

              // Out-of-bounds taps contribute zero (implicit zero padding).
              if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) {
                ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g));

                // kConvolution reads the filter flipped in both spatial dims.
                ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation)
                                 ? tensor_B.at(cutlass::make_Coord(g, r, s, 0))
                                 : tensor_B.at(cutlass::make_Coord(
                                       g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0));

                acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
              }
            }
          }

          // Apply Epilogue, compute ElementCompute, convert and store ElementC.
          // Unlike Conv2dFprop, tensor_C is read unconditionally here.
          ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g));
          tensor_D.at(cutlass::make_Coord(n, p, q, g)) =
              convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
        }
      }
    }
  }
}
198
+
199
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
200
+ /// Dgrad / Deconv
201
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
202
+
203
+ /// dx = dgrad(dy, w)
204
+ template <
205
+ typename ElementA,
206
+ typename LayoutA,
207
+ typename ElementB,
208
+ typename LayoutB,
209
+ typename ElementC,
210
+ typename LayoutC,
211
+ typename ElementCompute,
212
+ typename ElementAccumulator = ElementCompute,
213
+ typename ElementD = ElementC,
214
+ typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
215
+ typename InnerProductOp = multiply_add<ElementAccumulator>
216
+ >
217
+ void Conv2dDgrad(
218
+ cutlass::conv::Conv2dProblemSize problem_size,
219
+ TensorRef<ElementA, LayoutA> tensor_dy,
220
+ TensorRef<ElementB, LayoutB> tensor_w,
221
+ TensorRef<ElementC, LayoutC> tensor_dx_in,
222
+ TensorRef<ElementD, LayoutC> tensor_dx_out,
223
+ ElementCompute alpha,
224
+ ElementCompute beta,
225
+ bool is_deconv = false) {
226
+
227
+ ConvertOp convert_op;
228
+ InnerProductOp inner_product_op;
229
+
230
+ // Apply MMA and accumulate ElementAccumulator
231
+ for (int n = 0; n < problem_size.N; ++n) {
232
+ for (int h = 0; h < problem_size.H; ++h) {
233
+ for (int w = 0; w < problem_size.W; ++w) {
234
+ for (int c = 0; c < problem_size.C; ++c) {
235
+
236
+ ElementAccumulator acc = ElementAccumulator();
237
+
238
+ for (int r = 0; r < problem_size.R; ++r) {
239
+ for (int s = 0; s < problem_size.S; ++s) {
240
+ for (int k = 0; k < problem_size.K; ++k) {
241
+
242
+ int filter_r = r;
243
+ int filter_s = s;
244
+
245
+ if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
246
+ filter_r = problem_size.R - 1 - r;
247
+ filter_s = problem_size.S - 1 - s;
248
+ }
249
+
250
+ int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
251
+ int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;
252
+
253
+ if (p >= 0 && (p % problem_size.stride_h) == 0 &&
254
+ q >= 0 && (q % problem_size.stride_w) == 0) {
255
+
256
+ p = p / problem_size.stride_h;
257
+ q = q / problem_size.stride_w;
258
+ #if 0
259
+ std::cout << "row:"
260
+ << n * problem_size.H * problem_size.W +
261
+ h * problem_size.W +
262
+ w << " "
263
+ << "n, p, q: ("
264
+ << n << ", "
265
+ << p << ", "
266
+ << q << ") * "
267
+ << "r, s: ("
268
+ << r << ", "
269
+ << s << ") ["
270
+ << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]"
271
+ << std::endl;
272
+ #endif
273
+ if (p < problem_size.P && q < problem_size.Q) {
274
+
275
+ ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k));
276
+ ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k))
277
+ : tensor_w.at(cutlass::make_Coord(k, r, s, c));
278
+
279
+ acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
280
+ }
281
+ }
282
+
283
+ } // for (K)
284
+ } // for (S)
285
+ } // for (R)
286
+
287
+ // Apply Epilogue, compute ElementCompute, convert and store ElementC
288
+ ElementC c_ref = ElementC();
289
+
290
+ if (beta != ElementCompute()) {
291
+ c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c));
292
+ }
293
+
294
+ tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) =
295
+ convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
296
+
297
+ } // for (C)
298
+ } // for (W)
299
+ } // for (H)
300
+ } // for (N)
301
+ }
302
+
303
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
304
+ /// Wgrad
305
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
306
+
307
/// dw = wgrad(dy, x)
///
/// Host-side reference for the 2D weight-gradient (wgrad) convolution: for
/// each filter element (k, r, s, c) it accumulates dy * x over the whole
/// output domain (n, p, q), then applies the linear epilogue
///   dw_out = convert(alpha * acc + beta * dw_in).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ElementD = ElementC,
  typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
void Conv2dWgrad(
  cutlass::conv::Conv2dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,     // output gradient, coordinate order (n, p, q, k)
  TensorRef<ElementB, LayoutB> tensor_x,      // activations, coordinate order (n, h, w, c)
  TensorRef<ElementC, LayoutC> tensor_dw_in,  // epilogue source; read only when beta != 0
  TensorRef<ElementD, LayoutC> tensor_dw_out, // destination, coordinate order (k, r, s, c)
  ElementCompute alpha,
  ElementCompute beta) {

  InnerProductOp inner_product_op;
  ConvertOp convert_op;

  // Apply MMA and accumulate ElementAccumulator
  for (int k = 0; k < problem_size.K; ++k) {
    for (int r = 0; r < problem_size.R; ++r) {
      for (int s = 0; s < problem_size.S; ++s) {
        for (int c = 0; c < problem_size.C; ++c) {

          ElementAccumulator acc = ElementAccumulator();

          for (int n = 0; n < problem_size.N; ++n) {
            for (int p = 0; p < problem_size.P; ++p) {
              for (int q = 0; q < problem_size.Q; ++q) {

                cutlass::Tensor4DCoord b_coord;

                int filter_r = r;
                int filter_s = s;

                // kConvolution flips the filter spatially.
                if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
                  filter_r = problem_size.R - 1 - r;
                  filter_s = problem_size.S - 1 - s;
                }

                // Activation coordinate read by output (p, q) through this tap.
                b_coord = make_Coord(
                  n,
                  p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
                  q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
                  c);

                // Out-of-bounds taps contribute zero (implicit zero padding).
                if (b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
                  b_coord.w() < problem_size.W && b_coord.w() >= 0) {

                  ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k)));
                  ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));
                  acc = inner_product_op(a, b, acc);
                }
              }
            }
          }

          // Apply Epilogue, compute ElementCompute, convert and store ElementC.
          // tensor_dw_in is skipped entirely when beta == 0.
          ElementC c_ref = ElementC();

          if (beta != ElementCompute()) {
            c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c));
          }

          tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) =
            convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));

        } // for (C)
      } // for (S)
    } // for (R)
  } // for (K)
}
387
+
388
+ /// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
389
+ template <
390
+ typename ElementA,
391
+ typename LayoutA,
392
+ typename ElementB,
393
+ typename LayoutB,
394
+ typename ElementC,
395
+ typename LayoutC,
396
+ typename ElementCompute,
397
+ typename ElementAccumulator = ElementCompute,
398
+ typename ElementD = ElementC,
399
+ typename ConvertOp = NumericConverter<ElementD, ElementCompute>,
400
+ typename InnerProductOp = multiply_add<ElementAccumulator>
401
+ >
402
+ void Conv2d(
403
+ conv::Operator convolutional_operator,
404
+ conv::Conv2dProblemSize problem_size,
405
+ TensorRef<ElementA, LayoutA> tensor_A,
406
+ TensorRef<ElementB, LayoutB> tensor_B,
407
+ TensorRef<ElementC, LayoutC> tensor_C,
408
+ TensorRef<ElementD, LayoutC> tensor_D,
409
+ ElementCompute alpha,
410
+ ElementCompute beta) {
411
+
412
+ switch (convolutional_operator) {
413
+ case conv::Operator::kFprop:
414
+ Conv2dFprop<
415
+ ElementA, LayoutA,
416
+ ElementB, LayoutB,
417
+ ElementC, LayoutC,
418
+ ElementCompute,
419
+ ElementAccumulator,
420
+ ElementD,
421
+ ConvertOp, InnerProductOp
422
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
423
+ break;
424
+
425
+ case conv::Operator::kDeconv:
426
+ case conv::Operator::kDgrad:
427
+ Conv2dDgrad<
428
+ ElementA, LayoutA,
429
+ ElementB, LayoutB,
430
+ ElementC, LayoutC,
431
+ ElementCompute,
432
+ ElementAccumulator,
433
+ ElementD,
434
+ ConvertOp, InnerProductOp
435
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
436
+ break;
437
+
438
+ case conv::Operator::kWgrad:
439
+ Conv2dWgrad<
440
+ ElementA, LayoutA,
441
+ ElementB, LayoutB,
442
+ ElementC, LayoutC,
443
+ ElementCompute,
444
+ ElementAccumulator,
445
+ ElementD,
446
+ ConvertOp, InnerProductOp
447
+ >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
448
+ break;
449
+
450
+ default:
451
+ break;
452
+ }
453
+ }
454
+
455
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
456
+ /// 3D convolution
457
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
458
+
459
/// y = conv3d(x, w)
///
/// Naive host-side reference for 3D forward-propagation convolution. For each
/// output element (n, z, p, q, k) it accumulates x * w over the filter taps
/// (t, r, s) and all input channels c, then applies the linear epilogue
///   y_out = convert(alpha * acc + beta * y_in).
/// Unlike Conv2dFprop, no grouped-convolution support is present here.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
void Conv3dFprop(
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_x,     // activations, coordinate order (n, d, h, w, c)
  TensorRef<ElementB, LayoutB> tensor_w,     // filters, coordinate order (k, t, r, s, c)
  TensorRef<ElementC, LayoutC> tensor_y_in,  // epilogue source; read only when beta != 0
  TensorRef<ElementC, LayoutC> tensor_y_out, // destination, coordinate order (n, z, p, q, k)
  ElementCompute alpha,
  ElementCompute beta) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Apply MMA and accumulate ElementAccumulator
  for (int n = 0; n < problem_size.N; ++n) {
    for (int z = 0; z < problem_size.Z; ++z) {
      for (int p = 0; p < problem_size.P; ++p) {
        for (int q = 0; q < problem_size.Q; ++q) {
          for (int k = 0; k < problem_size.K; ++k) {

            ElementAccumulator acc = ElementAccumulator();

            for (int t = 0; t < problem_size.T; ++t) {
              for (int r = 0; r < problem_size.R; ++r) {
                for (int s = 0; s < problem_size.S; ++s) {
                  for (int c = 0; c < problem_size.C; ++c) {

                    int filter_t = t;
                    int filter_r = r;
                    int filter_s = s;

                    // kConvolution flips the filter in all three spatial dims.
                    if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
                      filter_t = problem_size.T - 1 - t;
                      filter_r = problem_size.R - 1 - r;
                      filter_s = problem_size.S - 1 - s;
                    }

                    // Map output coordinate (z, p, q) plus filter tap to an
                    // input coordinate using stride, padding and dilation.
                    int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d;
                    int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h;
                    int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w;

                    // Out-of-bounds taps contribute zero (implicit zero padding).
                    if (d >= 0 && d < problem_size.D &&
                      h >=0 && h < problem_size.H &&
                      w >= 0 && w < problem_size.W) {

                      ElementA a = tensor_x.at({n, d, h, w, c});
                      ElementB b = tensor_w.at({k, t, r, s, c});

                      acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
                    }
                  }
                }
              }
            }

            // Apply Epilogue, compute ElementCompute, convert and store ElementC.
            // tensor_y_in is skipped entirely when beta == 0.
            ElementC c_ref = ElementC();

            if (beta != ElementCompute()) {
              c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k));
            }

            tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) =
              convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
          }
        }
      }
    }
  }
}
541
+
542
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
543
+ /// Dgrad / Deconv
544
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
545
+
546
/// dx = dgrad(dy, w)
///
/// Host-side reference for the 3D data-gradient (dgrad) convolution: for each
/// input-gradient element (n, d, h, w, c) it accumulates dy * w over all
/// filter taps (t, r, s) and output channels k that touched that input
/// position, then applies the linear epilogue
///   dx_out = convert(alpha * acc + beta * dx_in).
///
/// When is_deconv is true the filter tensor is read with transposed channel
/// dimensions, i.e. at (c, t, r, s, k) instead of (k, t, r, s, c).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
void Conv3dDgrad(
  cutlass::conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,     // output gradient, coordinate order (n, z, p, q, k)
  TensorRef<ElementB, LayoutB> tensor_w,      // filters
  TensorRef<ElementC, LayoutC> tensor_dx_in,  // epilogue source; read only when beta != 0
  TensorRef<ElementC, LayoutC> tensor_dx_out, // destination, coordinate order (n, d, h, w, c)
  ElementCompute alpha,
  ElementCompute beta,
  bool is_deconv = false) {

  ConvertOp convert_op;
  InnerProductOp inner_product_op;

  // Apply MMA and accumulate ElementAccumulator
  for (int n = 0; n < problem_size.N; ++n) {
    for (int d = 0; d < problem_size.D; ++d) {
      for (int h = 0; h < problem_size.H; ++h) {
        for (int w = 0; w < problem_size.W; ++w) {
          for (int c = 0; c < problem_size.C; ++c) {

            ElementAccumulator acc = ElementAccumulator();

            for (int t = 0; t < problem_size.T; ++t) {
              for (int r = 0; r < problem_size.R; ++r) {
                for (int s = 0; s < problem_size.S; ++s) {
                  for (int k = 0; k < problem_size.K; ++k) {

                    int filter_t = t;
                    int filter_r = r;
                    int filter_s = s;

                    // kConvolution flips the filter in all three spatial dims.
                    if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
                      filter_t = problem_size.T - 1 - t;
                      filter_r = problem_size.R - 1 - r;
                      filter_s = problem_size.S - 1 - s;
                    }

                    // Invert the fprop index map: the output position
                    // (z, p, q) that read input (d, h, w) through this tap.
                    int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d;
                    int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h;
                    int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w;

                    // Only positions that land exactly on a stride multiple
                    // correspond to a real output element.
                    if (z >= 0 && (z % problem_size.stride_d) == 0 &&
                      p >= 0 && (p % problem_size.stride_h) == 0 &&
                      q >= 0 && (q % problem_size.stride_w) == 0) {

                      z = z / problem_size.stride_d;
                      p = p / problem_size.stride_h;
                      q = q / problem_size.stride_w;

                      if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {

                        ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k));
                        ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k))
                                    : tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
                        acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
                      }
                    }

                  } // for (K)
                } // for (S)
              } // for (R)
            } // for (T)

            // Apply Epilogue, compute ElementCompute, convert and store ElementC.
            // tensor_dx_in is skipped entirely when beta == 0.
            ElementC c_ref = ElementC();

            if (beta != ElementCompute()) {
              c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c));
            }

            tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) =
              convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));

          } // for (C)
        } // for (W)
      } // for (H)
    } // for (D)
  } // for (N)
}
638
+
639
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
640
+ /// Wgrad
641
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
642
+
643
/// dw = wgrad(dy, x)
///
/// Host-side reference for the 3D weight-gradient (wgrad) convolution: for
/// each filter element (k, t, r, s, c) it accumulates dy * x over the whole
/// output domain (n, z, p, q), then applies the linear epilogue
///   dw_out = convert(alpha * acc + beta * dw_in).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
void Conv3dWgrad(
  cutlass::conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_dy,     // output gradient, coordinate order (n, z, p, q, k)
  TensorRef<ElementB, LayoutB> tensor_x,      // activations, coordinate order (n, d, h, w, c)
  TensorRef<ElementC, LayoutC> tensor_dw_in,  // epilogue source; read only when beta != 0
  TensorRef<ElementC, LayoutC> tensor_dw_out, // destination, coordinate order (k, t, r, s, c)
  ElementCompute alpha,
  ElementCompute beta) {

  InnerProductOp inner_product_op;
  ConvertOp convert_op;

  // Apply MMA and accumulate ElementAccumulator
  for (int k = 0; k < problem_size.K; ++k) {
    for (int t = 0; t < problem_size.T; ++t) {
      for (int r = 0; r < problem_size.R; ++r) {
        for (int s = 0; s < problem_size.S; ++s) {
          for (int c = 0; c < problem_size.C; ++c) {

            ElementAccumulator acc = ElementAccumulator();

            for (int n = 0; n < problem_size.N; ++n) {
              for (int z = 0; z < problem_size.Z; ++z) {
                for (int p = 0; p < problem_size.P; ++p) {
                  for (int q = 0; q < problem_size.Q; ++q) {

                    int filter_t = t;
                    int filter_r = r;
                    int filter_s = s;

                    // kConvolution flips the filter in all three spatial dims.
                    if (problem_size.mode == cutlass::conv::Mode::kConvolution) {
                      filter_t = problem_size.T - 1 - t;
                      filter_r = problem_size.R - 1 - r;
                      filter_s = problem_size.S - 1 - s;
                    }

                    // Activation coordinate read by output (z, p, q) through
                    // this filter tap.
                    Tensor5DCoord b_coord = make_Coord(
                      n,
                      z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d,
                      p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h,
                      q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w,
                      c);

                    // Out-of-bounds taps contribute zero (implicit zero padding).
                    if (b_coord.d() < problem_size.D && b_coord.d() >= 0 &&
                      b_coord.h() < problem_size.H && b_coord.h() >= 0 &&
                      b_coord.w() < problem_size.W && b_coord.w() >= 0) {

                      ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)));
                      ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord));

                      acc = inner_product_op(a, b, acc);
                    }
                  }
                }
              }
            }

            // Apply Epilogue, compute ElementCompute, convert and store ElementC.
            // tensor_dw_in is skipped entirely when beta == 0.
            ElementC c_ref = ElementC();

            if (beta != ElementCompute()) {
              c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c));
            }

            tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) =
              convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));

          } // for (C)
        } // for (S)
      } // for (R)
    } // for (T)
  } // for (K)
}
729
+
730
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
731
+
732
/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
/// (The original comment named the 2D kernels; this dispatcher calls the 3D ones.)
///
/// Dispatches on the convolutional operator and forwards all tensors and
/// epilogue scalars to the matching reference implementation. kDeconv shares
/// the dgrad path with transposed filter indexing. Unrecognized operators
/// fall through to the empty default case (no-op).
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ElementCompute,
  typename ElementAccumulator = ElementCompute,
  typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
  typename InnerProductOp = multiply_add<ElementAccumulator>
>
void Conv3d(
  conv::Operator convolutional_operator,
  conv::Conv3dProblemSize problem_size,
  TensorRef<ElementA, LayoutA> tensor_A,
  TensorRef<ElementB, LayoutB> tensor_B,
  TensorRef<ElementC, LayoutC> tensor_C,
  TensorRef<ElementC, LayoutC> tensor_D,
  ElementCompute alpha,
  ElementCompute beta) {

  switch (convolutional_operator) {
  case conv::Operator::kFprop:
    Conv3dFprop<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
    break;

  case conv::Operator::kDeconv:
  case conv::Operator::kDgrad:
    // Final argument selects transposed filter indexing for deconvolution.
    Conv3dDgrad<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
    break;

  case conv::Operator::kWgrad:
    Conv3dWgrad<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementCompute,
      ElementAccumulator,
      ConvertOp, InnerProductOp
    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
    break;

  default:
    break;
  }
}
794
+
795
+ /////////////////////////////////////////////////////////////////////////////////////////////////
796
+
797
+ } // namespace host
798
+ } // namespace reference
799
+ } // namespace cutlass
800
+
801
+ /////////////////////////////////////////////////////////////////////////////////////////////////
802
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ /***************************************************************************************************
3
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ *
6
+ * Redistribution and use in source and binary forms, with or without
7
+ * modification, are permitted provided that the following conditions are met:
8
+ *
9
+ * 1. Redistributions of source code must retain the above copyright notice, this
10
+ * list of conditions and the following disclaimer.
11
+ *
12
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ * this list of conditions and the following disclaimer in the documentation
14
+ * and/or other materials provided with the distribution.
15
+ *
16
+ * 3. Neither the name of the copyright holder nor the names of its
17
+ * contributors may be used to endorse or promote products derived from
18
+ * this software without specific prior written permission.
19
+ *
20
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ *
31
+ **************************************************************************************************/
32
+ #pragma once
33
+
34
+ #include <cmath>
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/complex.h"
38
+ #include "cutlass/util/reference/host/tensor_reduce.h"
39
+ #include "cutlass/core_io.h"
40
+
41
+ namespace cutlass {
42
+ namespace reference {
43
+ namespace host {
44
+
45
+ /// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference
46
+ template <
47
+ typename Element,
48
+ typename Layout,
49
+ typename ComputeType = double
50
+ >
51
+ ComputeType TensorRelativeErrorMetric(
52
+ TensorView<Element, Layout> view_A_computed,
53
+ TensorView<Element, Layout> view_B_reference,
54
+ ComputeType identity = ComputeType()
55
+ ) {
56
+
57
+ return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) /
58
+ cutlass::reference::host::TensorNorm(view_B_reference, identity);
59
+ }
60
+
61
+
62
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
63
+
64
+ } // namespace host
65
+ } // namespace reference
66
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GEMM in host-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/functional.h"
40
+ #include "cutlass/numeric_conversion.h"
41
+
42
+ #include "cutlass/tensor_view.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+ #include "cutlass/arch/mma.h"
45
+ #include "cutlass/util/host_tensor.h"
46
+
47
+ namespace cutlass {
48
+ namespace reference {
49
+ namespace host {
50
+
51
+ template<typename Out, typename In>
52
+ struct CastIfScalar {
53
+ static Out cast(In in) {
54
+ return Out(in);
55
+ }
56
+ };
57
+
58
+ template<typename OutScalar, typename In>
59
+ struct CastIfScalar<cutlass::complex<OutScalar>, In> {
60
+ typedef cutlass::complex<OutScalar> Out;
61
+ static Out cast(In in) {
62
+ return Out(static_cast<OutScalar>(in));
63
+ }
64
+ };
65
+
66
+ template<typename OutScalar, typename InScalar>
67
+ struct CastIfScalar<cutlass::complex<OutScalar>, cutlass::complex<InScalar>> {
68
+ typedef cutlass::complex<OutScalar> Out;
69
+ typedef cutlass::complex<InScalar> In;
70
+ static Out cast(In in) {
71
+ return Out(in);
72
+ }
73
+ };
74
+
75
+ template<typename Out, typename In>
76
+ Out cast_if_scalar(In in) {
77
+ return CastIfScalar<Out, In>::cast(in);
78
+ }
79
+
80
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
81
+
82
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
83
+ /// objects.
84
+ template <
85
+ typename ElementA,
86
+ typename LayoutA,
87
+ typename ElementB,
88
+ typename LayoutB,
89
+ typename ElementC,
90
+ typename LayoutC,
91
+ typename ScalarType,
92
+ typename ComputeType,
93
+ typename InnerProductOp = multiply_add<ComputeType>,
94
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
95
+ >
96
+ void compute_gemm(
97
+ gemm::GemmCoord problem_size,
98
+ ScalarType alpha,
99
+ TensorRef<ElementA, LayoutA> tensor_a,
100
+ TensorRef<ElementB, LayoutB> tensor_b,
101
+ ScalarType beta,
102
+ TensorRef<ElementC, LayoutC> tensor_c,
103
+ TensorRef<ElementC, LayoutC> tensor_d,
104
+ ComputeType initial_accum) {
105
+
106
+ static_assert(
107
+ LayoutA::kRank == 2 &&
108
+ LayoutB::kRank == 2 &&
109
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
110
+
111
+
112
+ // Note: batch is ignored.
113
+ int const M = problem_size.m();
114
+ int const N = problem_size.n();
115
+ int const K = problem_size.k();
116
+
117
+ // Blocking necessary to speedup reference implementation
118
+ int const Mblock = 16;
119
+ int const Nblock = 16;
120
+
121
+ ConvertOp convert_op;
122
+ InnerProductOp inner_product_op;
123
+
124
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
125
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
126
+
127
+ ComputeType accum[Mblock][Nblock];
128
+
129
+ for (int j = 0; j < Nblock; j++) {
130
+ for (int i = 0; i < Mblock; i++) {
131
+ accum[i][j] = initial_accum;
132
+ }
133
+ }
134
+
135
+ for (int k_block = 0; k_block < K; ++k_block) {
136
+ for (int j = 0; j < Nblock; j++) {
137
+ for (int i = 0; i < Mblock; i++) {
138
+ int row = row_block + i;
139
+ int col = col_block + j;
140
+
141
+ if (row < M && col < N) {
142
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
143
+ ElementB b = tensor_b.at(MatrixCoord(k_block, col));
144
+
145
+ ComputeType compute_a(cast_if_scalar<ComputeType>(a));
146
+ ComputeType compute_b(cast_if_scalar<ComputeType>(b));
147
+
148
+ accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
149
+ }
150
+ }
151
+ }
152
+ }
153
+
154
+ for (int j = 0; j < Nblock; j++) {
155
+ for (int i = 0; i < Mblock; i++) {
156
+ int row = row_block + i;
157
+ int col = col_block + j;
158
+
159
+ MatrixCoord coord = MatrixCoord(row, col);
160
+
161
+ if (row < M && col < N) {
162
+ tensor_d.at(coord) = convert_op(
163
+ alpha * ScalarType(accum[i][j]) +
164
+ beta * ScalarType(tensor_c.at(coord)));
165
+ }
166
+ }
167
+ }
168
+ }
169
+ }
170
+ }
171
+
172
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
173
+
174
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
175
+ /// objects.
176
+ template <
177
+ typename ElementA,
178
+ typename LayoutA,
179
+ typename ElementB,
180
+ typename LayoutB,
181
+ typename ElementC,
182
+ typename LayoutC,
183
+ typename ScalarType,
184
+ typename ComputeType,
185
+ typename InnerProductOp = multiply_add<ComputeType>,
186
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
187
+ >
188
+ void compute_gemm(
189
+ gemm::GemmCoord problem_size,
190
+ ScalarType alpha,
191
+ TensorRef<ElementA, LayoutA> tensor_a,
192
+ TensorRef<ElementB, LayoutB> tensor_b,
193
+ ScalarType beta,
194
+ TensorRef<ElementC, LayoutC> tensor_c,
195
+ ComputeType initial_accum) {
196
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
197
+ ScalarType, ComputeType, InnerProductOp, ConvertOp>(
198
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
199
+ initial_accum);
200
+ }
201
+
202
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
203
+
204
+ template <
205
+ typename ElementA,
206
+ typename LayoutA,
207
+ typename ElementB,
208
+ typename LayoutB,
209
+ typename ElementC,
210
+ typename LayoutC,
211
+ typename ScalarType,
212
+ typename ComputeType,
213
+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd
214
+ >
215
+ struct Gemm;
216
+
217
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
218
+
219
+ /// Partial specialization for multiply-add
220
+ template <typename ElementA, typename LayoutA, typename ElementB,
221
+ typename LayoutB, typename ElementC, typename LayoutC,
222
+ typename ScalarType, typename ComputeType>
223
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
224
+ ComputeType, arch::OpMultiplyAdd> {
225
+
226
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
227
+ TensorRef<ElementA, LayoutA> tensor_a,
228
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
229
+ TensorRef<ElementC, LayoutC> tensor_c,
230
+ ComputeType initial_accum = ComputeType(0)) {
231
+ static_assert(
232
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
233
+ "Tensors must be of rank 2");
234
+
235
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
236
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
237
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
238
+ }
239
+
240
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
241
+ TensorRef<ElementA, LayoutA> tensor_a,
242
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
243
+ TensorRef<ElementC, LayoutC> tensor_c,
244
+ TensorRef<ElementC, LayoutC> tensor_d,
245
+ ComputeType initial_accum = ComputeType(0)) {
246
+ static_assert(
247
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
248
+ "Tensors must be of rank 2");
249
+
250
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
251
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
252
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
253
+ }
254
+ };
255
+
256
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
257
+
258
+ /// Partial specialization for multiply-add
259
+ template <typename ElementA, typename LayoutA, typename ElementB,
260
+ typename LayoutB, typename ElementC, typename LayoutC,
261
+ typename ScalarType, typename ComputeType>
262
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
263
+ ComputeType, arch::OpMultiplyAddFastBF16> {
264
+
265
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
266
+ TensorRef<ElementA, LayoutA> tensor_a,
267
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
268
+ TensorRef<ElementC, LayoutC> tensor_c,
269
+ ComputeType initial_accum = ComputeType(0)) {
270
+ static_assert(
271
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
272
+ "Tensors must be of rank 2");
273
+
274
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
275
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
276
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
277
+ }
278
+
279
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
280
+ TensorRef<ElementA, LayoutA> tensor_a,
281
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
282
+ TensorRef<ElementC, LayoutC> tensor_c,
283
+ TensorRef<ElementC, LayoutC> tensor_d,
284
+ ComputeType initial_accum = ComputeType(0)) {
285
+ static_assert(
286
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
287
+ "Tensors must be of rank 2");
288
+
289
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
290
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
291
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
292
+ }
293
+ };
294
+
295
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
296
+
297
+ /// Partial specialization for multiply-add-saturate
298
+ template <typename ElementA, typename LayoutA, typename ElementB,
299
+ typename LayoutB, typename ElementC, typename LayoutC,
300
+ typename ScalarType, typename ComputeType>
301
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
302
+ ComputeType, arch::OpMultiplyAddSaturate> {
303
+
304
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
305
+ TensorRef<ElementA, LayoutA> tensor_a,
306
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
307
+ TensorRef<ElementC, LayoutC> tensor_c,
308
+ ComputeType initial_accum = ComputeType(0)) {
309
+ static_assert(
310
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
311
+ "Tensors must be of rank 2");
312
+
313
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
314
+ ScalarType, ComputeType, multiply_add<ComputeType>,
315
+ NumericConverterClamp<ElementC, ScalarType>>(
316
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
317
+ }
318
+
319
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
320
+ TensorRef<ElementA, LayoutA> tensor_a,
321
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
322
+ TensorRef<ElementC, LayoutC> tensor_c,
323
+ TensorRef<ElementC, LayoutC> tensor_d,
324
+ ComputeType initial_accum = ComputeType(0)) {
325
+ static_assert(
326
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
327
+ "Tensors must be of rank 2");
328
+
329
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
330
+ ScalarType, ComputeType, multiply_add<ComputeType>,
331
+ NumericConverterClamp<ElementC, ScalarType>>(
332
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
333
+ }
334
+ };
335
+
336
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
337
+
338
+ /// Partial specialization for XOR-popc
339
+ template <typename ElementA, typename LayoutA, typename ElementB,
340
+ typename LayoutB, typename ElementC, typename LayoutC,
341
+ typename ScalarType, typename ComputeType>
342
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
343
+ ComputeType, arch::OpXorPopc> {
344
+
345
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
346
+ TensorRef<ElementA, LayoutA> tensor_a,
347
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
348
+ TensorRef<ElementC, LayoutC> tensor_c,
349
+ ComputeType initial_accum = ComputeType(0)) {
350
+ static_assert(
351
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
352
+ "Tensors must be of rank 2");
353
+
354
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
355
+ ScalarType, ComputeType, xor_popc_add<ComputeType>>(
356
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
357
+ }
358
+
359
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
360
+ TensorRef<ElementA, LayoutA> tensor_a,
361
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
362
+ TensorRef<ElementC, LayoutC> tensor_c,
363
+ TensorRef<ElementC, LayoutC> tensor_d,
364
+ ComputeType initial_accum = ComputeType(0)) {
365
+ static_assert(
366
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
367
+ "Tensors must be of rank 2");
368
+
369
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
370
+ ScalarType, ComputeType, xor_popc_add<ComputeType>>(
371
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
372
+ }
373
+ };
374
+
375
+ /// Partial specialization for AND-popc
376
+ template <typename ElementA, typename LayoutA, typename ElementB,
377
+ typename LayoutB, typename ElementC, typename LayoutC,
378
+ typename ScalarType, typename ComputeType>
379
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
380
+ ComputeType, arch::OpAndPopc> {
381
+
382
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
383
+ TensorRef<ElementA, LayoutA> tensor_a,
384
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
385
+ TensorRef<ElementC, LayoutC> tensor_c,
386
+ ComputeType initial_accum = ComputeType(0)) {
387
+ static_assert(
388
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
389
+ "Tensors must be of rank 2");
390
+
391
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
392
+ ScalarType, ComputeType, and_popc_add<ComputeType>>(
393
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
394
+ }
395
+
396
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
397
+ TensorRef<ElementA, LayoutA> tensor_a,
398
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
399
+ TensorRef<ElementC, LayoutC> tensor_c,
400
+ TensorRef<ElementC, LayoutC> tensor_d,
401
+ ComputeType initial_accum = ComputeType(0)) {
402
+ static_assert(
403
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
404
+ "Tensors must be of rank 2");
405
+
406
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
407
+ ScalarType, ComputeType, and_popc_add<ComputeType>>(
408
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
409
+ }
410
+ };
411
+
412
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
413
+
414
+ /// Partial specialization for multiply-add
415
+ template <typename ElementA, typename LayoutA, typename ElementB,
416
+ typename LayoutB, typename ElementC, typename LayoutC,
417
+ typename ScalarType, typename ComputeType>
418
+ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
419
+ ComputeType, arch::OpMultiplyAddFastF32> {
420
+
421
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
422
+ TensorRef<ElementA, LayoutA> tensor_a,
423
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
424
+ TensorRef<ElementC, LayoutC> tensor_c,
425
+ ComputeType initial_accum = ComputeType(0)) {
426
+ static_assert(
427
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
428
+ "Tensors must be of rank 2");
429
+
430
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
431
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
432
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
433
+ }
434
+
435
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
436
+ TensorRef<ElementA, LayoutA> tensor_a,
437
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
438
+ TensorRef<ElementC, LayoutC> tensor_c,
439
+ TensorRef<ElementC, LayoutC> tensor_d,
440
+ ComputeType initial_accum = ComputeType(0)) {
441
+ static_assert(
442
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
443
+ "Tensors must be of rank 2");
444
+
445
+ compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
446
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
447
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
448
+ }
449
+ };
450
+
451
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
452
+
453
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
454
+ //
455
+ // Batched GEMM
456
+ //
457
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
458
+
459
+ /// Computes a batch of GEMMs over a set of matrices of common dimension.
460
+ //
461
+ // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
462
+ //
463
+ template <
464
+ typename TensorRefCollectionA,
465
+ typename TensorRefCollectionB,
466
+ typename TensorRefCollectionC,
467
+ typename ScalarType,
468
+ typename AccumulatorType
469
+ >
470
+ void BatchedGemm(
471
+ gemm::GemmCoord problem_size,
472
+ int batch_count,
473
+ ScalarType alpha,
474
+ TensorRefCollectionA const& tensor_a,
475
+ TensorRefCollectionB const& tensor_b,
476
+ ScalarType beta,
477
+ TensorRefCollectionC &tensor_c,
478
+ AccumulatorType initial_accum) {
479
+
480
+ typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
481
+ typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
482
+ typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();
483
+
484
+ for (int batch = 0;
485
+ batch < batch_count;
486
+ ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {
487
+
488
+ Gemm<typename TensorRefCollectionA::Element,
489
+ typename TensorRefCollectionA::Layout,
490
+ typename TensorRefCollectionB::Element,
491
+ typename TensorRefCollectionB::Layout,
492
+ typename TensorRefCollectionC::Element,
493
+ typename TensorRefCollectionC::Layout,
494
+ typename TensorRefCollectionC::Element,
495
+ typename TensorRefCollectionC::Element>
496
+ gemm;
497
+
498
+ gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
499
+ initial_accum);
500
+ }
501
+ }
502
+
503
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
504
+ /// objects.
505
+ //
506
+ // TensorRefCollection* is a type satisfying the TensorRefCollection concept.
507
+ //
508
+ template <
509
+ typename TensorRefCollectionA,
510
+ typename TensorRefCollectionB,
511
+ typename TensorRefCollectionC,
512
+ typename ScalarType,
513
+ typename AccumulatorType
514
+ >
515
+ void BatchedGemm(
516
+ gemm::GemmCoord problem_size,
517
+ int batch_count,
518
+ ScalarType alpha,
519
+ TensorRefCollectionA const& tensor_a,
520
+ TensorRefCollectionB const& tensor_b,
521
+ ScalarType beta,
522
+ TensorRefCollectionC &tensor_c) {
523
+
524
+ BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
525
+ }
526
+
527
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
528
+
529
+ } // namespace host
530
+ } // namespace reference
531
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued GEMM in host-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/matrix_coord.h"
43
+
44
+ #include "cutlass/tensor_view.h"
45
+
46
+ #include "cutlass/gemm/gemm.h"
47
+
48
+ namespace cutlass {
49
+ namespace reference {
50
+ namespace host {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
+ /// objects.
56
+ ///
57
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
58
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
59
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
60
+ /// arguments explicitly.
61
+ template <
62
+ typename ElementA,
63
+ typename LayoutA,
64
+ typename ElementB,
65
+ typename LayoutB,
66
+ typename ElementC,
67
+ typename LayoutC,
68
+ typename ScalarType,
69
+ typename ComputeType,
70
+ typename ElementD = ElementC,
71
+ typename ConvertOp = NumericConverter<ElementD, ScalarType>,
72
+ typename InnerProductOp = multiply_add<ComputeType>
73
+ >
74
+ void GemmComplex(
75
+ gemm::GemmCoord problem_size,
76
+ ScalarType alpha,
77
+ TensorRef<ElementA, LayoutA> tensor_a,
78
+ ComplexTransform transform_a,
79
+ TensorRef<ElementB, LayoutB> tensor_b,
80
+ ComplexTransform transform_b,
81
+ ScalarType beta,
82
+ TensorRef<ElementC, LayoutC> tensor_c,
83
+ TensorRef<ElementD, LayoutC> tensor_d,
84
+ ComputeType initial_accum,
85
+ int batch_count = 1,
86
+ int64_t batch_stride_A = 0,
87
+ int64_t batch_stride_B = 0,
88
+ int64_t batch_stride_C = 0,
89
+ int64_t batch_stride_D = 0) {
90
+
91
+ static_assert(
92
+ LayoutA::kRank == 2 &&
93
+ LayoutB::kRank == 2 &&
94
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
95
+
96
+ // Note: batch is ignored.
97
+ int const M = problem_size.m();
98
+ int const N = problem_size.n();
99
+ int const K = problem_size.k();
100
+
101
+ // Blocking necessary to speedup reference implementation
102
+ int const Mblock = 16;
103
+ int const Nblock = 16;
104
+
105
+ ConvertOp convert_op;
106
+ InnerProductOp inner_product_op;
107
+
108
+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
109
+
110
+ // Compute matrix product using blocks
111
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
112
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
113
+
114
+ ComputeType accum[Mblock][Nblock];
115
+
116
+ for (int j = 0; j < Nblock; j++) {
117
+ for (int i = 0; i < Mblock; i++) {
118
+ accum[i][j] = initial_accum;
119
+ }
120
+ }
121
+
122
+ for (int k_block = 0; k_block < K; ++k_block) {
123
+ for (int j = 0; j < Nblock; j++) {
124
+ for (int i = 0; i < Mblock; i++) {
125
+ int row = row_block + i;
126
+ int col = col_block + j;
127
+
128
+ if (row < M && col < N) {
129
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
130
+ ElementB b = tensor_b.at(MatrixCoord(k_block, col));
131
+
132
+ ComputeType a_ik = ComputeType(a);
133
+ ComputeType b_kj = ComputeType(b);
134
+
135
+ if (transform_a == ComplexTransform::kConjugate) {
136
+ a_ik = conj(a_ik);
137
+ }
138
+
139
+ if (transform_b == ComplexTransform::kConjugate) {
140
+ b_kj = conj(b_kj);
141
+ }
142
+
143
+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
144
+ }
145
+ }
146
+ }
147
+ }
148
+
149
+ for (int j = 0; j < Nblock; j++) {
150
+ for (int i = 0; i < Mblock; i++) {
151
+ int row = row_block + i;
152
+ int col = col_block + j;
153
+
154
+ MatrixCoord coord = MatrixCoord(row, col);
155
+
156
+ if (row < M && col < N) {
157
+
158
+ tensor_d.at(coord) = convert_op(
159
+ alpha * ScalarType(accum[i][j]) +
160
+ beta * ScalarType(tensor_c.at(coord)));
161
+ }
162
+ }
163
+ }
164
+
165
+ } // for (col_block)
166
+ } // for (row_block)
167
+
168
+ tensor_a.add_pointer_offset(batch_stride_A);
169
+ tensor_b.add_pointer_offset(batch_stride_B);
170
+ tensor_c.add_pointer_offset(batch_stride_C);
171
+ tensor_d.add_pointer_offset(batch_stride_D);
172
+
173
+ } // for (batch_idx)
174
+ }
175
+
176
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
177
+
178
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
179
+ /// objects.
180
+ ///
181
+ /// This assumes the accumulator type is the same type as the scalars.
182
+ template <
183
+ typename ElementA,
184
+ typename LayoutA,
185
+ typename ElementB,
186
+ typename LayoutB,
187
+ typename ElementC,
188
+ typename LayoutC,
189
+ typename ScalarType,
190
+ typename ElementD = ElementC
191
+ >
192
+ void GemmComplex(
193
+ gemm::GemmCoord problem_size,
194
+ ScalarType alpha,
195
+ TensorRef<ElementA, LayoutA> tensor_a,
196
+ ComplexTransform transform_a,
197
+ TensorRef<ElementB, LayoutB> tensor_b,
198
+ ComplexTransform transform_b,
199
+ ScalarType beta,
200
+ TensorRef<ElementC, LayoutC> tensor_c,
201
+ TensorRef<ElementD, LayoutC> tensor_d) {
202
+
203
+ GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0));
204
+ }
205
+
206
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
207
+
208
+ } // namespace host
209
+ } // namespace reference
210
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued GEMM in host-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/coord.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/tensor_ref_planar_complex.h"
43
+
44
+ #include "cutlass/tensor_view.h"
45
+ #include "cutlass/gemm/gemm.h"
46
+
47
+ namespace cutlass {
48
+ namespace reference {
49
+ namespace host {
50
+
51
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
54
+ /// objects.
55
+ ///
56
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
57
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
58
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
59
+ /// arguments explicitly.
60
+ template <
61
+ typename ElementA,
62
+ typename LayoutA,
63
+ typename ElementB,
64
+ typename LayoutB,
65
+ typename ElementC,
66
+ typename LayoutC,
67
+ typename ScalarType,
68
+ typename ComputeType,
69
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>,
70
+ typename InnerProductOp = multiply_add<complex<ComputeType>>
71
+ >
72
+ void GemmPlanarComplex(
73
+ gemm::GemmCoord problem_size,
74
+ complex<ScalarType> alpha,
75
+ TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
76
+ ComplexTransform transform_a,
77
+ TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
78
+ ComplexTransform transform_b,
79
+ complex<ScalarType> beta,
80
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
81
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_d,
82
+ complex<ComputeType> initial_accum) {
83
+
84
+ static_assert(
85
+ LayoutA::kRank == 2 &&
86
+ LayoutB::kRank == 2 &&
87
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
88
+
89
+ using ComplexA = typename TensorRefPlanarComplex<ElementA, LayoutA>::ComplexElement;
90
+ using ComplexB = typename TensorRefPlanarComplex<ElementB, LayoutB>::ComplexElement;
91
+ using ComplexC = typename TensorRefPlanarComplex<ElementC, LayoutC>::ComplexElement;
92
+
93
+ // Note: batch is ignored.
94
+ int const M = problem_size.m();
95
+ int const N = problem_size.n();
96
+ int const K = problem_size.k();
97
+
98
+ // Blocking necessary to speedup reference implementation
99
+ int const Mblock = 16;
100
+ int const Nblock = 16;
101
+
102
+ ConvertOp convert_op;
103
+ InnerProductOp inner_product_op;
104
+
105
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
106
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
107
+
108
+ complex<ComputeType> accum[Mblock][Nblock];
109
+
110
+ for (int j = 0; j < Nblock; j++) {
111
+ for (int i = 0; i < Mblock; i++) {
112
+ accum[i][j] = initial_accum;
113
+ }
114
+ }
115
+
116
+ for (int k_block = 0; k_block < K; ++k_block) {
117
+ for (int j = 0; j < Nblock; j++) {
118
+ for (int i = 0; i < Mblock; i++) {
119
+ int row = row_block + i;
120
+ int col = col_block + j;
121
+
122
+ if (row < M && col < N) {
123
+
124
+ ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block));
125
+ ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col));
126
+
127
+ complex<ComputeType> a = complex<ComputeType>{
128
+ ComputeType(a_ik.real()),
129
+ ComputeType(a_ik.imag())
130
+ };
131
+
132
+ complex<ComputeType> b = complex<ComputeType>{
133
+ ComputeType(b_kj.real()),
134
+ ComputeType(b_kj.imag())
135
+ };
136
+
137
+ if (transform_a == ComplexTransform::kConjugate) {
138
+ a = conj(a);
139
+ }
140
+
141
+ if (transform_b == ComplexTransform::kConjugate) {
142
+ b = conj(b);
143
+ }
144
+
145
+ accum[i][j] = inner_product_op(a, b, accum[i][j]);
146
+ }
147
+ }
148
+ }
149
+ }
150
+
151
+ for (int j = 0; j < Nblock; j++) {
152
+ for (int i = 0; i < Mblock; i++) {
153
+ int row = row_block + i;
154
+ int col = col_block + j;
155
+
156
+ MatrixCoord coord = MatrixCoord(row, col);
157
+
158
+ if (row < M && col < N) {
159
+
160
+ complex<ScalarType> acc{
161
+ ScalarType(accum[i][j].real()),
162
+ ScalarType(accum[i][j].imag())
163
+ };
164
+
165
+ ComplexC d_ij = tensor_c.at(coord);
166
+
167
+ complex<ScalarType> src{
168
+ ScalarType(d_ij.real()),
169
+ ScalarType(d_ij.imag())
170
+ };
171
+
172
+ complex<ScalarType> result = alpha * acc + beta * src;
173
+
174
+ d_ij.real() = convert_op(result.real());
175
+ d_ij.imag() = convert_op(result.imag());
176
+
177
+ tensor_d.at(coord) = d_ij;
178
+ }
179
+ }
180
+ }
181
+ }
182
+ }
183
+ }
184
+
185
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
186
+
187
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
188
+ /// objects.
189
+ ///
190
+ /// This assumes the accumulator type is the same type as the scalars.
191
+ template <
192
+ typename ElementA,
193
+ typename LayoutA,
194
+ typename ElementB,
195
+ typename LayoutB,
196
+ typename ElementC,
197
+ typename LayoutC,
198
+ typename ScalarType
199
+ >
200
+ void GemmPlanarComplex(
201
+ gemm::GemmCoord problem_size,
202
+ complex<ScalarType> alpha,
203
+ TensorRefPlanarComplex<ElementA, LayoutA> tensor_a,
204
+ ComplexTransform transform_a,
205
+ TensorRefPlanarComplex<ElementB, LayoutB> tensor_b,
206
+ ComplexTransform transform_b,
207
+ complex<ScalarType> beta,
208
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_c,
209
+ TensorRefPlanarComplex<ElementC, LayoutC> tensor_d) {
210
+
211
+ GemmPlanarComplex(
212
+ problem_size,
213
+ alpha,
214
+ tensor_a, transform_a,
215
+ tensor_b, transform_b,
216
+ beta,
217
+ tensor_c,
218
+ tensor_d,
219
+ complex<ScalarType>());
220
+ }
221
+
222
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
223
+
224
+ } // namespace host
225
+ } // namespace reference
226
+ } // namespace cutlass
227
+
228
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for GETT in host-side code.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ /////////////////////////////////////////////////////////////////////////////////////////////////
38
+ #include "cutlass/gemm/gemm.h"
39
+ #include "cutlass/complex.h"
40
+ #include "cutlass/numeric_conversion.h"
41
+ #include "cutlass/epilogue/thread/activation.h"
42
+ #include "cutlass/relatively_equal.h"
43
+
44
+ #include "cute/tensor.hpp"
45
+ #include "cute/pointer.hpp"
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass::reference::host {
50
+
51
+ template<class T, class = void>
52
+ struct ElementTraits {
53
+ using type = T;
54
+ };
55
+
56
+ template<class T>
57
+ struct ElementTraits<T, std::enable_if_t<!std::is_same_v<decltype(std::declval<T>().get()), void> > > {
58
+ using type = decltype(std::declval<T>().get());
59
+ };
60
+
61
+ /////////////////////////////////////////////////////////////////////////////////////////////////
62
+
63
+ ///////////////////////////////////////////////////////////
64
+ //
65
+ // Gett Mainloop Parameters
66
+ //
67
+ ///////////////////////////////////////////////////////////
68
+
69
+ template<
70
+ class ElementAccumulator_,
71
+ class TensorA_, // (M, K, L)
72
+ class TensorB_ // (N, K, L)
73
+
74
+ , class TensorSfA_ = TensorA_,
75
+ class TensorSfB_ = TensorB_
76
+
77
+ >
78
+ struct GettMainloopParams {
79
+ using ElementAccumulator = ElementAccumulator_;
80
+ using TensorA = TensorA_;
81
+ using TensorB = TensorB_;
82
+ using EngineA = typename TensorA::engine_type;
83
+ using LayoutA = typename TensorA::layout_type;
84
+ using EngineB = typename TensorB::engine_type;
85
+ using LayoutB = typename TensorB::layout_type;
86
+
87
+ TensorA A{};
88
+ TensorB B{};
89
+
90
+ ComplexTransform transform_A = ComplexTransform::kNone;
91
+ ComplexTransform transform_B = ComplexTransform::kNone;
92
+
93
+
94
+ using TensorSfA = TensorSfA_;
95
+ using TensorSfB = TensorSfB_;
96
+ using EngineSfA = typename TensorSfA::engine_type;
97
+ using LayoutSfA = typename TensorSfA::layout_type;
98
+ using EngineSfB = typename TensorSfB::engine_type;
99
+ using LayoutSfB = typename TensorSfB::layout_type;
100
+ TensorSfA_ SfA{};
101
+ TensorSfB_ SfB{};
102
+
103
+
104
+ GettMainloopParams() {}
105
+
106
+ GettMainloopParams(TensorA tensor_A, TensorB tensor_B)
107
+ : A(tensor_A), B(tensor_B) {}
108
+
109
+
110
+ GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
111
+ : A(tensor_A), SfA(tensor_SfA),
112
+ B(tensor_B), SfB(tensor_SfB) {}
113
+
114
+
115
+ };
116
+
117
+
118
+
119
+ ////////////////////////////////////////////////////////////////////////
120
+ //
121
+ // Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels
122
+ //
123
+ ////////////////////////////////////////////////////////////////////////
124
+
125
+ template<
126
+ class ElementAccumulator_,
127
+ class TensorA_, // (M, K, L)
128
+ class TensorSfA_, // (M, K, L)
129
+ class TensorB_, // (N, K, L)
130
+ class TensorSfB_ // (N, K, L)
131
+ >
132
+ struct GettBlockScalingMainloopParams : public GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_> {
133
+ using Base = GettMainloopParams<ElementAccumulator_, TensorA_, TensorB_, TensorSfA_, TensorSfB_>;
134
+ using ElementAccumulator = typename Base::ElementAccumulator;
135
+ using TensorA = typename Base::TensorA;
136
+ using TensorB = typename Base::TensorB;
137
+ using EngineA = typename Base::EngineA;
138
+ using LayoutA = typename Base::LayoutA;
139
+ using EngineB = typename Base::EngineB;
140
+ using LayoutB = typename Base::LayoutB;
141
+ ComplexTransform transform_A = Base::transform_A;
142
+ ComplexTransform transform_B = Base::transform_B;
143
+
144
+ using TensorSfA = typename Base::TensorSfA;
145
+ using TensorSfB = typename Base::TensorSfB;
146
+ using EngineSfA = typename Base::EngineSfA;
147
+ using LayoutSfA = typename Base::LayoutSfA;
148
+ using EngineSfB = typename Base::EngineSfB;
149
+ using LayoutSfB = typename Base::LayoutSfB;
150
+
151
+ GettBlockScalingMainloopParams() {}
152
+
153
+ GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB)
154
+ : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {}
155
+
156
+
157
+ };
158
+
159
+
160
+ /////////////////////////////////////////////////////////////////////////////////////////////////
161
+
162
+ enum class SfStrategy {
163
+ None = 0,
164
+ SfDGen = 1
165
+ };
166
+
167
+
168
+ ///////////////////////////////////////////////////////////
169
+ //
170
+ // Gett Epilogue Parameters
171
+ //
172
+ ///////////////////////////////////////////////////////////
173
+
174
+ template<
175
+ class ElementScalar_,
176
+ class ElementScalingFactor_,
177
+ class ElementAccumulator_,
178
+ class ElementCompute_,
179
+ class TensorC_, // (M, N, L)
180
+ class TensorD_, // (M, N, L)
181
+ class VectorBias_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
182
+ class TensorAux_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, N, L)
183
+ class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // (M, 1)
184
+ class VectorBeta_ = VectorAlpha_, // (M, 1)
185
+ class ActivationFunctor_ = cutlass::epilogue::thread::Identity<ElementCompute_>,
186
+ class TensorSFD_ = TensorD_,
187
+ class SFD_VectorSize_ = cute::Int<0>,
188
+ class BiasBinaryOp_ = cutlass::plus<ElementCompute_>,
189
+ bool PerColumnBias_ = false
190
+ ,
191
+ SfStrategy SfGenStrategy_ = SfStrategy::None
192
+ >
193
+ struct GettEpilogueParams {
194
+ using ElementScalar = ElementScalar_;
195
+ using ElementScalingFactor = ElementScalingFactor_;
196
+ using ElementAccumulator = ElementAccumulator_;
197
+ using ElementCompute = ElementCompute_;
198
+ using TensorC = TensorC_;
199
+ using TensorD = TensorD_;
200
+ using TensorAux = TensorAux_;
201
+ using VectorBias = VectorBias_;
202
+ using VectorAlpha = VectorAlpha_;
203
+ using VectorBeta = VectorBeta_;
204
+ using TensorSFD = TensorSFD_;
205
+ using SFD_VectorSize = SFD_VectorSize_;
206
+ using ActivationFunctor = ActivationFunctor_;
207
+ using BiasBinaryOp = BiasBinaryOp_;
208
+
209
+ using EngineC = typename TensorC::engine_type;
210
+ using LayoutC = typename TensorC::layout_type;
211
+ using EngineD = typename TensorD::engine_type;
212
+ using LayoutD = typename TensorD::layout_type;
213
+ using EngineSfD = typename TensorSFD::engine_type;
214
+ using LayoutSfD = typename TensorSFD::layout_type;
215
+ static constexpr bool PerColumnBias = PerColumnBias_;
216
+ static constexpr SfStrategy SfGenStrategy = SfGenStrategy_;
217
+
218
+ ElementScalar alpha = ElementScalar(1);
219
+ ElementScalar beta = ElementScalar(0);
220
+
221
+ TensorC C{};
222
+ TensorD D{};
223
+ VectorBias Bias{};
224
+ TensorAux Aux{};
225
+ VectorAlpha Valpha{};
226
+ VectorBeta Vbeta{};
227
+ TensorSFD SfD{};
228
+ ElementCompute st = ElementCompute(1);
229
+
230
+ ElementAccumulator* abs_max_D = nullptr;
231
+ ElementAccumulator* abs_max_Aux = nullptr;
232
+
233
+ ElementScalingFactor scale_a = ElementScalingFactor(1);
234
+ ElementScalingFactor scale_b = ElementScalingFactor(1);
235
+ ElementScalingFactor scale_c = ElementScalingFactor(1);
236
+ ElementScalingFactor scale_d = ElementScalingFactor(1);
237
+ ElementScalingFactor scale_aux = ElementScalingFactor(1);
238
+
239
+ bool beta_per_channel_scaling = false;
240
+ GettEpilogueParams() {}
241
+
242
+ GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
243
+ : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {}
244
+
245
+
246
+ GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
247
+ : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {}
248
+
249
+
250
+ GettEpilogueParams(
251
+ ElementScalar alpha, ElementScalar beta,
252
+ TensorC tensor_C, TensorD tensor_D,
253
+ VectorBias bias, TensorAux tensor_aux,
254
+ VectorAlpha vector_alpha, VectorBeta vector_beta)
255
+ : alpha(alpha), beta(beta),
256
+ C(tensor_C), D(tensor_D),
257
+ Bias(bias), Aux(tensor_aux),
258
+ Valpha(vector_alpha), Vbeta(vector_beta) {}
259
+ };
260
+
261
+
262
+
263
+ ////////////////////////////////////////////////////////////////////////
264
+ //
265
+ // Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels
266
+ //
267
+ ////////////////////////////////////////////////////////////////////////
268
+
269
+ template<
270
+ class ElementScalar_,
271
+ class ElementAccumulator_,
272
+ class ElementCompute_,
273
+ class TensorC_,
274
+ class TensorD_,
275
+ class TensorSfD_ = TensorD_,
276
+ class SFD_VectorSize_ = cute::Int<0>,
277
+ SfStrategy SfGenStrategy_ = SfStrategy::None
278
+ >
279
+ struct GettBlockScalingEpilogueParams : public GettEpilogueParams<
280
+ ElementScalar_, // ElementScalar
281
+ ElementScalar_, // ElementScalingFactor
282
+ ElementAccumulator_, // ElementAccumulator
283
+ ElementCompute_, // ElementCompute
284
+ TensorC_, // TensorC (M, N, L)
285
+ TensorD_, // TensorD (M, N, L)
286
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
287
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
288
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
289
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
290
+ cutlass::epilogue::thread::Identity<ElementCompute_>, //
291
+ TensorSfD_, // TensorSfD
292
+ SFD_VectorSize_, // SFD_VectorSize
293
+ cutlass::plus<ElementCompute_>, // class BiasBinaryOp_ =
294
+ false, //PerColumnBias_
295
+ SfGenStrategy_ // SfGenStrategy
296
+ > {
297
+ using Base = GettEpilogueParams<
298
+ ElementScalar_, // ElementScalar
299
+ ElementScalar_, // ElementScalingFactor
300
+ ElementAccumulator_, // ElementAccumulator
301
+ ElementCompute_, // ElementCompute
302
+ TensorC_, // TensorC (M, N, L)
303
+ TensorD_, // TensorD (M, N, L)
304
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1)
305
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L)
306
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1)
307
+ decltype(make_tensor(cute::recast_ptr<ElementCompute_>(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1)
308
+ cutlass::epilogue::thread::Identity<ElementCompute_>, //
309
+ TensorSfD_, // TensorSfD
310
+ SFD_VectorSize_, // SFD_VectorSize
311
+ cutlass::plus<ElementCompute_>, // BiasBinaryOp
312
+ false, // PerColumnBias
313
+ SfGenStrategy_ // SfGenStrategy
314
+ >;
315
+ using ElementScalar = typename Base::ElementScalar;
316
+ using ElementScalingFactor = typename Base::ElementScalingFactor;
317
+ using ElementAccumulator = typename Base::ElementAccumulator;
318
+ using ElementCompute = typename Base::ElementCompute;
319
+ using TensorC = typename Base::TensorC;
320
+ using TensorD = typename Base::TensorD;
321
+ using TensorAux = typename Base::TensorAux;
322
+ using VectorBias = typename Base::VectorBias;
323
+ using VectorAlpha = typename Base::VectorAlpha;
324
+ using VectorBeta = typename Base::VectorBeta;
325
+ using TensorSFD = typename Base::TensorSFD;
326
+ using SFD_VectorSize = typename Base::SFD_VectorSize;
327
+ using ActivationFunctor = typename Base::ActivationFunctor;
328
+ using BiasBinaryOp = typename Base::BiasBinaryOp;
329
+
330
+ using EngineC = typename Base::EngineC;
331
+ using LayoutC = typename Base::LayoutC;
332
+ using EngineD = typename Base::EngineD;
333
+ using LayoutD = typename Base::LayoutD;
334
+ using EngineSfD = typename Base::EngineSfD;
335
+ using LayoutSfD = typename Base::LayoutSfD;
336
+ static constexpr bool PerColumnBias = Base::PerColumnBias;
337
+ static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy;
338
+
339
+ GettBlockScalingEpilogueParams() {}
340
+
341
+ GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D)
342
+ : Base(alpha, beta, tensor_C, tensor_D) {}
343
+
344
+ GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD)
345
+ : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {}
346
+
347
+ GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st)
348
+ : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {}
349
+ };
350
+
351
+
352
+
353
+
354
+
355
+ ///////////////////////////////////////////////////////////
356
+ //
357
+ // Generic Gett 3x Implementation
358
+ //
359
+ ///////////////////////////////////////////////////////////
360
+
361
+
362
+ /////////////////////////////////////////////////////////////////////////////////////////////////
363
+ template <int kVectorSize, class EpilogueParams, class TensorD, class TensorSFD, class ElementCompute, int kBlockM, int kBlockN>
364
+ void compute_1d_scaling_factor_and_quantized_output(
365
+ EpilogueParams const& epilogue_params,
366
+ TensorD &tensor_D,
367
+ TensorSFD &tensor_SfD,
368
+ int64_t m,
369
+ int64_t n,
370
+ int64_t l,
371
+ ElementCompute (&acc)[kBlockM][kBlockN])
372
+ {
373
+ using ElementD = typename ElementTraits<typename EpilogueParams::EngineD::value_type>::type;
374
+ using ElementSfD = typename ElementTraits<typename EpilogueParams::EngineSfD::value_type>::type;
375
+
376
+ int const M = cute::size<0>(tensor_D.layout());
377
+ int const N = cute::size<1>(tensor_D.layout());
378
+ int const L = cute::size<2>(tensor_D.layout());
379
+
380
+ auto mul = cutlass::multiplies<ElementCompute>{};
381
+ auto div = divides<ElementCompute>{};
382
+ // Get FP max
383
+ ElementCompute fp_max = ElementCompute(std::numeric_limits<ElementD>::max());
384
+ float scale_down_factor = div(1.0f, fp_max);
385
+ // Get st' = st / FP max
386
+ ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor);
387
+
388
+ absolute_value_op<ElementCompute> abs_op;
389
+ maximum_with_nan_propogation<ElementCompute> max_op;
390
+
391
+ if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) {
392
+ // MN major output
393
+ int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
394
+ // Col major output
395
+ for (int n_b = 0; n_b < kBlockN; ++n_b) {
396
+ for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
397
+ int64_t col = n + n_b;
398
+
399
+ /// Step1: get max across a vector
400
+ ElementCompute accum_max = ElementCompute(0);
401
+ for (int v = 0; v < kVectorSize; v++) {
402
+ int accum_row = v_b * kVectorSize + v;
403
+ int64_t output_row = accum_row + m;
404
+ if (output_row < M && col < N) {
405
+ accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b]));
406
+ }
407
+ }
408
+
409
+ /// Step2: Compute Scale
410
+ ElementCompute pvscale = mul(accum_max, st_scaled_down);
411
+ ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
412
+ // Store the Scaling Factors
413
+ int64_t sf_row = m + kVectorSize * v_b;
414
+ if (sf_row < M && col < N) {
415
+ tensor_SfD(sf_row, col, l) = qpvscale;
416
+ }
417
+
418
+ /// Step3: Compute quantized output values
419
+ ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
420
+ // Get float reciprocal
421
+ ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
422
+ ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
423
+ // Map INF to fp32::max
424
+ acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
425
+ // Store the intermediate_accum
426
+ for (int v = 0; v < kVectorSize; v++) {
427
+ int accum_row = v_b * kVectorSize + v;
428
+ int64_t output_row = accum_row + m;
429
+ if (output_row < M && col < N) {
430
+ acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale);
431
+ }
432
+ }
433
+ }
434
+ }
435
+ }
436
+ else {
437
+ int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize);
438
+ // row major output
439
+ for (int m_b = 0; m_b < kBlockM; ++m_b) {
440
+ for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) {
441
+ int64_t row = m + m_b;
442
+
443
+ /// Step1: get max across a vector
444
+ ElementCompute accum_max = ElementCompute(0);
445
+ for (int v = 0; v < kVectorSize; v++) {
446
+ int accum_col = v_b * kVectorSize + v;
447
+ int64_t output_col = accum_col + n;
448
+ if (row < M && output_col < N) {
449
+ accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col]));
450
+ }
451
+ }
452
+
453
+ /// Step2: Compute Scale
454
+ ElementCompute pvscale = mul(accum_max, st_scaled_down);
455
+ ElementSfD qpvscale = static_cast<ElementSfD>(pvscale);
456
+ // Store the Scaling Factors
457
+ int64_t sf_col = n + kVectorSize * v_b;
458
+
459
+ if (row < M && sf_col < N) {
460
+ tensor_SfD(row, sf_col, l) = qpvscale;
461
+ }
462
+
463
+ /// Step3: Compute quantized output values
464
+ ElementCompute qpvscale_up = NumericConverter<ElementCompute, ElementSfD>{}(qpvscale);
465
+ // Get float reciprocal
466
+ ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up);
467
+ ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp);
468
+ // Map INF to fp32::max
469
+ acc_scale = cutlass::minimum_with_nan_propagation<ElementCompute>{}(acc_scale, cutlass::platform::numeric_limits<ElementCompute>::max());
470
+ // Store the intermediate_accum
471
+ for (int v = 0; v < kVectorSize; v++) {
472
+ int accum_col = v_b * kVectorSize + v;
473
+ int64_t output_col = accum_col + n;
474
+ if (row < M && output_col < N) {
475
+ acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale);
476
+ }
477
+ }
478
+ }
479
+ }
480
+ }
481
+ }
482
+
483
+
484
+ /////////////////////////////////////////////////////////////////////////////////////////////////
485
+
486
+ /// GETT - General Tensor-Tensor contraction reference kernel
487
+ template <
488
+ class MainloopParams,
489
+ class EpilogueParams
490
+ >
491
+ void Gett(
492
+ MainloopParams const& mainloop_params,
493
+ EpilogueParams const& epilogue_params)
494
+ {
495
+
496
+ static int constexpr kBlockM = 64;
497
+ static int constexpr kBlockN = 64;
498
+
499
+ #if defined(_OPENMP)
500
+ #pragma omp parallel for collapse(3)
501
+ #endif
502
+ for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
503
+ for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
504
+ for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
505
+ typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN];
506
+ gett_mainloop(mainloop_params, m, n, l, acc);
507
+ gett_epilogue(epilogue_params, m, n, l, acc);
508
+ }
509
+ }
510
+ }
511
+ }
512
+
513
+ /////////////////////////////////////////////////////////////////////////////////////////////////
514
+
515
+ /// GETT - Mainloop
516
+ template <class MainloopParams, class ElementAccumulator, int kBlockM, int kBlockN>
517
+ void gett_mainloop(
518
+ MainloopParams const& mainloop_params,
519
+ int64_t m,
520
+ int64_t n,
521
+ int64_t l,
522
+ ElementAccumulator (&acc)[kBlockM][kBlockN])
523
+ {
524
+
525
+ static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B");
526
+ static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B");
527
+
528
+ using cute::raw_pointer_cast;
529
+
530
+ using ElementA = typename ElementTraits<typename MainloopParams::EngineA::value_type>::type;
531
+ using ElementB = typename ElementTraits<typename MainloopParams::EngineB::value_type>::type;
532
+
533
+
534
+ using ElementSFA = typename ElementTraits<typename MainloopParams::EngineSfA::value_type>::type;
535
+ using ElementSFB = typename ElementTraits<typename MainloopParams::EngineSfB::value_type>::type;
536
+
537
+
538
+ using RingOp = multiply_add<ElementAccumulator, ElementAccumulator, ElementAccumulator>;
539
+ RingOp fma_op;
540
+
541
+ // Zero out accumulators
542
+ for (int m_b = 0; m_b < kBlockM; ++m_b) {
543
+ for (int n_b = 0; n_b < kBlockN; ++n_b) {
544
+ acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
545
+ }
546
+ }
547
+
548
+ // Compute on this k-block
549
+ for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
550
+ // Load A
551
+ ElementAccumulator a_frag[kBlockM];
552
+ for (int m_b = 0; m_b < kBlockM; ++m_b) {
553
+ if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
554
+ // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
555
+ a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
556
+
557
+
558
+ if constexpr (not cute::is_same_v<ElementSFA, ElementA>){
559
+ // Load SFA
560
+ auto sfa = static_cast<ElementAccumulator>(mainloop_params.SfA(m + m_b, k, l));
561
+ a_frag[m_b] *= sfa;
562
+ }
563
+
564
+
565
+ if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
566
+ a_frag[m_b] = conj(a_frag[m_b]);
567
+ }
568
+ } else {
569
+ a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
570
+ }
571
+ }
572
+
573
+ // Load B
574
+ ElementAccumulator b_frag[kBlockN];
575
+ for (int n_b = 0; n_b < kBlockN; ++n_b) {
576
+ if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
577
+ // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
578
+ b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
579
+
580
+
581
+ if constexpr (not cute::is_same_v<ElementSFB, ElementB>){
582
+ // Load SFB
583
+ auto sfb = static_cast<ElementAccumulator>(mainloop_params.SfB(n + n_b, k, l));
584
+ b_frag[n_b] *= sfb;
585
+ }
586
+
587
+
588
+ if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
589
+ b_frag[n_b] = conj(b_frag[n_b]);
590
+ }
591
+ } else {
592
+ b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity
593
+ }
594
+ }
595
+
596
+ // do compute
597
+ for (int m_b = 0; m_b < kBlockM; ++m_b) {
598
+ for (int n_b = 0; n_b < kBlockN; ++n_b) {
599
+ acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]);
600
+ }
601
+ }
602
+
603
+ }
604
+ }
605
+
606
+ /////////////////////////////////////////////////////////////////////////////////////////////////
607
+
608
+ /// GETT - Epilogue
609
+ template <class EpilogueParams, class ElementAccumulator, int kBlockM, int kBlockN>
610
+ void gett_epilogue(
611
+ EpilogueParams const& epilogue_params,
612
+ int64_t m,
613
+ int64_t n,
614
+ int64_t l,
615
+ ElementAccumulator (&acc)[kBlockM][kBlockN])
616
+ {
617
+ static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B");
618
+ static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B");
619
+
620
+ using cute::raw_pointer_cast;
621
+
622
+ using ElementCompute = typename EpilogueParams::ElementCompute;
623
+ using ElementC = typename EpilogueParams::TensorC::value_type;
624
+ using ElementD = typename EpilogueParams::TensorD::value_type;
625
+ using ElementSfD = typename EpilogueParams::TensorSFD::value_type;
626
+ using ElementAux = typename EpilogueParams::TensorAux::value_type;
627
+ using ElementBias = typename EpilogueParams::VectorBias::value_type;
628
+ using ElementScalar = typename EpilogueParams::ElementScalar;
629
+ using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor;
630
+ using ActivationFunctor = typename EpilogueParams::ActivationFunctor;
631
+ using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp;
632
+
633
+ constexpr bool PerColBias = EpilogueParams::PerColumnBias;
634
+ constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy;
635
+
636
+ constexpr bool IsScalingAndAmaxOutputNeeded =
637
+ cute::is_same_v<ElementD, cutlass::float_e4m3_t> or
638
+ cute::is_same_v<ElementD, cutlass::float_e5m2_t>;
639
+
640
+ constexpr bool IsScalingAndAmaxAuxOutputNeeded =
641
+ cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
642
+ cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
643
+
644
+ constexpr bool IsReLUAuxNeeded =
645
+ (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ReLu<ElementCompute>> or
646
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) and
647
+ cute::is_same_v<ElementAux, cutlass::uint1b_t>;
648
+ constexpr bool UseReLU =
649
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>; // Treat Clamp as ReLU
650
+
651
+ constexpr bool IsBackpropFusion =
652
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dGELU<ElementCompute>> or
653
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::dReLU<ElementCompute>>;
654
+
655
+ // Input related converter
656
+ NumericConverter<ElementCompute, ElementAccumulator> accumulator_converter;
657
+ NumericConverter<ElementCompute, ElementC> source_converter;
658
+ NumericConverter<ElementCompute, ElementBias> bias_converter;
659
+ [[maybe_unused]] NumericConverter<ElementCompute, ElementAux> aux_source_converter;
660
+
661
+ // Scale related converter
662
+ NumericConverter<ElementCompute, ElementScalar> scale_converter;
663
+ NumericConverter<ElementCompute, ElementScalingFactor> scaling_factor_converter;
664
+
665
+ // Abs max converter
666
+ [[maybe_unused]] NumericConverter<ElementAccumulator, ElementCompute> abs_max_output_converter;
667
+
668
+ // Output related converter
669
+ NumericConverter<ElementD, ElementCompute> destination_converter;
670
+ [[maybe_unused]] NumericConverter<ElementAux, ElementCompute> aux_destination_converter;
671
+ NumericConverter<ElementBias, ElementCompute> dBias_converter;
672
+
673
+ // Epilogue operations
674
+ multiply_add<ElementCompute, ElementCompute, ElementCompute> epilogue_fma;
675
+ multiplies<ElementCompute> mul;
676
+ plus<ElementCompute> add;
677
+
678
+ // Activation operation
679
+ ActivationFunctor activation;
680
+
681
+ // Bias binary operation
682
+ BiasBinaryOp bias_op;
683
+
684
+ // Do conversion
685
+ ElementCompute converted_alpha = scale_converter(epilogue_params.alpha);
686
+ ElementCompute converted_beta = scale_converter(epilogue_params.beta);
687
+ ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a);
688
+ ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b);
689
+ ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c);
690
+ ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d);
691
+ ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux);
692
+
693
+ // Init local var
694
+ [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0);
695
+ [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0);
696
+
697
+ converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
698
+ converted_beta = mul(converted_beta, converted_scale_c);
699
+
700
+ ElementCompute inter_accum[kBlockM][kBlockN];
701
+
702
+ for (int m_b = 0; m_b < kBlockM; ++m_b) {
703
+ ElementCompute local_dBias = ElementCompute(0);
704
+
705
+ for (int n_b = 0; n_b < kBlockN; ++n_b) {
706
+ if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
707
+ // Convert every type to ElementCompute first, do compute, convert to output type, write it out
708
+ ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
709
+ // vector alpha
710
+ if (raw_pointer_cast(epilogue_params.Valpha.data())) {
711
+ converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l));
712
+ converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
713
+ }
714
+ ElementCompute output = mul(converted_alpha, converted_acc);
715
+
716
+ if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) {
717
+ ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b));
718
+ output = bias_op(output, converted_bias);
719
+ }
720
+
721
+ if (raw_pointer_cast(epilogue_params.C.data())) {
722
+ ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
723
+ // vector beta
724
+ if (epilogue_params.Vbeta.data()) {
725
+ converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l));
726
+ converted_beta = mul(converted_beta, converted_scale_c);
727
+ }
728
+ output = epilogue_fma(converted_beta, converted_src, output);
729
+ }
730
+
731
+ if constexpr (IsBackpropFusion) {
732
+ ElementAux aux_input = ElementAux(0);
733
+ if (raw_pointer_cast(epilogue_params.Aux.data())) {
734
+ aux_input = epilogue_params.Aux(m + m_b, n + n_b, l);
735
+ }
736
+
737
+ output = activation(output, aux_source_converter(aux_input));
738
+ local_dBias = add(local_dBias, output);
739
+ }
740
+ else {
741
+ if (raw_pointer_cast(epilogue_params.Aux.data())) {
742
+ auto aux_output = output;
743
+ if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
744
+ maximum_absolute_value_reduction<ElementCompute, true> amax_op;
745
+ local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output);
746
+ aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0));
747
+ }
748
+
749
+ if constexpr (IsReLUAuxNeeded) {
750
+ epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0);
751
+ } else {
752
+ epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output);
753
+ }
754
+ }
755
+
756
+ if constexpr (UseReLU) {
757
+ cutlass::epilogue::thread::ReLU<ElementCompute> relu;
758
+ output = relu(output);
759
+ }
760
+ else {
761
+ output = activation(output);
762
+ }
763
+ }
764
+
765
+ if constexpr (IsScalingAndAmaxOutputNeeded) {
766
+ maximum_absolute_value_reduction<ElementCompute, true> amax_op;
767
+ local_abs_max_output = amax_op(local_abs_max_output, output);
768
+ output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
769
+ }
770
+
771
+ inter_accum[m_b][n_b] = ElementCompute(output);
772
+ }
773
+ } // n_b
774
+
775
+ if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) {
776
+ if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) {
777
+ ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b));
778
+ local_dBias = add(local_dBias, converted_dBias);
779
+ epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias);
780
+ }
781
+ }
782
+ } // m_b
783
+
784
+ if constexpr (
785
+ SfGenStrategy == SfStrategy::SfDGen
786
+ ) {
787
+ // 1d scale factor generation
788
+ constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{};
789
+ if (epilogue_params.SfD.data() != nullptr) {
790
+ compute_1d_scaling_factor_and_quantized_output<kVectorSize>(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum);
791
+ }
792
+ }
793
+
794
+ for (int m_b = 0; m_b < kBlockM; ++m_b) {
795
+ for (int n_b = 0; n_b < kBlockN; ++n_b) {
796
+ if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
797
+ epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
798
+ }
799
+ }
800
+ }
801
+
802
+ #if defined(_OPENMP)
803
+ #pragma omp critical(Abs_Max_Data_Update)
804
+ #endif
805
+ {
806
+ if constexpr (IsScalingAndAmaxOutputNeeded) {
807
+ if (epilogue_params.abs_max_D) {
808
+ *epilogue_params.abs_max_D = maximum_with_nan_propogation<ElementAccumulator>{}(
809
+ *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output));
810
+ }
811
+ }
812
+
813
+ if constexpr (IsScalingAndAmaxAuxOutputNeeded) {
814
+ if (epilogue_params.abs_max_Aux) {
815
+ *epilogue_params.abs_max_Aux = maximum_with_nan_propogation<ElementAccumulator>{}(
816
+ *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output));
817
+ }
818
+ }
819
+ }
820
+ }
821
+
822
+ /////////////////////////////////////////////////////////////////////////////////////////////////
823
+
824
+ template <class TensorType>
825
+ auto make_layout_rank3(const TensorType& tensor) {
826
+ // append a batch mode of size 1 if we do not have tensors that are rank 3
827
+ return make_layout(
828
+ make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}),
829
+ make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout()))));
830
+ }
831
+
832
+ /// GEMM - General Matrix-Matrix contraction without conjugation options
833
+ template <
834
+ class MainloopParams,
835
+ class EpilogueParams
836
+ >
837
+ void Gemm3x(
838
+ MainloopParams const& mainloop_params,
839
+ EpilogueParams const& epilogue_params)
840
+ {
841
+ using namespace cute;
842
+
843
+ static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
844
+ static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
845
+ static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
846
+
847
+ if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
848
+ cute::Layout layout_A = make_layout_rank3(mainloop_params.A);
849
+ cute::Layout layout_B = make_layout_rank3(mainloop_params.B);
850
+ cute::Layout layout_C = make_layout_rank3(epilogue_params.C);
851
+ cute::Layout layout_D = make_layout_rank3(epilogue_params.D);
852
+ cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux);
853
+ cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias);
854
+ cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha);
855
+ cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta);
856
+
857
+ auto TensorA = make_tensor(mainloop_params.A.data(), layout_A);
858
+ auto TensorB = make_tensor(mainloop_params.B.data(), layout_B);
859
+ auto TensorC = make_tensor(epilogue_params.C.data(), layout_C);
860
+ auto TensorD = make_tensor(epilogue_params.D.data(), layout_D);
861
+ auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux);
862
+ auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias);
863
+ auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha);
864
+ auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta);
865
+
866
+ // Reconstruct mainloop params
867
+ GettMainloopParams<typename MainloopParams::ElementAccumulator,
868
+ decltype(TensorA),
869
+ decltype(TensorB)>
870
+ mainloop_params_converted{TensorA,
871
+ TensorB,
872
+ mainloop_params.transform_A,
873
+ mainloop_params.transform_B};
874
+
875
+ // Reconstruct epilogue params
876
+ GettEpilogueParams<typename EpilogueParams::ElementScalar,
877
+ typename EpilogueParams::ElementScalingFactor,
878
+ typename EpilogueParams::ElementAccumulator,
879
+ typename EpilogueParams::ElementCompute,
880
+ decltype(TensorC),
881
+ decltype(TensorD),
882
+ decltype(VectorBias),
883
+ decltype(TensorAux),
884
+ decltype(VectorAlpha),
885
+ decltype(VectorBeta)
886
+ >
887
+ epilogue_params_converted{epilogue_params.alpha,
888
+ epilogue_params.beta,
889
+ TensorC,
890
+ TensorD,
891
+ VectorBias,
892
+ TensorAux,
893
+ VectorAlpha,
894
+ VectorBeta,
895
+ epilogue_params.abs_amax_D,
896
+ epilogue_params.abs_amax_Aux,
897
+ epilogue_params.scale_a,
898
+ epilogue_params.scale_b,
899
+ epilogue_params.scale_c,
900
+ epilogue_params.scale_d,
901
+ epilogue_params.scale_aux
902
+ };
903
+
904
+ Gett(mainloop_params_converted, epilogue_params_converted);
905
+ }
906
+ else {
907
+ // if we already have a batch mode, just pass it through
908
+ Gett(mainloop_params, epilogue_params);
909
+ }
910
+ }
911
+
912
+ /////////////////////////////////////////////////////////////////////////////////////////////////
913
+
914
+ } // cutlass::reference::host
915
+
916
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for Rank 2k update in host-side code.
33
+
34
+
35
+
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/blas3.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/tensor_view.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+ #include "cutlass/arch/mma.h"
45
+ #include "cutlass/util/host_tensor.h"
46
+ #include "cutlass/util/reference/host/gemm.h"
47
+
48
+ namespace cutlass {
49
+ namespace reference {
50
+ namespace host {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
55
+ /// objects.
56
+ template <
57
+ typename ElementA,
58
+ typename LayoutA,
59
+ typename ElementB,
60
+ typename LayoutB,
61
+ typename ElementC,
62
+ typename LayoutC,
63
+ FillMode FillModeC,
64
+ typename ScalarType,
65
+ typename ComputeType,
66
+ typename InnerProductOp = multiply_add<ComputeType>,
67
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
68
+ >
69
+ void compute_rank2k(
70
+ gemm::GemmCoord problem_size,
71
+ ScalarType alpha,
72
+ TensorRef<ElementA, LayoutA> tensor_a,
73
+ TensorRef<ElementB, LayoutB> tensor_b,
74
+ ScalarType beta,
75
+ TensorRef<ElementC, LayoutC> tensor_c,
76
+ TensorRef<ElementC, LayoutC> tensor_d,
77
+ ComputeType initial_accum) {
78
+
79
+ static_assert(
80
+ LayoutA::kRank == 2 &&
81
+ LayoutB::kRank == 2 &&
82
+ LayoutC::kRank == 2,
83
+ "Tensors must be of rank 2");
84
+
85
+ static_assert(
86
+ FillModeC == FillMode::kLower ||
87
+ FillModeC == FillMode::kUpper,
88
+ "Fill Mode can either be Lower or Upper.");
89
+
90
+ using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower),
91
+ std::greater_equal<int>,
92
+ std::less_equal<int>>::type;
93
+
94
+ // Note: batch is ignored.
95
+ // Note: M is same as N for Rank 2k update
96
+ int const N = problem_size.n();
97
+ int const K = problem_size.k();
98
+
99
+ // Blocking necessary to speedup reference implementation
100
+ int const Nblock = 16;
101
+
102
+ ConvertOp convert_op;
103
+ InnerProductOp inner_product_op;
104
+ CompareOp compare_op;
105
+
106
+ for (int row_block = 0; row_block < N; row_block += Nblock) {
107
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
108
+
109
+ ComputeType accum[Nblock][Nblock];
110
+
111
+ for (int j = 0; j < Nblock; j++) {
112
+ for (int i = 0; i < Nblock; i++) {
113
+ accum[i][j] = initial_accum;
114
+ }
115
+ }
116
+
117
+ for (int k_block = 0; k_block < K; ++k_block) {
118
+ for (int j = 0; j < Nblock; j++) {
119
+ for (int i = 0; i < Nblock; i++) {
120
+ int row = row_block + i;
121
+ int col = col_block + j;
122
+
123
+ if (row < N && col < N && compare_op(row, col))
124
+ {
125
+
126
+ // A x B^T
127
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
128
+ ElementB b_t = tensor_b.at(MatrixCoord(col, k_block));
129
+
130
+ ComputeType compute_a(cast_if_scalar<ComputeType>(a));
131
+ ComputeType compute_b_t(cast_if_scalar<ComputeType>(b_t));
132
+
133
+ accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]);
134
+
135
+ // B x A^T
136
+ ElementB b = tensor_b.at(MatrixCoord(row, k_block));
137
+ ElementA a_t = tensor_a.at(MatrixCoord(col, k_block));
138
+
139
+ ComputeType compute_b(cast_if_scalar<ComputeType>(b));
140
+ ComputeType compute_a_t(cast_if_scalar<ComputeType>(a_t));
141
+
142
+ accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]);
143
+ }
144
+ }
145
+ }
146
+ }
147
+
148
+ for (int j = 0; j < Nblock; j++) {
149
+ for (int i = 0; i < Nblock; i++) {
150
+ int row = row_block + i;
151
+ int col = col_block + j;
152
+
153
+ MatrixCoord coord = MatrixCoord(row, col);
154
+
155
+ if (row < N && col < N &&
156
+ ( (FillModeC == FillMode::kLower && row >= col) ||
157
+ (FillModeC == FillMode::kUpper && row <= col) )
158
+ ) {
159
+ tensor_d.at(coord) = convert_op(
160
+ alpha * ScalarType(accum[i][j]) +
161
+ beta * ScalarType(tensor_c.at(coord)));
162
+ }
163
+ }
164
+ }
165
+ }
166
+ }
167
+ }
168
+
169
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
170
+
171
+ /// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef
172
+ /// objects.
173
+ template <
174
+ typename ElementA,
175
+ typename LayoutA,
176
+ typename ElementB,
177
+ typename LayoutB,
178
+ typename ElementC,
179
+ typename LayoutC,
180
+ FillMode FillModeC,
181
+ typename ScalarType,
182
+ typename ComputeType,
183
+ typename InnerProductOp = multiply_add<ComputeType>,
184
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
185
+ >
186
+ void compute_rank2k(
187
+ gemm::GemmCoord problem_size,
188
+ ScalarType alpha,
189
+ TensorRef<ElementA, LayoutA> tensor_a,
190
+ TensorRef<ElementB, LayoutB> tensor_b,
191
+ ScalarType beta,
192
+ TensorRef<ElementC, LayoutC> tensor_c,
193
+ ComputeType initial_accum) {
194
+ compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
195
+ ScalarType, ComputeType, InnerProductOp, ConvertOp>(
196
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
197
+ initial_accum);
198
+ }
199
+
200
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
201
+
202
+ template <
203
+ typename ElementA,
204
+ typename LayoutA,
205
+ typename ElementB,
206
+ typename LayoutB,
207
+ typename ElementC,
208
+ typename LayoutC,
209
+ FillMode FillModeC,
210
+ typename ScalarType,
211
+ typename ComputeType,
212
+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd
213
+ >
214
+ struct Rank2K;
215
+
216
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
217
+
218
+ /// Partial specialization for multiply-add
219
+ template <typename ElementA, typename LayoutA,
220
+ typename ElementB, typename LayoutB,
221
+ typename ElementC, typename LayoutC, FillMode FillModeC,
222
+ typename ScalarType, typename ComputeType>
223
+ struct Rank2K<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC, ScalarType,
224
+ ComputeType, arch::OpMultiplyAdd> {
225
+
226
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
227
+ TensorRef<ElementA, LayoutA> tensor_a,
228
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
229
+ TensorRef<ElementC, LayoutC> tensor_c,
230
+ ComputeType initial_accum = ComputeType(0)) {
231
+ static_assert(
232
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
233
+ "Tensors must be of rank 2");
234
+
235
+ compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
236
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
237
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
238
+ }
239
+
240
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
241
+ TensorRef<ElementA, LayoutA> tensor_a,
242
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
243
+ TensorRef<ElementC, LayoutC> tensor_c,
244
+ TensorRef<ElementC, LayoutC> tensor_d,
245
+ ComputeType initial_accum = ComputeType(0)) {
246
+ static_assert(
247
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
248
+ "Tensors must be of rank 2");
249
+
250
+ compute_rank2k<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, FillModeC,
251
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
252
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
253
+ }
254
+ };
255
+
256
+
257
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
258
+
259
+ } // namespace host
260
+ } // namespace reference
261
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued Rank 2K update in host-side code.
33
+
34
+
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/blas3.h"
40
+ #include "cutlass/complex.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/tensor_view.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+ #include <cassert>
45
+
46
+ namespace cutlass {
47
+ namespace reference {
48
+ namespace host {
49
+
50
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
53
+ /// objects.
54
+ ///
55
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
56
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
57
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
58
+ /// arguments explicitly.
59
+ template <
60
+ typename ElementA,
61
+ typename LayoutA,
62
+ typename ElementB,
63
+ typename LayoutB,
64
+ typename ElementC,
65
+ typename LayoutC,
66
+ typename ScalarType,
67
+ typename ComputeType,
68
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>,
69
+ typename InnerProductOp = multiply_add<ComputeType>
70
+ >
71
+ void Rank2KComplex(
72
+ gemm::GemmCoord problem_size,
73
+ ScalarType alpha,
74
+ TensorRef<ElementA, LayoutA> tensor_a,
75
+ ComplexTransform transform_a,
76
+ TensorRef<ElementB, LayoutB> tensor_b,
77
+ ComplexTransform transform_b,
78
+ ScalarType beta,
79
+ TensorRef<ElementC, LayoutC> tensor_c,
80
+ TensorRef<ElementC, LayoutC> tensor_d,
81
+ ComputeType initial_accum,
82
+ FillMode fill_mode_c,
83
+ BlasMode blas_mode,
84
+ int batch_count = 1,
85
+ int64_t batch_stride_A = 0,
86
+ int64_t batch_stride_B = 0,
87
+ int64_t batch_stride_C = 0,
88
+ int64_t batch_stride_D = 0) {
89
+
90
+ static_assert(
91
+ LayoutA::kRank == 2 &&
92
+ LayoutB::kRank == 2 &&
93
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
94
+
95
+ // Note: batch is ignored.
96
+ int const M = problem_size.m();
97
+ int const N = problem_size.n();
98
+ int const K = problem_size.k();
99
+
100
+ // Rank2K update operates on A=NxK, B=NxK, and C=NxN
101
+ assert(M==N);
102
+
103
+ // Blocking necessary to speedup reference implementation
104
+ int const Mblock = 16;
105
+ int const Nblock = 16;
106
+
107
+ ConvertOp convert_op;
108
+ InnerProductOp inner_product_op;
109
+
110
+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
111
+
112
+ // Compute matrix product using blocks
113
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
114
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
115
+
116
+ ComputeType accum[Mblock][Nblock];
117
+
118
+ for (int j = 0; j < Nblock; j++) {
119
+ for (int i = 0; i < Mblock; i++) {
120
+ accum[i][j] = initial_accum;
121
+ }
122
+ }
123
+
124
+ for (int k_block = 0; k_block < K; ++k_block) {
125
+ for (int j = 0; j < Nblock; j++) {
126
+ for (int i = 0; i < Mblock; i++) {
127
+ int row = row_block + i;
128
+ int col = col_block + j;
129
+
130
+ if (row < M && col < N &&
131
+ ( (fill_mode_c == FillMode::kLower && row >= col) ||
132
+ (fill_mode_c == FillMode::kUpper && row <= col) )
133
+ ) {
134
+
135
+ // A x B^T (Symmetric) or A x B^H (Hermitian)
136
+ // complex conjugation on operandB (b_t) is function of blas3 computation
137
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
138
+ ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
139
+ conj(tensor_b.at(MatrixCoord(col, k_block))) :
140
+ tensor_b.at(MatrixCoord(col, k_block));
141
+
142
+ ComputeType a_ik = ComputeType(a);
143
+ ComputeType b_jk = ComputeType(b_t);
144
+
145
+ // complex conjugation is a function of operand layouts
146
+ if (transform_a == ComplexTransform::kConjugate) {
147
+ a_ik = conj(a_ik);
148
+ }
149
+ // complex conjugation is a function of operand layouts
150
+ if (transform_b == ComplexTransform::kConjugate) {
151
+ b_jk = conj(b_jk);
152
+ }
153
+
154
+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
155
+ }
156
+ }
157
+ }
158
+ }
159
+
160
+ /* HER2K need two epilogues to handle complex alpha value */
161
+ if ( blas_mode == BlasMode::kHermitian ) {
162
+ for (int j = 0; j < Nblock; j++) {
163
+ for (int i = 0; i < Mblock; i++) {
164
+ int row = row_block + i;
165
+ int col = col_block + j;
166
+
167
+ MatrixCoord coord = MatrixCoord(row, col);
168
+
169
+ if (row < M && col < N &&
170
+ ((fill_mode_c == FillMode::kLower && row >= col) ||
171
+ (fill_mode_c == FillMode::kUpper && row <= col))
172
+ ) {
173
+
174
+ ScalarType c = tensor_c.at(coord);
175
+ // The imaginary parts of the diagonal elements of
176
+ // a complex data type are assumed and set to zero
177
+ if (blas_mode == BlasMode::kHermitian) {
178
+ c = (row == col) ? real(c) : c;
179
+ }
180
+
181
+ tensor_d.at(coord) = convert_op(alpha *
182
+ ScalarType(accum[i][j]) +
183
+ beta * c);
184
+ }
185
+ }
186
+ }
187
+
188
+ /* Zeoring out accum for second HERK */
189
+ for (int j = 0; j < Nblock; j++) {
190
+ for (int i = 0; i < Mblock; i++) {
191
+ accum[i][j] = initial_accum;
192
+ }
193
+ }
194
+ }
195
+
196
+ for (int k_block = 0; k_block < K; ++k_block) {
197
+ for (int j = 0; j < Nblock; j++) {
198
+ for (int i = 0; i < Mblock; i++) {
199
+ int row = row_block + i;
200
+ int col = col_block + j;
201
+
202
+ if (row < M && col < N &&
203
+ ( (fill_mode_c == FillMode::kLower && row >= col) ||
204
+ (fill_mode_c == FillMode::kUpper && row <= col) )
205
+ ) {
206
+
207
+ // B x A^T (Symmetric) or B x A^H (Hermitian)
208
+ // complex conjugation on operandB (a_t) is function of blas3 computation
209
+ ElementB b = tensor_b.at(MatrixCoord(row, k_block));
210
+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
211
+ conj(tensor_a.at(MatrixCoord(col, k_block))):
212
+ tensor_a.at(MatrixCoord(col, k_block));
213
+
214
+ ComputeType b_ik = ComputeType(b);
215
+ ComputeType a_jk = ComputeType(a_t);
216
+
217
+ // complex conjugation here is a function of operand layouts
218
+ if (transform_b == ComplexTransform::kConjugate) {
219
+ b_ik = conj(b_ik);
220
+ }
221
+ // complex conjugation here is a function of operand layouts
222
+ if (transform_a == ComplexTransform::kConjugate) {
223
+ a_jk = conj(a_jk);
224
+ }
225
+
226
+ accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ?
233
+ conj(alpha) : alpha;
234
+ ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ?
235
+ 1 : beta;
236
+
237
+ for (int j = 0; j < Nblock; j++) {
238
+ for (int i = 0; i < Mblock; i++) {
239
+ int row = row_block + i;
240
+ int col = col_block + j;
241
+
242
+ MatrixCoord coord = MatrixCoord(row, col);
243
+
244
+ if (row < M && col < N &&
245
+ ((fill_mode_c == FillMode::kLower && row >= col) ||
246
+ (fill_mode_c == FillMode::kUpper && row <= col))
247
+ ) {
248
+
249
+ ScalarType d = (blas_mode == BlasMode::kHermitian) ?
250
+ tensor_d.at(coord) : tensor_c.at(coord);
251
+
252
+ ScalarType tmp_d = convert_op(
253
+ alpha_hermitian * ScalarType(accum[i][j]) +
254
+ beta_hermitian * d);
255
+
256
+ if (blas_mode == BlasMode::kHermitian && row == col ) {
257
+ tensor_d.at(coord) = real(tmp_d);
258
+ } else {
259
+ tensor_d.at(coord) = tmp_d;
260
+ }
261
+ }
262
+ }
263
+ }
264
+
265
+ } // for (col_block)
266
+ } // for (row_block)
267
+
268
+ tensor_a.add_pointer_offset(batch_stride_A);
269
+ tensor_b.add_pointer_offset(batch_stride_B);
270
+ tensor_c.add_pointer_offset(batch_stride_C);
271
+ tensor_d.add_pointer_offset(batch_stride_D);
272
+
273
+ } // for (batch_idx)
274
+ }
275
+
276
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
277
+
278
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
279
+ /// objects.
280
+ ///
281
+ /// This assumes the accumulator type is the same type as the scalars.
282
+ template <
283
+ typename ElementA,
284
+ typename LayoutA,
285
+ typename ElementB,
286
+ typename LayoutB,
287
+ typename ElementC,
288
+ typename LayoutC,
289
+ typename ScalarType
290
+ >
291
+ void Rank2KComplex(
292
+ gemm::GemmCoord problem_size,
293
+ ScalarType alpha,
294
+ TensorRef<ElementA, LayoutA> tensor_a,
295
+ ComplexTransform transform_a,
296
+ TensorRef<ElementB, LayoutB> tensor_b,
297
+ ComplexTransform transform_b,
298
+ ScalarType beta,
299
+ TensorRef<ElementC, LayoutC> tensor_c,
300
+ TensorRef<ElementC, LayoutC> tensor_d,
301
+ FillMode fill_mode_c,
302
+ BlasMode blas_mode) {
303
+
304
+ Rank2KComplex(
305
+ problem_size, alpha,
306
+ tensor_a, transform_a,
307
+ tensor_b, transform_b,
308
+ beta, tensor_c, tensor_d,
309
+ ScalarType(0),
310
+ fill_mode_c,
311
+ blas_mode);
312
+ }
313
+
314
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
315
+
316
+ } // namespace host
317
+ } // namespace reference
318
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued Rank 2K update in host-side code.
33
+
34
+
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/blas3.h"
40
+ #include "cutlass/complex.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/tensor_view.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+ #include <cassert>
45
+
46
+ namespace cutlass {
47
+ namespace reference {
48
+ namespace host {
49
+
50
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
53
+ /// objects.
54
+ ///
55
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
56
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
57
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
58
+ /// arguments explicitly.
59
+ template <
60
+ typename ElementA,
61
+ typename LayoutA,
62
+ typename ElementC,
63
+ typename LayoutC,
64
+ typename ScalarType,
65
+ typename ComputeType,
66
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>,
67
+ typename InnerProductOp = multiply_add<ComputeType>
68
+ >
69
+ void Rank2KComplex(
70
+ gemm::GemmCoord problem_size,
71
+ ScalarType alpha,
72
+ TensorRef<ElementA, LayoutA> tensor_a,
73
+ ComplexTransform transform_a,
74
+ ScalarType beta,
75
+ TensorRef<ElementC, LayoutC> tensor_c,
76
+ TensorRef<ElementC, LayoutC> tensor_d,
77
+ ComputeType initial_accum,
78
+ FillMode fill_mode_c,
79
+ BlasMode blas_mode,
80
+ int batch_count = 1,
81
+ int64_t batch_stride_A = 0,
82
+ int64_t batch_stride_C = 0,
83
+ int64_t batch_stride_D = 0) {
84
+
85
+ static_assert(
86
+ LayoutA::kRank == 2 &&
87
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
88
+
89
+ // Note: batch is ignored.
90
+ int const M = problem_size.m();
91
+ int const N = problem_size.n();
92
+ int const K = problem_size.k();
93
+
94
+ // Rank2K update operates on A=NxK, B=NxK, and C=NxN
95
+ assert(M==N);
96
+
97
+ // Blocking necessary to speedup reference implementation
98
+ int const Mblock = 16;
99
+ int const Nblock = 16;
100
+
101
+ ConvertOp convert_op;
102
+ InnerProductOp inner_product_op;
103
+
104
+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
105
+
106
+ // Compute matrix product using blocks
107
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
108
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
109
+
110
+ ComputeType accum[Mblock][Nblock];
111
+
112
+ for (int j = 0; j < Nblock; j++) {
113
+ for (int i = 0; i < Mblock; i++) {
114
+ accum[i][j] = initial_accum;
115
+ }
116
+ }
117
+
118
+ for (int k_block = 0; k_block < K; ++k_block) {
119
+ for (int j = 0; j < Nblock; j++) {
120
+ for (int i = 0; i < Mblock; i++) {
121
+ int row = row_block + i;
122
+ int col = col_block + j;
123
+
124
+ if (row < M && col < N &&
125
+ ( (fill_mode_c == FillMode::kLower && row >= col) ||
126
+ (fill_mode_c == FillMode::kUpper && row <= col) )
127
+ ) {
128
+
129
+ // A x A^T (Symmetric) or A x A^H (Hermitian)
130
+ // complex conjugation on operandB (a_t) (function of blas3 computation)
131
+ ElementA a = tensor_a.at(MatrixCoord(row, k_block));
132
+ ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
133
+ conj(tensor_a.at(MatrixCoord(col, k_block))) :
134
+ tensor_a.at(MatrixCoord(col, k_block));
135
+
136
+ ComputeType a_ik = ComputeType(a);
137
+ ComputeType b_jk = ComputeType(a_t);
138
+
139
+ // complex conjugation (function of input layouts)
140
+ if (transform_a == ComplexTransform::kConjugate) {
141
+ a_ik = conj(a_ik);
142
+ }
143
+ // complex conjugation (function of input layouts)
144
+ if (transform_a == ComplexTransform::kConjugate) {
145
+ b_jk = conj(b_jk);
146
+ }
147
+
148
+ accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
149
+
150
+ }
151
+ }
152
+ }
153
+ }
154
+
155
+ for (int j = 0; j < Nblock; j++) {
156
+ for (int i = 0; i < Mblock; i++) {
157
+ int row = row_block + i;
158
+ int col = col_block + j;
159
+
160
+ MatrixCoord coord = MatrixCoord(row, col);
161
+
162
+ if (row < M && col < N &&
163
+ ((fill_mode_c == FillMode::kLower && row >= col) ||
164
+ (fill_mode_c == FillMode::kUpper && row <= col))
165
+ ) {
166
+
167
+ ScalarType c = tensor_c.at(coord);
168
+ // The imaginary parts of the diagonal elements of
169
+ // a complex data type are assumed and set to zero
170
+ if (blas_mode == BlasMode::kHermitian) {
171
+ c = (row == col) ? real(c) : c;
172
+ }
173
+
174
+ ScalarType tmp_d = convert_op(
175
+ alpha * ScalarType(accum[i][j]) +
176
+ beta * c);
177
+
178
+ if (blas_mode == BlasMode::kHermitian && row == col ) {
179
+ tensor_d.at(coord) = real(tmp_d);
180
+ } else {
181
+ tensor_d.at(coord) = tmp_d;
182
+ }
183
+ }
184
+ }
185
+ }
186
+
187
+ } // for (col_block)
188
+ } // for (row_block)
189
+
190
+ tensor_a.add_pointer_offset(batch_stride_A);
191
+ tensor_c.add_pointer_offset(batch_stride_C);
192
+ tensor_d.add_pointer_offset(batch_stride_D);
193
+
194
+ } // for (batch_idx)
195
+ }
196
+
197
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
198
+
199
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
200
+ /// objects.
201
+ ///
202
+ /// This assumes the accumulator type is the same type as the scalars.
203
+ template <
204
+ typename ElementA,
205
+ typename LayoutA,
206
+ typename ElementC,
207
+ typename LayoutC,
208
+ typename ScalarType
209
+ >
210
+ void RankKComplex(
211
+ gemm::GemmCoord problem_size,
212
+ ScalarType alpha,
213
+ TensorRef<ElementA, LayoutA> tensor_a,
214
+ ComplexTransform transform_a,
215
+ ScalarType beta,
216
+ TensorRef<ElementC, LayoutC> tensor_c,
217
+ TensorRef<ElementC, LayoutC> tensor_d,
218
+ FillMode fill_mode_c,
219
+ BlasMode blas_mode) {
220
+
221
+ Rank2KComplex(
222
+ problem_size, alpha,
223
+ tensor_a, transform_a,
224
+ beta, tensor_c, tensor_d,
225
+ ScalarType(0),
226
+ fill_mode_c,
227
+ blas_mode);
228
+ }
229
+
230
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
231
+
232
+ } // namespace host
233
+ } // namespace reference
234
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for SYMM update in host-side code.
33
+
34
+
35
+
36
+ */
37
+
38
+ #pragma once
39
+
40
+ #include "cutlass/blas3.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+
43
+ #include "cutlass/tensor_view.h"
44
+ #include "cutlass/gemm/gemm.h"
45
+ #include "cutlass/arch/mma.h"
46
+ #include "cutlass/util/host_tensor.h"
47
+ #include "cutlass/util/reference/host/gemm.h"
48
+
49
+ namespace cutlass {
50
+ namespace reference {
51
+ namespace host {
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
56
+ /// objects.
57
+ template <
58
+ typename ElementA,
59
+ typename LayoutA,
60
+ SideMode SideModeA,
61
+ FillMode FillModeA,
62
+ typename ElementB,
63
+ typename LayoutB,
64
+ typename ElementC,
65
+ typename LayoutC,
66
+ typename ScalarType,
67
+ typename ComputeType,
68
+ typename InnerProductOp = multiply_add<ComputeType>,
69
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
70
+ >
71
+ void compute_symm(
72
+ gemm::GemmCoord problem_size,
73
+ ScalarType alpha,
74
+ TensorRef<ElementA, LayoutA> tensor_a,
75
+ TensorRef<ElementB, LayoutB> tensor_b,
76
+ ScalarType beta,
77
+ TensorRef<ElementC, LayoutC> tensor_c,
78
+ TensorRef<ElementC, LayoutC> tensor_d,
79
+ ComputeType initial_accum) {
80
+
81
+ static_assert(
82
+ LayoutA::kRank == 2 &&
83
+ LayoutB::kRank == 2 &&
84
+ LayoutC::kRank == 2,
85
+ "Tensors must be of rank 2");
86
+
87
+ static_assert(SideModeA != SideMode::kInvalid
88
+ , "Side Mode can either be Left or Right.");
89
+
90
+ static_assert(
91
+ FillModeA == FillMode::kLower ||
92
+ FillModeA == FillMode::kUpper,
93
+ "Fill Mode can either be Lower or Upper.");
94
+
95
+ using CompareOp_w_diag = typename TrMatrixCompareOp<FillModeA, DiagType::kNonUnit>::Type;
96
+ using CompareOp_wo_diag = typename TrMatrixCompareOp<FillModeA, DiagType::kZero>::Type;
97
+
98
+ // Note: batch is ignored.
99
+ int const M = problem_size.m();
100
+ int const N = problem_size.n();
101
+ // Assuming correct k-dimension value is passed
102
+ int const K = problem_size.k();
103
+
104
+ // Blocking necessary to speedup reference implementation
105
+ int const Mblock = 16;
106
+ int const Nblock = 16;
107
+
108
+ ConvertOp convert_op;
109
+ InnerProductOp inner_product_op;
110
+ CompareOp_w_diag compare_op_1;
111
+ CompareOp_wo_diag compare_op_2;
112
+
113
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
114
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
115
+
116
+ ComputeType accum[Mblock][Nblock];
117
+
118
+ for (int j = 0; j < Nblock; j++) {
119
+ for (int i = 0; i < Mblock; i++) {
120
+ accum[i][j] = initial_accum;
121
+ }
122
+ }
123
+
124
+ for (int k_block = 0; k_block < K; ++k_block) {
125
+ for (int j = 0; j < Nblock; j++) {
126
+ for (int i = 0; i < Mblock; i++) {
127
+ int row = row_block + i;
128
+ int col = col_block + j;
129
+
130
+ if (row < M && col < N) {
131
+ ElementA a_1 = ElementA();
132
+ ElementB b_1 = ElementB();
133
+ ElementA a_2 = ElementA();
134
+ ElementB b_2 = ElementB();
135
+
136
+ // A x B or B x A (with diagonal)
137
+ if (SideModeA == SideMode::kLeft) {
138
+ a_1 = (compare_op_1(row, k_block)) ?
139
+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA();
140
+ b_1 = tensor_b.at(MatrixCoord(k_block, col));
141
+ } else if (SideModeA == SideMode::kRight) {
142
+ a_1 = tensor_b.at(MatrixCoord(row, k_block));
143
+ b_1 = (compare_op_1(k_block, col)) ?
144
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA();
145
+ }
146
+
147
+ ComputeType compute_a_1(cast_if_scalar<ComputeType>(a_1));
148
+ ComputeType compute_b_1(cast_if_scalar<ComputeType>(b_1));
149
+
150
+ accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]);
151
+
152
+ // A^T x B or B x A^T (without diagonal)
153
+ if (SideModeA == SideMode::kLeft) {
154
+ a_2 = (compare_op_2(k_block, row)) ?
155
+ (tensor_a.at(MatrixCoord(k_block, row))) : ElementA();
156
+ b_2 = tensor_b.at(MatrixCoord(k_block, col));
157
+ } else if (SideModeA == SideMode::kRight) {
158
+ a_2 = tensor_b.at(MatrixCoord(row, k_block));
159
+ b_2 = (compare_op_2(col, k_block)) ?
160
+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA();
161
+ }
162
+
163
+ ComputeType compute_a_2(cast_if_scalar<ComputeType>(a_2));
164
+ ComputeType compute_b_2(cast_if_scalar<ComputeType>(b_2));
165
+
166
+ accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]);
167
+ }
168
+ }
169
+ }
170
+ }
171
+
172
+ for (int j = 0; j < Nblock; j++) {
173
+ for (int i = 0; i < Mblock; i++) {
174
+ int row = row_block + i;
175
+ int col = col_block + j;
176
+
177
+ MatrixCoord coord = MatrixCoord(row, col);
178
+
179
+ if (row < M && col < N) {
180
+ tensor_d.at(coord) = convert_op(
181
+ alpha * ScalarType(accum[i][j]) +
182
+ beta * ScalarType(tensor_c.at(coord)));
183
+ }
184
+ }
185
+ }
186
+ }
187
+ }
188
+ }
189
+
190
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
191
+
192
+ /// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef
193
+ /// objects.
194
+ template <
195
+ typename ElementA,
196
+ typename LayoutA,
197
+ SideMode SideModeA,
198
+ FillMode FillModeA,
199
+ typename ElementB,
200
+ typename LayoutB,
201
+ typename ElementC,
202
+ typename LayoutC,
203
+ typename ScalarType,
204
+ typename ComputeType,
205
+ typename InnerProductOp = multiply_add<ComputeType>,
206
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
207
+ >
208
+ void compute_symm(
209
+ gemm::GemmCoord problem_size,
210
+ ScalarType alpha,
211
+ TensorRef<ElementA, LayoutA> tensor_a,
212
+ TensorRef<ElementB, LayoutB> tensor_b,
213
+ ScalarType beta,
214
+ TensorRef<ElementC, LayoutC> tensor_c,
215
+ ComputeType initial_accum) {
216
+ compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
217
+ ScalarType, ComputeType, InnerProductOp, ConvertOp>(
218
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
219
+ initial_accum);
220
+ }
221
+
222
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
223
+
224
+ template <
225
+ typename ElementA,
226
+ typename LayoutA,
227
+ SideMode SideModeA,
228
+ FillMode FillModeA,
229
+ typename ElementB,
230
+ typename LayoutB,
231
+ typename ElementC,
232
+ typename LayoutC,
233
+ typename ScalarType,
234
+ typename ComputeType,
235
+ typename InnerProductOp = cutlass::arch::OpMultiplyAdd
236
+ >
237
+ struct Symm;
238
+
239
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
240
+
241
+ /// Partial specialization for multiply-add
242
+ template <typename ElementA, typename LayoutA,
243
+ SideMode SideModeA, FillMode FillModeA,
244
+ typename ElementB, typename LayoutB,
245
+ typename ElementC, typename LayoutC,
246
+ typename ScalarType, typename ComputeType>
247
+ struct Symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
248
+ ComputeType, arch::OpMultiplyAdd> {
249
+
250
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
251
+ TensorRef<ElementA, LayoutA> tensor_a,
252
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
253
+ TensorRef<ElementC, LayoutC> tensor_c,
254
+ ComputeType initial_accum = ComputeType(0)) {
255
+ static_assert(
256
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
257
+ "Tensors must be of rank 2");
258
+
259
+ compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
260
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
261
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
262
+ }
263
+
264
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
265
+ TensorRef<ElementA, LayoutA> tensor_a,
266
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
267
+ TensorRef<ElementC, LayoutC> tensor_c,
268
+ TensorRef<ElementC, LayoutC> tensor_d,
269
+ ComputeType initial_accum = ComputeType(0)) {
270
+ static_assert(
271
+ LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
272
+ "Tensors must be of rank 2");
273
+
274
+ compute_symm<ElementA, LayoutA, SideModeA, FillModeA, ElementB, LayoutB, ElementC, LayoutC,
275
+ ScalarType, ComputeType, multiply_add<ComputeType>>(
276
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
277
+ }
278
+ };
279
+
280
+
281
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
282
+
283
+ } // namespace host
284
+ } // namespace reference
285
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued SYMM update in host-side code.
33
+
34
+
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/blas3.h"
40
+ #include "cutlass/complex.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/tensor_view.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+ #include <cassert>
45
+
46
+ namespace cutlass {
47
+ namespace reference {
48
+ namespace host {
49
+
50
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
53
+ /// objects.
54
+ ///
55
+ /// Explicitly naming types needed by this template can be cumbersome, particularly for the
56
+ /// accumulator type, so a function argument 'initial_accum' is exposed. Passing
57
+ /// AccumulatorType(0) as the last function argument can be easier than naming all template
58
+ /// arguments explicitly.
59
+ template <
60
+ typename ElementA,
61
+ typename LayoutA,
62
+ SideMode SideModeA,
63
+ FillMode FillModeA,
64
+ typename ElementB,
65
+ typename LayoutB,
66
+ typename ElementC,
67
+ typename LayoutC,
68
+ typename ScalarType,
69
+ typename ComputeType,
70
+ BlasMode BlasMode_ = BlasMode::kSymmetric,
71
+ typename InnerProductOp = multiply_add<ComputeType>,
72
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
73
+ >
74
+ void compute_symm_complex(
75
+ gemm::GemmCoord problem_size,
76
+ ScalarType alpha,
77
+ TensorRef<ElementA, LayoutA> tensor_a,
78
+ TensorRef<ElementB, LayoutB> tensor_b,
79
+ ScalarType beta,
80
+ TensorRef<ElementC, LayoutC> tensor_c,
81
+ TensorRef<ElementC, LayoutC> tensor_d,
82
+ ComputeType initial_accum,
83
+ int batch_count = 1,
84
+ int64_t batch_stride_A = 0,
85
+ int64_t batch_stride_B = 0,
86
+ int64_t batch_stride_C = 0,
87
+ int64_t batch_stride_D = 0) {
88
+
89
+ static SideMode const kSideModeA = SideModeA;
90
+ static FillMode const kFillModeA = FillModeA;
91
+ static BlasMode const kBlasMode = BlasMode_;
92
+
93
+ static_assert(
94
+ LayoutA::kRank == 2 &&
95
+ LayoutB::kRank == 2 &&
96
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
97
+
98
+ static_assert(kSideModeA != SideMode::kInvalid
99
+ , "Side Mode can either be Left or Right.");
100
+
101
+ static_assert(
102
+ kFillModeA == FillMode::kLower ||
103
+ kFillModeA == FillMode::kUpper,
104
+ "Fill Mode can either be Lower or Upper.");
105
+
106
+ using CompareOp_w_diag = typename TrMatrixCompareOp<kFillModeA, DiagType::kNonUnit>::Type;
107
+ using CompareOp_wo_diag = typename TrMatrixCompareOp<kFillModeA, DiagType::kZero>::Type;
108
+
109
+ // Note: batch is ignored.
110
+ int const M = problem_size.m();
111
+ int const N = problem_size.n();
112
+ // Assuming correct k-dimension value is passed
113
+ int const K = problem_size.k();
114
+
115
+ // Blocking necessary to speedup reference implementation
116
+ int const Mblock = 16;
117
+ int const Nblock = 16;
118
+
119
+ ConvertOp convert_op;
120
+ InnerProductOp inner_product_op;
121
+ CompareOp_w_diag compare_op_1;
122
+ CompareOp_wo_diag compare_op_2;
123
+
124
+ for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
125
+
126
+ // Compute matrix product using blocks
127
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
128
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
129
+
130
+ ComputeType accum[Mblock][Nblock];
131
+
132
+ for (int j = 0; j < Nblock; j++) {
133
+ for (int i = 0; i < Mblock; i++) {
134
+ accum[i][j] = initial_accum;
135
+ }
136
+ }
137
+
138
+ for (int k_block = 0; k_block < K; ++k_block) {
139
+ for (int j = 0; j < Nblock; j++) {
140
+ for (int i = 0; i < Mblock; i++) {
141
+ int row = row_block + i;
142
+ int col = col_block + j;
143
+
144
+ if (row < M && col < N)
145
+ {
146
+ ElementA a_1 = ElementA();
147
+ ElementB b_1 = ElementB();
148
+ ElementA a_2 = ElementA();
149
+ ElementB b_2 = ElementB();
150
+
151
+ // A x B or B x A (with diagonal)
152
+ if (kSideModeA == SideMode::kLeft) {
153
+ a_1 = (compare_op_1(row, k_block)) ?
154
+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA();
155
+ b_1 = tensor_b.at(MatrixCoord(k_block, col));
156
+ } else if (kSideModeA == SideMode::kRight) {
157
+ a_1 = tensor_b.at(MatrixCoord(row, k_block));
158
+ b_1 = (compare_op_1(k_block, col)) ?
159
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA();
160
+ }
161
+ ComputeType compute_a_1 = ComputeType(a_1);
162
+ ComputeType compute_b_1 = ComputeType(b_1);
163
+
164
+ // The imaginary parts of the diagonal elements of
165
+ // a complex data type are assumed and set to zero
166
+ if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) {
167
+ compute_a_1 = real(compute_a_1);
168
+ } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) {
169
+ compute_b_1 = real(compute_b_1);
170
+ }
171
+
172
+ accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]);
173
+
174
+ // A^T x B or B x A^T (without diagonal)
175
+ if (kSideModeA == SideMode::kLeft) {
176
+ a_2 = (compare_op_2(k_block, row)) ?
177
+ (tensor_a.at(MatrixCoord(k_block, row))) : ElementA();
178
+ b_2 = tensor_b.at(MatrixCoord(k_block, col));
179
+ if (kBlasMode == BlasMode::kHermitian)
180
+ a_2 = conj(a_2);
181
+ } else if (kSideModeA == SideMode::kRight) {
182
+ a_2 = tensor_b.at(MatrixCoord(row, k_block));
183
+ b_2 = (compare_op_2(col, k_block)) ?
184
+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA();
185
+ if (kBlasMode == BlasMode::kHermitian)
186
+ b_2 = conj(b_2);
187
+ }
188
+
189
+ ComputeType compute_a_2 = ComputeType(a_2);
190
+ ComputeType compute_b_2 = ComputeType(b_2);
191
+
192
+ accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]);
193
+ }
194
+ }
195
+ }
196
+ }
197
+
198
+ for (int j = 0; j < Nblock; j++) {
199
+ for (int i = 0; i < Mblock; i++) {
200
+ int row = row_block + i;
201
+ int col = col_block + j;
202
+
203
+ MatrixCoord coord = MatrixCoord(row, col);
204
+
205
+ if (row < M && col < N) {
206
+
207
+ ScalarType c = tensor_c.at(coord);
208
+
209
+ tensor_d.at(coord) = convert_op(
210
+ alpha * ScalarType(accum[i][j]) +
211
+ beta * c);
212
+ }
213
+ }
214
+ }
215
+
216
+ } // for (col_block)
217
+ } // for (row_block)
218
+
219
+ tensor_a.add_pointer_offset(batch_stride_A);
220
+ tensor_b.add_pointer_offset(batch_stride_B);
221
+ tensor_c.add_pointer_offset(batch_stride_C);
222
+ tensor_d.add_pointer_offset(batch_stride_D);
223
+
224
+ } // for (batch_idx)
225
+ }
226
+
227
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
228
+
229
+ template <
230
+ typename ElementA,
231
+ typename LayoutA,
232
+ SideMode SideModeA,
233
+ FillMode FillModeA,
234
+ typename ElementB,
235
+ typename LayoutB,
236
+ typename ElementC,
237
+ typename LayoutC,
238
+ typename ScalarType,
239
+ typename ComputeType,
240
+ BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric,
241
+ typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex
242
+ >
243
+ struct SymmComplex;
244
+
245
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
246
+
247
+ /// Partial specialization for multiply-add
248
+ template <typename ElementA, typename LayoutA,
249
+ SideMode SideModeA, FillMode FillModeA,
250
+ typename ElementB, typename LayoutB,
251
+ typename ElementC, typename LayoutC,
252
+ typename ScalarType, typename ComputeType,
253
+ BlasMode BlasMode_>
254
+ struct SymmComplex<ElementA, LayoutA,
255
+ SideModeA, FillModeA,
256
+ ElementB, LayoutB,
257
+ ElementC, LayoutC, ScalarType,
258
+ ComputeType, BlasMode_,
259
+ arch::OpMultiplyAddComplex> {
260
+
261
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
262
+ TensorRef<ElementA, LayoutA> tensor_a,
263
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
264
+ TensorRef<ElementC, LayoutC> tensor_c,
265
+ TensorRef<ElementC, LayoutC> tensor_d,
266
+ ComputeType initial_accum = ComputeType(0)) {
267
+ static_assert(
268
+ LayoutA::kRank == 2 && LayoutC::kRank == 2,
269
+ "Tensors must be of rank 2");
270
+
271
+ compute_symm_complex<ElementA, LayoutA,
272
+ SideModeA, FillModeA,
273
+ ElementB, LayoutB,
274
+ ElementC, LayoutC,
275
+ ScalarType, ComputeType, BlasMode_, multiply_add<ComputeType>>(
276
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
277
+ }
278
+ };
279
+
280
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
281
+
282
+ /// Partial specialization for gaussian multiply-add
283
+ template <typename ElementA, typename LayoutA,
284
+ SideMode SideModeA, FillMode FillModeA,
285
+ typename ElementB, typename LayoutB,
286
+ typename ElementC, typename LayoutC,
287
+ typename ScalarType, typename ComputeType,
288
+ BlasMode BlasMode_>
289
+ struct SymmComplex<ElementA, LayoutA,
290
+ SideModeA, FillModeA,
291
+ ElementB, LayoutB,
292
+ ElementC, LayoutC, ScalarType,
293
+ ComputeType, BlasMode_,
294
+ arch::OpMultiplyAddGaussianComplex> {
295
+
296
+ void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
297
+ TensorRef<ElementA, LayoutA> tensor_a,
298
+ TensorRef<ElementB, LayoutB> tensor_b, ScalarType beta,
299
+ TensorRef<ElementC, LayoutC> tensor_c,
300
+ TensorRef<ElementC, LayoutC> tensor_d,
301
+ ComputeType initial_accum = ComputeType(0)) {
302
+ static_assert(
303
+ LayoutA::kRank == 2 && LayoutC::kRank == 2,
304
+ "Tensors must be of rank 2");
305
+
306
+ compute_symm_complex<ElementA, LayoutA,
307
+ SideModeA, FillModeA,
308
+ ElementB, LayoutB,
309
+ ElementC, LayoutC,
310
+ ScalarType, ComputeType, BlasMode_, multiply_add<ComputeType>>(
311
+ problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
312
+ }
313
+ };
314
+
315
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
316
+
317
+ } // namespace host
318
+ } // namespace reference
319
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Defines host-side elementwise operations on TensorView.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // Standard Library includes
38
+ #include <utility>
39
+
40
+ // Cutlass includes
41
+ #include "cutlass/cutlass.h"
42
+ #include "cutlass/relatively_equal.h"
43
+ #include "cutlass/tensor_view.h"
44
+ #include "cutlass/tensor_view_planar_complex.h"
45
+
46
+ #include "cutlass/util/distribution.h"
47
+ #include "tensor_foreach.h"
48
+
49
+ namespace cutlass {
50
+ namespace reference {
51
+ namespace host {
52
+
53
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
54
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace detail {
57
+
58
+ template <
59
+ typename Element, ///< Element type
60
+ typename Layout> ///< Layout function
61
+ struct TensorGreatestErrorFunc {
62
+
63
+ //
64
+ // Data members
65
+ //
66
+
67
+ TensorView<Element, Layout> lhs;
68
+ TensorView<Element, Layout> rhs;
69
+ double result;
70
+
71
+ /// Ctor
72
+ TensorGreatestErrorFunc(
73
+ TensorView<Element, Layout> const &lhs_,
74
+ TensorView<Element, Layout> const &rhs_
75
+ ) :
76
+ lhs(lhs_),
77
+ rhs(rhs_),
78
+ result(0.0) { }
79
+
80
+ /// Visits a coordinate
81
+ void operator()(Coord<Layout::kRank> const &coord) {
82
+
83
+ Element lhs_ = lhs.at(coord);
84
+ Element rhs_ = rhs.at(coord);
85
+
86
+ result = std::max(result, std::abs(double(lhs_) - double(rhs_)));
87
+ }
88
+
89
+ /// Returns true if equal
90
+ operator double() const {
91
+ return result;
92
+ }
93
+ };
94
+
95
+ template <
96
+ typename Element, ///< Element type
97
+ typename Layout> ///< Layout function
98
+ struct TensorMREFunc {
99
+
100
+ //
101
+ // Data members
102
+ //
103
+
104
+ TensorView<Element, Layout> lhs;
105
+ TensorView<Element, Layout> rhs;
106
+ double sum;
107
+ uint64_t count;
108
+ static constexpr double epsilon = 1e-6;
109
+
110
+ /// Ctor
111
+ TensorMREFunc(
112
+ TensorView<Element, Layout> const &lhs_,
113
+ TensorView<Element, Layout> const &rhs_
114
+ ) :
115
+ lhs(lhs_),
116
+ rhs(rhs_),
117
+ sum(0.0),
118
+ count(0) { }
119
+
120
+ /// Visits a coordinate
121
+ void operator()(Coord<Layout::kRank> const &coord) {
122
+
123
+ Element lhs_ = lhs.at(coord);
124
+ Element rhs_ = rhs.at(coord);
125
+
126
+ sum += std::abs(double(lhs_) - double(rhs_) / (double(rhs_) + epsilon));
127
+ ++count;
128
+ }
129
+
130
+ /// Returns true if equal
131
+ operator double() const {
132
+ return sum / double(count);
133
+ }
134
+ };
135
+
136
+ template <
137
+ typename Element, ///< Element type
138
+ typename Layout> ///< Layout function
139
+ struct TensorMSEFunc {
140
+
141
+ //
142
+ // Data members
143
+ //
144
+
145
+ TensorView<Element, Layout> lhs;
146
+ TensorView<Element, Layout> rhs;
147
+ double sum;
148
+ uint64_t count;
149
+
150
+ /// Ctor
151
+ TensorMSEFunc(
152
+ TensorView<Element, Layout> const &lhs_,
153
+ TensorView<Element, Layout> const &rhs_
154
+ ) :
155
+ lhs(lhs_),
156
+ rhs(rhs_),
157
+ sum(0.0),
158
+ count(0) { }
159
+
160
+ /// Visits a coordinate
161
+ void operator()(Coord<Layout::kRank> const &coord) {
162
+
163
+ Element lhs_ = lhs.at(coord);
164
+ Element rhs_ = rhs.at(coord);
165
+
166
+ sum += std::pow((double(lhs_) - double(rhs_)), 2);
167
+ ++count;
168
+ }
169
+
170
+ /// Returns true if equal
171
+ operator double() const {
172
+ return sum / double(count);
173
+ }
174
+ };
175
+
176
+ template <
177
+ typename Element, ///< Element type
178
+ typename Layout> ///< Layout function
179
+ struct TensorEqualsFunc {
180
+
181
+ //
182
+ // Data members
183
+ //
184
+
185
+ TensorView<Element, Layout> lhs;
186
+ TensorView<Element, Layout> rhs;
187
+ bool result;
188
+
189
+ /// Ctor
190
+ TensorEqualsFunc(): result(true) { }
191
+
192
+ /// Ctor
193
+ TensorEqualsFunc(
194
+ TensorView<Element, Layout> const &lhs_,
195
+ TensorView<Element, Layout> const &rhs_
196
+ ) :
197
+ lhs(lhs_), rhs(rhs_), result(true) { }
198
+
199
+ /// Visits a coordinate
200
+ void operator()(Coord<Layout::kRank> const &coord) {
201
+
202
+ Element lhs_ = lhs.at(coord);
203
+ Element rhs_ = rhs.at(coord);
204
+
205
+ if (lhs_ != rhs_) {
206
+ result = false;
207
+ }
208
+ }
209
+
210
+ /// Returns true if equal
211
+ operator bool() const {
212
+ return result;
213
+ }
214
+ };
215
+
216
+ template <
217
+ typename Element, ///< Element type
218
+ typename Layout> ///< Layout function
219
+ struct TensorRelativelyEqualsFunc {
220
+
221
+ //
222
+ // Data members
223
+ //
224
+
225
+ TensorView<Element, Layout> lhs;
226
+ TensorView<Element, Layout> rhs;
227
+ Element epsilon;
228
+ Element nonzero_floor;
229
+ bool result;
230
+
231
+ /// Ctor
232
+ TensorRelativelyEqualsFunc(
233
+ TensorView<Element, Layout> const &lhs_,
234
+ TensorView<Element, Layout> const &rhs_,
235
+ Element epsilon_,
236
+ Element nonzero_floor_
237
+ ) :
238
+ lhs(lhs_),
239
+ rhs(rhs_),
240
+ epsilon(epsilon_),
241
+ nonzero_floor(nonzero_floor_),
242
+ result(true) { }
243
+
244
+ /// Visits a coordinate
245
+ void operator()(Coord<Layout::kRank> const &coord) {
246
+
247
+ Element lhs_ = lhs.at(coord);
248
+ Element rhs_ = rhs.at(coord);
249
+
250
+ if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) {
251
+ result = false;
252
+ }
253
+ }
254
+
255
+ /// Returns true if equal
256
+ operator bool() const {
257
+ return result;
258
+ }
259
+ };
260
+
261
+ } // namespace detail
262
+
263
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
264
+
265
+ /// Returns the Mean Squared Error between two tensors.
266
+ template <
267
+ typename Element, ///< Element type
268
+ typename Layout> ///< Layout function
269
+ double TensorMSE(
270
+ TensorView<Element, Layout> const &lhs,
271
+ TensorView<Element, Layout> const &rhs) {
272
+
273
+ // Extents must be identical
274
+ if (lhs.extent() != rhs.extent()) {
275
+ return -1;
276
+ }
277
+
278
+ detail::TensorMSEFunc<Element, Layout> func(lhs, rhs);
279
+ TensorForEach(
280
+ lhs.extent(),
281
+ func
282
+ );
283
+
284
+ return double(func);
285
+ }
286
+
287
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
288
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
289
+
290
+ /// Returns the Mean Relative Error between two tensors.
291
+ template <
292
+ typename Element, ///< Element type
293
+ typename Layout> ///< Layout function
294
+ double TensorMRE(
295
+ TensorView<Element, Layout> const &lhs,
296
+ TensorView<Element, Layout> const &rhs) {
297
+
298
+ // Extents must be identical
299
+ if (lhs.extent() != rhs.extent()) {
300
+ return -1;
301
+ }
302
+
303
+ detail::TensorMREFunc<Element, Layout> func(lhs, rhs);
304
+ TensorForEach(
305
+ lhs.extent(),
306
+ func
307
+ );
308
+
309
+ return double(func);
310
+ }
311
+
312
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
313
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
314
+
315
+ /// Returns the greatest error between two tensors.
316
+ template <
317
+ typename Element, ///< Element type
318
+ typename Layout> ///< Layout function
319
+ double TensorGreatestError(
320
+ TensorView<Element, Layout> const &lhs,
321
+ TensorView<Element, Layout> const &rhs) {
322
+
323
+ // Extents must be identical
324
+ if (lhs.extent() != rhs.extent()) {
325
+ return -1;
326
+ }
327
+
328
+ detail::TensorGreatestErrorFunc<Element, Layout> func(lhs, rhs);
329
+ TensorForEach(
330
+ lhs.extent(),
331
+ func
332
+ );
333
+
334
+ return double(func);
335
+ }
336
+
337
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
338
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
339
+
340
+ /// Returns true if two tensor views are equal.
341
+ template <
342
+ typename Element, ///< Element type
343
+ typename Layout> ///< Layout function
344
+ bool TensorEquals(
345
+ TensorView<Element, Layout> const &lhs,
346
+ TensorView<Element, Layout> const &rhs) {
347
+
348
+ // Extents must be identical
349
+ if (lhs.extent() != rhs.extent()) {
350
+ return false;
351
+ }
352
+
353
+ detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);
354
+ TensorForEach(
355
+ lhs.extent(),
356
+ func
357
+ );
358
+
359
+ return bool(func);
360
+ }
361
+
362
+ /// Returns true if two tensor views are equal.
363
+ template <
364
+ typename Element, ///< Element type
365
+ typename Layout> ///< Layout function
366
+ bool TensorEquals(
367
+ TensorViewPlanarComplex<Element, Layout> const &lhs,
368
+ TensorViewPlanarComplex<Element, Layout> const &rhs) {
369
+
370
+ // Extents must be identical
371
+ if (lhs.extent() != rhs.extent()) {
372
+ return false;
373
+ }
374
+
375
+ detail::TensorEqualsFunc<Element, Layout> real_func(
376
+ {lhs.data(), lhs.layout(), lhs.extent()},
377
+ {rhs.data(), rhs.layout(), rhs.extent()}
378
+ );
379
+
380
+ TensorForEach(
381
+ lhs.extent(),
382
+ real_func
383
+ );
384
+
385
+ if (!bool(real_func)) {
386
+ return false;
387
+ }
388
+
389
+ detail::TensorEqualsFunc<Element, Layout> imag_func(
390
+ {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
391
+ {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}
392
+ );
393
+
394
+ TensorForEach(
395
+ lhs.extent(),
396
+ imag_func
397
+ );
398
+
399
+ return bool(imag_func);
400
+ }
401
+
402
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
403
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
404
+
405
+ /// Returns true if two tensor views are relatively equal.
406
+ template <
407
+ typename Element, ///< Element type
408
+ typename Layout> ///< Layout function
409
+ bool TensorRelativelyEquals(
410
+ TensorView<Element, Layout> const &lhs,
411
+ TensorView<Element, Layout> const &rhs,
412
+ Element epsilon,
413
+ Element nonzero_floor) {
414
+
415
+ // Extents must be identical
416
+ if (lhs.extent() != rhs.extent()) {
417
+ return false;
418
+ }
419
+
420
+ detail::TensorRelativelyEqualsFunc<Element, Layout> func(lhs, rhs, epsilon, nonzero_floor);
421
+ TensorForEach(
422
+ lhs.extent(),
423
+ func
424
+ );
425
+
426
+ return bool(func);
427
+ }
428
+
429
+ /// Returns true if two tensor views are relatively equal.
430
+ template <
431
+ typename Element, ///< Element type
432
+ typename Layout> ///< Layout function
433
+ bool TensorRelativelyEquals(
434
+ TensorViewPlanarComplex<Element, Layout> const &lhs,
435
+ TensorViewPlanarComplex<Element, Layout> const &rhs,
436
+ Element epsilon,
437
+ Element nonzero_floor) {
438
+
439
+ // Extents must be identical
440
+ if (lhs.extent() != rhs.extent()) {
441
+ return false;
442
+ }
443
+
444
+ detail::TensorRelativelyEqualsFunc<Element, Layout> real_func(
445
+ {lhs.data(), lhs.layout(), lhs.extent()},
446
+ {rhs.data(), rhs.layout(), rhs.extent()},
447
+ epsilon,
448
+ nonzero_floor
449
+ );
450
+
451
+ TensorForEach(
452
+ lhs.extent(),
453
+ real_func
454
+ );
455
+
456
+ if (!bool(real_func)) {
457
+ return false;
458
+ }
459
+
460
+ detail::TensorEqualsFunc<Element, Layout> imag_func(
461
+ {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
462
+ {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()},
463
+ epsilon,
464
+ nonzero_floor
465
+ );
466
+
467
+ TensorForEach(
468
+ lhs.extent(),
469
+ imag_func
470
+ );
471
+
472
+ return bool(imag_func);
473
+ }
474
+
475
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
476
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
477
+
478
+ /// Returns true if two tensor views are NOT equal.
479
+ template <
480
+ typename Element, ///< Element type
481
+ typename Layout> ///< Layout function
482
+ bool TensorNotEquals(
483
+ TensorView<Element, Layout> const &lhs,
484
+ TensorView<Element, Layout> const &rhs) {
485
+
486
+ // Extents must be identical
487
+ if (lhs.extent() != rhs.extent()) {
488
+ return true;
489
+ }
490
+
491
+ detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);
492
+ TensorForEach(
493
+ lhs.extent(),
494
+ func
495
+ );
496
+
497
+ return !bool(func);
498
+ }
499
+
500
+ /// Returns true if two tensor views are equal.
501
+ template <
502
+ typename Element, ///< Element type
503
+ typename Layout> ///< Layout function
504
+ bool TensorNotEquals(
505
+ TensorViewPlanarComplex<Element, Layout> const &lhs,
506
+ TensorViewPlanarComplex<Element, Layout> const &rhs) {
507
+
508
+ return !TensorEquals(lhs, rhs);
509
+ }
510
+
511
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
512
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
513
+
514
+ namespace detail {
515
+
516
+ template <
517
+ typename Element, ///< Element type
518
+ typename Layout> ///< Layout function
519
+ struct TensorContainsFunc {
520
+
521
+ //
522
+ // Data members
523
+ //
524
+
525
+ TensorView<Element, Layout> view;
526
+ Element value;
527
+ bool contains;
528
+ Coord<Layout::kRank> location;
529
+
530
+ //
531
+ // Methods
532
+ //
533
+
534
+ /// Ctor
535
+ TensorContainsFunc(): contains(false) { }
536
+
537
+ /// Ctor
538
+ TensorContainsFunc(
539
+ TensorView<Element, Layout> const &view_,
540
+ Element value_
541
+ ) :
542
+ view(view_), value(value_), contains(false) { }
543
+
544
+ /// Visits a coordinate
545
+ void operator()(Coord<Layout::kRank> const &coord) {
546
+
547
+ if (view.at(coord) == value) {
548
+ if (!contains) {
549
+ location = coord;
550
+ }
551
+ contains = true;
552
+ }
553
+ }
554
+
555
+ /// Returns true if equal
556
+ operator bool() const {
557
+ return contains;
558
+ }
559
+ };
560
+
561
+ } // namespace detail
562
+
563
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
564
+
565
+ /// Returns true if a value is present in a tensor
566
+ template <
567
+ typename Element, ///< Element type
568
+ typename Layout> ///< Layout function
569
+ bool TensorContains(
570
+ TensorView<Element, Layout> const & view,
571
+ Element value) {
572
+
573
+ detail::TensorContainsFunc<Element, Layout> func(
574
+ view,
575
+ value
576
+ );
577
+
578
+ TensorForEach(
579
+ view.extent(),
580
+ func
581
+ );
582
+
583
+ return bool(func);
584
+ }
585
+
586
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
587
+
588
+ /// Returns a pair containing a boolean of whether a value exists in a tensor and the location of
589
+ /// of the first occurrence. If the value is not contained in the tensor, the second element of the
590
+ /// pair is undefined.
591
+ template <
592
+ typename Element, ///< Element type
593
+ typename Layout> ///< Layout function
594
+ std::pair<bool, Coord<Layout::kRank> > TensorFind(
595
+ TensorView<Element, Layout> const & view,
596
+ Element value) {
597
+
598
+ detail::TensorContainsFunc<Element, Layout> func(
599
+ view,
600
+ value
601
+ );
602
+
603
+ TensorForEach(
604
+ view.extent(),
605
+ func
606
+ );
607
+
608
+ return std::make_pair(bool(func), func.location);
609
+ }
610
+
611
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
612
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
613
+
614
+ } // namespace host
615
+ } // namespace reference
616
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Provides several functions for filling tensors with data.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // Standard Library includes
38
+ #include <utility>
39
+ #include <cstdlib>
40
+ #include <cmath>
41
+
42
+ // Cute includes
43
+ #include "cute/tensor.hpp"
44
+
45
+ // Cutlass includes
46
+ #include "cutlass/cutlass.h"
47
+ #include "cutlass/complex.h"
48
+ #include "cutlass/quaternion.h"
49
+ #include "cutlass/array.h"
50
+ #include "cutlass/numeric_types.h"
51
+
52
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace cutlass {
55
+ namespace reference {
56
+ namespace host {
57
+
58
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ /// Returns true if two tensor views are equal.
61
+ template <
62
+ typename TensorL,
63
+ typename TensorR
64
+ >
65
+ bool TensorEquals(
66
+ TensorL lhs,
67
+ TensorR rhs) {
68
+
69
+ // Extents must be identical
70
+ if (cute::size(lhs) != cute::size(rhs)) {
71
+ return false;
72
+ }
73
+
74
+ for (int64_t idx = 0; idx < cute::size(lhs); ++idx) {
75
+ if (lhs(idx) != rhs(idx)) {
76
+ return false;
77
+ }
78
+ }
79
+
80
+ return true;
81
+ }
82
+
83
/// Returns true if two tensor views are NOT equal.
template <
  typename TensorL,
  typename TensorR
>
bool TensorNotEquals(
  TensorL lhs,
  TensorR rhs) {

  // Bug fix: this previously returned TensorEquals(lhs, rhs) without negation,
  // i.e. it reported "not equal" exactly when the tensors WERE equal.
  return !TensorEquals(lhs, rhs);
}
94
+
95
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
96
+
97
+ } // namespace host
98
+ } // namespace reference
99
+ } // namespace cutlass
100
+
101
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Defines host-side elementwise operations on TensorView.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // Standard Library includes
38
+ #include <utility>
39
+
40
+ // Cutlass includes
41
+ #include "cutlass/cutlass.h"
42
+ #include "tensor_foreach.h"
43
+
44
+ namespace cutlass {
45
+ namespace reference {
46
+ namespace host {
47
+
48
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace detail {
51
+
52
/// Helper functor that converts a source element to the destination element
/// type using the destination type's converting constructor.
template <
  typename DstElement,
  typename SrcElement
>
struct TrivialConvert {

  TrivialConvert() { }

  /// Converts a single element by value.
  DstElement operator()(SrcElement src) const {
    return static_cast<DstElement>(src);
  }
};
65
+
66
/// Helper to conditionally copy between tensor views.
///
/// Invoked once per coordinate (e.g. by TensorForEach); an element is copied
/// only when the coordinate lies within the bounds of BOTH the destination
/// and the source views, applying the conversion functor F to each element.
template <
  typename DstElement,
  typename DstLayout,
  typename SrcElement,
  typename SrcLayout,
  typename F
>
struct TensorCopyIf {

  using DstTensorView = TensorView<DstElement, DstLayout>;
  using SrcTensorView = TensorView<SrcElement, SrcLayout>;

  //
  // Data members
  //

  DstTensorView dst;   // destination view written to
  SrcTensorView src;   // source view read from
  F convert;           // element-wise conversion functor applied on copy

  //
  // Methods
  //

  /// Default constructor; views are empty and must be assigned before use.
  TensorCopyIf() { }

  /// Constructs the functor from destination/source views and a converter.
  TensorCopyIf(
    DstTensorView const &dst_,
    SrcTensorView const &src_,
    F const &convert_): dst(dst_), src(src_), convert(convert_) {}

  /// Copies based on destination and source bounds
  void operator()(Coord<DstLayout::kRank> const &coord) {
    if (dst.contains(coord) && src.contains(coord)) {
      dst.at(coord) = convert(src.at(coord));
    }
  }
};
105
+
106
+ } // namespace detail
107
+
108
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
109
+
110
+ /// Copies elements from one tensor view into another, satisfying bounds of each tensor.
111
+ template <
112
+ typename DstElement, /// Destination tensor's element type
113
+ typename DstLayout, /// Destination tensor's layout
114
+ typename SrcElement, /// Source tensor's element type
115
+ typename SrcLayout, /// Source tensor's layout
116
+ typename F /// Transformation functor
117
+ >
118
+ void TensorCopy(
119
+ TensorView<DstElement, DstLayout> dst,
120
+ TensorView<SrcElement, SrcLayout> src,
121
+ F const &transform) {
122
+
123
+ using CopyIf = detail::TensorCopyIf<
124
+ DstElement,
125
+ DstLayout,
126
+ SrcElement,
127
+ SrcLayout,
128
+ F>;
129
+
130
+ CopyIf copy_if(dst, src, transform);
131
+
132
+ TensorForEach(dst.extent(), copy_if);
133
+ }
134
+
135
+
136
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
137
+
138
+ /// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent
139
+ /// to avoid out of bounds accesses.
140
+ template <
141
+ typename DstElement, /// Destination tensor's element type
142
+ typename DstLayout, /// Destination tensor's layout
143
+ typename SrcElement, /// Source tensor's element type
144
+ typename SrcLayout, /// Source tensor's layout
145
+ typename F /// Transformation functor
146
+ >
147
+ void TensorCopy(
148
+ TensorView<DstElement, DstLayout> dst,
149
+ TensorRef<SrcElement, SrcLayout> src,
150
+ F const &transform) {
151
+
152
+ using CopyIf = detail::TensorCopyIf<
153
+ DstElement,
154
+ DstLayout,
155
+ SrcElement,
156
+ SrcLayout,
157
+ F>;
158
+
159
+ TensorView<SrcElement, SrcLayout> src_view(src, dst.extent());
160
+
161
+ CopyIf copy_if(dst, src_view, transform);
162
+
163
+ TensorForEach(dst.extent(), copy_if);
164
+ }
165
+
166
+ /// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent
167
+ /// to avoid out of bounds accesses.
168
+ template <
169
+ typename DstElement, /// Destination tensor's element type
170
+ typename DstLayout, /// Destination tensor's layout
171
+ typename SrcElement, /// Source tensor's element type
172
+ typename SrcLayout, /// Source tensor's layout
173
+ typename F /// Transformation functor
174
+ >
175
+ void TensorCopy(
176
+ TensorRef<DstElement, DstLayout> dst,
177
+ TensorView<SrcElement, SrcLayout> src,
178
+ F const &transform) {
179
+
180
+ using CopyIf = detail::TensorCopyIf<
181
+ DstElement,
182
+ DstLayout,
183
+ SrcElement,
184
+ SrcLayout,
185
+ F>;
186
+
187
+ TensorView<DstElement, DstLayout> dst_view(dst, src.extent());
188
+
189
+ CopyIf copy_if(dst_view, src, transform);
190
+
191
+ TensorForEach(src.extent(), copy_if);
192
+ }
193
+
194
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
195
+
196
+ /// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
197
+ /// if SrcElement can be converted to DstElement.
198
+ template <
199
+ typename DstElement, /// Destination tensor's element type
200
+ typename DstLayout, /// Destination tensor's layout
201
+ typename SrcElement, /// Source tensor's element type
202
+ typename SrcLayout /// Source tensor's layout
203
+ >
204
+ void TensorCopy(
205
+ TensorView<DstElement, DstLayout> dst,
206
+ TensorView<SrcElement, SrcLayout> src) {
207
+
208
+ detail::TrivialConvert<DstElement, SrcElement> convert;
209
+
210
+ TensorCopy(dst, src, convert);
211
+ }
212
+
213
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
214
+
215
+ /// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
216
+ /// if SrcElement can be converted to DstElement.
217
+ template <
218
+ typename DstElement, /// Destination tensor's element type
219
+ typename DstLayout, /// Destination tensor's layout
220
+ typename SrcElement, /// Source tensor's element type
221
+ typename SrcLayout, /// Source tensor's layout
222
+ typename F /// Transformation functor
223
+ >
224
+ void TensorCopy(
225
+ TensorView<DstElement, DstLayout> dst,
226
+ TensorRef<SrcElement, SrcLayout> src) {
227
+
228
+ detail::TrivialConvert<DstElement, SrcElement> convert;
229
+
230
+ TensorCopy(dst, src, convert);
231
+ }
232
+
233
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
234
+
235
+ /// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
236
+ /// if SrcElement can be converted to DstElement.
237
+ template <
238
+ typename DstElement, /// Destination tensor's element type
239
+ typename DstLayout, /// Destination tensor's layout
240
+ typename SrcElement, /// Source tensor's element type
241
+ typename SrcLayout /// Source tensor's layout
242
+ >
243
+ void TensorCopy(
244
+ TensorRef<DstElement, DstLayout> dst,
245
+ TensorView<SrcElement, SrcLayout> src) {
246
+
247
+ detail::TrivialConvert<DstElement, SrcElement> convert;
248
+
249
+ TensorCopy(dst, src, convert);
250
+ }
251
+
252
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
253
+
254
+ } // namespace host
255
+ } // namespace reference
256
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Defines host-side elementwise operations on TensorView.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // Cutlass includes
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/functional.h"
40
+
41
+ #include "tensor_foreach.h"
42
+
43
+ namespace cutlass {
44
+ namespace reference {
45
+ namespace host {
46
+
47
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
48
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace detail {
51
+
52
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
/// Helper to apply a binary operator elementwise:
///   view_d.at(coord) = func(view_a.at(coord), view_b.at(coord))
/// Intended to be invoked once per coordinate by TensorForEach.
template <
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementD,
  typename LayoutD,
  typename BinaryFunc>
struct TensorFuncBinaryOp {

  //
  // Data members
  //

  /// Destination view written with the result of the binary operator
  TensorView<ElementD, LayoutD> view_d;
  /// Left-hand operand tensor
  TensorRef<ElementA, LayoutA> view_a;
  /// Right-hand operand tensor
  TensorRef<ElementB, LayoutB> view_b;
  /// Binary functor applied to each pair of elements
  BinaryFunc func;

  //
  // Methods
  //

  /// Constructor
  TensorFuncBinaryOp() { }

  /// Constructor
  TensorFuncBinaryOp(
    TensorView<ElementD, LayoutD> const & view_d_,
    TensorRef<ElementA, LayoutA> const & view_a_,
    TensorRef<ElementB, LayoutB> const & view_b_,
    BinaryFunc func = BinaryFunc()
  ):
    view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { }

  /// Applies the binary operator to the operands at 'coord' and stores the
  /// result. Note both operands are converted to ElementD before the operation.
  void operator()(Coord<LayoutD::kRank> const &coord) const {
    view_d.at(coord) = func(
      ElementD(view_a.at(coord)),
      ElementD(view_b.at(coord))
    );
  }
};
99
+
100
+ } // namespace detail
101
+
102
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
103
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
104
+
105
+ /// Adds two tensors and stores in the destination tensor: d = a + b
106
+ template <
107
+ typename ElementD,
108
+ typename LayoutD,
109
+ typename ElementA,
110
+ typename LayoutA,
111
+ typename ElementB,
112
+ typename LayoutB
113
+ >
114
+ void TensorAdd(
115
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
116
+ TensorRef<ElementA, LayoutA> a, ///< A tensor reference
117
+ TensorRef<ElementB, LayoutB> b ///< B tensor reference
118
+ ) {
119
+
120
+ detail::TensorFuncBinaryOp<
121
+ ElementD,
122
+ LayoutD,
123
+ ElementA,
124
+ LayoutA,
125
+ ElementB,
126
+ LayoutB,
127
+ cutlass::plus<ElementD>
128
+ > func(d, a, b);
129
+
130
+ TensorForEach(
131
+ d.extent(),
132
+ func);
133
+ }
134
+
135
+ /// Adds a tensor in place: d = d .+ a
136
+ template <
137
+ typename ElementD,
138
+ typename LayoutD,
139
+ typename ElementA,
140
+ typename LayoutA
141
+ >
142
+ void TensorAdd(
143
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
144
+ TensorRef<ElementA, LayoutA> a ///< A tensor reference
145
+ ) {
146
+ TensorAdd(d, d, a);
147
+ }
148
+
149
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
150
+
151
+ /// Subtracts two tensors and stores in the destination tensor: d = a - b
152
+ template <
153
+ typename ElementD,
154
+ typename LayoutD,
155
+ typename ElementA,
156
+ typename LayoutA,
157
+ typename ElementB,
158
+ typename LayoutB
159
+ >
160
+ void TensorSub(
161
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
162
+ TensorRef<ElementA, LayoutA> a, ///< A tensor reference
163
+ TensorRef<ElementB, LayoutB> b ///< B tensor reference
164
+ ) {
165
+
166
+ detail::TensorFuncBinaryOp<
167
+ ElementD,
168
+ LayoutD,
169
+ ElementA,
170
+ LayoutA,
171
+ ElementB,
172
+ LayoutB,
173
+ cutlass::minus<ElementD>
174
+ > func(d, a, b);
175
+
176
+ TensorForEach(
177
+ d.extent(),
178
+ func);
179
+ }
180
+
181
+ /// Subtracts two tensors in place: d = d .- a
182
+ template <
183
+ typename ElementD,
184
+ typename LayoutD,
185
+ typename ElementA,
186
+ typename LayoutA,
187
+ typename ElementB,
188
+ typename LayoutB
189
+ >
190
+ void TensorSub(
191
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
192
+ TensorRef<ElementA, LayoutA> a ///< A tensor reference
193
+ ) {
194
+
195
+ TensorSub(d, d, a);
196
+ }
197
+
198
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
199
+
200
+ /// Multiplies two tensors and stores in the destination tensor: d = a .* b
201
+ template <
202
+ typename ElementD,
203
+ typename LayoutD,
204
+ typename ElementA,
205
+ typename LayoutA,
206
+ typename ElementB,
207
+ typename LayoutB
208
+ >
209
+ void TensorMul(
210
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
211
+ TensorRef<ElementA, LayoutA> a, ///< A tensor reference
212
+ TensorRef<ElementB, LayoutB> b ///< B tensor reference
213
+ ) {
214
+
215
+ detail::TensorFuncBinaryOp<
216
+ ElementD,
217
+ LayoutD,
218
+ ElementA,
219
+ LayoutA,
220
+ ElementB,
221
+ LayoutB,
222
+ cutlass::multiplies<ElementD>
223
+ > func(d, a, b);
224
+
225
+ TensorForEach(
226
+ d.extent(),
227
+ func);
228
+ }
229
+
230
+ /// Multiplies tensors in place: d = d .* a
231
+ template <
232
+ typename ElementD,
233
+ typename LayoutD,
234
+ typename ElementA,
235
+ typename LayoutA
236
+ >
237
+ void TensorMul(
238
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
239
+ TensorRef<ElementA, LayoutA> a ///< A tensor reference
240
+ ) {
241
+ TensorMul(d, d, a);
242
+ }
243
+
244
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
245
+
246
+ /// Divides two tensors and stores in the destination tensor: d = a ./ b
247
+ template <
248
+ typename ElementD,
249
+ typename LayoutD,
250
+ typename ElementA,
251
+ typename LayoutA,
252
+ typename ElementB,
253
+ typename LayoutB
254
+ >
255
+ void TensorDiv(
256
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
257
+ TensorRef<ElementA, LayoutA> a, ///< A tensor reference
258
+ TensorRef<ElementB, LayoutB> b ///< B tensor reference
259
+ ) {
260
+
261
+ detail::TensorFuncBinaryOp<
262
+ ElementD,
263
+ LayoutD,
264
+ ElementA,
265
+ LayoutA,
266
+ ElementB,
267
+ LayoutB,
268
+ cutlass::divides<ElementD>
269
+ > func(d, a, b);
270
+
271
+ TensorForEach(
272
+ d.extent(),
273
+ func);
274
+ }
275
+
276
+ /// Divides tensors in place: d = d ./ a
277
+ template <
278
+ typename ElementD,
279
+ typename LayoutD,
280
+ typename ElementA,
281
+ typename LayoutA
282
+ >
283
+ void TensorDiv(
284
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
285
+ TensorRef<ElementA, LayoutA> a ///< A tensor reference
286
+ ) {
287
+ TensorDiv(d, d, a);
288
+ }
289
+
290
+
291
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
292
+
293
+ /// Divides two tensors and stores in the destination tensor: d = a ./ b
294
+ template <
295
+ typename ElementD,
296
+ typename LayoutD,
297
+ typename ElementA,
298
+ typename LayoutA,
299
+ typename ElementB,
300
+ typename LayoutB
301
+ >
302
+ void TensorModulus(
303
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
304
+ TensorRef<ElementA, LayoutA> a, ///< A tensor reference
305
+ TensorRef<ElementB, LayoutB> b ///< B tensor reference
306
+ ) {
307
+
308
+ detail::TensorFuncBinaryOp<
309
+ ElementD,
310
+ LayoutD,
311
+ ElementA,
312
+ LayoutA,
313
+ ElementB,
314
+ LayoutB,
315
+ cutlass::divides<ElementD>
316
+ > func(d, a, b);
317
+
318
+ TensorForEach(
319
+ d.extent(),
320
+ func);
321
+ }
322
+
323
+ /// Divides tensors in place: d = d ./ a
324
+ template <
325
+ typename ElementD,
326
+ typename LayoutD,
327
+ typename ElementA,
328
+ typename LayoutA
329
+ >
330
+ void TensorModulus(
331
+ TensorView<ElementD, LayoutD> d, ///< destination tensor view
332
+ TensorRef<ElementA, LayoutA> a ///< A tensor reference
333
+ ) {
334
+ TensorDiv(d, d, a);
335
+ }
336
+
337
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
338
+
339
+ } // namespace host
340
+ } // namespace reference
341
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h ADDED
@@ -0,0 +1,1718 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Provides several functions for filling tensors with data.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // Standard Library includes
38
+ #include <utility>
39
+ #include <cstdlib>
40
+ #include <cmath>
41
+ #include <random>
42
+ #include <stdexcept>
43
+
44
+ // Cutlass includes
45
+ #include "cutlass/cutlass.h"
46
+ #include "cutlass/complex.h"
47
+ #include "cutlass/quaternion.h"
48
+ #include "cutlass/array.h"
49
+ #include "cutlass/numeric_types.h"
50
+ #include "cutlass/subbyte_reference.h"
51
+ #include "cutlass/tensor_view.h"
52
+ #include "cutlass/tensor_view_planar_complex.h"
53
+ #include "cutlass/blas3.h"
54
+
55
+ #include "cutlass/util/distribution.h"
56
+ #include "tensor_foreach.h"
57
+
58
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace reference {
62
+ namespace host {
63
+
64
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
65
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
66
+
67
+ namespace detail {
68
+
69
/// Functor that writes a constant value at every coordinate it is invoked
/// with; intended to be driven by TensorForEach.
template <
  typename Element, ///< Element type
  typename Layout>  ///< Layout function
struct TensorFillFunc {

  using TensorView = TensorView<Element, Layout>;

  //
  // Data members
  //

  TensorView view;  // destination view written to
  Element value;    // constant stored at each visited coordinate

  //
  // Methods
  //

  /// Constructs the functor from a destination view and fill value.
  TensorFillFunc(
    TensorView const &view_ = TensorView(),
    Element value_ = Element(0)
  ): view(view_), value(value_) { }

  /// Stores 'value' at the given coordinate.
  void operator()(Coord<Layout::kRank> const & coord) const {
    view.at(coord) = value;
  }
};
96
+
97
/// Returns a pair of values of the Gaussian distribution generated by the Box-Muller method.
struct BoxMullerFunc {

  BoxMullerFunc() {}

  /// Writes two normally distributed samples into rnd[0] and rnd[1].
  void operator()(
    double* rnd,                   ///< Size-2 vector to be filled with random values
    double mean = 0,               ///< Mean of the Gaussian distribution
    double stddev = 1,             ///< Standard deviation of the Gaussian distribution
    double pi = std::acos(-1)) const {

    // Two independent uniforms in [0, 1] drawn from the C RNG.
    double const u1 = double(std::rand()) / double(RAND_MAX);
    double const u2 = double(std::rand()) / double(RAND_MAX);

    // Box-Muller: map the uniform pair onto a standard-normal pair.
    double const radius = std::sqrt(-2 * std::log(u1));
    rnd[0] = radius * std::cos(2 * pi * u2);
    rnd[1] = radius * std::sin(2 * pi * u2);

    // Shift and scale to the requested mean / standard deviation.
    rnd[0] = mean + stddev * rnd[0];
    rnd[1] = mean + stddev * rnd[1];
  }
};
116
+ } // namespace detail
117
+
118
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
119
+
120
+ /// Fills a tensor with a uniform value
121
+ template <
122
+ typename Element, ///< Element type
123
+ typename Layout> ///< Layout function
124
+ void TensorFill(
125
+ TensorView<Element, Layout> dst, ///< destination tensor
126
+ Element val = Element(0)) { ///< value to uniformly fill it with
127
+
128
+ detail::TensorFillFunc<Element, Layout> func(dst, val);
129
+
130
+ TensorForEach(
131
+ dst.extent(),
132
+ func
133
+ );
134
+ }
135
+
136
+ /// Fills a tensor with a uniform value
137
+ template <
138
+ typename Element, ///< Element type
139
+ typename Layout> ///< Layout function
140
+ void TensorFill(
141
+ TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
142
+ cutlass::complex<Element> val = cutlass::complex<Element>(0)) { ///< value to uniformly fill it with
143
+
144
+ TensorFill(dst.view_real(), val.real());
145
+ TensorFill(dst.view_imag(), val.imag());
146
+ }
147
+
148
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
149
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
150
+
151
+ namespace detail {
152
+
153
/// Generates Gaussian-distributed scalars with optional sparsification,
/// quantization to a fixed number of fractional bits, and zero exclusion.
template <typename Element>
struct RandomGaussianFunc {

  uint64_t seed;        // seed passed to std::srand at construction
  double mean;          // Gaussian distribution mean
  double stddev;        // Gaussian distribution standard deviation
  int int_scale;        // if non-negative, number of fractional bits retained
  double pi;            // cached value of pi used by Box-Muller
  double pnz;           // probability that an element is nonzero
  bool exclude_zero;    // if true, exact-zero results are nudged off zero

  //
  // Methods
  //
  RandomGaussianFunc(
    uint64_t seed_ = 0,
    double mean_ = 0,
    double stddev_ = 1,
    int int_scale_ = -1,
    double pnz_ = 1.0,
    bool exclude_zero_ = false
  ):
    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
    std::srand((unsigned)seed);
  }

  /// Compute random value and update RNG state.
  Element operator()() const {

    // Box-Muller transform: two uniforms in [0, 1] -> one standard normal.
    double const u1 = double(std::rand()) / double(RAND_MAX);
    double const u2 = double(std::rand()) / double(RAND_MAX);
    double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
    rnd = mean + stddev * rnd;

    // Bernoulli draw with probability pnz decides whether this element is nonzero.
    // NOTE(review): the engine is re-seeded from std::random_device on every call,
    // so the zero/nonzero pattern is not reproducible from `seed`.
    std::random_device rnd_device;
    std::mt19937 bernoulli_rnd(rnd_device());
    std::bernoulli_distribution bernoulli_dist(pnz);

    Element result;
    if (bernoulli_dist(bernoulli_rnd)) {
      if (int_scale >= 0) {
        // Quantize to int_scale fractional bits to facilitate error testing.
        rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
      }
      result = static_cast<Element>(rnd);
    }
    else {
      result = static_cast<Element>(0);
    }

    // Optionally push exact zeros away from zero; this overrides a Bernoulli zero.
    if (exclude_zero && result == Element(0)) {
      rnd += (rnd > 0) ? 1 : -1;
      result = Element(rnd);
    }

    return result;
  }
};
226
+
227
/// Partial specialization for initializing a complex value.
template <typename Element>
struct RandomGaussianFunc<complex<Element> > {

  uint64_t seed;        // seed passed to std::srand at construction
  double mean;          // Gaussian distribution mean
  double stddev;        // Gaussian distribution standard deviation
  int int_scale;        // if non-negative, number of fractional bits retained
  double pi;            // cached pi, forwarded to the Box-Muller helper
  double pnz;           // probability that an element is nonzero
  bool exclude_zero;    // if true, an all-zero sample is nudged off zero

  //
  // Methods
  //
  RandomGaussianFunc(
    uint64_t seed_ = 0,
    double mean_ = 0,
    double stddev_ = 1,
    int int_scale_ = -1,
    double pnz_ = 1.0,
    bool exclude_zero_ = false
  ):
    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
    std::srand((unsigned)seed);
  }

  /// Compute random value and update RNG state
  complex<Element> operator()() const {

    Element reals[2];

    // One Box-Muller draw yields the (real, imaginary) pair.
    double rnd[2];
    detail::BoxMullerFunc func;
    func(rnd, mean, stddev, pi);

    // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian.
    // NOTE(review): the engine is re-seeded from std::random_device on every call, so the
    // zero/nonzero pattern is not reproducible from `seed`.
    std::random_device rnd_device;
    std::mt19937 bernoulli_rnd(rnd_device());
    std::bernoulli_distribution bernoulli_dist(pnz);
    bool bernoulli_result = bernoulli_dist(bernoulli_rnd);

    // Sample from the Gaussian distribution for a nonzero element
    if (bernoulli_result) {
      if (int_scale >= 0) {
        // Round to int_scale fractional bits before converting to Element.
        rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale)));
        rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale)));
        reals[0] = from_real<Element>(rnd[0] / double(1 << int_scale));
        reals[1] = from_real<Element>(rnd[1] / double(1 << int_scale));
      }
      else {
        reals[0] = from_real<Element>(rnd[0]);
        reals[1] = from_real<Element>(rnd[1]);
      }
    }
    else {
      reals[0] = from_real<Element>(0);
      reals[1] = from_real<Element>(0);
    }

    // Note that this will invalidate the above else statement because it unsets zero elements.
    // Only the real part is adjusted; the imaginary part keeps its zero value.
    if (exclude_zero &&
        reals[0] == from_real<Element>(0.0) &&
        reals[1] == from_real<Element>(0.0)) {

      if (rnd[0] > 0.0) {
        rnd[0] += 1.0;
      } else {
        rnd[0] -= 1.0;
      }
      reals[0] = from_real<Element>(rnd[0]);
    }

    return complex<Element>(reals[0], reals[1]);
  }
};
303
+
304
/// Partial specialization for initializing a Quaternion value.
template <typename Element>
struct RandomGaussianFunc<Quaternion<Element> > {

  uint64_t seed;        // seed passed to std::srand at construction
  double mean;          // Gaussian distribution mean
  double stddev;        // Gaussian distribution standard deviation
  int int_scale;        // if non-negative, number of fractional bits retained
  double pi;            // cached pi, forwarded to the Box-Muller helper
  double pnz;           // probability that an element is nonzero
  bool exclude_zero;    // if true, an all-zero sample is nudged off zero

  //
  // Methods
  //
  RandomGaussianFunc(
    uint64_t seed_ = 0,
    double mean_ = 0,
    double stddev_ = 1,
    int int_scale_ = -1,
    double pnz_ = 1.0,
    bool exclude_zero_ = false
  ):
    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
    std::srand((unsigned)seed);
  }

  /// Compute random value and update RNG state
  Quaternion<Element> operator()() const {

    Element reals[4];

    // Two Box-Muller draws produce the four quaternion components.
    double rnd1[2];
    double rnd2[2];
    detail::BoxMullerFunc func;
    func(rnd1, mean, stddev, pi);
    func(rnd2, mean, stddev, pi);

    // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian.
    // NOTE(review): the engine is re-seeded from std::random_device on every call, so the
    // zero/nonzero pattern is not reproducible from `seed`.
    std::random_device rnd_device;
    std::mt19937 bernoulli_rnd(rnd_device());
    std::bernoulli_distribution bernoulli_dist(pnz);
    bool bernoulli_result = bernoulli_dist(bernoulli_rnd);

    // Sample from the Gaussian distribution for a nonzero element
    if (bernoulli_result) {
      if (int_scale >= 0) {
        // Round each component to int_scale fractional bits before conversion.
        rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale)));
        rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale)));
        rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale)));
        rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale)));

        reals[0] = from_real<Element>(rnd1[0] / double(1 << int_scale));
        reals[1] = from_real<Element>(rnd1[1] / double(1 << int_scale));
        reals[2] = from_real<Element>(rnd2[0] / double(1 << int_scale));
        reals[3] = from_real<Element>(rnd2[1] / double(1 << int_scale));
      }
      else {
        reals[0] = from_real<Element>(rnd1[0]);
        reals[1] = from_real<Element>(rnd1[1]);
        reals[2] = from_real<Element>(rnd2[0]);
        reals[3] = from_real<Element>(rnd2[1]);
      }
    }
    else {
      reals[0] = from_real<Element>(0);
      reals[1] = from_real<Element>(0);
      reals[2] = from_real<Element>(0);
      reals[3] = from_real<Element>(0);
    }

    // Note that this will invalidate the above else statement because it unsets zero elements.
    // Only the first component is adjusted; the other three keep their zero values.
    if (exclude_zero &&
        reals[0] == from_real<Element>(0) &&
        reals[1] == from_real<Element>(0) &&
        reals[2] == from_real<Element>(0) &&
        reals[3] == from_real<Element>(0)) {

      if (rnd1[0] > 0.0) {
        rnd1[0] += 1.0;
      } else {
        rnd1[0] -= 1.0;
      }
      reals[0] = from_real<Element>(rnd1[0]);
    }

    return Quaternion<Element>(reals[0], reals[1], reals[2], reals[3]);
  }
};
393
+
394
+ /// Computes a random Gaussian distribution
395
+ template <
396
+ typename Element, ///< Element type
397
+ typename Layout> ///< Layout function
398
+ struct TensorFillGaussianFunc {
399
+
400
+ using TensorView = TensorView<Element, Layout>;
401
+
402
+ //
403
+ // Data members
404
+ //
405
+
406
+ TensorView view;
407
+ RandomGaussianFunc<Element> func;
408
+
409
+ //
410
+ // Methods
411
+ //
412
+
413
+ /// Construction of Gaussian RNG functor.
414
+ TensorFillGaussianFunc(
415
+ TensorView view_ = TensorView(),
416
+ RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>()
417
+ ):
418
+ view(view_), func(func_) {
419
+
420
+ }
421
+
422
+ /// Compute random value and update RNG state
423
+ void operator()(Coord<Layout::kRank> const &coord) const {
424
+ view.at(coord) = func();
425
+ }
426
+ };
427
+
428
+ /// Computes a random Gaussian distribution for a rank-2 tensor
429
+ template <
430
+ typename Element, ///< Element type
431
+ typename Layout> ///< Layout function
432
+ struct TensorFillSymmetricGaussianFunc {
433
+
434
+ using TensorView = TensorView<Element, Layout>;
435
+
436
+ //
437
+ // Data members
438
+ //
439
+
440
+ TensorView view;
441
+ RandomGaussianFunc<Element> func;
442
+ cutlass::FillMode fill_mode;
443
+
444
+ //
445
+ // Methods
446
+ //
447
+
448
+ /// Construction of Gaussian RNG functor.
449
+ TensorFillSymmetricGaussianFunc(
450
+ TensorView view_ = TensorView(),
451
+ RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>(),
452
+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
453
+ ):
454
+ view(view_), func(func_), fill_mode(fill_mode_) {
455
+
456
+ }
457
+
458
+ /// Compute random value and update RNG state
459
+ void operator()(Coord<Layout::kRank> const &coord) const {
460
+ // Fill half of matrix based on FillMode
461
+ if (Layout::kRank == 2 &&
462
+ fill_mode == cutlass::FillMode::kLower &&
463
+ coord[0] >= coord[1]) {
464
+ view.at(coord) = func();
465
+ } else if (Layout::kRank == 2 &&
466
+ fill_mode == cutlass::FillMode::kUpper &&
467
+ coord[0] <= coord[1]) {
468
+ view.at(coord) = func();
469
+ }
470
+ }
471
+ };
472
+
473
+ } // namespace detail
474
+
475
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
476
+
477
+ /// Fills a tensor with random values with a Gaussian distribution.
478
+ template <
479
+ typename Element, ///< Element type
480
+ typename Layout> ///< Layout function
481
+ void TensorFillRandomGaussian(
482
+ TensorView<Element, Layout> dst, ///< destination tensor
483
+ uint64_t seed, ///< seed for RNG
484
+ double mean = 0, ///< Gaussian distribution's mean
485
+ double stddev = 1, ///< Gaussian distribution's standard deviation
486
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
487
+ double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of
488
+ /// data.
489
+ bool exclude_zero = false) { ///< Exclude zeros from tensor init.
490
+
491
+ detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz, exclude_zero);
492
+
493
+ detail::TensorFillGaussianFunc<Element, Layout> func(
494
+ dst,
495
+ random_func
496
+ );
497
+
498
+ TensorForEach(
499
+ dst.extent(),
500
+ func
501
+ );
502
+ }
503
+
504
+ /// Fills a tensor with random values with a Gaussian distribution.
505
+ template <
506
+ typename Element, ///< Element type
507
+ typename Layout> ///< Layout function
508
+ void TensorFillRandomGaussian(
509
+ TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
510
+ uint64_t seed, ///< seed for RNG
511
+ double mean = 0, ///< Gaussian distribution's mean
512
+ double stddev = 1, ///< Gaussian distribution's standard deviation
513
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
514
+ double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of
515
+ /// data.
516
+ bool exclude_zero = false) { ///< Exclude zeros from tensor init.
517
+
518
+ TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz);
519
+ TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz);
520
+ }
521
+
522
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
523
+ /// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution.
524
+ template <
525
+ typename Element, ///< Element type
526
+ typename Layout> ///< Layout function
527
+ void TensorFillSymmetricRandomGaussian(
528
+ TensorView<Element, Layout> dst, ///< destination tensor
529
+ uint64_t seed, ///< seed for RNG
530
+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
531
+ double mean = 0, ///< Gaussian distribution's mean
532
+ double stddev = 1, ///< Gaussian distribution's standard deviation
533
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
534
+ double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of
535
+ /// data.
536
+
537
+ detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz);
538
+
539
+ detail::TensorFillSymmetricGaussianFunc<Element, Layout> func(
540
+ dst,
541
+ random_func,
542
+ fill_mode
543
+ );
544
+
545
+ TensorForEach(
546
+ dst.extent(),
547
+ func
548
+ );
549
+ }
550
+
551
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
552
+
553
+ /// Fills a tensor with random values of a Gaussian distribution.
554
+ template <
555
+ typename Element ///< Element type
556
+ >
557
+ void BlockFillRandomGaussian(
558
+ Element *ptr, ///< destination buffer
559
+ size_t capacity, ///< number of elements
560
+ uint64_t seed, ///< seed for RNG
561
+ double mean = 0, ///< Gaussian distribution's mean
562
+ double stddev = 1, ///< Gaussian distribution's standard deviation
563
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
564
+ double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of
565
+ /// data.
566
+
567
+
568
+ detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz);
569
+
570
+ for (size_t i = 0; i < capacity; ++i) {
571
+ ReferenceFactory<Element>::get(ptr, i) = random_func();
572
+ }
573
+ }
574
+
575
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
576
+
577
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
578
+
579
+ namespace detail {
580
+
581
+ template <typename Element>
582
+ struct RandomUniformFunc {
583
+
584
+ using Real = typename RealType<Element>::Type;
585
+
586
+ uint64_t seed;
587
+ double range;
588
+ double min;
589
+ int int_scale;
590
+
591
+ double pnan;
592
+ private:
593
+ using engine_type = std::mt19937;
594
+ public:
595
+ engine_type bernoulli_rnd;
596
+ std::bernoulli_distribution bernoulli_dist;
597
+
598
+ bool exclude_zero;
599
+
600
+ RandomUniformFunc(
601
+ uint64_t seed_ = 0,
602
+ double max = 1,
603
+ double min_ = 0,
604
+ int int_scale_ = -1,
605
+ double pnan_ = 0,
606
+ bool exclude_zero_ = false
607
+ ):
608
+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_)
609
+ , bernoulli_rnd{static_cast<engine_type::result_type>(seed_)}
610
+ , bernoulli_dist(pnan_)
611
+ , exclude_zero(exclude_zero_)
612
+ {
613
+ std::srand((unsigned)seed);
614
+
615
+ // Handle cases where min = 0 or max = 0 for excluding zeros
616
+ if (exclude_zero) {
617
+ min = (min == 0.0) ? min + 1: min;
618
+ range = (max == 0.0) ? range - 1: range;
619
+ }
620
+ }
621
+
622
+
623
+ /// Compute random value and update RNG state
624
+ Element operator()() {
625
+
626
+ // Sample from NaN distribution.
627
+ if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
628
+ if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
629
+ return Element(NAN);
630
+ }
631
+ }
632
+
633
+ double rnd = double(std::rand()) / double(RAND_MAX);
634
+
635
+ rnd = min + range * rnd;
636
+
637
+ // Random values are cast to integer after scaling by a power of two to facilitate error
638
+ // testing
639
+ Element result;
640
+ if (int_scale >= 0) {
641
+ rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale);
642
+ result = static_cast<Element>(Real(rnd));
643
+ }
644
+ else {
645
+ result = static_cast<Element>(Real(rnd));
646
+ }
647
+
648
+ if (exclude_zero && result == Element(0)) {
649
+ if (rnd > 0.0) {
650
+ rnd = std::min(min + range, rnd + 1.0);
651
+ } else {
652
+ rnd = std::max(min, rnd - 1.0);
653
+ }
654
+ result = static_cast<Element>(Real(rnd));
655
+ }
656
+
657
+ return result;
658
+ }
659
+ };
660
+
661
/// Partial specialization for initializing a complex value.
template <typename Element>
struct RandomUniformFunc<complex<Element> > {

  using Real = typename RealType<Element>::Type;

  uint64_t seed;      // seed passed to std::srand at construction
  double range;       // width of the sampling interval (max - min)
  double min;         // lower bound of the sampling interval
  int int_scale;      // if non-negative, number of fractional bits retained

  double pnan;        // probability of emitting a quiet NaN
 private:
  using engine_type = std::mt19937;
 public:
  engine_type bernoulli_rnd;                  // engine driving the NaN Bernoulli draw
  std::bernoulli_distribution bernoulli_dist;

  bool exclude_zero;  // if true, a zero real part is nudged back into the interval

  //
  // Methods
  //

  RandomUniformFunc(
    uint64_t seed_ = 0,
    double max = 1,
    double min_ = 0,
    int int_scale_ = -1,
    double pnan_ = 0,
    bool exclude_zero_ = false
  ):
    seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_)
    , bernoulli_rnd{static_cast<engine_type::result_type>(seed_)}
    , bernoulli_dist(pnan_)
    , exclude_zero(exclude_zero_) {
    std::srand((unsigned)seed);

    // Handle cases where min = 0 or max = 0 for excluding zeros
    if (exclude_zero) {
      min = (min == 0.0) ? min + 1: min;
      range = (max == 0.0) ? range - 1: range;
    }
  }


  /// Compute random value and update RNG state
  complex<Element> operator()() {

    // Sample from NaN distribution. The branch is gated on the *component* type;
    // the returned Element(NAN) is implicitly converted to complex<Element>.
    if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
      if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
        return Element(NAN);
      }
    }

    Element reals[2];

    // reals[0] is the real part, reals[1] the imaginary part.
    for (int i = 0; i < 2; ++i) {
      double rnd = double(std::rand()) / double(RAND_MAX);

      rnd = min + range * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing
      if (int_scale >= 0) {
        rnd = double(std::llround(rnd * double(1 << int_scale)));
        reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
      }
      else {
        reals[i] = from_real<Element>(Real(rnd));
      }

      // Zero exclusion only applies to the real part (i == 0); the imaginary
      // part may still be zero.
      if (exclude_zero &&
          i == 0 &&
          reals[0] == from_real<Element>(0.0)) {

        if (rnd > 0.0) {
          rnd = std::min(min + range, rnd + 1.0);
        } else {
          rnd = std::max(min, rnd - 1.0);
        }
        reals[0] = from_real<Element>(Real(rnd));
      }

    }

    return complex<Element>(reals[0], reals[1]);
  }
};
752
+
753
/// Partial specialization for initializing a Quaternion value.
/// Unlike the scalar and complex specializations, this one has no exclude_zero option.
template <typename Element>
struct RandomUniformFunc<Quaternion<Element> > {

  using Real = typename RealType<Element>::Type;

  uint64_t seed;      // seed passed to std::srand at construction
  double range;       // width of the sampling interval (max - min)
  double min;         // lower bound of the sampling interval
  int int_scale;      // if non-negative, number of fractional bits retained

  double pnan;        // probability of emitting a quiet NaN
 private:
  using engine_type = std::mt19937;
 public:
  engine_type bernoulli_rnd;                  // engine driving the NaN Bernoulli draw
  std::bernoulli_distribution bernoulli_dist;

  //
  // Methods
  //

  RandomUniformFunc(
    uint64_t seed_ = 0,
    double max = 1,
    double min_ = 0,
    int int_scale_ = -1,
    double pnan_ = 0
  ):
    seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_),
    bernoulli_rnd{static_cast<engine_type::result_type>(seed_)},
    bernoulli_dist(pnan_)
  {
    std::srand((unsigned)seed);
  }


  /// Compute random value and update RNG state
  Quaternion<Element> operator()() {

    // Sample from NaN distribution. The branch is gated on the *component* type;
    // the returned Element(NAN) relies on implicit conversion to Quaternion<Element>.
    if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
      if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
        return Element(NAN);
      }
    }

    Element reals[4];

    // Draw the four quaternion components independently.
    for (int i = 0; i < 4; ++i) {
      double rnd = double(std::rand()) / double(RAND_MAX);

      rnd = min + range * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing
      if (int_scale >= 0) {
        rnd = double(std::llround(rnd * double(1 << int_scale)));
        reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
      }
      else {
        reals[i] = from_real<Element>(Real(rnd));
      }
    }

    return make_Quaternion(reals[0], reals[1], reals[2], reals[3]);
  }
};
822
+
823
+ /// Computes a random uniform distribution
824
+ template <
825
+ typename Element, ///< Element type
826
+ typename Layout> ///< Layout function
827
+ struct TensorFillRandomUniformFunc {
828
+
829
+ using TensorView = TensorView<Element, Layout>;
830
+
831
+ //
832
+ // Data members
833
+ //
834
+
835
+ TensorView view;
836
+ RandomUniformFunc<Element> func;
837
+
838
+ //
839
+ // Methods
840
+ //
841
+
842
+ /// Construction of uniform RNG functor.
843
+ TensorFillRandomUniformFunc(
844
+ TensorView view_ = TensorView(),
845
+ RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>()
846
+ ):
847
+ view(view_), func(func_) {
848
+
849
+ }
850
+
851
+ /// Compute random value and update RNG state
852
+ void operator()(Coord<Layout::kRank> const &coord) {
853
+
854
+ view.at(coord) = func();
855
+ }
856
+ };
857
+
858
+ /// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution.
859
+ template <
860
+ typename Element, ///< Element type
861
+ typename Layout> ///< Layout function
862
+ struct TensorFillSymmetricRandomUniformFunc {
863
+
864
+ using TensorView = TensorView<Element, Layout>;
865
+
866
+ //
867
+ // Data members
868
+ //
869
+
870
+ TensorView view;
871
+ RandomUniformFunc<Element> func;
872
+ cutlass::FillMode fill_mode;
873
+
874
+ //
875
+ // Methods
876
+ //
877
+
878
+ /// Construction of uniform RNG functor.
879
+ TensorFillSymmetricRandomUniformFunc(
880
+ TensorView view_ = TensorView(),
881
+ RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>(),
882
+ cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
883
+ ):
884
+ view(view_), func(func_), fill_mode(fill_mode_) {
885
+
886
+ }
887
+
888
+ /// Compute random value and update RNG state
889
+ void operator()(Coord<Layout::kRank> const &coord) {
890
+ // Fill half of matrix based on FillMode
891
+ if (Layout::kRank == 2 &&
892
+ fill_mode == cutlass::FillMode::kLower &&
893
+ coord[0] >= coord[1]) {
894
+ view.at(coord) = func();
895
+ } else if (Layout::kRank == 2 &&
896
+ fill_mode == cutlass::FillMode::kUpper &&
897
+ coord[0] <= coord[1]) {
898
+ view.at(coord) = func();
899
+ }
900
+ }
901
+ };
902
+
903
/// Computes a random Uniform distribution and pads diagonal with zeros
template <
  typename Element,                ///< Element type
  typename Layout>                 ///< Layout function
struct TensorFillPadDiagonalRandomUniformFunc {

  using TensorView = TensorView<Element, Layout>;

  //
  // Data members
  //

  TensorView view;                  // destination tensor view
  RandomUniformFunc<Element> func;  // per-element uniform generator
  cutlass::FillMode fill_mode;      // selects lower or upper triangle
  int alignment;                    // width of the band measured from the diagonal

  //
  // Methods
  //

  /// Construction of uniform RNG functor.
  TensorFillPadDiagonalRandomUniformFunc(
    TensorView view_ = TensorView(),
    RandomUniformFunc<Element> func_ = RandomUniformFunc<Element>(),
    cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid,
    int alignment_ = 1
  ):
    view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { }

  /// Compute random value and update RNG state
  void operator()(Coord<Layout::kRank> const &coord) {
    // Fill half of matrix based on FillMode.
    // NOTE(review): '&&' binds tighter than '||', so each trailing
    // '(... >= alignment)' clause is evaluated independently of the rank and
    // fill-mode checks — confirm whether the parenthesization is intended.
    if (Layout::kRank == 2 &&
        (fill_mode == cutlass::FillMode::kLower) &&
        (coord[0] >= coord[1]) ||
        ((coord[1] - coord[0]) >= alignment)) {
      view.at(coord) = func();
    } else if (Layout::kRank == 2 &&
               fill_mode == cutlass::FillMode::kUpper &&
               (coord[0] <= coord[1]) ||
               ((coord[0] - coord[1]) >= alignment)) {
      view.at(coord) = func();
    }
  }
};
951
+
952
+ } // namespace detail
953
+
954
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
955
+
956
+ /// Fills a tensor with random values of a uniform random distribution.
957
+ template <
958
+ typename Element, ///< Element type
959
+ typename Layout> ///< Layout function
960
+ void TensorFillRandomUniform(
961
+ TensorView<Element, Layout> dst, ///< destination tensor
962
+ uint64_t seed, ///< seed for RNG
963
+ double max = 1, ///< upper bound of distribution
964
+ double min = 0, ///< lower bound for distribution
965
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
966
+ /// are not truncated to zero. Permits reducing precision of
967
+ /// data.
968
+ double pnan = 0, ///< Percentage of NaN elements.
969
+ bool exclude_zero = false) { ///< Exclude zero from tensor init
970
+ detail::RandomUniformFunc<Element> random_func(seed, max, min, bits, pnan, exclude_zero);
971
+
972
+ detail::TensorFillRandomUniformFunc<Element, Layout> func(
973
+ dst,
974
+ random_func
975
+ );
976
+
977
+ TensorForEach(
978
+ dst.extent(),
979
+ func
980
+ );
981
+ }
982
+
983
+ /// Fills a tensor with random values of a uniform random distribution.
984
+ template <
985
+ typename Element, ///< Element type
986
+ typename Layout> ///< Layout function
987
+ void TensorFillRandomUniform(
988
+ TensorViewPlanarComplex<Element, Layout> dst, ///< destination tensor
989
+ uint64_t seed, ///< seed for RNG
990
+ double max = 1, ///< upper bound of distribution
991
+ double min = 0, ///< lower bound for distribution
992
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
993
+ /// are not truncated to zero. Permits reducing precision of
994
+ /// data.
995
+ double pnan = 0, ///< Percentage of NaN elements.
996
+ bool exclude_zero = false) { ///< Exclude zero from tensor init
997
+
998
+ TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero);
999
+ TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero);
1000
+ }
1001
+
1002
+
1003
+ /// Fills a tensor with random values with a uniform random distribution.
1004
+ template <
1005
+ typename Element, ///< Element type
1006
+ typename Layout> ///< Layout function
1007
+ void TensorFillRandomUniform(
1008
+ TensorView<Quaternion<Element>, Layout> dst, ///< destination tensor
1009
+ uint64_t seed, ///< seed for RNG
1010
+ double max = 1, ///< upper bound of distribution
1011
+ double min = 0, ///< lower bound for distribution
1012
+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that
1013
+ /// are not truncated to zero. Permits reducing precision of
1014
+ /// data.
1015
+ detail::RandomUniformFunc<Quaternion<Element>> random_func(seed, max, min, bits);
1016
+
1017
+ detail::TensorFillRandomUniformFunc<Quaternion<Element>, Layout> func(
1018
+ dst,
1019
+ random_func
1020
+ );
1021
+
1022
+ TensorForEach(
1023
+ dst.extent(),
1024
+ func
1025
+ );
1026
+ }
1027
+
1028
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1029
+
1030
+ /// Fills a tensor with random values with a uniform random distribution.
1031
+ template <
1032
+ typename Element, ///< Element type
1033
+ typename Layout> ///< Layout function
1034
+ void TensorFillSymmetricRandomUniform(
1035
+ TensorView<Element, Layout> dst, ///< destination tensor
1036
+ uint64_t seed, ///< seed for RNG
1037
+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
1038
+ double max = 1, ///< upper bound of distribution
1039
+ double min = 0, ///< lower bound for distribution
1040
+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that
1041
+ /// are not truncated to zero. Permits reducing precision of
1042
+ /// data.
1043
+
1044
+ detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
1045
+
1046
+ detail::TensorFillSymmetricRandomUniformFunc<Element, Layout> func(
1047
+ dst,
1048
+ random_func,
1049
+ fill_mode
1050
+ );
1051
+
1052
+ TensorForEach(
1053
+ dst.extent(),
1054
+ func
1055
+ );
1056
+ }
1057
+
1058
+ /// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal
1059
+ template <
1060
+ typename Element, ///< Element type
1061
+ typename Layout> ///< Layout function
1062
+ void TensorFillPadDiagonalRandomUniform(
1063
+ TensorView<Element, Layout> dst, ///< destination tensor
1064
+ uint64_t seed, ///< seed for RNG
1065
+ cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices
1066
+ double max = 1, ///< upper bound of distribution
1067
+ double min = 0, ///< lower bound for distribution
1068
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
1069
+ /// are not truncated to zero. Permits reducing precision of
1070
+ /// data.
1071
+ int alignment = 1
1072
+ ) {
1073
+
1074
+ detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
1075
+
1076
+ detail::TensorFillPadDiagonalRandomUniformFunc<Element, Layout> func(
1077
+ dst,
1078
+ random_func,
1079
+ fill_mode,
1080
+ alignment
1081
+ );
1082
+
1083
+ TensorForEach(
1084
+ dst.extent(),
1085
+ func
1086
+ );
1087
+ }
1088
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1089
+
1090
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1091
+
1092
+ /// Fills a tensor with a uniform value
1093
+ template <
1094
+ typename Element ///< Element type
1095
+ >
1096
+ void BlockFill(
1097
+ Element *ptr,
1098
+ size_t capacity,
1099
+ Element val
1100
+ ) {
1101
+ for (size_t i = 0; i < capacity; ++i) {
1102
+ ReferenceFactory<Element>::get(ptr, i) = val;
1103
+ }
1104
+ }
1105
+
1106
+ /// Fills a tensor with random values with a uniform random distribution.
1107
+ template <
1108
+ typename Element ///< Element type
1109
+ >
1110
+ void BlockFillRandomUniform(
1111
+ Element *ptr,
1112
+ size_t capacity,
1113
+ uint64_t seed, ///< seed for RNG
1114
+ double max = 1, ///< upper bound of distribution
1115
+ double min = 0, ///< lower bound for distribution
1116
+ int bits = -1, ///< If non-negative, specifies number of fractional bits that
1117
+ /// are not truncated to zero. Permits reducing precision of
1118
+ /// data.
1119
+ double pnan = 0) { ///< Percentage of NaN elements.
1120
+ detail::RandomUniformFunc<Element> random_func(seed, max, min, bits, pnan);
1121
+
1122
+ for (size_t i = 0; i < capacity; ++i) {
1123
+ ReferenceFactory<Element>::get(ptr, i) = random_func();
1124
+ }
1125
+ }
1126
+
1127
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1128
+
1129
+ namespace detail {
1130
+
1131
template <
  typename Element, ///< Element type
  typename Layout>  ///< Layout function
struct TensorFillDiagonalFunc {

  using TensorView = TensorView<Element, Layout>;

  //
  // Data members
  //

  TensorView view;   ///< tensor being filled
  Element diag;      ///< value written to generalized-diagonal elements
  Element other;     ///< value written to all remaining elements

  //
  // Methods
  //

  /// Constructs the functor with the target view and the two fill values.
  TensorFillDiagonalFunc(
    TensorView const &view_ = TensorView(),
    Element diag_ = Element(1),
    Element other_ = Element(0)
  ):
    view(view_), diag(diag_), other(other_) { }

  /// Visits one coordinate: writes `diag` when the coordinate lies on the
  /// generalized diagonal (all index components equal), otherwise `other`.
  void operator()(Coord<Layout::kRank> const & coord) const {
    bool is_diag = true;

    // On the generalized diagonal iff every component equals the previous one.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < Layout::kRank; ++i) {
      if (coord[i] != coord[i - 1]) {
        is_diag = false;
        break;
      }
    }

    view.at(coord) = (is_diag ? diag : other);
  }
};
1171
+
1172
+ } // namespace detail
1173
+
1174
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1175
+
1176
+ /// Fills a tensor everywhere with a unique value for its diagonal.
1177
+ template <
1178
+ typename Element, ///< Element type
1179
+ typename Layout> ///< Layout function
1180
+ void TensorFillDiagonal(
1181
+ TensorView<Element, Layout> dst, ///< destination tensor
1182
+ Element diag = Element(1), ///< value to write in the diagonal
1183
+ Element other = Element(0)) { ///< value to write off the diagonal
1184
+
1185
+ detail::TensorFillDiagonalFunc<Element, Layout> func(
1186
+ dst,
1187
+ diag,
1188
+ other
1189
+ );
1190
+
1191
+ TensorForEach(
1192
+ dst.extent(),
1193
+ func
1194
+ );
1195
+ }
1196
+
1197
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1198
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1199
+
1200
+ /// Helper to fill a tensor's diagonal with 1 and 0 everywhere else.
1201
+ template <
1202
+ typename Element, ///< Element type
1203
+ typename Layout> ///< Layout function
1204
+ void TensorFillIdentity(
1205
+ TensorView<Element, Layout> dst) { ///< destination tensor
1206
+
1207
+ TensorFillDiagonal(dst, Element(1), Element(0));
1208
+ }
1209
+
1210
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1211
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1212
+
1213
+ /// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements.
1214
+ template <
1215
+ typename Element, ///< Element type
1216
+ typename Layout> ///< Layout function
1217
+ void TensorUpdateDiagonal(
1218
+ TensorView<Element, Layout> dst, ///< destination tensor
1219
+ Element val = Element(1)) {
1220
+
1221
+ typename Layout::Index extent = dst.extent().min();
1222
+
1223
+ for (typename Layout::Index i = 0; i < extent; ++i) {
1224
+ Coord<Layout::kRank> coord(i);
1225
+ dst.at(coord) = val;
1226
+ }
1227
+ }
1228
+
1229
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1230
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1231
+
1232
+ namespace detail {
1233
+
1234
template <
  typename Element, ///< Element type
  typename Layout>  ///< Layout function
struct TensorUpdateOffDiagonalFunc {

  using TensorView = TensorView<Element, Layout>;

  //
  // Data members
  //

  TensorView view;   ///< tensor being updated
  Element other;     ///< value written to off-diagonal elements

  //
  // Methods
  //

  /// Constructs the functor; generalized-diagonal elements are left untouched.
  TensorUpdateOffDiagonalFunc(
    TensorView const &view_ = TensorView(),
    Element other_ = Element(0)
  ):
    view(view_), other(other_) { }

  /// Visits one coordinate: overwrites it with `other` unless it lies on the
  /// generalized diagonal (all index components equal).
  void operator()(Coord<Layout::kRank> const & coord) const {
    bool is_diag = true;

    // On the generalized diagonal iff every component equals the previous one.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < Layout::kRank; ++i) {
      if (coord[i] != coord[i - 1]) {
        is_diag = false;
        break;
      }
    }

    if (!is_diag) {
      view.at(coord) = other;
    }
  }
};
1274
+
1275
+ } // namespace detail
1276
+
1277
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1278
+
1279
+ /// Writes a uniform value to all elements in the tensor without modifying diagonal elements.
1280
+ template <
1281
+ typename Element, ///< Element type
1282
+ typename Layout> ///< Layout function
1283
+ void TensorUpdateOffDiagonal(
1284
+ TensorView<Element, Layout> dst, ///< destination tensor
1285
+ Element other = Element(1)) {
1286
+
1287
+ detail::TensorUpdateOffDiagonalFunc<Element, Layout> func(
1288
+ dst,
1289
+ other
1290
+ );
1291
+
1292
+ TensorForEach(
1293
+ dst.extent(),
1294
+ func
1295
+ );
1296
+ }
1297
+
1298
+
1299
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1300
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1301
+
1302
+ namespace detail {
1303
+
1304
template <
  typename Element, ///< Element type
  typename Layout>  ///< Layout function
struct TensorFillLinearFunc {

  using TensorView = TensorView<Element, Layout>;

  //
  // Data members
  //

  TensorView view;                   ///< tensor being filled
  Array<Element, Layout::kRank> v;   ///< per-rank coefficients of the linear form
  Element s;                         ///< constant offset added to every element

  //
  // Methods
  //

  TensorFillLinearFunc() { }

  /// Constructs functor
  TensorFillLinearFunc(
    TensorView const &view_,
    Array<Element, Layout::kRank> const & v_,
    Element s_ = Element(0)
  ):
    view(view_), v(v_), s(s_) { }

  /// Updates the tensor: element <- s + dot(coord, v)
  void operator()(Coord<Layout::kRank> const & coord) const {

    Element sum(s);

    // Accumulate the inner product of the coordinate with the coefficients.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Layout::kRank; ++i) {
      sum += Element(coord[i]) * v[i];
    }

    view.at(coord) = sum;
  }
};
1346
+
1347
+ } // namespace detail
1348
+
1349
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1350
+
1351
+ /// Fills tensor with a linear combination of its coordinate and another vector
1352
+ template <
1353
+ typename Element, ///< Element type
1354
+ typename Layout> ///< Layout function
1355
+ void TensorFillLinear(
1356
+ TensorView<Element, Layout> dst, ///< destination tensor
1357
+ Array<Element, Layout::kRank> const & v,
1358
+ Element s = Element(0)) {
1359
+
1360
+ detail::TensorFillLinearFunc<Element, Layout> func(
1361
+ dst,
1362
+ v,
1363
+ s
1364
+ );
1365
+
1366
+ TensorForEach(
1367
+ dst.extent(),
1368
+ func
1369
+ );
1370
+ }
1371
+
1372
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1373
+
1374
+ /// Fills tensor with a linear combination of its coordinate and another vector
1375
+ template <
1376
+ typename Element, ///< Element type
1377
+ typename Layout> ///< Layout function
1378
+ void TensorFillSequential(
1379
+ TensorView<Element, Layout> dst, ///< destination tensor
1380
+ Element s = Element(0)) {
1381
+
1382
+ Array<Element, Layout::kRank> stride;
1383
+
1384
+ stride[0] = Element(1);
1385
+
1386
+ CUTLASS_PRAGMA_UNROLL
1387
+ for (int i = 1; i < Layout::kRank; ++i) {
1388
+ stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]);
1389
+ }
1390
+
1391
+ TensorFillLinear(dst, stride, s);
1392
+ }
1393
+
1394
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1395
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1396
+
1397
+ /// Fills a tensor with random values from a distribution.
1398
+ template <
1399
+ typename Element, ///< Element type
1400
+ typename Layout> ///< Layout function
1401
+ void TensorFillRandom(
1402
+ TensorView<Element, Layout> view, ///< destination tensor
1403
+ uint64_t seed,
1404
+ Distribution dist,
1405
+ bool exclude_zero = false ///< If true, excludes 0.
1406
+ /// Note that setting this flag will result in more 1's,
1407
+ /// as we use a simple mechanism to replace 0's by adding/subtracting 1's.
1408
+ ) {
1409
+
1410
+ using Real = typename RealType<Element>::Type;
1411
+
1412
+ if (dist.kind == Distribution::Gaussian) {
1413
+ TensorFillRandomGaussian(
1414
+ view,
1415
+ seed,
1416
+ dist.gaussian.mean,
1417
+ dist.gaussian.stddev,
1418
+ dist.int_scale,
1419
+ dist.gaussian.pnz,
1420
+ exclude_zero);
1421
+ } else if (dist.kind == Distribution::Uniform) {
1422
+ TensorFillRandomUniform(
1423
+ view,
1424
+ seed,
1425
+ dist.uniform.max,
1426
+ dist.uniform.min,
1427
+ dist.int_scale,
1428
+ dist.uniform.pnan,
1429
+ exclude_zero);
1430
+ }
1431
+ }
1432
+
1433
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1434
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1435
+
1436
+ /// Fills a block of data with sequential elements
1437
+ template <
1438
+ typename Element
1439
+ >
1440
+ void BlockFillSequential(
1441
+ Element *ptr,
1442
+ int64_t capacity,
1443
+ Element v = Element(1),
1444
+ Element s = Element(0)) {
1445
+ int i = 0;
1446
+
1447
+ while (i < capacity) {
1448
+ cutlass::ReferenceFactory<Element, (cutlass::sizeof_bits<Element>::value <
1449
+ 8)>::get(ptr, i) = s;
1450
+
1451
+ s = Element(s + v);
1452
+ ++i;
1453
+ }
1454
+ }
1455
+
1456
+ /// Fills a block of data with sequential elements
1457
+ template <
1458
+ typename Element
1459
+ >
1460
+ void BlockFillSequentialModN(
1461
+ Element *ptr,
1462
+ int64_t capacity,
1463
+ int64_t mod,
1464
+ int64_t v = int64_t(1),
1465
+ int64_t s = int64_t(0)) {
1466
+ int i = 0;
1467
+
1468
+ while (i < capacity) {
1469
+ cutlass::ReferenceFactory<Element, (cutlass::sizeof_bits<Element>::value <
1470
+ 8)>::get(ptr, i) = Element(s);
1471
+
1472
+ s = int64_t(s + v) % mod;
1473
+ ++i;
1474
+ }
1475
+ }
1476
+
1477
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1478
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1479
+
1480
+ /// Fills a block of data with sequential elements
1481
+ template <
1482
+ typename Element
1483
+ >
1484
+ void BlockFillRandom(
1485
+ Element *ptr,
1486
+ size_t capacity,
1487
+ uint64_t seed,
1488
+ Distribution dist) {
1489
+
1490
+ if (dist.kind == Distribution::Gaussian) {
1491
+ BlockFillRandomGaussian<Element>(
1492
+ ptr,
1493
+ capacity,
1494
+ seed,
1495
+ dist.gaussian.mean,
1496
+ dist.gaussian.stddev,
1497
+ dist.int_scale,
1498
+ dist.gaussian.pnz);
1499
+ }
1500
+ else if (dist.kind == Distribution::Uniform) {
1501
+ BlockFillRandomUniform<Element>(
1502
+ ptr,
1503
+ capacity,
1504
+ seed,
1505
+ dist.uniform.max,
1506
+ dist.uniform.min,
1507
+ dist.int_scale,
1508
+ dist.uniform.pnan);
1509
+ }
1510
+ }
1511
+
1512
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1513
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1514
+
1515
+ namespace detail {
1516
+
1517
template <typename Element>
struct RandomSparseMetaFunc {

  uint64_t seed;       ///< seed forwarded to std::srand
  int range;           ///< number of valid metadata nibbles to sample from
  int MetaSizeInBits;  ///< metadata width per group: 2 or 4 bits

  //
  // Methods
  //

  /// Seeds the C RNG and selects the valid metadata alphabet.
  /// Throws std::invalid_argument for widths other than 2 or 4.
  RandomSparseMetaFunc(
    uint64_t seed_ = 0,
    int MetaSizeInBits_ = 2
  ):
    seed(seed_), MetaSizeInBits(MetaSizeInBits_) {
    std::srand((unsigned)seed);
    if (MetaSizeInBits_ == 2) {
      range = 6;
    }
    else if (MetaSizeInBits_ == 4) {
      range = 2;
    }
    else {
      throw std::invalid_argument("Invalid MetaSizeInBits");
    }
  }

  /// Compute random value and update RNG state
  Element operator()() const {
    // Valid 4-bit metadata nibbles for the two sparsity modes — presumably
    // encoding which elements of each group survive pruning (confirm against
    // the sparse-GEMM metadata format documentation).
    Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
    Element TwoToOneMeta[2] = {0x4, 0xe};

    Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta;

    Element result = 0x0;

    // Pack one randomly chosen nibble into each 4-bit slot of the result word.
    for (int i = 0; i < cutlass::sizeof_bits<Element>::value / 4; ++i) {
      int rnd = std::rand() % range;
      Element meta = MetaArray[rnd];

      result = (Element)(result | ((Element)(meta << (i * 4))));
    }

    return result;
  }
};
1564
+
1565
/// Computes a random sparse meta
template <
  typename Element, ///< Element type
  typename Layout>  ///< Layout function
struct TensorFillRandomSparseMetaFunc {

  using TensorView = TensorView<Element, Layout>;

  //
  // Data members
  //

  TensorView view;                     ///< destination tensor
  RandomSparseMetaFunc<Element> func;  ///< generator of random metadata words

  //
  // Methods
  //

  /// Constructs the sparse-metadata fill functor.
  /// (Previous comment incorrectly referred to a "Gaussian RNG functor".)
  TensorFillRandomSparseMetaFunc(
    TensorView view_ = TensorView(),
    RandomSparseMetaFunc<Element> func_ = RandomSparseMetaFunc<Element>()
  ):
    view(view_), func(func_) {

  }

  /// Compute random value and update RNG state
  void operator()(Coord<Layout::kRank> const &coord) const {

    view.at(coord) = func();
  }
};
1599
+
1600
+ } // namespace detail
1601
+
1602
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1603
+
1604
+ /// Fills a tensor with random values with a uniform random distribution.
1605
+ template <
1606
+ typename Element, ///< Element type
1607
+ typename Layout> ///< Layout function
1608
+ void TensorFillRandomSparseMeta(
1609
+ TensorView<Element, Layout> dst, ///< destination tensor
1610
+ uint64_t seed, ///< seed for RNG
1611
+ int MetaSizeInBits) { ///< 2 bit or 4 bit
1612
+
1613
+ detail::RandomSparseMetaFunc<Element> random_func(seed, MetaSizeInBits);
1614
+
1615
+ detail::TensorFillRandomSparseMetaFunc<Element, Layout> func(
1616
+ dst,
1617
+ random_func
1618
+ );
1619
+
1620
+ TensorForEach(
1621
+ dst.extent(),
1622
+ func
1623
+ );
1624
+ }
1625
+
1626
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1627
+
1628
+ /// Fills a tensor with random values with a uniform random distribution.
1629
+ template <
1630
+ typename Element ///< Element type
1631
+ >
1632
+ void BlockFillRandomSparseMeta(
1633
+ Element *ptr,
1634
+ size_t capacity,
1635
+ uint64_t seed, ///< seed for RNG
1636
+ int MetaSizeInBits) { ///< 2 bit or 4bit
1637
+
1638
+ detail::RandomSparseMetaFunc<Element> random_func(seed, MetaSizeInBits);
1639
+
1640
+ for (size_t i = 0; i < capacity; ++i) {
1641
+ ptr[i] = random_func();
1642
+ }
1643
+ }
1644
+
1645
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1646
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1647
+
1648
/// Fills a ell block index matrix with random values with a uniform random distribution.
template <
  typename Element,               ///< Element type
  typename Layout>                ///< Layout function
void TensorFillRandomEllIdx(
  TensorView<Element, Layout> dst,        ///< destination tensor
  uint64_t seed,                          ///< seed for RNG
  int rows, int ell_cols, int cols) {     ///< dimension of the matrix

  std::srand((unsigned)seed);

  // Each row receives a strictly increasing sequence of random column
  // indices; once the last column is reached, the remaining slots are
  // written as -1 (which appears to act as a "no block" sentinel for the
  // ELL format — confirm against the consuming kernel).
  for (int i = 0; i < rows; ++i) {
    int col_idx = std::rand() % cols;

    for (int j = 0; j < ell_cols; ++j) {
      dst.at({i, j}) = col_idx;

      if (col_idx != -1) {
        if (col_idx == (cols - 1)) {
          // Columns exhausted: subsequent entries of this row become -1.
          col_idx = -1;
        } else {
          // Advance to a uniformly chosen later column, keeping the
          // sequence strictly increasing.
          col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1;
        }
      }
    }
  }
}
1675
+
1676
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1677
+
1678
+ /// Copies a diagonal in from host memory without modifying off-diagonal elements.
1679
+ template <
1680
+ typename Element, ///< Element type
1681
+ typename Layout> ///< Layout function
1682
+ void TensorCopyDiagonalIn(
1683
+ TensorView<Element, Layout> dst, ///< destination tensor
1684
+ Element const *ptr) { ///< dense buffer of elements
1685
+
1686
+ typename Layout::Index extent = dst.extent().min();
1687
+
1688
+ for (typename Layout::Index i = 0; i < extent; ++i) {
1689
+ Coord<Layout::kRank> coord(i);
1690
+ dst.at(coord) = ReferenceFactory<Element>::get(ptr, i);
1691
+ }
1692
+ }
1693
+
1694
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1695
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1696
+
1697
+ /// Copies the diagonal of a tensor into a dense buffer in host memory.
1698
+ template <
1699
+ typename Element, ///< Element type
1700
+ typename Layout> ///< Layout function
1701
+ void TensorCopyDiagonalOut(
1702
+ Element *ptr, ///< dense buffer of elements
1703
+ TensorView<Element, Layout> src) { ///< source tensor
1704
+
1705
+ typename Layout::Index extent = src.extent().min();
1706
+
1707
+ for (typename Layout::Index i = 0; i < extent; ++i) {
1708
+ Coord<Layout::kRank> coord(i);
1709
+ ReferenceFactory<Element>::get(ptr, i) = src.at(coord);
1710
+ }
1711
+ }
1712
+
1713
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1714
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1715
+
1716
+ } // namespace host
1717
+ } // namespace reference
1718
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Provides several functions for filling tensors with data.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // Standard Library includes
38
+ #include <utility>
39
+ #include <cstdlib>
40
+ #include <cmath>
41
+
42
+ // Cute includes
43
+ #include "cute/tensor.hpp"
44
+
45
+ // Cutlass includes
46
+ #include "cutlass/cutlass.h"
47
+ #include "cutlass/complex.h"
48
+ #include "cutlass/quaternion.h"
49
+ #include "cutlass/array.h"
50
+ #include "cutlass/numeric_types.h"
51
+
52
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace cutlass {
55
+ namespace reference {
56
+ namespace host {
57
+
58
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
59
+ //
60
+ // Uniform and procedural tensor fills
61
+ //
62
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
63
+
64
+ /// Fills a tensor with a scalar element
65
+ template <typename Tensor>
66
+ void TensorFill(Tensor dst, typename Tensor::value_type element) {
67
+
68
+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
69
+ dst(idx) = element;
70
+ }
71
+ }
72
+
73
+ /// Fills a tensor with the contents of its layout
74
+ template <typename Tensor>
75
+ void TensorFillSequential(Tensor dst) {
76
+
77
+ auto layout = dst.layout();
78
+
79
+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
80
+ dst(idx) = layout(idx);
81
+ }
82
+ }
83
+
84
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
85
+ //
86
+ // Random uniform values
87
+ //
88
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
89
+
90
+ namespace detail {
91
+
92
+ template <typename Element>
93
+ struct RandomUniformFunc {
94
+
95
+ using Real = typename RealType<Element>::Type;
96
+
97
+ uint64_t seed;
98
+ double range;
99
+ double min;
100
+ int int_scale;
101
+
102
+ //
103
+ // Methods
104
+ //
105
+
106
+ RandomUniformFunc(
107
+ uint64_t seed_ = 0,
108
+ double max = 1,
109
+ double min_ = 0,
110
+ int int_scale_ = -1
111
+ ):
112
+ seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
113
+ std::srand((unsigned)seed);
114
+ }
115
+
116
+
117
+ /// Compute random value and update RNG state
118
+ Element operator()() const {
119
+
120
+ double rnd = double(std::rand()) / double(RAND_MAX);
121
+
122
+ rnd = min + range * rnd;
123
+
124
+ // Random values are cast to integer after scaling by a power of two to facilitate error
125
+ // testing
126
+ Element result;
127
+
128
+ if (int_scale >= 0) {
129
+ rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale);
130
+ result = static_cast<Element>(Real(rnd));
131
+ }
132
+ else {
133
+ result = static_cast<Element>(Real(rnd));
134
+ }
135
+
136
+ return result;
137
+ }
138
+ };
139
+
140
/// Partial specialization for initializing a complex value.
template <typename Element>
struct RandomUniformFunc<complex<Element> > {

  using Real = typename RealType<Element>::Type;

  uint64_t seed;    ///< seed forwarded to std::srand
  double range;     ///< width of the distribution (max - min)
  double min;       ///< lower bound of the distribution
  int int_scale;    ///< if non-negative, quantize to this many fractional bits

  //
  // Methods
  //

  /// Seeds the C RNG and records the distribution parameters.
  RandomUniformFunc(
    uint64_t seed_ = 0,
    double max = 1,
    double min_ = 0,
    int int_scale_ = -1
  ):
    seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
    std::srand((unsigned)seed);
  }


  /// Compute random value and update RNG state
  complex<Element> operator()() const {

    // Real and imaginary parts are drawn independently.
    Element reals[2];

    for (int i = 0; i < 2; ++i) {
      double rnd = double(std::rand()) / double(RAND_MAX);

      rnd = min + range * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing

      if (int_scale >= 0) {
        rnd = double(int(rnd * double(1 << int_scale)));
        reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
      }
      else {
        reals[i] = from_real<Element>(Real(rnd));
      }
    }

    return complex<Element>(reals[0], reals[1]);
  }
};
191
+
192
/// Partial specialization for initializing a Quaternion value.
template <typename Element>
struct RandomUniformFunc<Quaternion<Element> > {

  using Real = typename RealType<Element>::Type;

  uint64_t seed;    ///< seed forwarded to std::srand
  double range;     ///< width of the distribution (max - min)
  double min;       ///< lower bound of the distribution
  int int_scale;    ///< if non-negative, quantize to this many fractional bits

  //
  // Methods
  //

  /// Seeds the C RNG and records the distribution parameters.
  RandomUniformFunc(
    uint64_t seed_ = 0,
    double max = 1,
    double min_ = 0,
    int int_scale_ = -1
  ):
    seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) {
    std::srand((unsigned)seed);
  }


  /// Compute random value and update RNG state
  Quaternion<Element> operator()() const {

    // All four quaternion components are drawn independently.
    Element reals[4];

    for (int i = 0; i < 4; ++i) {
      double rnd = double(std::rand()) / double(RAND_MAX);

      rnd = min + range * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing

      if (int_scale >= 0) {
        rnd = double(int(rnd * double(1 << int_scale)));
        reals[i] = from_real<Element>(Real(rnd / double(1 << int_scale)));
      }
      else {
        reals[i] = from_real<Element>(Real(rnd));
      }
    }

    return make_Quaternion(reals[0], reals[1], reals[2], reals[3]);
  }
};
243
+
244
+ } // namespace detail
245
+
246
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
247
+
248
+ /// Fills a tensor with random values with a uniform random distribution.
249
+ template <typename Tensor> ///< Tensor object
250
+ void TensorFillRandomUniform(
251
+ Tensor dst, ///< destination tensor
252
+ uint64_t seed, ///< seed for RNG
253
+ double max = 1, ///< upper bound of distribution
254
+ double min = 0, ///< lower bound for distribution
255
+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that
256
+ /// are not truncated to zero. Permits reducing precision of
257
+ /// data.
258
+
259
+ detail::RandomUniformFunc<typename Tensor::value_type> random_func(seed, max, min, bits);
260
+
261
+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
262
+ dst(idx) = random_func();
263
+ }
264
+ }
265
+
266
+ /// Fills a block with random values with a uniform random distribution.
267
+ template <
268
+ typename Element ///< Element type
269
+ >
270
+ void BlockFillRandomUniform(
271
+ Element *ptr,
272
+ size_t capacity,
273
+ uint64_t seed, ///< seed for RNG
274
+ double max = 1, ///< upper bound of distribution
275
+ double min = 0, ///< lower bound for distribution
276
+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that
277
+ /// are not truncated to zero. Permits reducing precision of
278
+ /// data.
279
+ detail::RandomUniformFunc<Element> random_func(seed, max, min, bits);
280
+
281
+ for (size_t i = 0; i < capacity; ++i) {
282
+ ptr[i] = random_func();
283
+ }
284
+ }
285
+
286
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
287
+ //
288
+ // Random Gaussian
289
+ //
290
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
291
+
292
+ namespace detail {
293
+
294
template <typename Element>
struct RandomGaussianFunc {

  uint64_t seed;    ///< seed forwarded to std::srand
  double mean;      ///< mean of the Gaussian distribution
  double stddev;    ///< standard deviation of the Gaussian distribution
  int int_scale;    ///< if non-negative, quantize to this many fractional bits
  double pi;        ///< cached value of pi for the Box-Muller transform

  //
  // Methods
  //

  /// Seeds the C RNG and records the distribution parameters.
  RandomGaussianFunc(
    uint64_t seed_ = 0,
    double mean_ = 0,
    double stddev_ = 1,
    int int_scale_ = -1
  ):
    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) {
    std::srand((unsigned)seed);
  }

  /// Compute random value and update RNG state
  Element operator()() const {

    // Box-Muller transform: two independent uniform samples produce one
    // normally distributed sample.
    double const u1 = double(std::rand()) / double(RAND_MAX);
    double const u2 = double(std::rand()) / double(RAND_MAX);

    double sample = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
    sample = mean + stddev * sample;

    // Optionally quantize to 'int_scale' fractional bits before conversion.
    if (int_scale >= 0) {
      sample = double(int64_t(sample * double(1 << int_scale))) / double(1 << int_scale);
    }

    return static_cast<Element>(sample);
  }
};
341
+
342
+ } // namespace detail
343
+
344
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
345
+
346
+ /// Fills a tensor with random values with a Gaussian distribution.
347
+ template <
348
+ typename Tensor
349
+ >
350
+ void TensorFillRandomGaussian(
351
+ Tensor dst, ///< destination tensor
352
+ uint64_t seed, ///< seed for RNG
353
+ double mean = 0, ///< Gaussian distribution's mean
354
+ double stddev = 1, ///< Gaussian distribution's standard deviation
355
+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that
356
+ /// are not truncated to zero. Permits reducing precision of
357
+ /// data.
358
+
359
+ detail::RandomGaussianFunc<typename Tensor::value_type> random_func(seed, mean, stddev, bits);
360
+
361
+ for (int64_t idx = 0; idx < cute::size(dst); ++idx) {
362
+ dst(idx) = random_func();
363
+ }
364
+ }
365
+
366
+ /// Fills a block with random values with a Gaussian distribution.
367
+ template <
368
+ typename Element ///< Element type
369
+ >
370
+ void BlockFillRandomGaussian(
371
+ Element *ptr, ///< destination buffer
372
+ size_t capacity, ///< number of elements
373
+ uint64_t seed, ///< seed for RNG
374
+ double mean = 0, ///< Gaussian distribution's mean
375
+ double stddev = 1, ///< Gaussian distribution's standard deviation
376
+ int bits = -1) { ///< If non-negative, specifies number of fractional bits that
377
+ /// are not truncated to zero. Permits reducing precision of
378
+ /// data.
379
+
380
+ detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits);
381
+
382
+ for (size_t i = 0; i < capacity; ++i) {
383
+ ptr[i] = random_func();
384
+ }
385
+ }
386
+
387
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
388
+
389
/// Fills a block of data with sequential elements: ptr[i] = s + i * v.
///
/// Fixes two defects: the previous body wrote the constant `s + v` into every
/// element (never advancing `s`, contradicting the "sequential" contract and
/// the tensor_fill.h implementation), and used a 32-bit loop counter against
/// the 64-bit `capacity` (overflow/UB for very large blocks).
template <
  typename Element
>
void BlockFillSequential(
  Element *ptr,
  int64_t capacity,
  Element v = Element(1),                 ///< increment between consecutive elements
  Element s = Element(0)) {               ///< value of the first element
  int64_t i = 0;

  while (i < capacity) {
    ptr[i] = s;
    s = Element(s + v);
    ++i;
  }
}
406
+
407
/// Fills a block of data with sequential elements reduced modulo `mod`:
/// values follow s, (s+v) % mod, (s+2v) % mod, ...
///
/// Fixes two defects: the previous body wrote the constant `(s + v) % mod`
/// into every element (never advancing `s`, contradicting the "sequential"
/// contract and the tensor_fill.h implementation), and used a 32-bit loop
/// counter against the 64-bit `capacity` (overflow/UB for very large blocks).
template <
  typename Element
>
void BlockFillSequentialModN(
  Element *ptr,
  int64_t capacity,
  int64_t mod,                            ///< modulus applied to the running value
  int64_t v = int64_t(1),                 ///< increment between consecutive elements
  int64_t s = int64_t(0)) {               ///< starting value
  int64_t i = 0;

  while (i < capacity) {
    ptr[i] = static_cast<Element>(int32_t(s));
    s = int64_t(s + v) % mod;
    ++i;
  }
}
425
+
426
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
427
+
428
+ } // namespace host
429
+ } // namespace reference
430
+ } // namespace cutlass
431
+
432
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include <stdexcept>
34
+ #include "cutlass/cutlass.h"
35
+
36
+ namespace cutlass {
37
+ namespace reference {
38
+ namespace host {
39
+
40
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ /// Defines several helpers
43
+ namespace detail {
44
+
45
+ /// Helper to perform for-each operation
46
+ template <typename Func, int Rank, int RankRemaining>
47
+ struct TensorForEachHelper {
48
+
49
+ /// Index of the active rank
50
+ static int const kActiveRank = Rank - RankRemaining - 1;
51
+
52
+ /// Constructor for general rank
53
+ TensorForEachHelper(
54
+ Func &func,
55
+ Coord<Rank> const &extent,
56
+ Coord<Rank> &coord) {
57
+
58
+ for (int i = 0; i < extent.at(kActiveRank); ++i) {
59
+ coord[kActiveRank] = i;
60
+ TensorForEachHelper<Func, Rank, RankRemaining - 1>(func, extent, coord);
61
+ }
62
+ }
63
+ };
64
+
65
+ /// Helper to perform for-each operation
66
+ template <typename Func, int Rank>
67
+ struct TensorForEachHelper<Func, Rank, 0> {
68
+
69
+ /// Index of the active rank
70
+ static int const kActiveRank = Rank - 1;
71
+
72
+ /// Constructor for fastest changing rank
73
+ TensorForEachHelper(
74
+ Func &func,
75
+ Coord<Rank> const &extent,
76
+ Coord<Rank> &coord) {
77
+
78
+ for (int i = 0; i < extent.at(kActiveRank); ++i) {
79
+ coord[kActiveRank] = i;
80
+ func(coord);
81
+ }
82
+ }
83
+ };
84
+
85
+ } // namespace detail
86
+
87
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
88
+
89
+ /// Iterates over the index space of a tensor
90
+ template <
91
+ typename Func, ///< function applied to each point in a tensor's index space
92
+ int Rank> ///< rank of index space
93
+ void TensorForEach(Coord<Rank> extent, Func & func) {
94
+ Coord<Rank> coord;
95
+ detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, extent, coord);
96
+ }
97
+
98
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
99
+
100
+ /// Iterates over the index space of a tensor and calls a C++ lambda
101
+ template <
102
+ typename Func, ///< function applied to each point in a tensor's index space
103
+ int Rank> ///< rank of index space
104
+ void TensorForEachLambda(Coord<Rank> extent, Func func) {
105
+ Coord<Rank> coord;
106
+ detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, extent, coord);
107
+ }
108
+
109
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
110
+
111
/// Initializes each element of a linear block of memory with successive values
/// produced by a nullary functor: ptr[i] = func() for i in [0, capacity).
template <typename Element, typename Func>
struct BlockForEach {

  /// Constructor performs the operation; `params` configures the functor.
  BlockForEach(
      Element *ptr,                                               ///< start of the block
      size_t capacity,                                            ///< number of elements
      typename Func::Params params = typename Func::Params()) {   ///< functor configuration

    Func func(params);

    size_t index = 0;
    while (index < capacity) {
      ptr[index] = func();
      ++index;
    }
  }
};
127
+
128
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
129
+
130
+ } // namespace host
131
+ } // namespace reference
132
+ } // namespace cutlass
133
+
134
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+
34
+ #include "cutlass/cutlass.h"
35
+
36
+ // The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions.
37
+
38
+ #include "cutlass/util/reference/host/tensor_reduce.h"
39
+
40
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
#include <cmath>
#include <stdexcept>

#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/util/reference/detail/linear_to_coordinate.h"
#include "cutlass/core_io.h"
41
+
42
+ namespace cutlass {
43
+ namespace reference {
44
+ namespace host {
45
+
46
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
49
+ /// workspace
50
+ template <
51
+ typename Element,
52
+ typename Layout,
53
+ typename ComputeType,
54
+ typename ReduceOp,
55
+ typename TransformOp
56
+ >
57
+ ComputeType TensorTransformReduce(
58
+ TensorView<Element, Layout> view,
59
+ ComputeType identity,
60
+ ReduceOp reduce,
61
+ TransformOp transform
62
+ ) {
63
+
64
+ for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) {
65
+ typename Layout::TensorCoord coord;
66
+ cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view.extent());
67
+
68
+ if (view.contains(coord)) {
69
+ Element x = view.at(coord);
70
+ identity = reduce(identity, transform(x));
71
+ }
72
+ }
73
+
74
+ return identity;
75
+ }
76
+
77
+ /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
78
+ /// workspace
79
+ template <
80
+ typename Element,
81
+ typename Layout,
82
+ typename ComputeType,
83
+ typename ReduceOp,
84
+ typename TransformOp
85
+ >
86
+ ComputeType TensorTransformReduce(
87
+ TensorView<Element, Layout> view_A,
88
+ TensorView<Element, Layout> view_B,
89
+ ComputeType identity,
90
+ ReduceOp reduce,
91
+ TransformOp transform) {
92
+
93
+ if (view_A.extent() != view_B.extent()) {
94
+ throw std::runtime_error("Tensor extents must match.");
95
+ }
96
+
97
+ for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) {
98
+
99
+ typename Layout::TensorCoord coord;
100
+ cutlass::reference::detail::LinearToCoordinate<Layout::kRank>()(coord, idx, view_A.extent());
101
+
102
+ if (view_A.contains(coord)) {
103
+ Element a = view_A.at(coord);
104
+ Element b = view_B.at(coord);
105
+ identity = reduce(identity, transform(a, b));
106
+ }
107
+ }
108
+
109
+ return identity;
110
+ }
111
+
112
+ /// Helper to compute the sum of the elements of a tensor
113
+ template <
114
+ typename Element,
115
+ typename Layout,
116
+ typename ComputeType = Element
117
+ >
118
+ ComputeType TensorSum(
119
+ TensorView<Element, Layout> view,
120
+ ComputeType identity = ComputeType()
121
+ ) {
122
+
123
+ plus<ComputeType> reduce;
124
+ NumericConverter<ComputeType, Element> transform;
125
+
126
+ return TensorTransformReduce(
127
+ view, identity, reduce, transform);
128
+ }
129
+
130
+ /// Helper to compute the sum of the squares of the elements of a tensor
131
+ template <
132
+ typename Element,
133
+ typename Layout,
134
+ typename ComputeType = Element
135
+ >
136
+ ComputeType TensorSumSq(
137
+ TensorView<Element, Layout> view,
138
+ ComputeType identity = ComputeType()
139
+ ) {
140
+
141
+ plus<ComputeType> reduce;
142
+ magnitude_squared<Element, ComputeType> transform;
143
+
144
+ return TensorTransformReduce(
145
+ view, identity, reduce, transform);
146
+ }
147
+
148
+ /// Helper to compute the norm of the elements of a tensor.
149
+ template <
150
+ typename Element,
151
+ typename Layout,
152
+ typename ComputeType = double
153
+ >
154
+ ComputeType TensorNorm(
155
+ TensorView<Element, Layout> view,
156
+ ComputeType identity = ComputeType()
157
+ ) {
158
+
159
+ return std::sqrt(TensorSumSq(view, identity));
160
+ }
161
+
162
+ /// Helper to compute the sum of the squares of the differences of two tensors
163
+ template <
164
+ typename Element,
165
+ typename Layout,
166
+ typename ComputeType = double
167
+ >
168
+ ComputeType TensorSumSqDiff(
169
+ TensorView<Element, Layout> view_A,
170
+ TensorView<Element, Layout> view_B,
171
+ ComputeType identity = ComputeType()
172
+ ) {
173
+
174
+ plus<ComputeType> reduce;
175
+ magnitude_squared_difference<Element, ComputeType> transform;
176
+
177
+ return TensorTransformReduce(
178
+ view_A, view_B, identity, reduce, transform);
179
+ }
180
+
181
+
182
+ /// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
183
+ template <
184
+ typename Element,
185
+ typename Layout,
186
+ typename ComputeType = double
187
+ >
188
+ ComputeType TensorNormDiff(
189
+ TensorView<Element, Layout> view_A,
190
+ TensorView<Element, Layout> view_B,
191
+ ComputeType identity = ComputeType()
192
+ ) {
193
+
194
+ return std::sqrt(TensorSumSqDiff(view_A, view_B, identity));
195
+ }
196
+
197
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
198
+
199
+ } // namespace host
200
+ } // namespace reference
201
+ } // namespace cutlass
202
+
203
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Provides several functions for filling tensors with data.
33
+ */
34
+
35
+ #pragma once
36
+
37
// Standard Library includes
#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include <utility>

// Cute includes
#include "cute/tensor.hpp"

// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/complex.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/quaternion.h"
53
+
54
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ namespace cutlass {
57
+ namespace reference {
58
+ namespace host {
59
+
60
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
61
+ //
62
+ // Tensor reductions
63
+ //
64
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
65
+
66
+ /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
67
+ /// workspace
68
+ template <
69
+ typename Tensor,
70
+ typename ComputeType,
71
+ typename ReduceOp,
72
+ typename TransformOp
73
+ >
74
+ ComputeType TensorTransformReduce(
75
+ Tensor view,
76
+ ComputeType identity,
77
+ ReduceOp reduce,
78
+ TransformOp transform
79
+ ) {
80
+
81
+ for (int64_t idx = 0; idx < cute::size(view); ++idx) {
82
+ identity = reduce(identity, transform(view(idx)));
83
+ }
84
+
85
+ return identity;
86
+ }
87
+
88
+ /// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side
89
+ /// workspace
90
+ template <
91
+ typename TensorA,
92
+ typename TensorB,
93
+ typename ComputeType,
94
+ typename ReduceOp,
95
+ typename TransformOp
96
+ >
97
+ ComputeType TensorTransformReduce(
98
+ TensorA view_A,
99
+ TensorB view_B,
100
+ ComputeType identity,
101
+ ReduceOp reduce,
102
+ TransformOp transform) {
103
+
104
+ if (cute::size(view_A) != cute::size(view_B)) {
105
+ throw std::runtime_error("Tensor sizes must match.");
106
+ }
107
+
108
+ for (int64_t idx = 0; idx < cute::size(view_A); ++idx) {
109
+ identity = reduce(identity, transform(view_A(idx), view_B(idx)));
110
+ }
111
+
112
+ return identity;
113
+ }
114
+
115
+ /// Helper to compute the sum of the elements of a tensor
116
+ template <
117
+ typename Tensor,
118
+ typename ComputeType = typename Tensor::value_type
119
+ >
120
+ ComputeType TensorSum(
121
+ Tensor view,
122
+ ComputeType identity = ComputeType()
123
+ ) {
124
+
125
+ plus<ComputeType> reduce;
126
+ NumericConverter<ComputeType, typename Tensor::value_type> transform;
127
+
128
+ return TensorTransformReduce(
129
+ view, identity, reduce, transform);
130
+ }
131
+
132
+ /// Helper to compute the sum of the squares of the elements of a tensor
133
+ template <
134
+ typename Tensor,
135
+ typename ComputeType = typename Tensor::value_type
136
+ >
137
+ ComputeType TensorSumSq(
138
+ Tensor view,
139
+ ComputeType identity = ComputeType()
140
+ ) {
141
+
142
+ plus<ComputeType> reduce;
143
+ magnitude_squared<typename Tensor::value_type, ComputeType> transform;
144
+
145
+ return TensorTransformReduce(
146
+ view, identity, reduce, transform);
147
+ }
148
+
149
+ /// Helper to compute the norm of the elements of a tensor.
150
/// Computes the Frobenius norm of a cute tensor: sqrt(identity + sum |x|^2).
template <
  typename Tensor,
  typename ComputeType = double
>
ComputeType TensorNorm(
  Tensor view,
  ComputeType identity = ComputeType()
) {
  // Seed the squared-magnitude reduction with `identity`, then take the root.
  return std::sqrt(TensorSumSq(view, identity));
}
161
+
162
+ /// Helper to compute the sum of the squares of the differences of two tensors
163
+ template <
164
+ typename TensorA,
165
+ typename TensorB,
166
+ typename ComputeType = double
167
+ >
168
+ ComputeType TensorSumSqDiff(
169
+ TensorA view_A,
170
+ TensorB view_B,
171
+ ComputeType identity = ComputeType()
172
+ ) {
173
+
174
+ plus<ComputeType> reduce;
175
+ magnitude_squared_difference<typename TensorA::value_type, ComputeType> transform;
176
+
177
+ return TensorTransformReduce(
178
+ view_A, view_B, identity, reduce, transform);
179
+ }
180
+
181
+
182
+ /// Helper to compute the norm of the tensor computed as the difference of two tensors in memory
183
/// Computes the Frobenius norm of the elementwise difference of two cute
/// tensors: sqrt(identity + sum |a - b|^2).
template <
  typename TensorA,
  typename TensorB,
  typename ComputeType = double
>
ComputeType TensorNormDiff(
  TensorA view_A,
  TensorB view_B,
  ComputeType identity = ComputeType()
) {
  return std::sqrt(TensorSumSqDiff(view_A, view_B, identity));
}
196
+
197
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
198
+
199
+ } // namespace host
200
+ } // namespace reference
201
+ } // namespace cutlass
202
+
203
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for TRMM in host-side code.
33
+
34
+
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/blas3.h"
40
+ #include "cutlass/numeric_conversion.h"
41
+ #include "cutlass/tensor_view.h"
42
+ #include "cutlass/gemm/gemm.h"
43
+ #include "cutlass/arch/mma.h"
44
+ #include "cutlass/util/host_tensor.h"
45
+
46
+ #include "cutlass/util/reference/host/gemm.h"
47
+
48
+ namespace cutlass {
49
+ namespace reference {
50
+ namespace host {
51
+
52
/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef
/// objects.
///
/// For SideMode::kLeft the product accumulated is A * B (A triangular); for
/// SideMode::kRight it is B * A. The value written is
///   tensor_d(row, col) = ConvertOp(alpha * accum(row, col))
/// where every accumulator starts at `initial_accum`. Elements of A outside
/// the triangle selected by FillModeA contribute zero, and diagonal elements
/// are forced to one when DiagTypeA == DiagType::kUnit.
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  DiagType DiagTypeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = multiply_add<ComputeType>,
  typename ConvertOp = NumericConverter<ElementC, ScalarType>
>
void compute_trmm(
  gemm::GemmCoord problem_size,           ///< problem dimensions (m, n, k); batch is ignored
  ScalarType alpha,                       ///< scale applied to the accumulated product
  TensorRef<ElementA, LayoutA> tensor_a,  ///< triangular operand A
  TensorRef<ElementB, LayoutB> tensor_b,  ///< dense operand B
  TensorRef<ElementC, LayoutC> tensor_d,  ///< destination tensor (in-bounds elements overwritten)
  ComputeType initial_accum) {            ///< starting value for every accumulator

  static_assert(
    LayoutA::kRank == 2 &&
    LayoutC::kRank == 2, "Tensors must be of rank 2");

  static_assert(SideModeA != SideMode::kInvalid
    , "Side Mode can either be Left or Right.");

  static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper
    , "Fill Mode can either be Lower or Upper.");

  // Predicate deciding whether an index pair lies inside the stored triangle
  // of A for the given fill mode / diagonal type.
  using CompareOp = typename TrMatrixCompareOp<FillModeA, DiagTypeA>::Type;

  // Note: batch is ignored.
  int const M = problem_size.m();
  int const N = problem_size.n();
  // Assuming correct k-dimension value is passed
  int const K = problem_size.k();

  // Blocking necessary to speedup reference implementation
  int const Mblock = 16;
  int const Nblock = 16;

  ConvertOp convert_op;
  InnerProductOp inner_product_op;
  CompareOp compare_op;

  // Process the output in Mblock x Nblock tiles so the tile of accumulators
  // fits in a small stack array.
  for (int row_block = 0; row_block < M; row_block += Mblock) {
    for (int col_block = 0; col_block < N; col_block += Nblock) {

      ComputeType accum[Mblock][Nblock];

      // Initialize every accumulator of the tile.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          accum[i][j] = initial_accum;
        }
      }

      // Accumulate the inner products over the K dimension.
      for (int k_block = 0; k_block < K; ++k_block) {
        for (int j = 0; j < Nblock; j++) {
          for (int i = 0; i < Mblock; i++) {
            int row = row_block + i;
            int col = col_block + j;

            if (row < M && col < N) {
              ElementA a = ElementA();
              ElementB b = ElementB();

              if (SideModeA == SideMode::kLeft) {
                // Left side: 'a' comes from triangular A (zero outside the
                // triangle, one on a unit diagonal), 'b' from dense B.
                a = (compare_op(row, k_block)) ?
                  (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0);
                if (row == k_block && DiagTypeA == DiagType::kUnit) {
                  a = ElementA(1);
                }
                b = tensor_b.at(MatrixCoord(k_block, col));
              } else if (SideModeA == SideMode::kRight) {
                // Right side: roles swap — 'a' holds the dense B element and
                // 'b' the triangular A element.
                // NOTE(review): 'b' is declared ElementB but is assigned
                // ElementA values below; this assumes ElementA and ElementB
                // are interchangeable — confirm for mixed-type instantiations.
                a = tensor_b.at(MatrixCoord(row, k_block));
                b = (compare_op(k_block, col)) ?
                  tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0);
                if (k_block == col && DiagTypeA == DiagType::kUnit) {
                  b = ElementA(1);
                }
              }

              ComputeType compute_a(cast_if_scalar<ComputeType>(a));
              ComputeType compute_b(cast_if_scalar<ComputeType>(b));

              accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]);
            }
          }
        }
      }

      // Scale, convert, and write back only the in-bounds part of the tile.
      for (int j = 0; j < Nblock; j++) {
        for (int i = 0; i < Mblock; i++) {
          int row = row_block + i;
          int col = col_block + j;

          MatrixCoord coord = MatrixCoord(row, col);

          if (row < M && col < N) {
            tensor_d.at(coord) = convert_op(
              alpha * ScalarType(accum[i][j]));
          }
        }
      }
    }
  }
}
165
+
166
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
167
+
168
/// Host-side reference functor for TRMM. The primary template is
/// declaration-only; partial specializations selected by the inner-product
/// operator tag (e.g. arch::OpMultiplyAdd) define operator().
template <
  typename ElementA,
  typename LayoutA,
  SideMode SideModeA,
  FillMode FillModeA,
  DiagType DiagTypeA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename ScalarType,
  typename ComputeType,
  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
>
struct Trmm;
183
+
184
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
185
+
186
/// Partial specialization for multiply-add
template <typename ElementA, typename LayoutA, SideMode SideModeA,
          FillMode FillModeA, DiagType DiagTypeA,
          typename ElementB, typename LayoutB,
          typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct Trmm<ElementA, LayoutA, SideModeA, FillModeA, DiagTypeA, ElementB, LayoutB,
            ElementC, LayoutC, ScalarType,
            ComputeType, arch::OpMultiplyAdd> {

  /// Runs the host-reference TRMM, writing alpha-scaled results to tensor_d.
  /// Accumulation is performed in ComputeType, starting from initial_accum.
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    // Forward to the free-function reference with the plain multiply-add
    // inner-product operator.
    compute_trmm<ElementA, LayoutA, SideModeA, FillModeA, DiagTypeA, ElementB, LayoutB,
                 ElementC, LayoutC, ScalarType, ComputeType, multiply_add<ComputeType>>(
      problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
  }
};
210
+
211
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
212
+
213
+ } // namespace host
214
+ } // namespace reference
215
+ } // namespace cutlass
build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Reference implementation for complex-valued TRMM in host-side code.
33
+
34
+
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/blas3.h"
40
+ #include "cutlass/complex.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/tensor_view.h"
43
+ #include "cutlass/gemm/gemm.h"
44
+
45
+ #include "cutlass/util/reference/host/gemm.h"
46
+
47
+ namespace cutlass {
48
+ namespace reference {
49
+ namespace host {
50
+
51
+ /// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef
52
+ /// objects.
53
+ template <
54
+ typename ElementA,
55
+ typename LayoutA,
56
+ ComplexTransform TransformA,
57
+ SideMode SideModeA,
58
+ FillMode FillModeA,
59
+ DiagType DiagTypeA,
60
+ typename ElementB,
61
+ typename LayoutB,
62
+ ComplexTransform TransformB,
63
+ typename ElementC,
64
+ typename LayoutC,
65
+ typename ScalarType,
66
+ typename ComputeType,
67
+ typename InnerProductOp = multiply_add<ComputeType>,
68
+ typename ConvertOp = NumericConverter<ElementC, ScalarType>
69
+ >
70
+ void compute_trmm_complex(
71
+ gemm::GemmCoord problem_size,
72
+ ScalarType alpha,
73
+ TensorRef<ElementA, LayoutA> tensor_a,
74
+ TensorRef<ElementB, LayoutB> tensor_b,
75
+ TensorRef<ElementC, LayoutC> tensor_d,
76
+ ComputeType initial_accum) {
77
+
78
+ static_assert(
79
+ LayoutA::kRank == 2 &&
80
+ LayoutC::kRank == 2, "Tensors must be of rank 2");
81
+
82
+ static_assert(SideModeA != SideMode::kInvalid
83
+ , "Side Mode can either be Left or Right.");
84
+
85
+ static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper
86
+ , "Fill Mode can either be Lower or Upper.");
87
+
88
+ using CompareOp = typename TrMatrixCompareOp<FillModeA, DiagTypeA>::Type;
89
+
90
+ // Note: batch is ignored.
91
+ int const M = problem_size.m();
92
+ int const N = problem_size.n();
93
+ // Assuming correct k-dimension value is passed
94
+ int const K = problem_size.k();
95
+
96
+ // Blocking necessary to speedup reference implementation
97
+ int const Mblock = 16;
98
+ int const Nblock = 16;
99
+
100
+ ConvertOp convert_op;
101
+ InnerProductOp inner_product_op;
102
+ CompareOp compare_op;
103
+
104
+ for (int row_block = 0; row_block < M; row_block += Mblock) {
105
+ for (int col_block = 0; col_block < N; col_block += Nblock) {
106
+
107
+ ComputeType accum[Mblock][Nblock];
108
+
109
+ for (int j = 0; j < Nblock; j++) {
110
+ for (int i = 0; i < Mblock; i++) {
111
+ accum[i][j] = initial_accum;
112
+ }
113
+ }
114
+
115
+ for (int k_block = 0; k_block < K; ++k_block) {
116
+ for (int j = 0; j < Nblock; j++) {
117
+ for (int i = 0; i < Mblock; i++) {
118
+ int row = row_block + i;
119
+ int col = col_block + j;
120
+
121
+ if (row < M && col < N) {
122
+ ElementA a = ElementA();
123
+ ElementB b = ElementB();
124
+
125
+ if (SideModeA == SideMode::kLeft) {
126
+ a = (compare_op(row, k_block)) ?
127
+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0);
128
+ if (row == k_block && DiagTypeA == DiagType::kUnit) {
129
+ a = ElementA(1);
130
+ }
131
+ b = tensor_b.at(MatrixCoord(k_block, col));
132
+ } else if (SideModeA == SideMode::kRight) {
133
+ a = tensor_b.at(MatrixCoord(row, k_block));
134
+ b = (compare_op(k_block, col)) ?
135
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0);
136
+ if (k_block == col && DiagTypeA == DiagType::kUnit) {
137
+ b = ElementA(1);
138
+ }
139
+ }
140
+
141
+ ComputeType a_ik = ComputeType(a);
142
+ ComputeType b_kj = ComputeType(b);
143
+
144
+ // Conjugate, and hence hermitian, is only allowed for the triangular matrix
145
+ if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) {
146
+ a_ik = conj(a_ik);
147
+ } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) {
148
+ b_kj = conj(b_kj);
149
+ }
150
+
151
+ accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
152
+ }
153
+ }
154
+ }
155
+ }
156
+
157
+ for (int j = 0; j < Nblock; j++) {
158
+ for (int i = 0; i < Mblock; i++) {
159
+ int row = row_block + i;
160
+ int col = col_block + j;
161
+
162
+ MatrixCoord coord = MatrixCoord(row, col);
163
+
164
+ if (row < M && col < N) {
165
+ tensor_d.at(coord) = convert_op(
166
+ alpha * ScalarType(accum[i][j]));
167
+ }
168
+ }
169
+ }
170
+ }
171
+ }
172
+ }
173
+
174
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
175
+
176
/// Primary template for the host-side complex TRMM reference operator. Intentionally
/// left undefined: only the partial specializations below (for the complex and
/// Gaussian-complex multiply-add op tags) provide an implementation.
template <
  typename ElementA,        ///< element type of the triangular operand A
  typename LayoutA,         ///< layout of A (must be rank-2)
  ComplexTransform TransformA,  ///< optional conjugation applied to A
  SideMode SideModeA,       ///< whether A multiplies B from the left or the right
  FillMode FillModeA,       ///< which triangle of A is stored (lower/upper)
  DiagType DiagTypeA,       ///< unit or non-unit diagonal of A
  typename ElementB,        ///< element type of the dense operand B
  typename LayoutB,         ///< layout of B
  ComplexTransform TransformB,  ///< transform tag for B (not applied by compute_trmm_complex)
  typename ElementC,        ///< element type of the destination D
  typename LayoutC,         ///< layout of D (must be rank-2)
  typename ScalarType,      ///< type of the alpha scalar
  typename ComputeType,     ///< accumulator type used internally
  typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex  ///< math-op tag selecting a specialization
>
struct TrmmComplex;
193
+
194
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
195
+
196
/// Partial specialization for multiply-add
template <typename ElementA, typename LayoutA, ComplexTransform TransformA,
          SideMode SideModeA, FillMode FillModeA, DiagType DiagTypeA,
          typename ElementB, typename LayoutB, ComplexTransform TransformB,
          typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct TrmmComplex<ElementA, LayoutA, TransformA,
                   SideModeA, FillModeA, DiagTypeA,
                   ElementB, LayoutB, TransformB,
                   ElementC, LayoutC, ScalarType,
                   ComputeType, arch::OpMultiplyAddComplex> {

  /// Runs the complex TRMM reference computation by delegating to compute_trmm_complex
  /// with a plain multiply_add<ComputeType> inner product.
  ///
  /// \param problem_size  GEMM-style extents {m, n, k}
  /// \param alpha         scale applied to the accumulated product
  /// \param tensor_a      triangular operand A
  /// \param tensor_b      dense operand B
  /// \param tensor_d      destination tensor D (written by compute_trmm_complex)
  /// \param initial_accum starting value of each accumulator (defaults to zero)
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_trmm_complex<ElementA, LayoutA, TransformA,
                         SideModeA, FillModeA, DiagTypeA,
                         ElementB, LayoutB, TransformB,
                         ElementC, LayoutC,
                         ScalarType, ComputeType, multiply_add<ComputeType>>(
      problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
  }
};
225
+
226
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
227
+
228
/// Partial specialization for gaussian multiply-add
template <typename ElementA, typename LayoutA, ComplexTransform TransformA,
          SideMode SideModeA, FillMode FillModeA, DiagType DiagTypeA,
          typename ElementB, typename LayoutB, ComplexTransform TransformB,
          typename ElementC, typename LayoutC,
          typename ScalarType, typename ComputeType>
struct TrmmComplex<ElementA, LayoutA, TransformA,
                   SideModeA, FillModeA, DiagTypeA,
                   ElementB, LayoutB, TransformB,
                   ElementC, LayoutC, ScalarType,
                   ComputeType, arch::OpMultiplyAddGaussianComplex> {

  /// Runs the complex TRMM reference computation for the Gaussian-complex op tag.
  ///
  /// NOTE(review): this host reference delegates with a plain multiply_add inner
  /// product, identical to the OpMultiplyAddComplex specialization — presumably the
  /// Gaussian 3-multiply formulation only affects device kernels while the reference
  /// result is the same product; confirm against the device-side implementation.
  ///
  /// \param problem_size  GEMM-style extents {m, n, k}
  /// \param alpha         scale applied to the accumulated product
  /// \param tensor_a      triangular operand A
  /// \param tensor_b      dense operand B
  /// \param tensor_d      destination tensor D (written by compute_trmm_complex)
  /// \param initial_accum starting value of each accumulator (defaults to zero)
  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
                  TensorRef<ElementA, LayoutA> tensor_a,
                  TensorRef<ElementB, LayoutB> tensor_b,
                  TensorRef<ElementC, LayoutC> tensor_d,
                  ComputeType initial_accum = ComputeType(0)) {
    static_assert(
      LayoutA::kRank == 2 && LayoutC::kRank == 2,
      "Tensors must be of rank 2");

    compute_trmm_complex<ElementA, LayoutA, TransformA,
                         SideModeA, FillModeA, DiagTypeA,
                         ElementB, LayoutB, TransformB,
                         ElementC, LayoutC,
                         ScalarType, ComputeType, multiply_add<ComputeType>>(
      problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum);
  }
};
257
+
258
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
259
+
260
+ } // namespace host
261
+ } // namespace reference
262
+ } // namespace cutlass