danieldk committed
Commit 29e93ec · 1 Parent(s): 8d5a3ce

Add MoE kernels from vLLM

TODO: add MoE configs, but this requires a change to the builder.

README.md CHANGED
@@ -1,3 +1,7 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ## MoE
6
+
7
+ MoE kernels from [vLLM](https://github.com/vllm-project/).
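The commit also adds Python bindings (`ext-torch/_custom_ops.py`) that expose these kernels as torch ops. A minimal usage sketch for the `silu_and_mul` activation op, assuming a CUDA device and that the built `moe` package is importable; shapes follow the kernel's documented layout (`input` is `[..., 2 * d]`, `out` is `[..., d]`):

```python
import torch
import moe._custom_ops as ops  # wrapper module added in this commit

d = 128
x = torch.randn(4, 2 * d, dtype=torch.float16, device="cuda")  # [..., 2 * d]
out = torch.empty(4, d, dtype=torch.float16, device="cuda")    # [..., d]
ops.silu_and_mul(out, x)  # out = silu(x[..., :d]) * x[..., d:]
```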
activation/activation_kernels.cu ADDED
@@ -0,0 +1,204 @@
1
+ #include <ATen/cuda/CUDAContext.h>
2
+ #include <torch/all.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include <cmath>
6
+
7
+ #include "cuda_compat.h"
8
+ #include "dispatch_utils.h"
9
+
10
+ namespace vllm {
11
+
12
+ // Activation and gating kernel template.
13
+ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
14
+ __global__ void act_and_mul_kernel(
15
+ scalar_t* __restrict__ out, // [..., d]
16
+ const scalar_t* __restrict__ input, // [..., 2, d]
17
+ const int d) {
18
+ const int64_t token_idx = blockIdx.x;
19
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
20
+ const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
21
+ const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
22
+ out[token_idx * d + idx] = ACT_FN(x) * y;
23
+ }
24
+ }
25
+
26
+ template <typename T>
27
+ __device__ __forceinline__ T silu_kernel(const T& x) {
28
+ // x * sigmoid(x)
29
+ return (T)(((float)x) / (1.0f + expf((float)-x)));
30
+ }
31
+
32
+ template <typename T>
33
+ __device__ __forceinline__ T gelu_kernel(const T& x) {
34
+ // Equivalent to PyTorch GELU with 'none' approximation.
35
+ // Refer to:
36
+ // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
37
+ const float f = (float)x;
38
+ constexpr float ALPHA = M_SQRT1_2;
39
+ return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
40
+ }
41
+
42
+ template <typename T>
43
+ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
44
+ // Equivalent to PyTorch GELU with 'tanh' approximation.
45
+ // Refer to:
46
+ // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
47
+ const float f = (float)x;
48
+ constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
49
+ constexpr float KAPPA = 0.044715;
50
+ float x_cube = f * f * f;
51
+ float inner = BETA * (f + KAPPA * x_cube);
52
+ return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
53
+ }
54
+
55
+ } // namespace vllm
56
+
57
+ // Launch activation and gating kernel.
58
+ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
59
+ int d = input.size(-1) / 2; \
60
+ int64_t num_tokens = input.numel() / input.size(-1); \
61
+ dim3 grid(num_tokens); \
62
+ dim3 block(std::min(d, 1024)); \
63
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
64
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
65
+ VLLM_DISPATCH_FLOATING_TYPES( \
66
+ input.scalar_type(), "act_and_mul_kernel", [&] { \
67
+ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
68
+ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
69
+ input.data_ptr<scalar_t>(), d); \
70
+ });
71
+
72
+ void silu_and_mul(torch::Tensor& out, // [..., d]
73
+ torch::Tensor& input) // [..., 2 * d]
74
+ {
75
+ LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
76
+ }
77
+
78
+ //void gelu_and_mul(torch::Tensor& out, // [..., d]
79
+ // torch::Tensor& input) // [..., 2 * d]
80
+ //{
81
+ // LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
82
+ //}
83
+
84
+ //void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
85
+ // torch::Tensor& input) // [..., 2 * d]
86
+ //{
87
+ // LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
88
+ //}
89
+
90
+ namespace vllm {
91
+
92
+ template <typename T>
93
+ __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
94
+ const float f = (float)x;
95
+ return (T)(f > threshold ? f : 0.0f);
96
+ }
97
+
98
+ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
99
+ __global__ void act_and_mul_kernel_with_param(
100
+ scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
101
+ const float param) {
102
+ const int64_t token_idx = blockIdx.x;
103
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
104
+ const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
105
+ const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
106
+ out[token_idx * d + idx] = ACT_FN(x, param) * y;
107
+ }
108
+ }
109
+
110
+ } // namespace vllm
111
+
112
+ #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
113
+ int d = input.size(-1) / 2; \
114
+ int64_t num_tokens = input.numel() / input.size(-1); \
115
+ dim3 grid(num_tokens); \
116
+ dim3 block(std::min(d, 1024)); \
117
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
118
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
119
+ VLLM_DISPATCH_FLOATING_TYPES( \
120
+ input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
121
+ vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
122
+ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
123
+ input.data_ptr<scalar_t>(), d, \
124
+ PARAM); \
125
+ });
126
+
127
+ //void fatrelu_and_mul(torch::Tensor& out, // [..., d],
128
+ // torch::Tensor& input, // [..., 2 * d]
129
+ // double threshold) {
130
+ // LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
131
+ //}
132
+ namespace vllm {
133
+
134
+ // Element-wise activation kernel template.
135
+ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
136
+ __global__ void activation_kernel(
137
+ scalar_t* __restrict__ out, // [..., d]
138
+ const scalar_t* __restrict__ input, // [..., d]
139
+ const int d) {
140
+ const int64_t token_idx = blockIdx.x;
141
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
142
+ const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
143
+ out[token_idx * d + idx] = ACT_FN(x);
144
+ }
145
+ }
146
+
147
+ } // namespace vllm
148
+
149
+ // Launch element-wise activation kernel.
150
+ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
151
+ int d = input.size(-1); \
152
+ int64_t num_tokens = input.numel() / d; \
153
+ dim3 grid(num_tokens); \
154
+ dim3 block(std::min(d, 1024)); \
155
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
156
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
157
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
158
+ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
159
+ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
160
+ input.data_ptr<scalar_t>(), d); \
161
+ });
162
+
163
+ namespace vllm {
164
+
165
+ template <typename T>
166
+ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
167
+ const float x3 = (float)(x * x * x);
168
+ const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
169
+ return ((T)0.5) * x * (((T)1.0) + t);
170
+ }
171
+
172
+ template <typename T>
173
+ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
174
+ const float f = (float)x;
175
+ const T t =
176
+ (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
177
+ return ((T)0.5) * x * (((T)1.0) + t);
178
+ }
179
+
180
+ template <typename T>
181
+ __device__ __forceinline__ T gelu_quick_kernel(const T& x) {
182
+ // x * sigmoid(1.702 * x)
183
+ return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
184
+ }
185
+
186
+ } // namespace vllm
187
+
188
+ //void gelu_new(torch::Tensor& out, // [..., d]
189
+ // torch::Tensor& input) // [..., d]
190
+ //{
191
+ // LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
192
+ //}
193
+
194
+ //void gelu_fast(torch::Tensor& out, // [..., d]
195
+ // torch::Tensor& input) // [..., d]
196
+ //{
197
+ // LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
198
+ //}
199
+
200
+ //void gelu_quick(torch::Tensor& out, // [..., d]
201
+ // torch::Tensor& input) // [..., d]
202
+ //{
203
+ // LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
204
+ //}
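For reference, a pure-PyTorch sketch of what `act_and_mul_kernel` computes when instantiated with `silu_kernel` (useful as an illustrative correctness check against the CUDA op; not part of the commit):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference for act_and_mul_kernel<..., silu_kernel>: input [..., 2*d] -> output [..., d]."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
```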
activation/cuda_compat.h ADDED
@@ -0,0 +1,49 @@
1
+ #pragma once
2
+
3
+ #ifdef USE_ROCM
4
+ #include <hip/hip_runtime.h>
5
+ #endif
6
+
7
+ #ifndef USE_ROCM
8
+ #define WARP_SIZE 32
9
+ #else
10
+ #define WARP_SIZE warpSize
11
+ #endif
12
+
13
+ #ifndef USE_ROCM
14
+ #define VLLM_LDG(arg) __ldg(arg)
15
+ #else
16
+ #define VLLM_LDG(arg) *(arg)
17
+ #endif
18
+
19
+ #ifndef USE_ROCM
20
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
21
+ __shfl_xor_sync(uint32_t(-1), var, lane_mask)
22
+ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
23
+ __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
24
+ #else
25
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
26
+ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
27
+ __shfl_xor(var, lane_mask, width)
28
+ #endif
29
+
30
+ #ifndef USE_ROCM
31
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
32
+ #else
33
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
34
+ #endif
35
+
36
+ #ifndef USE_ROCM
37
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
38
+ __shfl_down_sync(uint32_t(-1), var, lane_delta)
39
+ #else
40
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
41
+ #endif
42
+
43
+ #ifndef USE_ROCM
44
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
45
+ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
46
+ #else
47
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
48
+ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
49
+ #endif
activation/dispatch_utils.h ADDED
@@ -0,0 +1,35 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
+ */
5
+ #pragma once
6
+
7
+ #include <torch/all.h>
8
+
9
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
12
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
13
+
14
+ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
16
+
17
+ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
18
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
19
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
20
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
21
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
22
+
23
+ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
24
+ AT_DISPATCH_SWITCH(TYPE, NAME, \
25
+ VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
26
+
27
+ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
28
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
29
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
30
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
31
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
32
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
33
+
34
+ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
35
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
build.toml ADDED
@@ -0,0 +1,49 @@
1
+ [general]
2
+ version = "0.0.1"
3
+
4
+ [torch]
5
+ name = "moe"
6
+ src = [
7
+ "core/scalar_type.hpp",
8
+ "ext-torch/registration.h",
9
+ "ext-torch/torch_binding.cpp",
10
+ "ext-torch/torch_binding.h"
11
+ ]
12
+ include = [ "." ]
13
+ pyroot = "ext-torch"
14
+
15
+ [kernel.moe]
16
+ capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
17
+ src = [
18
+ "cuda_compat.h",
19
+ "dispatch_utils.h",
20
+ "moe/moe_align_sum_kernels.cu",
21
+ "moe/topk_softmax_kernels.cu",
22
+ ]
23
+ depends = [ "torch" ]
24
+
25
+ [kernel.moe-marlin]
26
+ capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0" ]
27
+ src = [
28
+ "core/exception.hpp",
29
+ "core/scalar_type.hpp",
30
+ "marlin-moe/marlin_moe_ops.cu",
31
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu",
32
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu",
33
+ "marlin-moe/marlin_kernels/marlin_moe_kernel.h",
34
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h",
35
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h",
36
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu",
37
+ "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h",
38
+ ]
39
+ include = [ "." ]
40
+ depends = [ "torch" ]
41
+
42
+ [kernel.activation]
43
+ capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
44
+ src = [
45
+ "activation/activation_kernels.cu",
46
+ "activation/cuda_compat.h",
47
+ "activation/dispatch_utils.h",
48
+ ]
49
+ depends = [ "torch" ]
core/exception.hpp ADDED
@@ -0,0 +1,3 @@
1
+ #pragma once
2
+
3
+ #define VLLM_IMPLIES(p, q) (!(p) || (q))
core/scalar_type.hpp ADDED
@@ -0,0 +1,347 @@
1
+ #pragma once
2
+
3
+ // For TORCH_CHECK
4
+ #include <torch/library.h>
5
+
6
+ namespace vllm {
7
+
8
+ //
9
+ // ScalarType can represent a wide range of floating point and integer types,
10
+ // in particular it can be used to represent sub-byte data types (something
11
+ // that torch.dtype currently does not support).
12
+ //
13
+ // The type definitions on the Python side can be found in vllm/scalar_type.py;
14
+ // these type definitions should be kept up to date with any Python API changes
15
+ // here.
16
+ //
17
+ class ScalarType {
18
+ public:
19
+ enum NanRepr : uint8_t {
20
+ NAN_NONE = 0, // nans are not supported
21
+ NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
22
+ NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
23
+
24
+ NAN_REPR_ID_MAX
25
+ };
26
+
27
+ constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
28
+ int32_t bias, bool finite_values_only = false,
29
+ NanRepr nan_repr = NAN_IEEE_754)
30
+ : exponent(exponent),
31
+ mantissa(mantissa),
32
+ signed_(signed_),
33
+ bias(bias),
34
+ finite_values_only(finite_values_only),
35
+ nan_repr(nan_repr){};
36
+
37
+ static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
38
+ return ScalarType(0, size_bits - 1, true, bias);
39
+ }
40
+
41
+ static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
42
+ return ScalarType(0, size_bits, false, bias);
43
+ }
44
+
45
+ // IEEE 754 compliant floating point type
46
+ static constexpr ScalarType float_IEEE754(uint8_t exponent,
47
+ uint8_t mantissa) {
48
+ TORCH_CHECK(mantissa > 0 && exponent > 0);
49
+ return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
50
+ }
51
+
52
+ // IEEE 754 non-compliant floating point type
53
+ static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
54
+ bool finite_values_only,
55
+ NanRepr nan_repr) {
56
+ TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
57
+ TORCH_CHECK(mantissa > 0 && exponent > 0);
58
+ TORCH_CHECK(nan_repr != NAN_IEEE_754,
59
+ "use `float_IEEE754` constructor for floating point types that "
60
+ "follow IEEE 754 conventions");
61
+ return ScalarType(exponent, mantissa, true, 0, finite_values_only,
62
+ nan_repr);
63
+ }
64
+
65
+ uint8_t const exponent; // size of the exponent field (0 for integer types)
66
+ uint8_t const mantissa; // size of the mantissa field (size of the integer
67
+ // excluding the sign bit for integer types)
68
+ bool const signed_; // flag if the type supports negative numbers (i.e. has a
69
+ // sign bit)
70
+ int32_t const bias; // stored values equal value + bias,
71
+ // used for quantized type
72
+
73
+ // Extra Floating point info
74
+ bool const finite_values_only; // i.e. no +/-inf if true
75
+ NanRepr const nan_repr; // how NaNs are represented
76
+ // (not applicable for integer types)
77
+
78
+ using Id = int64_t;
79
+
80
+ private:
81
+ // Field size in id
82
+ template <typename T_>
83
+ static constexpr size_t member_id_field_width() {
84
+ using T = std::decay_t<T_>;
85
+ return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
86
+ }
87
+
88
+ template <typename Fn, typename Init, typename Member, typename... Rest>
89
+ static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
90
+ Rest... rest) {
91
+ auto new_val = f(val, member);
92
+ if constexpr (sizeof...(rest) > 0) {
93
+ return reduce_members_helper(f, new_val, rest...);
94
+ } else {
95
+ return new_val;
96
+ };
97
+ }
98
+
99
+ template <typename Fn, typename Init>
100
+ constexpr auto reduce_members(Fn f, Init init) const {
101
+ // Should be in constructor order for `from_id`
102
+ return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
103
+ finite_values_only, nan_repr);
104
+ };
105
+
106
+ template <typename Fn, typename Init>
107
+ static constexpr auto reduce_member_types(Fn f, Init init) {
108
+ constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
109
+ return dummy_type.reduce_members(f, init);
110
+ };
111
+
112
+ static constexpr auto id_size_bits() {
113
+ return reduce_member_types(
114
+ [](int acc, auto member) -> int {
115
+ return acc + member_id_field_width<decltype(member)>();
116
+ },
117
+ 0);
118
+ }
119
+
120
+ public:
121
+ // unique id for this scalar type that can be computed at compile time for
122
+ // c++17 template specialization this is not needed once we migrate to
123
+ // c++20 and can pass literal classes as template parameters
124
+ constexpr Id id() const {
125
+ static_assert(id_size_bits() <= sizeof(Id) * 8,
126
+ "ScalarType id is too large to be stored");
127
+
128
+ auto or_and_advance = [](std::pair<Id, uint32_t> result,
129
+ auto member) -> std::pair<Id, uint32_t> {
130
+ auto [id, bit_offset] = result;
131
+ auto constexpr bits = member_id_field_width<decltype(member)>();
132
+ return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
133
+ << bit_offset,
134
+ bit_offset + bits};
135
+ };
136
+ return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
137
+ }
138
+
139
+ // create a ScalarType from an id, for c++17 template specialization,
140
+ // this is not needed once we migrate to c++20 and can pass literal
141
+ // classes as template parameters
142
+ static constexpr ScalarType from_id(Id id) {
143
+ auto extract_and_advance = [id](auto result, auto member) {
144
+ using T = decltype(member);
145
+ auto [tuple, bit_offset] = result;
146
+ auto constexpr bits = member_id_field_width<T>();
147
+ auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
148
+ ((uint64_t(1) << bits) - 1));
149
+ auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
150
+ return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
151
+ };
152
+
153
+ auto [tuple_args, _] = reduce_member_types(extract_and_advance,
154
+ std::pair<std::tuple<>, int>{});
155
+ return std::apply([](auto... args) { return ScalarType(args...); },
156
+ tuple_args);
157
+ }
158
+
159
+ constexpr int64_t size_bits() const {
160
+ return mantissa + exponent + is_signed();
161
+ }
162
+ constexpr bool is_signed() const { return signed_; }
163
+ constexpr bool is_integer() const { return exponent == 0; }
164
+ constexpr bool is_floating_point() const { return exponent > 0; }
165
+ constexpr bool is_ieee_754() const {
166
+ return is_floating_point() && finite_values_only == false &&
167
+ nan_repr == NAN_IEEE_754;
168
+ }
169
+ constexpr bool has_nans() const {
170
+ return is_floating_point() && nan_repr != NAN_NONE;
171
+ }
172
+ constexpr bool has_infs() const {
173
+ return is_floating_point() && finite_values_only == false;
174
+ }
175
+ constexpr bool has_bias() const { return bias != 0; }
176
+
177
+ private:
178
+ double _floating_point_max() const {
179
+ TORCH_CHECK(mantissa <= 52 && exponent <= 11,
180
+ "Cannot represent max/min as a double for type ", str());
181
+
182
+ uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
183
+ if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
184
+ max_mantissa -= 1;
185
+ }
186
+
187
+ uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
188
+ if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
189
+ TORCH_CHECK(exponent < 11,
190
+ "Cannot represent max/min as a double for type ", str());
191
+ max_exponent += 1;
192
+ }
193
+
194
+ // adjust the exponent to match that of a double
195
+ // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
196
+ // is the exponent bits), there is some precedent for non-standard biases,
197
+ // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
198
+ // but to avoid premature over complication we are just assuming the
199
+ // standard exponent bias until there is a need to support non-standard
200
+ // biases
201
+ uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
202
+ uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
203
+
204
+ uint64_t max_exponent_double =
205
+ max_exponent - exponent_bias + exponent_bias_double;
206
+
207
+ // shift the mantissa into the position for a double and
208
+ // the exponent
209
+ uint64_t double_raw =
210
+ (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
211
+
212
+ return *reinterpret_cast<double*>(&double_raw);
213
+ }
214
+
215
+ constexpr std::variant<int64_t, double> _raw_max() const {
216
+ if (is_floating_point()) {
217
+ return {_floating_point_max()};
218
+ } else {
219
+ TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
220
+ "Cannot represent max as a int64_t");
221
+ return {(int64_t(1) << mantissa) - 1};
222
+ }
223
+ }
224
+
225
+ constexpr std::variant<int64_t, double> _raw_min() const {
226
+ if (is_floating_point()) {
227
+ TORCH_CHECK(is_signed(),
228
+ "We currently assume all floating point types are signed");
229
+ constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
230
+
231
+ double max = _floating_point_max();
232
+ uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
233
+ uint64_t min_raw = max_raw | sign_bit_double;
234
+ return {*reinterpret_cast<double*>(&min_raw)};
235
+ } else {
236
+ TORCH_CHECK(!is_signed() || size_bits() <= 64,
237
+ "Cannot represent min as a int64_t");
238
+ if (is_signed()) {
239
+ // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
240
+ // then perform an arithmetic shift right to set all the bits above
241
+ // (size_bits() - 1) to 1
242
+ return {INT64_MIN >> (64 - size_bits())};
243
+ } else {
244
+ return {int64_t(0)};
245
+ }
246
+ }
247
+ }
248
+
249
+ public:
250
+ // Max representable value for this scalar type.
251
+ // (accounting for bias if there is one)
252
+ constexpr std::variant<int64_t, double> max() const {
253
+ return std::visit(
254
+ [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
255
+ _raw_max());
256
+ }
257
+
258
+ // Min representable value for this scalar type.
259
+ // (accounting for bias if there is one)
260
+ constexpr std::variant<int64_t, double> min() const {
261
+ return std::visit(
262
+ [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
263
+ _raw_min());
264
+ }
265
+
266
+ std::string str() const {
267
+ /* naming generally follows: https://github.com/jax-ml/ml_dtypes
268
+ * for floating point types (leading f) the scheme is:
269
+ * `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
270
+ * flags:
271
+ * - no-flags: means it follows IEEE 754 conventions
272
+ * - f: means finite values only (no infinities)
273
+ * - n: means nans are supported (non-standard encoding)
274
+ * for integer types the scheme is:
275
+ * `[u]int<size_bits>[b<bias>]`
276
+ * - if bias is not present it means its zero
277
+ */
278
+ if (is_floating_point()) {
279
+ auto ret = "float" + std::to_string(size_bits()) + "_e" +
280
+ std::to_string(exponent) + "m" + std::to_string(mantissa);
281
+ if (!is_ieee_754()) {
282
+ if (finite_values_only) {
283
+ ret += "f";
284
+ }
285
+ if (nan_repr != NAN_NONE) {
286
+ ret += "n";
287
+ }
288
+ }
289
+ return ret;
290
+ } else {
291
+ auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
292
+ if (has_bias()) {
293
+ ret += "b" + std::to_string(bias);
294
+ }
295
+ return ret;
296
+ }
297
+ }
298
+
299
+ constexpr bool operator==(ScalarType const& other) const {
300
+ return mantissa == other.mantissa && exponent == other.exponent &&
301
+ bias == other.bias && signed_ == other.signed_ &&
302
+ finite_values_only == other.finite_values_only &&
303
+ nan_repr == other.nan_repr;
304
+ }
305
+ };
306
+
307
+ using ScalarTypeId = ScalarType::Id;
308
+
309
+ // "rust style" names generally following:
310
+ // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
311
+ static inline constexpr auto kS4 = ScalarType::int_(4);
312
+ static inline constexpr auto kU4 = ScalarType::uint(4);
313
+ static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
314
+ static inline constexpr auto kS8 = ScalarType::int_(8);
315
+ static inline constexpr auto kU8 = ScalarType::uint(8);
316
+ static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
317
+
318
+ static inline constexpr auto kFE3M2f =
319
+ ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
320
+ static inline constexpr auto kFE4M3fn =
321
+ ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
322
+ static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
323
+ static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
324
+ static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
325
+
326
+ // Fixed width style names, generally following:
327
+ // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
328
+ static inline constexpr auto kInt4 = kS4;
329
+ static inline constexpr auto kUint4 = kU4;
330
+ static inline constexpr auto kUint4b8 = kU4B8;
331
+ static inline constexpr auto kInt8 = kS8;
332
+ static inline constexpr auto kUint8 = kU8;
333
+ static inline constexpr auto kUint8b128 = kU8B128;
334
+
335
+ static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
336
+ static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
337
+ static inline constexpr auto kFloat8_e5m2 = kFE5M2;
338
+ static inline constexpr auto kFloat16_e8m7 = kFE8M7;
339
+ static inline constexpr auto kFloat16_e5m10 = kFE5M10;
340
+
341
+ // colloquial names
342
+ static inline constexpr auto kHalf = kFE5M10;
343
+ static inline constexpr auto kFloat16 = kHalf;
344
+ static inline constexpr auto kBFloat16 = kFE8M7;
345
+
346
+ static inline constexpr auto kFloat16Id = kFloat16.id();
347
+ }; // namespace vllm
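To illustrate the `id()` encoding above, here is a hedged Python sketch of the same bit packing (fields in constructor order; widths follow `member_id_field_width`: `bool` takes 1 bit, `uint8_t` 8 bits, `int32_t` 32 bits). It mirrors the C++ only for explanation and is not part of the commit:

```python
def scalar_type_id(exponent: int, mantissa: int, signed: bool, bias: int,
                   finite_values_only: bool = False, nan_repr: int = 1) -> int:
    """Pack ScalarType fields into an integer id, LSB first, in constructor order."""
    fields = [(exponent, 8), (mantissa, 8), (int(signed), 1),
              (bias, 32), (int(finite_values_only), 1), (nan_repr, 8)]
    id_, offset = 0, 0
    for value, bits in fields:
        id_ |= (value & ((1 << bits) - 1)) << offset
        offset += bits
    return id_

# e.g. kU4B8 = ScalarType::uint(4, 8): exponent=0, mantissa=4, unsigned, bias=8
uint4b8_id = scalar_type_id(0, 4, False, 8)
```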
cuda_compat.h ADDED
@@ -0,0 +1,49 @@
1
+ #pragma once
2
+
3
+ #ifdef USE_ROCM
4
+ #include <hip/hip_runtime.h>
5
+ #endif
6
+
7
+ #ifndef USE_ROCM
8
+ #define WARP_SIZE 32
9
+ #else
10
+ #define WARP_SIZE warpSize
11
+ #endif
12
+
13
+ #ifndef USE_ROCM
14
+ #define VLLM_LDG(arg) __ldg(arg)
15
+ #else
16
+ #define VLLM_LDG(arg) *(arg)
17
+ #endif
18
+
19
+ #ifndef USE_ROCM
20
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
21
+ __shfl_xor_sync(uint32_t(-1), var, lane_mask)
22
+ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
23
+ __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
24
+ #else
25
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
26
+ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
27
+ __shfl_xor(var, lane_mask, width)
28
+ #endif
29
+
30
+ #ifndef USE_ROCM
31
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
32
+ #else
33
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
34
+ #endif
35
+
36
+ #ifndef USE_ROCM
37
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
38
+ __shfl_down_sync(uint32_t(-1), var, lane_delta)
39
+ #else
40
+ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
41
+ #endif
42
+
43
+ #ifndef USE_ROCM
44
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
45
+ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
46
+ #else
47
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
48
+ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
49
+ #endif
dispatch_utils.h ADDED
@@ -0,0 +1,35 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
+ */
5
+ #pragma once
6
+
7
+ #include <torch/all.h>
8
+
9
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
12
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
13
+
14
+ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
16
+
17
+ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
18
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
19
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
20
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
21
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
22
+
23
+ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
24
+ AT_DISPATCH_SWITCH(TYPE, NAME, \
25
+ VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
26
+
27
+ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
28
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
29
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
30
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
31
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
32
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
33
+
34
+ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
35
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
ext-torch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ import moe._custom_ops as ops
ext-torch/_custom_ops.py ADDED
@@ -0,0 +1,135 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import torch
4
+
5
+ # Neuron builds ship a torch version that doesn't even have impl_abstract
6
+ if TYPE_CHECKING:
7
+
8
+ def register_fake(fn):
9
+ return lambda name: fn
10
+
11
+ else:
12
+ try:
13
+ from torch.library import register_fake
14
+ except ImportError:
15
+ from torch.library import impl_abstract as register_fake
16
+
17
+ try:
18
+ from ._ops import ops, add_op_namespace_prefix
19
+ except ImportError as e:
20
+ # Fallback for local development.
21
+ try:
22
+ import _moe
23
+
24
+ ops = torch.ops._moe  # importing the extension registers its ops under torch.ops
25
+
26
+ def add_op_namespace_prefix(op_name: str):
27
+ return f"_quantization::{op_name}"
28
+
29
+ except ImportError:
30
+ raise e
31
+
32
+ from .scalar_type import ScalarType
33
+
34
+ def gptq_marlin_moe_repack(
35
+ b_q_weight: torch.Tensor,
36
+ perm: torch.Tensor,
37
+ size_k: int,
38
+ size_n: int,
39
+ num_bits: int,
40
+ ) -> torch.Tensor:
41
+ num_experts = b_q_weight.shape[0]
42
+ assert size_k % 16 == 0
43
+ output = torch.empty(
44
+ (num_experts, size_k // 16, size_n * (num_bits // 2)),
45
+ device=b_q_weight.device,
46
+ dtype=b_q_weight.dtype,
47
+ )
48
+ for e in range(num_experts):
49
+ output[e] = ops.gptq_marlin_repack(
50
+ b_q_weight[e], perm[e], size_k, size_n, num_bits
51
+ )
52
+ return output
53
+
54
+
55
+ def awq_marlin_moe_repack(
56
+ b_q_weight: torch.Tensor,
57
+ perm: torch.Tensor,
58
+ size_k: int,
59
+ size_n: int,
60
+ num_bits: int,
61
+ ) -> torch.Tensor:
62
+ num_experts = b_q_weight.shape[0]
63
+ assert size_k % 16 == 0
64
+ output = torch.empty(
65
+ (num_experts, size_k // 16, size_n * (num_bits // 2)),
66
+ device=b_q_weight.device,
67
+ dtype=b_q_weight.dtype,
68
+ )
69
+ for e in range(num_experts):
70
+ output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
71
+ return output
72
+
73
+
74
+ def moe_sum(input: torch.Tensor, output: torch.Tensor):
75
+ ops.moe_sum(input, output)
76
+
77
+
78
+ def moe_align_block_size(
79
+ topk_ids: torch.Tensor,
80
+ num_experts: int,
81
+ block_size: int,
82
+ sorted_token_ids: torch.Tensor,
83
+ experts_ids: torch.Tensor,
84
+ num_tokens_post_pad: torch.Tensor,
85
+ ) -> None:
86
+ ops.moe_align_block_size(
87
+ topk_ids,
88
+ num_experts,
89
+ block_size,
90
+ sorted_token_ids,
91
+ experts_ids,
92
+ num_tokens_post_pad,
93
+ )
94
+
95
+
96
+ def topk_softmax(
97
+ topk_weights: torch.Tensor,
98
+ topk_ids: torch.Tensor,
99
+ token_expert_indicies: torch.Tensor,
100
+ gating_output: torch.Tensor,
101
+ ) -> None:
102
+ ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
103
+
104
+ if hasattr(ops, "marlin_gemm_moe"):
105
+
106
+ @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
107
+ def marlin_gemm_moe_fake(
108
+ a: torch.Tensor,
109
+ b_q_weights: torch.Tensor,
110
+ sorted_ids: torch.Tensor,
111
+ topk_weights: torch.Tensor,
112
+ topk_ids: torch.Tensor,
113
+ b_scales: torch.Tensor,
114
+ b_zero_points: torch.Tensor,
115
+ g_idx: torch.Tensor,
116
+ perm: torch.Tensor,
117
+ workspace: torch.Tensor,
118
+ b_q_type: ScalarType,
119
+ size_m: torch.SymInt,
120
+ size_n: torch.SymInt,
121
+ size_k: torch.SymInt,
122
+ is_k_full: bool,
123
+ num_experts: int,
124
+ topk: int,
125
+ moe_block_size: int,
126
+ replicate_input: bool,
127
+ apply_weights: bool,
128
+ ) -> torch.Tensor:
129
+ return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)
130
+
131
+
132
+
133
+ def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
134
+ ops.silu_and_mul(out, x)
135
+ return out
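A hedged usage sketch for the `topk_softmax` wrapper above. The output tensors are assumed to be preallocated as `(num_tokens, topk)` and `gating_output` as `(num_tokens, num_experts)`, following the usual vLLM calling convention; the shapes and dtypes here are assumptions, not taken from this commit:

```python
import torch
import moe._custom_ops as ops

num_tokens, num_experts, topk = 8, 4, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indicies = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
```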
ext-torch/fp8.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+
3
+ from typing import Tuple, Optional, Union
4
+
5
+
6
+ def is_hip() -> bool:
7
+ return torch.version.hip is not None
8
+
9
+
10
+ def scaled_fp8_quant(
11
+ input: torch.Tensor,
12
+ scale: Optional[torch.Tensor] = None,
13
+ num_token_padding: Optional[int] = None,
14
+ scale_ub: Optional[torch.Tensor] = None,
15
+ use_per_token_if_dynamic: bool = False,
16
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
17
+ """
18
+ Quantize input tensor to FP8 and return quantized tensor and scale.
19
+
20
+ This function supports both static and dynamic quantization: If you
21
+ provide the scale, it will use static scaling and if you omit it,
22
+ the scale will be determined dynamically. The function also allows
23
+ optional padding of the output tensors for downstream kernels that
24
+ will benefit from padding.
25
+
26
+ Args:
27
+ input: The input tensor to be quantized to FP8
28
+ scale: Optional scaling factor for the FP8 quantization
29
+ scale_ub: Optional upper bound for scaling factor in dynamic
30
+ per token case
31
+ num_token_padding: If specified, pad the first dimension
32
+ of the output to at least this value.
33
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
34
+ in the dynamic quantization case.
35
+
36
+ Returns:
37
+ Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
38
+ scaling factor.
39
+ """
40
+ # This code assumes batch_dim and num_tokens are flattened
41
+ assert input.ndim == 2
42
+ shape: Union[Tuple[int, int], torch.Size] = input.shape
43
+ # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
44
+ out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
45
+ if num_token_padding:
46
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
47
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
48
+
49
+ if scale is None:
50
+ if use_per_token_if_dynamic:
51
+ scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
+ torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
+ output, input, scale, scale_ub
54
+ )
55
+ else:
56
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
+ torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
+ else:
59
+ # num_token_padding not implemented for this case
60
+ assert scale.numel() == 1 or num_token_padding is None
61
+ torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
+
63
+ return output, scale
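A short usage sketch for `scaled_fp8_quant`, assuming a CUDA or ROCm device and a 2-D activation tensor (the function asserts `input.ndim == 2`); illustrative only:

```python
import torch

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor quantization: the scale is computed by the kernel.
x_fp8, scale = scaled_fp8_quant(x)

# Static quantization: reuse a precomputed scale.
x_fp8_static, _ = scaled_fp8_quant(x, scale=scale)
```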
ext-torch/fused_marlin_moe.py ADDED
@@ -0,0 +1,338 @@
1
+ """Fused MoE utilities for GPTQ."""
2
+
3
+ import functools
4
+ from typing import Any, Dict, Optional
5
+
6
+ import torch
7
+
8
+ from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config
9
+ from .scalar_type import scalar_types
10
+ import moe._custom_ops as ops
11
+
12
+
13
+ def get_scalar_type(num_bits: int, has_zp: bool):
14
+ if has_zp:
15
+ assert num_bits == 4
16
+ return scalar_types.uint4
17
+ else:
18
+ return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
19
+
20
+
21
+ def single_marlin_moe(
22
+ hidden_states: torch.Tensor,
23
+ w: torch.Tensor,
24
+ scales: torch.Tensor,
25
+ gating_output: torch.Tensor,
26
+ topk: int,
27
+ renormalize: bool,
28
+ g_idx: Optional[torch.Tensor] = None,
29
+ sort_indices: Optional[torch.Tensor] = None,
30
+ w_zeros: Optional[torch.Tensor] = None,
31
+ override_config: Optional[Dict[str, Any]] = None,
32
+ num_bits: int = 8,
33
+ is_k_full: bool = True,
34
+ ) -> torch.Tensor:
35
+ """
36
+ This function computes the multiplication of hidden_states with expert
37
+ weights used in Marlin MoE, using weights w and top-k gating mechanism.
38
+ Its purpose is testing and debugging the fused MoE kernel.
39
+
40
+ Parameters:
41
+ - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
42
+ - w (torch.Tensor): The set of expert weights.
43
+ - scales (torch.Tensor): The quantization scales.
44
+ - gating_output (torch.Tensor): The output of the gating operation
45
+ (before softmax).
46
+ - g_idx (Optional[torch.Tensor]): Optional act_order indices.
47
+ - sort_indices (Optional[torch.Tensor]): Optional act_order input
48
+ permutation.
49
+ - topk (int): The number of top-k experts to select.
50
+ - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
51
+ - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
52
+ - override_config (Optional[Dict[str, Any]]): Optional override
53
+ for the kernel configuration.
54
+ - num_bits (int): The number of bits used for expert weight quantization.
55
+
56
+ Returns:
57
+ - torch.Tensor: The output tensor after applying the MoE layer.
58
+ """
59
+ # Check constraints.
60
+ assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
61
+ assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
62
+ assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
63
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
64
+ assert w.is_contiguous(), "Expert weights must be contiguous"
65
+ assert hidden_states.dtype == torch.float16
66
+ assert num_bits in [4, 8]
67
+
68
+ M, K = hidden_states.shape
69
+ E = w.shape[0]
70
+ N = w.shape[2] // (num_bits // 2)
71
+
72
+ topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)
73
+
74
+ # This might not be an optimal config for a single MMM
75
+ get_config_func = functools.partial(
76
+ try_get_optimal_moe_config,
77
+ w.shape,
78
+ w.shape,
79
+ topk_ids.shape[1],
80
+ None,
81
+ override_config=override_config,
82
+ is_marlin=True,
83
+ )
84
+ config = get_config_func(M)
85
+
86
+ block_size_m = config["BLOCK_SIZE_M"]
87
+
88
+ sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
89
+
90
+ max_workspace_size = (N // 64) * 16
91
+ workspace = torch.zeros(
92
+ max_workspace_size,
93
+ dtype=torch.int,
94
+ device=hidden_states.device,
95
+ requires_grad=False,
96
+ )
97
+
98
+ has_zero_point = w_zeros is not None
99
+ if w_zeros is None:
100
+ w_zeros = torch.empty(
101
+ (0, 0),
102
+ dtype=hidden_states.dtype,
103
+ device=hidden_states.device,
104
+ requires_grad=False,
105
+ )
106
+
107
+ if g_idx is None:
108
+ g_idx = torch.empty(
109
+ (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
110
+ )
111
+
112
+ if sort_indices is None:
113
+ sort_indices = torch.empty(
114
+ (0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
115
+ )
116
+
117
+ scalar_type = get_scalar_type(num_bits, has_zero_point)
118
+
119
+ intermediate_cache = ops.ops.marlin_gemm_moe(
120
+ hidden_states,
121
+ w,
122
+ sorted_token_ids,
123
+ topk_weights,
124
+ topk_ids,
125
+ scales,
126
+ w_zeros,
127
+ g_idx,
128
+ sort_indices,
129
+ workspace,
130
+ scalar_type.id,
131
+ M,
132
+ N,
133
+ K,
134
+ is_k_full,
135
+ E,
136
+ topk,
137
+ block_size_m,
138
+ True,
139
+ False,
140
+ )
141
+
142
+ return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
143
+
144
+
145
+ def fused_marlin_moe(
146
+ hidden_states: torch.Tensor,
147
+ w1: torch.Tensor,
148
+ w2: torch.Tensor,
149
+ w1_scale: torch.Tensor,
150
+ w2_scale: torch.Tensor,
151
+ gating_output: torch.Tensor,
152
+ topk_weights: torch.Tensor,
153
+ topk_ids: torch.Tensor,
154
+ g_idx1: Optional[torch.Tensor] = None,
155
+ g_idx2: Optional[torch.Tensor] = None,
156
+ sort_indices1: Optional[torch.Tensor] = None,
157
+ sort_indices2: Optional[torch.Tensor] = None,
158
+ w1_zeros: Optional[torch.Tensor] = None,
159
+ w2_zeros: Optional[torch.Tensor] = None,
160
+ override_config: Optional[Dict[str, Any]] = None,
161
+ num_bits: int = 8,
162
+ is_k_full: bool = True,
163
+ ) -> torch.Tensor:
164
+ """
165
+ This function computes a Mixture of Experts (MoE) layer using two sets of
166
+ weights, w1 and w2, and top-k gating mechanism.
167
+
168
+ Parameters:
169
+ - hidden_states (torch.Tensor): The input tensor to the MoE layer.
170
+ - w1 (torch.Tensor): The first set of expert weights.
171
+ - w2 (torch.Tensor): The second set of expert weights.
172
+ - w1_scale (torch.Tensor): Scale to be used for w1.
173
+ - w2_scale (torch.Tensor): Scale to be used for w2.
174
+ - gating_output (torch.Tensor): The output of the gating operation
175
+ (before softmax).
176
+ - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
177
+ - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
178
+ - sort_indices1 (Optional[torch.Tensor]): The first act_order input
179
+ permutation.
180
+ - sort_indices2 (Optional[torch.Tensor]): The second act_order input
181
+ permutation.
182
+ - topk_weights (torch.Tensor): Top-k weights.
183
+ - topk_ids (torch.Tensor): Indices of the top-k elements.
184
+ - override_config (Optional[Dict[str, Any]]): Optional override
185
+ for the kernel configuration.
186
+ - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
187
+ - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
188
+ - num_bits (int): The number of bits used for expert weight quantization.
189
+
190
+ Returns:
191
+ - torch.Tensor: The output tensor after applying the MoE layer.
192
+ """
193
+ # Check constraints.
194
+ assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
195
+ assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
196
+ assert hidden_states.shape[1] == w2.shape[2] // (
197
+ num_bits // 2
198
+ ), "Hidden size mismatch w2"
199
+ assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
200
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
201
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
202
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
203
+ assert hidden_states.dtype == torch.float16
204
+ assert num_bits in [4, 8]
205
+
206
+ has_no_act_order = (
207
+ g_idx1 is None
208
+ and g_idx2 is None
209
+ and sort_indices1 is None
210
+ and sort_indices2 is None
211
+ )
212
+ has_all_act_order = (
213
+ g_idx1 is not None
214
+ and g_idx2 is not None
215
+ and sort_indices1 is not None
216
+ and sort_indices2 is not None
217
+ )
218
+ assert has_no_act_order or has_all_act_order, (
219
+ "g_idx and sorted_indices " "must be all not None or must be all None"
220
+ )
221
+
222
+ has_no_zp = w1_zeros is None and w2_zeros is None
223
+ has_all_zp = w1_zeros is not None and w2_zeros is not None
224
+ assert has_no_zp or has_all_zp, (
225
+ "zero points must be both not None or " "must be both None"
226
+ )
227
+
228
+ M, K = hidden_states.shape
229
+ E = w1.shape[0]
230
+ N = w2.shape[1] * 16
231
+ topk = topk_ids.shape[1]
232
+
233
+ get_config_func = functools.partial(
234
+ try_get_optimal_moe_config,
235
+ w1.shape,
236
+ w2.shape,
237
+ topk_ids.shape[1],
238
+ None,
239
+ override_config=override_config,
240
+ is_marlin=True,
241
+ )
242
+ config = get_config_func(M)
243
+
244
+ block_size_m = config["BLOCK_SIZE_M"]
245
+
246
+ sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
247
+
248
+ max_workspace_size = (max(2 * N, K) // 64) * 16
249
+ workspace = torch.zeros(
250
+ max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
251
+ )
252
+
253
+ if has_no_zp:
254
+ w1_zeros = torch.empty(
255
+ (0, 0),
256
+ dtype=hidden_states.dtype,
257
+ device=hidden_states.device,
258
+ requires_grad=False,
259
+ )
260
+ w2_zeros = torch.empty(
261
+ (0, 0),
262
+ dtype=hidden_states.dtype,
263
+ device=hidden_states.device,
264
+ requires_grad=False,
265
+ )
266
+
267
+ if has_no_act_order:
268
+ g_idx1 = torch.empty(
269
+ (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
270
+ )
271
+ g_idx2 = torch.empty(
272
+ (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
273
+ )
274
+ sort_indices1 = torch.empty(
275
+ (0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
276
+ )
277
+ sort_indices2 = torch.empty(
278
+ (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
279
+ )
280
+
281
+ scalar_type1 = get_scalar_type(num_bits, has_all_zp)
282
+ scalar_type2 = get_scalar_type(num_bits, has_all_zp)
283
+
284
+ intermediate_cache2 = torch.empty(
285
+ (M * topk_ids.shape[1], N),
286
+ device=hidden_states.device,
287
+ dtype=hidden_states.dtype,
288
+ )
289
+
290
+ intermediate_cache1 = ops.ops.marlin_gemm_moe(
291
+ hidden_states,
292
+ w1,
293
+ sorted_token_ids,
294
+ topk_weights,
295
+ topk_ids,
296
+ w1_scale,
297
+ w1_zeros,
298
+ g_idx1,
299
+ sort_indices1,
300
+ workspace,
301
+ scalar_type1.id,
302
+ M,
303
+ 2 * N,
304
+ K,
305
+ is_k_full,
306
+ E,
307
+ topk,
308
+ block_size_m,
309
+ True,
310
+ False,
311
+ )
312
+
313
+ ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
314
+
315
+ intermediate_cache3 = ops.ops.marlin_gemm_moe(
316
+ intermediate_cache2,
317
+ w2,
318
+ sorted_token_ids,
319
+ topk_weights,
320
+ topk_ids,
321
+ w2_scale,
322
+ w2_zeros,
323
+ g_idx2,
324
+ sort_indices2,
325
+ workspace,
326
+ scalar_type2.id,
327
+ M,
328
+ K,
329
+ N,
330
+ is_k_full,
331
+ E,
332
+ topk,
333
+ block_size_m,
334
+ False,
335
+ True,
336
+ )
337
+
338
+ return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
ext-torch/fused_moe.py ADDED
@@ -0,0 +1,703 @@
1
+ """Fused MoE kernel."""
2
+
3
+ import functools
4
+ import json
5
+ import os
6
+ from typing import Any, Callable, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from .platforms import current_platform
13
+ from .fp8 import scaled_fp8_quant
14
+ import moe._custom_ops as ops
15
+
16
+ VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
+
18
+
19
+ @triton.jit
20
+ def fused_moe_kernel(
21
+ # Pointers to matrices
22
+ a_ptr,
23
+ b_ptr,
24
+ c_ptr,
25
+ a_scale_ptr,
26
+ b_scale_ptr,
27
+ topk_weights_ptr,
28
+ sorted_token_ids_ptr,
29
+ expert_ids_ptr,
30
+ num_tokens_post_padded_ptr,
31
+ # Matrix dimensions
32
+ N,
33
+ K,
34
+ EM,
35
+ num_valid_tokens,
36
+ # The stride variables represent how much to increase the ptr by when
37
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
38
+ # how much to increase `a_ptr` by to get the element one row down
39
+ # (A has M rows).
40
+ stride_am,
41
+ stride_ak,
42
+ stride_be,
43
+ stride_bk,
44
+ stride_bn,
45
+ stride_cm,
46
+ stride_cn,
47
+ stride_bse,
48
+ stride_bsn,
49
+ # Meta-parameters
50
+ BLOCK_SIZE_M: tl.constexpr,
51
+ BLOCK_SIZE_N: tl.constexpr,
52
+ BLOCK_SIZE_K: tl.constexpr,
53
+ GROUP_SIZE_M: tl.constexpr,
54
+ MUL_ROUTED_WEIGHT: tl.constexpr,
55
+ top_k: tl.constexpr,
56
+ compute_type: tl.constexpr,
57
+ use_fp8_w8a8: tl.constexpr,
58
+ use_int8_w8a16: tl.constexpr,
59
+ ):
60
+ """
61
+ Implements the fused computation for a Mixture of Experts (MOE) using
62
+ token and expert matrices.
63
+
64
+ Key Parameters:
65
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
66
+ be any shape representing batches and K is the feature dimension of
67
+ each token.
68
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
69
+ the number of experts, K is the input feature dimension, and N is
70
+ the output feature dimension.
71
+ - C: The output cache tensor with shape (M, topk, N), where M is the
72
+ total number of tokens post padding, topk is the number of times
73
+ each token is repeated, and N is the output feature dimension.
74
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
75
+ repeated topk times and arranged by the expert index they are
76
+ assigned to.
77
+ - expert_ids: A tensor containing the indices of the expert for each
78
+ block. It determines which expert matrix from B should be used for
79
+ each block in A.
80
+ This kernel performs the multiplication of a token by its corresponding
81
+ expert matrix as determined by `expert_ids`. The sorting of
82
+ `sorted_token_ids` by expert index and padding ensures divisibility by
83
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
84
+ multiplication across different blocks processed by the same expert.
85
+ """
86
+ # -----------------------------------------------------------
87
+ # Map program ids `pid` to the block of C it should compute.
88
+ # This is done in a grouped ordering to promote L2 data reuse.
89
+ pid = tl.program_id(axis=0)
90
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
91
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
92
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
93
+ group_id = pid // num_pid_in_group
94
+ first_pid_m = group_id * GROUP_SIZE_M
95
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
96
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
97
+ pid_n = (pid % num_pid_in_group) // group_size_m
98
+
99
+ # ----------------------------------------------------------
100
+ # Create pointers for the first blocks of A and B.
101
+ # We will advance this pointer as we move in the K direction
102
+ # and accumulate
103
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
104
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
105
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
+ return
108
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
+ token_mask = offs_token < num_valid_tokens
111
+
112
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
114
+ a_ptrs = a_ptr + (
115
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
+ )
117
+
118
+ off_experts = tl.load(expert_ids_ptr + pid_m)
119
+ b_ptrs = (
120
+ b_ptr
121
+ + off_experts * stride_be
122
+ + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
123
+ )
124
+ if use_int8_w8a16:
125
+ b_scale_ptrs = (
126
+ b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
127
+ )
128
+ b_scale = tl.load(b_scale_ptrs)
129
+
130
+ if use_fp8_w8a8:
131
+ a_scale = tl.load(a_scale_ptr)
132
+ b_scale = tl.load(b_scale_ptr + off_experts)
133
+
134
+ # -----------------------------------------------------------
135
+ # Iterate to compute a block of the C matrix.
136
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
137
+ # of fp32 values for higher accuracy.
138
+ # `accumulator` will be converted back to fp16 after the loop.
139
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
140
+
141
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
142
+ # Load the next block of A and B, generate a mask by checking the
143
+ # K dimension.
144
+ a = tl.load(
145
+ a_ptrs,
146
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
147
+ other=0.0,
148
+ )
149
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
150
+ # We accumulate along the K dimension.
151
+ if use_int8_w8a16:
152
+ accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
+ elif use_fp8_w8a8:
154
+ accumulator = tl.dot(a, b, acc=accumulator)
155
+ else:
156
+ accumulator += tl.dot(a, b)
157
+ # Advance the ptrs to the next K block.
158
+ a_ptrs += BLOCK_SIZE_K * stride_ak
159
+ b_ptrs += BLOCK_SIZE_K * stride_bk
160
+
161
+ if MUL_ROUTED_WEIGHT:
162
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
163
+ accumulator = accumulator * moe_weight[:, None]
164
+ if use_int8_w8a16:
165
+ accumulator = (accumulator * b_scale).to(compute_type)
166
+ elif use_fp8_w8a8:
167
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
168
+ else:
169
+ accumulator = accumulator.to(compute_type)
170
+ # -----------------------------------------------------------
171
+ # Write back the block of the output
172
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
173
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
174
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
175
+ tl.store(c_ptrs, accumulator, mask=c_mask)
176
+
177
+
178
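The grouped program-id mapping in the kernel above is easiest to see in isolation. Below is a minimal, plain-Python sketch of that index arithmetic with made-up block sizes; it only illustrates the traversal order that promotes L2 reuse and is not part of the kernel itself.

```python
import math

def grouped_launch_order(EM, N, BLOCK_SIZE_M=4, BLOCK_SIZE_N=4, GROUP_SIZE_M=2):
    # Reproduces the pid -> (pid_m, pid_n) mapping used by fused_moe_kernel.
    num_pid_m = math.ceil(EM / BLOCK_SIZE_M)
    num_pid_n = math.ceil(N / BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    order = []
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order

# Consecutive pids cycle through a small window of pid_m values before moving
# to the next column block, so the same A rows stay hot in L2.
print(grouped_launch_order(EM=16, N=16))
```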
+ def moe_align_block_size(
179
+ topk_ids: torch.Tensor, block_size: int, num_experts: int
180
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
181
+ """
182
+ Aligns the token distribution across experts to be compatible with block
183
+ size for matrix multiplication.
184
+
185
+ Parameters:
186
+ - topk_ids: A tensor of shape [total_tokens, top_k] representing the
187
+ top-k expert indices for each token.
188
+ - block_size: The block size used in block matrix multiplication.
189
+ - num_experts: The total number of experts.
190
+
191
+ Returns:
192
+ - sorted_token_ids: A tensor containing the sorted token indices according
193
+ to their allocated expert.
194
+ - expert_ids: A tensor indicating the assigned expert index for each block.
195
+ - num_tokens_post_padded: The total number of tokens after padding,
196
+ ensuring divisibility by block_size.
197
+
198
+ This function pads the number of tokens that each expert needs to process
199
+ so that it is divisible by block_size.
200
+ Padding ensures that during block matrix multiplication, the dimensions
201
+ align correctly.
202
+
203
+ Example:
204
+ Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
205
+ block_size = 4, and num_experts = 4:
206
+ - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
207
+ with each expert needing to process 3 tokens.
208
+ - As block_size is 4, we pad 1 token for each expert.
209
+ - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
210
+ - Then append one padding token (id 12) for each expert: [12, 12, 12, 12].
211
+ - After sorting by expert index, we obtain token_ids
212
+ [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
213
+ Tokens 12 are non-existent (padding) and are ignored in
214
+ the subsequent matrix multiplication.
215
+ - The padding ensures that the total number of tokens is now divisible
216
+ by block_size for proper block matrix operations.
217
+ """
218
+ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
219
+ sorted_ids = torch.empty(
220
+ (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
221
+ )
222
+ sorted_ids.fill_(topk_ids.numel())
223
+ max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
224
+ expert_ids = torch.empty(
225
+ (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
+ )
227
+ num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
+ ops.moe_align_block_size(
229
+ topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
+ )
231
+ return sorted_ids, expert_ids, num_tokens_post_pad
232
+
233
+
234
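For intuition, here is a rough pure-PyTorch reference of the padding and sorting that `moe_align_block_size` delegates to the `ops.moe_align_block_size` kernel. It is a sketch for small inputs only (the per-expert loop is not how the kernel works), and the tiny example at the end is invented.

```python
import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    numel = topk_ids.numel()
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)
    # Pad each expert's token count up to a multiple of block_size.
    padded_counts = ((counts + block_size - 1) // block_size) * block_size
    num_tokens_post_pad = int(padded_counts.sum())

    # Padding slots get the out-of-range id `numel`, mirroring sorted_ids.fill_().
    sorted_ids = torch.full((num_tokens_post_pad,), numel, dtype=torch.int32)
    expert_ids = torch.empty(num_tokens_post_pad // block_size, dtype=torch.int32)

    offset = 0
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_ids[offset : offset + idx.numel()] = idx.to(torch.int32)
        n_blocks = int(padded_counts[e]) // block_size
        expert_ids[offset // block_size : offset // block_size + n_blocks] = e
        offset += int(padded_counts[e])
    return sorted_ids, expert_ids, torch.tensor([num_tokens_post_pad], dtype=torch.int32)

topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0]])
print(moe_align_block_size_ref(topk_ids, block_size=4, num_experts=3))
```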
+ def invoke_fused_moe_kernel(
235
+ A: torch.Tensor,
236
+ B: torch.Tensor,
237
+ C: torch.Tensor,
238
+ A_scale: Optional[torch.Tensor],
239
+ B_scale: Optional[torch.Tensor],
240
+ topk_weights: torch.Tensor,
241
+ topk_ids: torch.Tensor,
242
+ sorted_token_ids: torch.Tensor,
243
+ expert_ids: torch.Tensor,
244
+ num_tokens_post_padded: torch.Tensor,
245
+ mul_routed_weight: bool,
246
+ top_k: int,
247
+ config: Dict[str, Any],
248
+ compute_type: tl.dtype,
249
+ use_fp8_w8a8: bool,
250
+ use_int8_w8a16: bool,
251
+ ) -> None:
252
+ assert topk_weights.stride(1) == 1
253
+ assert sorted_token_ids.stride(0) == 1
254
+
255
+ if use_fp8_w8a8:
256
+ A, A_scale = scaled_fp8_quant(A, A_scale)
257
+ assert B_scale is not None
258
+ elif use_int8_w8a16:
259
+ assert B_scale is not None
260
+ else:
261
+ assert A_scale is None
262
+ assert B_scale is None
263
+
264
+ grid = lambda META: (
265
+ triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
+ * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
+ )
268
+
269
+ fused_moe_kernel[grid](
270
+ A,
271
+ B,
272
+ C,
273
+ A_scale,
274
+ B_scale,
275
+ topk_weights,
276
+ sorted_token_ids,
277
+ expert_ids,
278
+ num_tokens_post_padded,
279
+ B.shape[1],
280
+ B.shape[2],
281
+ sorted_token_ids.shape[0],
282
+ topk_ids.numel(),
283
+ A.stride(0),
284
+ A.stride(1),
285
+ B.stride(0),
286
+ B.stride(2),
287
+ B.stride(1),
288
+ C.stride(1),
289
+ C.stride(2),
290
+ B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
+ B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
293
+ top_k=top_k,
294
+ compute_type=compute_type,
295
+ use_fp8_w8a8=use_fp8_w8a8,
296
+ use_int8_w8a16=use_int8_w8a16,
297
+ **config,
298
+ )
299
+
300
+
301
+ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
302
+ device_name = current_platform.get_device_name().replace(" ", "_")
303
+ dtype_selector = "" if not dtype else f",dtype={dtype}"
304
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
305
+
306
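As a concrete, purely illustrative example of the naming scheme: the device string below is invented; in practice it comes from `current_platform.get_device_name()`.

```python
E, N, dtype = 8, 14336, "fp8_w8a8"
device_name = "NVIDIA_A100-SXM4-80GB"  # hypothetical; normally queried from the GPU
dtype_selector = "" if not dtype else f",dtype={dtype}"
print(f"E={E},N={N},device_name={device_name}{dtype_selector}.json")
# -> E=8,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=fp8_w8a8.json
```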
+
307
+ @functools.lru_cache
308
+ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
309
+ """
310
+ Return optimized configurations for the fused MoE kernel.
311
+
312
+ The return value will be a dictionary that maps an irregular grid of
313
+ batch sizes to configurations of the fused_moe kernel. To evaluate the
314
+ kernel on a given batch size bs, the closest batch size in the grid should
315
+ be picked and the associated configuration chosen to invoke the kernel.
316
+ """
317
+
318
+ # First look up if an optimized configuration is available in the configs
319
+ # directory
320
+ json_file_name = get_config_file_name(E, N, dtype)
321
+
322
+ config_file_path = os.path.join(
323
+ os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
+ )
325
+ if os.path.exists(config_file_path):
326
+ with open(config_file_path) as f:
327
+ # If a configuration has been found, return it
328
+ return {int(key): val for key, val in json.load(f).items()}
329
+
330
+ # If no optimized configuration is available, we will use the default
331
+ # configuration
332
+ return None
333
+
334
+
335
+ def get_default_config(
336
+ M: int,
337
+ E: int,
338
+ N: int,
339
+ K: int,
340
+ topk: int,
341
+ dtype: Optional[str],
342
+ is_marlin: bool,
343
+ ) -> Dict[str, int]:
344
+ config = {
345
+ "BLOCK_SIZE_M": 64,
346
+ "BLOCK_SIZE_N": 64,
347
+ "BLOCK_SIZE_K": 32,
348
+ "GROUP_SIZE_M": 8,
349
+ }
350
+ # A heuristic: fused marlin works faster with this config for small M
351
+ if M <= E or (is_marlin and M <= 32):
352
+ config = {
353
+ "BLOCK_SIZE_M": 16,
354
+ "BLOCK_SIZE_N": 32,
355
+ "BLOCK_SIZE_K": 64,
356
+ "GROUP_SIZE_M": 1,
357
+ }
358
+ return config
359
+
360
+
361
+ def try_get_optimal_moe_config(
362
+ w1_shape: Tuple[int, ...],
363
+ w2_shape: Tuple[int, ...],
364
+ top_k: int,
365
+ dtype: Optional[str],
366
+ M: int,
367
+ override_config: Optional[Dict[str, Any]] = None,
368
+ is_marlin: bool = False,
369
+ ):
370
+ if override_config:
371
+ config = override_config
372
+ else:
373
+ # First try to load optimal config from the file
374
+ E, _, N = w2_shape
375
+ configs = get_moe_configs(E, N, dtype)
376
+
377
+ if configs:
378
+ # If an optimal configuration map has been found, look up the
379
+ # optimal config
380
+ config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
+ else:
382
+ # Else use the default config
383
+ config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
384
+ return config
385
+
386
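The "closest batch size" lookup used above is simple but easy to misread; here is a tiny standalone illustration with made-up keys and abbreviated configs.

```python
# Hypothetical tuned-config map keyed by batch size (values abbreviated).
configs = {1: {"BLOCK_SIZE_M": 16}, 16: {"BLOCK_SIZE_M": 32}, 64: {"BLOCK_SIZE_M": 64}}
M = 24
chosen = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(chosen)  # key 16 is closest to M=24, so its config is used
```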
+
387
+ def fused_topk(
388
+ hidden_states: torch.Tensor,
389
+ gating_output: torch.Tensor,
390
+ topk: int,
391
+ renormalize: bool,
392
+ ):
393
+ assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
394
+
395
+ M, _ = hidden_states.shape
396
+
397
+ topk_weights = torch.empty(
398
+ M, topk, dtype=torch.float32, device=hidden_states.device
399
+ )
400
+ topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
401
+ token_expert_indicies = torch.empty(
402
+ M, topk, dtype=torch.int32, device=hidden_states.device
403
+ )
404
+
405
+ ops.topk_softmax(
406
+ topk_weights,
407
+ topk_ids,
408
+ token_expert_indicies,
409
+ gating_output.float(), # TODO(woosuk): Optimize this.
410
+ )
411
+ del token_expert_indicies # Not used. Will be used in the future.
412
+
413
+ if renormalize:
414
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
415
+
416
+ return topk_weights, topk_ids
417
+
418
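Semantically, `fused_topk` is a softmax over the gating logits followed by a top-k selection; the `ops.topk_softmax` kernel fuses this. A rough PyTorch-only reference (a sketch, not the kernel) is:

```python
import torch

def fused_topk_ref(hidden_states, gating_output, topk, renormalize):
    assert hidden_states.shape[0] == gating_output.shape[0]
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

gating = torch.randn(4, 8)                    # 4 tokens, 8 experts
w, ids = fused_topk_ref(torch.randn(4, 16), gating, topk=2, renormalize=True)
print(w.shape, ids.shape)                     # torch.Size([4, 2]) torch.Size([4, 2])
```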
+
419
+ # This is used by the Deepseek-V2 model
420
+ def grouped_topk(
421
+ hidden_states: torch.Tensor,
422
+ gating_output: torch.Tensor,
423
+ topk: int,
424
+ renormalize: bool,
425
+ num_expert_group: int = 0,
426
+ topk_group: int = 0,
427
+ ):
428
+
429
+ assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
+
431
+ scores = torch.softmax(gating_output, dim=-1)
432
+ num_token = scores.shape[0]
433
+ group_scores = (
434
+ scores.view(num_token, num_expert_group, -1).max(dim=-1).values
435
+ ) # [n, n_group]
436
+ group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
437
+ 1
438
+ ] # [n, top_k_group]
439
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
440
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
441
+ score_mask = (
442
+ group_mask.unsqueeze(-1)
443
+ .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
444
+ .reshape(num_token, -1)
445
+ ) # [n, e]
446
+ tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
448
+
449
+ if renormalize:
450
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
451
+
452
+ return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
453
+
454
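A hypothetical call for DeepSeek-V2-style grouped routing, with invented sizes (64 experts in 8 groups, the best 3 groups kept, then the top-6 experts per token):

```python
import torch

hidden = torch.randn(4, 128)   # 4 tokens
gating = torch.randn(4, 64)    # 64 experts
weights, ids = grouped_topk(
    hidden, gating, topk=6, renormalize=True,
    num_expert_group=8, topk_group=3,
)
print(weights.shape, ids.shape)  # torch.Size([4, 6]) torch.Size([4, 6])
```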
+
455
+ def get_config_dtype_str(
456
+ dtype: torch.dtype,
457
+ use_int8_w8a16: Optional[bool] = False,
458
+ use_fp8_w8a8: Optional[bool] = False,
459
+ ):
460
+ if use_fp8_w8a8:
461
+ return "fp8_w8a8"
462
+ elif use_int8_w8a16:
463
+ return "int8_w8a16"
464
+ elif dtype == torch.float:
465
+ # avoid cases where the kernel fails because a float32 MoE
466
+ # would otherwise pick up fp16/bfloat16 tuned configs
467
+ return "float32"
468
+ return None
469
+
470
+
471
+ def fused_experts(
472
+ hidden_states: torch.Tensor,
473
+ w1: torch.Tensor,
474
+ w2: torch.Tensor,
475
+ topk_weights: torch.Tensor,
476
+ topk_ids: torch.Tensor,
477
+ inplace: bool = False,
478
+ override_config: Optional[Dict[str, Any]] = None,
479
+ use_fp8_w8a8: bool = False,
480
+ use_int8_w8a16: bool = False,
481
+ w1_scale: Optional[torch.Tensor] = None,
482
+ w2_scale: Optional[torch.Tensor] = None,
483
+ a1_scale: Optional[torch.Tensor] = None,
484
+ a2_scale: Optional[torch.Tensor] = None,
485
+ ):
486
+ # Check constraints.
487
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
488
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
491
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
492
+ assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
493
+
494
+ num_tokens, _ = hidden_states.shape
495
+ E, N, _ = w1.shape
496
+ # We execute the fused_moe kernel in chunks to circumvent this issue:
497
+ # https://github.com/vllm-project/vllm/issues/5938
498
+ CHUNK_SIZE = VLLM_FUSED_MOE_CHUNK_SIZE
499
+ M = min(num_tokens, CHUNK_SIZE)
500
+ config_dtype = get_config_dtype_str(
501
+ use_fp8_w8a8=use_fp8_w8a8,
502
+ use_int8_w8a16=use_int8_w8a16,
503
+ dtype=hidden_states.dtype,
504
+ )
505
+
506
+ get_config_func = functools.partial(
507
+ try_get_optimal_moe_config,
508
+ w1.shape,
509
+ w2.shape,
510
+ topk_ids.shape[1],
511
+ config_dtype,
512
+ override_config=override_config,
513
+ )
514
+
515
+ config = get_config_func(M)
516
+
517
+ intermediate_cache1 = torch.empty(
518
+ (M, topk_ids.shape[1], N),
519
+ device=hidden_states.device,
520
+ dtype=hidden_states.dtype,
521
+ )
522
+ intermediate_cache2 = torch.empty(
523
+ (M * topk_ids.shape[1], N // 2),
524
+ device=hidden_states.device,
525
+ dtype=hidden_states.dtype,
526
+ )
527
+ intermediate_cache3 = torch.empty(
528
+ (M, topk_ids.shape[1], w2.shape[1]),
529
+ device=hidden_states.device,
530
+ dtype=hidden_states.dtype,
531
+ )
532
+
533
+ compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
534
+
535
+ if inplace:
536
+ out_hidden_states = hidden_states
537
+ else:
538
+ out_hidden_states = torch.empty_like(hidden_states)
539
+
540
+ for chunk in range((num_tokens // CHUNK_SIZE) + 1):
541
+ begin_chunk_idx, end_chunk_idx = (
542
+ chunk * CHUNK_SIZE,
543
+ min((chunk + 1) * CHUNK_SIZE, num_tokens),
544
+ )
545
+ curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
546
+ tokens_in_chunk, _ = curr_hidden_states.shape
547
+
548
+ if tokens_in_chunk == 0:
549
+ break
550
+
551
+ if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
552
+ # Adjust the intermediate cache size and config for the last
553
+ # chunk. Note that in most cases we only have one chunk
554
+ # so the cache size and config are already set correctly and
555
+ # do not need to be adjusted.
556
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
557
+ intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
558
+ intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
559
+ config = get_config_func(tokens_in_chunk)
560
+
561
+ curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
562
+ curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
563
+
564
+ sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
565
+ curr_topk_ids, config["BLOCK_SIZE_M"], E
566
+ )
567
+
568
+ invoke_fused_moe_kernel(
569
+ curr_hidden_states,
570
+ w1,
571
+ intermediate_cache1,
572
+ a1_scale,
573
+ w1_scale,
574
+ curr_topk_weights,
575
+ curr_topk_ids,
576
+ sorted_token_ids,
577
+ expert_ids,
578
+ num_tokens_post_padded,
579
+ False,
580
+ topk_ids.shape[1],
581
+ config,
582
+ compute_type=compute_type,
583
+ use_fp8_w8a8=use_fp8_w8a8,
584
+ use_int8_w8a16=use_int8_w8a16,
585
+ )
586
+
587
+ ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
588
+
589
+ invoke_fused_moe_kernel(
590
+ intermediate_cache2,
591
+ w2,
592
+ intermediate_cache3,
593
+ a2_scale,
594
+ w2_scale,
595
+ curr_topk_weights,
596
+ curr_topk_ids,
597
+ sorted_token_ids,
598
+ expert_ids,
599
+ num_tokens_post_padded,
600
+ True,
601
+ 1,
602
+ config,
603
+ compute_type=compute_type,
604
+ use_fp8_w8a8=use_fp8_w8a8,
605
+ use_int8_w8a16=use_int8_w8a16,
606
+ )
607
+
608
+ ops.moe_sum(
609
+ intermediate_cache3.view(*intermediate_cache3.shape),
610
+ out_hidden_states[begin_chunk_idx:end_chunk_idx],
611
+ )
612
+ return out_hidden_states
613
+
614
+
615
+ def fused_moe(
616
+ hidden_states: torch.Tensor,
617
+ w1: torch.Tensor,
618
+ w2: torch.Tensor,
619
+ gating_output: torch.Tensor,
620
+ topk: int,
621
+ renormalize: bool,
622
+ inplace: bool = False,
623
+ override_config: Optional[Dict[str, Any]] = None,
624
+ use_grouped_topk: bool = False,
625
+ num_expert_group: Optional[int] = None,
626
+ topk_group: Optional[int] = None,
627
+ custom_routing_function: Optional[Callable] = None,
628
+ use_fp8_w8a8: bool = False,
629
+ use_int8_w8a16: bool = False,
630
+ w1_scale: Optional[torch.Tensor] = None,
631
+ w2_scale: Optional[torch.Tensor] = None,
632
+ a1_scale: Optional[torch.Tensor] = None,
633
+ a2_scale: Optional[torch.Tensor] = None,
634
+ ) -> torch.Tensor:
635
+ """
636
+ This function computes a Mixture of Experts (MoE) layer using two sets of
637
+ weights, w1 and w2, and top-k gating mechanism.
638
+
639
+ Parameters:
640
+ - hidden_states (torch.Tensor): The input tensor to the MoE layer.
641
+ - w1 (torch.Tensor): The first set of expert weights.
642
+ - w2 (torch.Tensor): The second set of expert weights.
643
+ - gating_output (torch.Tensor): The output of the gating operation
644
+ (before softmax).
645
+ - topk (int): The number of top-k experts to select.
646
+ - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
+ - inplace (bool): If True, perform the operation in-place.
648
+ Defaults to False.
649
+ - override_config (Optional[Dict[str, Any]]): Optional override
650
+ for the kernel configuration.
651
+ - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
+ - topk_group: Optional[int]: additional parameter for grouped_topk
653
+ - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
+ note: Deepseekv2 model uses grouped_topk
655
+ - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
+ products for w1 and w2. Defaults to False.
657
+ - use_int8_w8a16 (bool): If True, use int8 arithmetic to compute the inner
658
+ products for w1 and w2. Defaults to False.
659
+ - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
+ w1.
661
+ - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
+ w2.
663
+
664
+ Returns:
665
+ - torch.Tensor: The output tensor after applying the MoE layer.
666
+ """
667
+ # Check constraints.
668
+ assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
669
+
670
+ if use_grouped_topk:
671
+ assert num_expert_group is not None and topk_group is not None
672
+ topk_weights, topk_ids = grouped_topk(
673
+ hidden_states,
674
+ gating_output,
675
+ topk,
676
+ renormalize,
677
+ num_expert_group,
678
+ topk_group,
679
+ )
680
+ elif custom_routing_function is None:
681
+ topk_weights, topk_ids = fused_topk(
682
+ hidden_states, gating_output, topk, renormalize
683
+ )
684
+ else:
685
+ topk_weights, topk_ids = custom_routing_function(
686
+ hidden_states, gating_output, topk, renormalize
687
+ )
688
+
689
+ return fused_experts(
690
+ hidden_states,
691
+ w1,
692
+ w2,
693
+ topk_weights,
694
+ topk_ids,
695
+ inplace=inplace,
696
+ override_config=override_config,
697
+ use_fp8_w8a8=use_fp8_w8a8,
698
+ use_int8_w8a16=use_int8_w8a16,
699
+ w1_scale=w1_scale,
700
+ w2_scale=w2_scale,
701
+ a1_scale=a1_scale,
702
+ a2_scale=a2_scale,
703
+ )
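An end-to-end usage sketch with invented shapes (8 experts, hidden size 128, intermediate size 256, top-2 routing; `w1` stacks the gate/up projections, hence the `2 *` factor). It assumes a CUDA device and is only meant to show how the pieces fit together.

```python
import torch

num_tokens, hidden_size, intermediate, E, topk = 16, 128, 256, 8, 2
x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * intermediate, hidden_size, dtype=torch.float16, device="cuda")
w2 = torch.randn(E, hidden_size, intermediate, dtype=torch.float16, device="cuda")
gating = torch.randn(num_tokens, E, dtype=torch.float16, device="cuda")

out = fused_moe(x, w1, w2, gating, topk=topk, renormalize=True)
print(out.shape)  # torch.Size([16, 128])
```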
ext-torch/platforms.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import Callable, ParamSpec, TypeVar
2
+ import os
3
+ from functools import lru_cache, wraps
4
+
5
+ import torch
6
+
7
+ IS_ROCM = torch.version.hip is not None
8
+
9
+ class CudaPlatform:
10
+ @classmethod
11
+ @lru_cache(maxsize=8)
12
+ def get_device_name(cls, device_id: int = 0) -> str:
13
+ return torch.cuda.get_device_name(device_id)
14
+
15
+ class RocmPlatform:
16
+ @classmethod
17
+ @lru_cache(maxsize=8)
18
+ def get_device_name(cls, device_id: int = 0) -> str:
19
+ return torch.cuda.get_device_name(device_id)
20
+
21
+
22
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
ext-torch/registration.h ADDED
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+
3
+ #include <Python.h>
4
+
5
+ #define _CONCAT(A, B) A##B
6
+ #define CONCAT(A, B) _CONCAT(A, B)
7
+
8
+ #define _STRINGIFY(A) #A
9
+ #define STRINGIFY(A) _STRINGIFY(A)
10
+
11
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
12
+ // could be a macro instead of a literal token.
13
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
14
+
15
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
16
+ // could be a macro instead of a literal token.
17
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
18
+ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
19
+
20
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
21
+ // via python's import statement.
22
+ #define REGISTER_EXTENSION(NAME) \
23
+ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
24
+ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
25
+ STRINGIFY(NAME), nullptr, 0, nullptr}; \
26
+ return PyModule_Create(&module); \
27
+ }
ext-torch/scalar_type.py ADDED
@@ -0,0 +1,330 @@
1
+ import functools
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional, Union
6
+
7
+
8
+ # Mirrors enum in `core/scalar_type.hpp`
9
+ class NanRepr(Enum):
10
+ NONE = 0 # nans are not supported
11
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
+
14
+
15
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
16
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
+ # in sync until the inductor fully supports custom C++ classes.
18
+ @dataclass(frozen=True)
19
+ class ScalarType:
20
+ """
21
+ ScalarType can represent a wide range of floating point and integer
22
+ types, in particular it can be used to represent sub-byte data types
23
+ (something that torch.dtype currently does not support). It is also
24
+ capable of representing types with a bias, i.e.:
25
+ `stored_value = value + bias`,
26
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
+ of 8). The implementation for this class can be found in
28
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
+ with that file.
30
+ """
31
+
32
+ exponent: int
33
+ """
34
+ Number of bits in the exponent if this is a floating point type
35
+ (zero if this is an integer type)
36
+ """
37
+
38
+ mantissa: int
39
+ """
40
+ Number of bits in the mantissa if this is a floating point type,
41
+ or the number of bits representing an integer excluding the sign bit if
42
+ this is an integer type.
43
+ """
44
+
45
+ signed: bool
46
+ "If the type is signed (i.e. has a sign bit)"
47
+
48
+ bias: int
49
+ """
50
+ bias used to encode the values in this scalar type
51
+ (value = stored_value - bias, default 0) for example if we store the
52
+ type as an unsigned integer with a bias of 128 then the value 0 will be
53
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
+ """
55
+
56
+ _finite_values_only: bool = False
57
+ """
58
+ Private: to check whether infs are supported, use `has_infs()` instead.
59
+ """
60
+
61
+ nan_repr: NanRepr = NanRepr.IEEE_754
62
+ """
63
+ How NaNs are represented in this scalar type; returns a NanRepr value.
64
+ (not applicable for integer types)
65
+ """
66
+
67
+ def _floating_point_max_int(self) -> int:
68
+ assert (
69
+ self.mantissa <= 52 and self.exponent <= 11
70
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
+
72
+ max_mantissa = (1 << self.mantissa) - 1
73
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
+ max_mantissa = max_mantissa - 1
75
+
76
+ max_exponent = (1 << self.exponent) - 2
77
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
+ or self.nan_repr == NanRepr.NONE):
79
+ assert (
80
+ self.exponent < 11
81
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
+ max_exponent = max_exponent + 1
83
+
84
+ # adjust the exponent to match that of a double
85
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
+ # e is the exponent bits), there is some precedent for non-standard
87
+ # biases, example `float8_e4m3b11fnuz` here:
88
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
+ # complication we are just assuming the standard exponent bias until
90
+ # there is a need to support non-standard biases
91
+ exponent_bias = (1 << (self.exponent - 1)) - 1
92
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
93
+
94
+ max_exponent_double = (max_exponent - exponent_bias +
95
+ exponent_bias_double)
96
+
97
+ # shift the mantissa and exponent into the proper positions for an
98
+ # IEEE double and bitwise-or them together.
99
+ return (max_mantissa <<
100
+ (52 - self.mantissa)) | (max_exponent_double << 52)
101
+
102
+ def _floating_point_max(self) -> float:
103
+ double_raw = self._floating_point_max_int()
104
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
+
106
+ def _raw_max(self) -> Union[int, float]:
107
+ if self.is_floating_point():
108
+ return self._floating_point_max()
109
+ else:
110
+ assert (self.size_bits < 64 or self.size_bits == 64
111
+ and self.is_signed()), "Cannot represent max as an int"
112
+ return (1 << self.mantissa) - 1
113
+
114
+ def _raw_min(self) -> Union[int, float]:
115
+ if self.is_floating_point():
116
+ assert self.is_signed(
117
+ ), "We currently assume all floating point types are signed"
118
+ sign_bit_double = 1 << 63
119
+
120
+ max_raw = self._floating_point_max_int()
121
+ min_raw = max_raw | sign_bit_double
122
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
+ else:
124
+ assert (not self.is_signed() or
125
+ self.size_bits <= 64), "Cannot represent min as an int64_t"
126
+
127
+ if self.is_signed():
128
+ return -(1 << (self.size_bits - 1))
129
+ else:
130
+ return 0
131
+
132
+ @functools.cached_property
133
+ def id(self) -> int:
134
+ """
135
+ Convert the ScalarType to an int which can be passed to pytorch custom
136
+ ops. This layout of the int must be kept in sync with the C++
137
+ ScalarType's from_id method.
138
+ """
139
+ val = 0
140
+ offset = 0
141
+
142
+ def or_and_advance(member, bit_width):
143
+ nonlocal val
144
+ nonlocal offset
145
+ bit_mask = (1 << bit_width) - 1
146
+ val = val | (int(member) & bit_mask) << offset
147
+ offset = offset + bit_width
148
+
149
+ or_and_advance(self.exponent, 8)
150
+ or_and_advance(self.mantissa, 8)
151
+ or_and_advance(self.signed, 1)
152
+ or_and_advance(self.bias, 32)
153
+ or_and_advance(self._finite_values_only, 1)
154
+ or_and_advance(self.nan_repr.value, 8)
155
+
156
+ assert offset <= 64, \
157
+ f"ScalarType fields too big {offset} to fit into an int64"
158
+
159
+ return val
160
+
161
+ @property
162
+ def size_bits(self) -> int:
163
+ return self.exponent + self.mantissa + int(self.signed)
164
+
165
+ def min(self) -> Union[int, float]:
166
+ """
167
+ Min representable value for this scalar type.
168
+ (accounting for bias if there is one)
169
+ """
170
+ return self._raw_min() - self.bias
171
+
172
+ def max(self) -> Union[int, float]:
173
+ """
174
+ Max representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_max() - self.bias
178
+
179
+ def is_signed(self) -> bool:
180
+ """
181
+ If the type is signed (i.e. has a sign bit), same as `signed`
182
+ added for consistency with:
183
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
+ """
185
+ return self.signed
186
+
187
+ def is_floating_point(self) -> bool:
188
+ "If the type is a floating point type"
189
+ return self.exponent != 0
190
+
191
+ def is_integer(self) -> bool:
192
+ "If the type is an integer type"
193
+ return self.exponent == 0
194
+
195
+ def has_bias(self) -> bool:
196
+ "If the type has a non-zero bias"
197
+ return self.bias != 0
198
+
199
+ def has_infs(self) -> bool:
200
+ "If the type is floating point and supports infinity"
201
+ return not self._finite_values_only
202
+
203
+ def has_nans(self) -> bool:
204
+ return self.nan_repr != NanRepr.NONE
205
+
206
+ def is_ieee_754(self) -> bool:
207
+ """
208
+ If the type is a floating point type that follows IEEE 754
209
+ conventions
210
+ """
211
+ return self.nan_repr == NanRepr.IEEE_754 and \
212
+ not self._finite_values_only
213
+
214
+ def __str__(self) -> str:
215
+ """
216
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
217
+ for floating point types (leading f) the scheme is:
218
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
+ flags:
220
+ - no-flags: means it follows IEEE 754 conventions
221
+ - f: means finite values only (no infinities)
222
+ - n: means nans are supported (non-standard encoding)
223
+ for integer types the scheme is:
224
+ `[u]int<size_bits>[b<bias>]`
225
+ - if bias is not present it means it is zero
226
+ """
227
+ if self.is_floating_point():
228
+ ret = "float" + str(self.size_bits) + "_e" + str(
229
+ self.exponent) + "m" + str(self.mantissa)
230
+
231
+ if not self.is_ieee_754():
232
+ if self._finite_values_only:
233
+ ret = ret + "f"
234
+ if self.nan_repr != NanRepr.NONE:
235
+ ret = ret + "n"
236
+
237
+ return ret
238
+ else:
239
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
+ if self.has_bias():
241
+ ret = ret + "b" + str(self.bias)
242
+ return ret
243
+
244
+ def __repr__(self) -> str:
245
+ return "ScalarType." + self.__str__()
246
+
247
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
+ # opcheck to work.
249
+ def __len__(self) -> int:
250
+ raise TypeError
251
+
252
+ #
253
+ # Convenience Constructors
254
+ #
255
+
256
+ @classmethod
257
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
259
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
+ ret.id # noqa B018: make sure the id is cached
261
+ return ret
262
+
263
+ @classmethod
264
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ """Create a unsigned integer scalar type."""
266
+ ret = cls(0, size_bits, False, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
+ """
273
+ Create a standard floating point type
274
+ (i.e. follows IEEE 754 conventions).
275
+ """
276
+ assert (mantissa > 0 and exponent > 0)
277
+ ret = cls(exponent, mantissa, True, 0)
278
+ ret.id # noqa B018: make sure the id is cached
279
+ return ret
280
+
281
+ @classmethod
282
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
+ nan_repr: NanRepr) -> 'ScalarType':
284
+ """
285
+ Create a non-standard floating point type
286
+ (i.e. does not follow IEEE 754 conventions).
287
+ """
288
+ assert (mantissa > 0 and exponent > 0)
289
+ assert (nan_repr != NanRepr.IEEE_754), (
290
+ "use `float_IEEE754` constructor for floating point types that "
291
+ "follow IEEE 754 conventions")
292
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
+ ret.id # noqa B018: make sure the id is cached
294
+ return ret
295
+
296
+
297
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
+ # for floating point types (leading f) the scheme is:
299
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
+ # flags:
301
+ # - no-flags: means it follows IEEE 754 conventions
302
+ # - f: means finite values only (no infinities)
303
+ # - n: means nans are supported (non-standard encoding)
304
+ # for integer types the scheme is:
305
+ # `[u]int<size_bits>[b<bias>]`
306
+ # - if bias is not present it means it is zero
307
+
308
+
309
+ class scalar_types:
310
+ int4 = ScalarType.int_(4, None)
311
+ uint4 = ScalarType.uint(4, None)
312
+ int8 = ScalarType.int_(8, None)
313
+ uint8 = ScalarType.uint(8, None)
314
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
+
319
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
+
322
+ # "gptq" types
323
+ uint2b2 = ScalarType.uint(2, 2)
324
+ uint3b4 = ScalarType.uint(3, 4)
325
+ uint4b8 = ScalarType.uint(4, 8)
326
+ uint8b128 = ScalarType.uint(8, 128)
327
+
328
+ # colloquial names
329
+ bfloat16 = float16_e8m7
330
+ float16 = float16_e5m10
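A quick illustration of what these types encode, assuming the module above is importable as shown; the values follow from the definitions.

```python
print(scalar_types.uint4b8)              # uint4b8: 4-bit unsigned storage with bias 8
print(scalar_types.uint4b8.min())        # -8  (raw 0..15 maps to -8..7)
print(scalar_types.uint4b8.max())        # 7
print(scalar_types.float8_e4m3fn)        # float8_e4m3fn
print(scalar_types.float8_e4m3fn.max())  # 448.0
```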
ext-torch/torch_binding.cpp ADDED
@@ -0,0 +1,45 @@
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ // Activation used in fused MoE layers.
8
+ ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
9
+ ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
10
+
11
+ // Apply topk softmax to the gating outputs.
12
+ ops.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
13
+ "token_expert_indices, Tensor gating_output) -> ()");
14
+ ops.impl("topk_softmax", torch::kCUDA, &topk_softmax);
15
+
16
+ // Calculate the result of moe by summing up the partial results
17
+ // from all selected experts.
18
+ ops.def("moe_sum(Tensor! input, Tensor output) -> ()");
19
+ ops.impl("moe_sum", torch::kCUDA, &moe_sum);
20
+
21
+ // Aligning the number of tokens to be processed by each expert such
22
+ // that it is divisible by the block size.
23
+ ops.def("moe_align_block_size(Tensor topk_ids, int num_experts,"
24
+ " int block_size, Tensor! sorted_token_ids,"
25
+ " Tensor! experts_ids,"
26
+ " Tensor! num_tokens_post_pad) -> ()");
27
+ ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
28
+
29
+ #ifndef USE_ROCM
30
+ ops.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
31
+ "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
32
+ "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
33
+ "int b_q_type, SymInt size_m, "
34
+ "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
35
+ "topk, "
36
+ "int moe_block_size, bool replicate_input, bool apply_weights)"
37
+ " -> Tensor");
38
+ #endif
39
+ }
40
+
41
+ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
42
+ ops.impl("marlin_gemm_moe", &marlin_gemm_moe);
43
+ }
44
+
45
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
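Once the extension is built, these ops are reachable from Python through `torch.ops.<TORCH_EXTENSION_NAME>`. The namespace below (`_moe`) is a placeholder for whatever name the build chooses, so treat this as a sketch rather than the actual import path.

```python
import torch

d = 128
x = torch.randn(4, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(4, d, dtype=torch.float16, device="cuda")
# Matches the schema "silu_and_mul(Tensor! out, Tensor input) -> ()":
# out = silu(x[..., :d]) * x[..., d:]
torch.ops._moe.silu_and_mul(out, x)
```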
ext-torch/torch_binding.h ADDED
@@ -0,0 +1,30 @@
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+
5
+ #include <core/scalar_type.hpp>
6
+
7
+ void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
8
+
9
+ void topk_softmax(torch::Tensor &topk_weights, torch::Tensor &topk_indices,
10
+ torch::Tensor &token_expert_indices,
11
+ torch::Tensor &gating_output);
12
+
13
+ void moe_sum(torch::Tensor &input, torch::Tensor &output);
14
+
15
+ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
16
+ int64_t block_size, torch::Tensor sorted_token_ids,
17
+ torch::Tensor experts_ids,
18
+ torch::Tensor num_tokens_post_pad);
19
+
20
+ #ifndef USE_ROCM
21
+ torch::Tensor marlin_gemm_moe(
22
+ const torch::Tensor &a, const torch::Tensor &b_q_weights,
23
+ const torch::Tensor &sorted_ids, const torch::Tensor &topk_weights,
24
+ const torch::Tensor &topk_ids, const torch::Tensor &b_scales,
25
+ torch::Tensor &b_zeros, const torch::Tensor &g_idx,
26
+ const torch::Tensor &perm, torch::Tensor &workspace,
27
+ vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
28
+ int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
29
+ int64_t moe_block_size, bool replicate_input, bool apply_weights);
30
+ #endif
ext-torch/utils/__init__.py ADDED
File without changes
ext-torch/utils/marlin_utils.py ADDED
@@ -0,0 +1,307 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ from moe.scalar_type import ScalarType, scalar_types
7
+
8
+ from .quant_utils import pack_cols, unpack_cols
9
+
10
+ GPTQ_MARLIN_TILE = 16
11
+ GPTQ_MARLIN_MIN_THREAD_N = 64
12
+ GPTQ_MARLIN_MIN_THREAD_K = 128
13
+ GPTQ_MARLIN_MAX_PARALLEL = 16
14
+
15
+ GPTQ_MARLIN_24_TILE = 16
16
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
17
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
18
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
19
+
20
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
21
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
22
+
23
+ MARLIN_QQQ_TILE = 16
24
+ MARLIN_QQQ_MIN_THREAD_N = 64
25
+ MARLIN_QQQ_MIN_THREAD_K = 128
26
+ MARLIN_QQQ_MAX_PARALLEL = 16
27
+
28
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
29
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
30
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
31
+
32
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
33
+
34
+ # In case there is a performance issue with Marlin, the variable below can be
35
+ # changed to False, which allows Marlin to perform global reductions in fp16
36
+ # precision (instead of fp32), and therefore, save on some memory movements.
37
+ USE_FP32_REDUCE_DEFAULT = True
38
+
39
+
40
+ # For binary size and compile time, we don't support the same types for with and
41
+ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
42
+ # TODO: we may want to move this into the C++ so its closer to the actual impl
43
+ def query_marlin_supported_quant_types(
44
+ has_zp: bool, device_capability: Optional[int] = None
45
+ ):
46
+ if device_capability is None:
47
+ capability_tuple = torch.cuda.get_device_capability()
48
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
49
+
50
+ if device_capability < 80:
51
+ return []
52
+
53
+ if has_zp:
54
+ # AWQ style, unsigned + runtime zero-point
55
+ return [scalar_types.uint4, scalar_types.uint8]
56
+ else:
57
+ # GPTQ style, unsigned + symmetric bias
58
+ # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
59
+ # to add `scalar_types.float8_e4m3fn` here
60
+ return [scalar_types.uint4b8, scalar_types.uint8b128]
61
+
62
+
63
+ def _check_marlin_supported(
64
+ quant_type: ScalarType,
65
+ group_size: Optional[int],
66
+ has_zp: bool,
67
+ device_capability: Optional[int] = None,
68
+ ) -> Tuple[bool, Optional[str]]:
69
+
70
+ if device_capability is None:
71
+ capability_tuple = torch.cuda.get_device_capability()
72
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
73
+
74
+ supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
75
+
76
+ if quant_type not in supported_types:
77
+ return (
78
+ False,
79
+ f"Marlin does not support weight_bits = {quant_type}. "
80
+ f"Only types = {supported_types} "
81
+ f"are supported (for group_size = {group_size}, "
82
+ f"device_capability = {device_capability}, zp = {has_zp}).",
83
+ )
84
+ if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
85
+ return (
86
+ False,
87
+ f"Marlin does not support group_size = {group_size}. "
88
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
89
+ "are supported.",
90
+ )
91
+
92
+ return True, None
93
+
94
+
95
+ def check_marlin_supported(
96
+ quant_type: ScalarType,
97
+ group_size: int,
98
+ has_zp: bool = False,
99
+ device_capability: Optional[int] = None,
100
+ ) -> bool:
101
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
102
+ return cond
103
+
104
+
105
+ def verify_marlin_supported(
106
+ quant_type: ScalarType, group_size: int, has_zp: bool = False
107
+ ) -> None:
108
+ cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
109
+ if not cond:
110
+ assert err_msg is not None
111
+ raise ValueError(err_msg)
112
+
113
+
114
+ def verify_marlin_supports_shape(
115
+ output_size_per_partition: int,
116
+ input_size_per_partition: int,
117
+ input_size: int,
118
+ group_size: int,
119
+ ) -> None:
120
+
121
+ # Validate output_size_per_partition
122
+ if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
123
+ raise ValueError(
124
+ f"Weight output_size_per_partition = "
125
+ f"{output_size_per_partition} is not divisible by "
126
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
127
+ "Consider reducing tensor_parallel_size or running "
128
+ "with --quantization gptq."
129
+ )
130
+
131
+ # Validate input_size_per_partition
132
+ if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
133
+ raise ValueError(
134
+ f"Weight input_size_per_partition = "
135
+ f"{input_size_per_partition} is not divisible "
136
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
137
+ "Consider reducing tensor_parallel_size or running "
138
+ "with --quantization gptq."
139
+ )
140
+
141
+ if group_size < input_size and input_size_per_partition % group_size != 0:
142
+ raise ValueError(
143
+ f"Weight input_size_per_partition = {input_size_per_partition}"
144
+ f" is not divisible by group_size = {group_size}."
145
+ "Consider reducing tensor_parallel_size or running "
146
+ "with --quantization gptq."
147
+ )
148
+
149
+
150
+ def check_marlin_supports_shape(
151
+ output_size_per_partition: int,
152
+ input_size_per_partition: int,
153
+ input_size: int,
154
+ group_size: int,
155
+ ) -> Tuple[bool, Optional[str]]:
156
+ try:
157
+ verify_marlin_supports_shape(
158
+ output_size_per_partition, input_size_per_partition, input_size, group_size
159
+ )
160
+ except ValueError as e:
161
+ return False, e.__str__()
162
+ return True, None
163
+
164
+
165
+ def marlin_make_workspace(
166
+ output_size_per_partition: int, device: torch.device
167
+ ) -> torch.Tensor:
168
+ max_workspace_size = (
169
+ output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
170
+ ) * GPTQ_MARLIN_MAX_PARALLEL
171
+
172
+ return torch.zeros(
173
+ max_workspace_size, dtype=torch.int, device=device, requires_grad=False
174
+ )
175
+
176
+
177
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
178
+ return (not act_order) or (act_order and not is_row_parallel)
179
+
180
+
181
+ def marlin_repeat_scales_on_all_ranks(
182
+ act_order: bool, group_size: int, is_row_parallel: bool
183
+ ) -> bool:
184
+ # Need to repeat scales on every rank if act_ordering or
185
+ # channelwise and RowParallelLinear
186
+ is_channelwise = group_size == -1
187
+ return act_order or (is_channelwise and is_row_parallel)
188
+
189
+
190
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
191
+ return torch.nn.Parameter(
192
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
193
+ )
194
+
195
+
196
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
197
+ return torch.nn.Parameter(
198
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
199
+ )
200
+
201
+
202
+ def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
203
+ g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
204
+ return g_idx[g_idx_sort_indices], g_idx_sort_indices
205
+
206
+
207
+ def get_scale_perms():
208
+ scale_perm: List[int] = []
209
+ for i in range(8):
210
+ scale_perm.extend([i + 8 * j for j in range(8)])
211
+ scale_perm_single: List[int] = []
212
+ for i in range(4):
213
+ scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
214
+ return scale_perm, scale_perm_single
215
+
216
+
217
+ def marlin_permute_scales(
218
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
219
+ ) -> torch.Tensor:
220
+
221
+ scale_perm, scale_perm_single = get_scale_perms()
222
+ if group_size < size_k and group_size != -1:
223
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
224
+ else:
225
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
226
+ s = s.reshape((-1, size_n)).contiguous()
227
+
228
+ return s
229
+
230
+
231
+ def marlin_moe_permute_scales(
232
+ s: torch.Tensor,
233
+ size_k: int,
234
+ size_n: int,
235
+ group_size: int,
236
+ ):
237
+ num_experts = s.shape[0]
238
+ output = torch.empty(
239
+ (num_experts, s.shape[1], s.shape[2]),
240
+ device=s.device,
241
+ dtype=s.dtype,
242
+ )
243
+
244
+ for e in range(num_experts):
245
+ output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
246
+ return output
247
+
248
+
249
+ def marlin_zero_points(
250
+ zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
251
+ ) -> torch.Tensor:
252
+ # Permute zero-points in a similar way to scales, but do not use the
253
+ # "single" permutation, since zero-points are applied on every MMA
254
+ scale_perm, _ = get_scale_perms()
255
+ zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
256
+
257
+ # Interleave column dim (for the dequantize code) and pack it to int32
258
+ if num_bits == 4:
259
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
260
+ elif num_bits == 8:
261
+ interleave = numpy.array([0, 2, 1, 3])
262
+ else:
263
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
264
+
265
+ zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
266
+ zp = zp.reshape((-1, size_n)).contiguous()
267
+ zp = pack_cols(zp, num_bits, size_k, size_n)
268
+
269
+ return zp
270
+
271
+
272
+ def awq_to_marlin_zero_points(
273
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
274
+ ) -> torch.Tensor:
275
+ # AWQ zero-points are quantized and packed on the column dim.
276
+ # In addition, the values are permuted based on dequantizer.
277
+ # Here we undo both of these, and then apply marlin permutation
278
+ # and pack it back.
279
+ q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
280
+
281
+ # Undo interleaving (use argsort(..) to get inverse perm)
282
+ if num_bits == 4:
283
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
284
+ elif num_bits == 8:
285
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
286
+ else:
287
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
288
+
289
+ q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
290
+ q_zp = q_zp.reshape((-1, size_n)).contiguous()
291
+
292
+ marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
293
+ return marlin_zp
294
+
295
+
296
+ def moe_awq_to_marlin_zero_points(
297
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
298
+ ):
299
+ num_experts = q_zp_packed.shape[0]
300
+ output = torch.empty(
301
+ (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
302
+ device=q_zp_packed.device,
303
+ dtype=q_zp_packed.dtype,
304
+ )
305
+ for e in range(num_experts):
306
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
307
+ return output
ext-torch/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,162 @@
 
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from moe.scalar_type import ScalarType
9
+
10
+ from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
11
+ from .quant_utils import (
12
+ get_pack_factor,
13
+ gptq_quantize_weights,
14
+ quantize_weights,
15
+ sort_weights,
16
+ )
17
+
18
+
19
+ class MarlinWorkspace:
20
+
21
+ def __init__(self, out_features, min_thread_n, max_parallel):
22
+ assert (
23
+ out_features % min_thread_n == 0
24
+ ), "out_features = {} is undivisible by min_thread_n = {}".format(
25
+ out_features, min_thread_n
26
+ )
27
+
28
+ max_workspace_size = (out_features // min_thread_n) * max_parallel
29
+
30
+ self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
31
+
32
+
33
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
34
+ assert q_w.shape == (size_k, size_n)
35
+ assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
36
+ assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
37
+
38
+ # Permute weights to 16x64 marlin tiles
39
+ q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
40
+ q_w = q_w.permute((0, 2, 1, 3))
41
+ q_w = q_w.reshape((size_k // tile, size_n * tile))
42
+
43
+ q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
44
+
45
+ return q_w
46
+
47
+
48
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
49
+ # Permute
50
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
51
+
52
+ # Pack
53
+ pack_factor = get_pack_factor(num_bits)
54
+ orig_device = q_w.device
55
+
56
+ q_w = q_w.cpu().numpy().astype(np.uint32)
57
+
58
+ q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
59
+ for i in range(pack_factor):
60
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
61
+
62
+ q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
63
+
64
+ return q_packed
65
+
66
+
67
+ def get_weight_perm(num_bits: int):
68
+ perm_list: List[int] = []
69
+ for i in range(32):
70
+ perm1: List[int] = []
71
+ col = i // 4
72
+ for block in [0, 1]:
73
+ for row in [
74
+ 2 * (i % 4),
75
+ 2 * (i % 4) + 1,
76
+ 2 * (i % 4 + 4),
77
+ 2 * (i % 4 + 4) + 1,
78
+ ]:
79
+ perm1.append(16 * row + col + 8 * block)
80
+ for j in range(4):
81
+ perm_list.extend([p + 256 * j for p in perm1])
82
+
83
+ perm = np.array(perm_list)
84
+
85
+ if num_bits == 4:
86
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
87
+ elif num_bits == 8:
88
+ interleave = np.array([0, 2, 1, 3])
89
+ else:
90
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
91
+
92
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
93
+ perm = torch.from_numpy(perm)
94
+ return perm
95
+
96
+
97
+ def marlin_quantize(
98
+ w: torch.Tensor,
99
+ quant_type: ScalarType,
100
+ group_size: int,
101
+ act_order: bool,
102
+ test_perm: Optional[torch.Tensor] = None,
103
+ ):
104
+ size_k, size_n = w.shape
105
+ num_bits = quant_type.size_bits
106
+
107
+ # Normalize group_size
108
+ if group_size == -1:
109
+ group_size = size_k
110
+ assert group_size <= size_k
111
+
112
+ # Quantize (and apply act_order if provided)
113
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
114
+ w, quant_type, group_size, act_order, test_perm
115
+ )
116
+
117
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
118
+ # increasing
119
+ sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
120
+ if act_order:
121
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
122
+
123
+ # Reformat to marlin
124
+ weight_perm = get_weight_perm(num_bits)
125
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
126
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
127
+
128
+ # Create result
129
+ res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
130
+ for i in range(len(res_list)):
131
+ res_list[i] = res_list[i].to(w.device)
132
+
133
+ return res_list
134
+
135
+
136
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
137
+ size_k, size_n = w.shape
138
+
139
+ # Normalize group_size
140
+ if group_size == -1:
141
+ group_size = size_k
142
+ assert group_size <= size_k
143
+
144
+ # Detect num groups
145
+ assert size_k % group_size == 0
146
+ num_groups = size_k // group_size
147
+
148
+ # Quantize with zp
149
+ w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
150
+
151
+ # Reformat to marlin
152
+ weight_perm = get_weight_perm(quant_type.size_bits)
153
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
154
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
155
+ marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
156
+
157
+ # Create result
158
+ res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
159
+ for i in range(len(res_list)):
160
+ res_list[i] = res_list[i].to(w.device)
161
+
162
+ return res_list
ext-torch/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """This file is used for /tests and /benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy
6
+ import torch
7
+
8
+ from moe.scalar_type import ScalarType, scalar_types
9
+
10
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
+
13
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
+
15
+ # Note: this is a hack. We should update each model to register the
16
+ # stacked params and get it from there instead in a future PR.
17
+ # fused_name: List[shard_name]
18
+ FUSED_LAYER_NAME_MAPPING = {
19
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
+ "gate_up_proj": ["gate_proj", "up_proj"],
21
+ }
22
+
23
+
24
+ def pack_quantized_values_into_int32(
25
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
+ ):
27
+ # move dim to pack to the end
28
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
+ w_q_perm = w_q.permute(perm)
31
+
32
+ pack_factor = 32 // wtype.size_bits
33
+ mask = (1 << wtype.size_bits) - 1
34
+
35
+ new_shape_perm = list(w_q_perm.shape)
36
+ assert w_q_perm.shape[-1] % pack_factor == 0
37
+ new_shape_perm[-1] //= pack_factor
38
+
39
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
+ for i in range(pack_factor):
41
+ res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
+
43
+ return res.permute(inv_perm)
44
+
45
+
46
+ def unpack_quantized_values_into_int32(
47
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
+ ):
49
+ # move dim to pack to the end
50
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
+ w_q_perm = w_q.permute(perm)
53
+
54
+ pack_factor = 32 // wtype.size_bits
55
+ mask = (1 << wtype.size_bits) - 1
56
+
57
+ new_shape_perm = list(w_q_perm.shape)
58
+ new_shape_perm[-1] *= pack_factor
59
+
60
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
+ for i in range(pack_factor):
62
+ res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
+
64
+ return res.permute(inv_perm)
65
+
66
+
67
+ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
+ # prefix: model.layers.0.self_attn.q_proj
69
+ # proj_name: q_proj
70
+ proj_name = prefix.split(".")[-1]
71
+ if proj_name in FUSED_LAYER_NAME_MAPPING:
72
+ shard_prefixes = [
73
+ prefix.replace(proj_name, shard_proj_name)
74
+ for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
+ ]
76
+
77
+ is_skipped = None
78
+ for shard_prefix in shard_prefixes:
79
+ is_shard_skipped = shard_prefix in ignored_layers
80
+
81
+ if is_skipped is None:
82
+ is_skipped = is_shard_skipped
83
+ elif is_shard_skipped != is_skipped:
84
+ raise ValueError(
85
+ f"Detected some but not all shards of {prefix} "
86
+ "are quantized. All shards of fused layers "
87
+ "to have the same precision."
88
+ )
89
+ else:
90
+ is_skipped = prefix in ignored_layers
91
+
92
+ assert is_skipped is not None
93
+ return is_skipped
94
+
95
+
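`is_layer_skipped` expands a fused projection name to its shard names before consulting the ignore list, so a quantization config that lists the unfused shard names still matches the fused module. A hypothetical usage sketch (the import path is an assumption; the file lives under `ext-torch/utils/`):

```python
# Hypothetical usage; assumes this module is importable as `quant_utils`.
from quant_utils import is_layer_skipped

ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# The fused name is expanded to its shard names before the lookup.
assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)
assert not is_layer_skipped("model.layers.1.mlp.gate_up_proj", ignored)
```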
96
+ def get_pack_factor(num_bits):
97
+ assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
+ return 32 // num_bits
99
+
100
+
101
+ def permute_rows(
102
+ q_w: torch.Tensor,
103
+ w_ref: torch.Tensor,
104
+ group_size: int,
105
+ test_perm: Optional[torch.Tensor] = None,
106
+ ):
107
+ assert q_w.shape == w_ref.shape
108
+
109
+ orig_device = q_w.device
110
+ k_size, _ = q_w.shape
111
+
112
+ g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
+ for i in range(k_size):
114
+ g_idx[i] = i // group_size
115
+
116
+ # Simulate act_order by doing a random permutation on K
117
+ rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
+
119
+ g_idx = g_idx[rand_perm].contiguous()
120
+ q_w = q_w[rand_perm, :].contiguous()
121
+ w_ref = w_ref[rand_perm, :].contiguous()
122
+
123
+ return (
124
+ w_ref.to(device=orig_device),
125
+ q_w.to(device=orig_device),
126
+ g_idx.to(device=orig_device),
127
+ rand_perm.to(device=orig_device),
128
+ )
129
+
130
+
131
+ def quantize_weights(
132
+ w: torch.Tensor,
133
+ quant_type: ScalarType,
134
+ group_size: Optional[int],
135
+ zero_points: bool = False,
136
+ ref_zero_points_after_scales: bool = False,
137
+ ):
138
+ assert (
139
+ quant_type.is_integer()
140
+ ), "Floating point quantization may work but has not been tested"
141
+ assert not zero_points or group_size is not None, (
142
+ "to have group zero points, group_size must be provided "
143
+ "(-1 group_size is channelwise)"
144
+ )
145
+
146
+ orig_device = w.device
147
+ orig_type = w.dtype
148
+ size_k, size_n = w.shape
149
+
150
+ assert w.is_floating_point(), "w must be float"
151
+
152
+ if group_size == -1:
153
+ group_size = size_k
154
+
155
+ # Reshape to [groupsize, -1]
156
+ if group_size is not None and group_size < size_k:
157
+ w = w.reshape((-1, group_size, size_n))
158
+ w = w.permute(1, 0, 2)
159
+ w = w.reshape((group_size, -1))
160
+
161
+ # Compute scale for each group
162
+ max_val = torch.max(w, 0, keepdim=True).values
163
+ min_val = torch.min(w, 0, keepdim=True).values
164
+
165
+ max_q_val = quant_type.max()
166
+ min_q_val = quant_type.min()
167
+
168
+ w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
+ maybe_w_zp = None
170
+ if group_size is not None:
171
+ if zero_points:
172
+ assert not quant_type.is_signed() and quant_type.max() > 0
173
+ w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
+ maybe_w_zp = (
175
+ torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
+ )
177
+ else:
178
+ # If the bias is such that there are no possible negative/positive
179
+ # values, set the max value to inf to avoid divide by 0
180
+ w_s = torch.max(
181
+ abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
+ abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
+ )
184
+
185
+ # Quantize
186
+ w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
+ w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
+
189
+ # Compute ref (dequantized)
190
+ # For some kernels (namely Machete) the zero-points are applied after the
191
+ # scales are applied, for this case computing the reference in similar way
192
+ # allows us to use tighter error tolerances in our unit tests.
193
+ if ref_zero_points_after_scales and maybe_w_zp is not None:
194
+ w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
+ else:
196
+ w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
+
198
+ if quant_type.has_bias():
199
+ w_q += quant_type.bias
200
+
201
+ # Restore original shapes
202
+ if group_size is not None and group_size < size_k:
203
+
204
+ def reshape_w(w):
205
+ w = w.reshape((group_size, -1, size_n))
206
+ w = w.permute(1, 0, 2)
207
+ w = w.reshape((size_k, size_n)).contiguous()
208
+ return w
209
+
210
+ w_q = reshape_w(w_q)
211
+ w_ref = reshape_w(w_ref)
212
+ w_s = w_s.reshape((-1, size_n)).contiguous()
213
+
214
+ if maybe_w_zp is not None:
215
+ maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
+ maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
+
218
+ return (
219
+ w_ref.to(device=orig_device),
220
+ w_q.to(device=orig_device),
221
+ w_s if group_size is not None else None,
222
+ maybe_w_zp,
223
+ )
224
+
225
+
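In the symmetric branch above (no zero-points), the per-group scale is `max(|max_val / max_q|, |min_val / min_q|)`, after which the weights are rounded, clamped, and (for biased types such as `uint4b8`) shifted by the bias. A small self-contained sketch of that scale computation for a single group and a `[-8, 7]` target range:

```python
import torch

# Illustrative only: symmetric scales for one group with a [-8, 7] target
# range (uint4b8 before its +8 bias is applied).
w = torch.tensor([[-1.5, 0.25], [0.75, 3.0]])
max_q_val, min_q_val = 7, -8

max_val = w.max(dim=0, keepdim=True).values
min_val = w.min(dim=0, keepdim=True).values
w_s = torch.max((max_val / max_q_val).abs(), (min_val / min_q_val).abs())

w_q = torch.clamp(torch.round(w / w_s), min_q_val, max_q_val).int()
w_ref = w_q.to(w.dtype) * w_s  # dequantized reference

print(w_s)   # ~[[0.1875, 0.4286]]
print(w_q)   # [[-8, 1], [4, 7]]
```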
226
+ def gptq_quantize_weights(
227
+ w: torch.Tensor,
228
+ quant_type: ScalarType,
229
+ group_size: int,
230
+ act_order: bool,
231
+ test_perm: Optional[torch.Tensor] = None,
232
+ ):
233
+ size_k, _ = w.shape
234
+
235
+ assert w.is_floating_point(), "w must be float"
236
+ assert (
237
+ quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
+ ), f"Unsupported gptq type = {quant_type}"
239
+ assert group_size in SUPPORTED_GROUP_SIZES + [
240
+ size_k
241
+ ], f"Unsupported groupsize = {group_size}"
242
+
243
+ w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
+
245
+ # Apply act_order
246
+ g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
+ rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
+ if act_order:
249
+ assert (
250
+ group_size < size_k
251
+ ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
+ group_size, size_k
253
+ )
254
+
255
+ w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
+
257
+ return w_ref, w_q, w_s, g_idx, rand_perm
258
+
259
+
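A hypothetical end-to-end use of the GPTQ path with `act_order`, combined with `sort_weights` further below (the module path and CPU execution are assumptions; this should run on CPU with a recent PyTorch build):

```python
# Hypothetical usage; assumes this module is importable as `quant_utils`.
import torch
from quant_utils import gptq_quantize_weights, sort_weights
from moe.scalar_type import scalar_types

size_k, size_n, group_size = 256, 64, 64
w = torch.randn(size_k, size_n, dtype=torch.half)

w_ref, w_q, w_s, g_idx, rand_perm = gptq_quantize_weights(
    w, scalar_types.uint4b8, group_size, act_order=True
)

# act_order randomly permutes the K rows; kernels usually want them re-sorted
# by group index afterwards.
w_q_sorted, g_idx_sorted, sort_indices = sort_weights(w_q, g_idx)
```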
260
+ # QQQ employs different quant schemes for per-group and
261
+ # per-channel quantization.
262
+ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
+ orig_device = w.device
264
+ size_k, size_n = w.shape
265
+
266
+ assert w.is_floating_point(), "w must be float"
267
+ assert (
268
+ num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
+ ), f"Unsupported num_bits = {num_bits}"
270
+ assert group_size in SUPPORTED_GROUP_SIZES + [
271
+ size_k
272
+ ], f"Unsupported groupsize = {group_size}"
273
+
274
+ if group_size == -1:
275
+ group_size = size_k
276
+ assert group_size <= size_k
277
+
278
+ if group_size < size_k:
279
+ # Reshape to [groupsize, -1]
280
+ w = w.reshape((-1, group_size, size_n))
281
+ w = w.permute(1, 0, 2)
282
+ w = w.reshape((group_size, -1))
283
+
284
+ max_q_val = 2**num_bits - 1
285
+ half_q_val = (max_q_val + 1) // 2
286
+
287
+ # Compute scale for each group
288
+ s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
+ s_group *= 2 / max_q_val # 2 => symmetric
290
+
291
+ # Quantize
292
+ q_w = torch.round(w / s_group).int()
293
+ q_w += half_q_val
294
+ q_w = torch.clamp(q_w, 0, max_q_val)
295
+ # Compute ref (dequantized)
296
+ w_ref = (q_w - half_q_val).half() * s_group
297
+
298
+ # Restore original shapes
299
+ def reshape_w(w):
300
+ w = w.reshape((group_size, -1, size_n))
301
+ w = w.permute(1, 0, 2)
302
+ w = w.reshape((size_k, size_n)).contiguous()
303
+ return w
304
+
305
+ q_w = reshape_w(q_w)
306
+ w_ref = reshape_w(w_ref)
307
+
308
+ # Compute int8 quantization scale for each channel
309
+ s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
+ s_channel /= 127.0
311
+ t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
+ w_ref = t_int8.half() * s_channel
313
+ s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
+
315
+ # Fuse scales
316
+ s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
+ dtype=torch.half
318
+ )
319
+ else:
320
+ max_q_val = 2 ** (num_bits - 1) - 1
321
+
322
+ # Compute scale for each channel
323
+ s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
+ s_channel /= max_q_val
325
+
326
+ # Quantize
327
+ q_w = torch.round(w / s_channel).int()
328
+ q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
+ # Compute ref (dequantized)
330
+ w_ref = q_w.half() * s_channel
331
+
332
+ s_group = torch.tensor([], dtype=torch.half)
333
+ # divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
334
+ s_channel /= 2 ** (8 - num_bits)
335
+ s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
+
337
+ return (
338
+ w_ref.to(device=orig_device),
339
+ q_w.to(device=orig_device),
340
+ s_group.to(device=orig_device),
341
+ s_channel.to(device=orig_device),
342
+ )
343
+
344
+
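A hypothetical call into both QQQ branches; in the per-group case the fp16 group scales come back already divided by the int8 channel scales, while in the per-channel case `s_group` is returned empty (the import path and CPU half-precision support are assumptions):

```python
# Hypothetical usage; assumes this module is importable as `quant_utils`.
import torch
from quant_utils import qqq_quantize_weights

w = torch.randn(128, 64, dtype=torch.half)

# Per-group path: 4-bit weights, fused fp16 group scales, fp32 channel scales.
w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4, group_size=64)

# Per-channel path: group_size == size_k, so no group scales are produced.
w_ref_c, q_w_c, s_group_c, s_channel_c = qqq_quantize_weights(w, num_bits=4, group_size=128)
assert s_group_c.numel() == 0
```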
345
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
+ orig_device = q_w.device
347
+
348
+ sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
+
350
+ g_idx = g_idx[sort_indices].contiguous()
351
+ q_w = q_w[sort_indices, :].contiguous()
352
+
353
+ return (
354
+ q_w.to(device=orig_device),
355
+ g_idx.to(device=orig_device),
356
+ sort_indices.to(device=orig_device),
357
+ )
358
+
359
+
360
+ def pack_rows(
361
+ q_w: torch.Tensor,
362
+ num_bits: int,
363
+ size_k: int,
364
+ size_n: int,
365
+ ):
366
+ assert q_w.shape == (size_k, size_n)
367
+
368
+ pack_factor = get_pack_factor(num_bits)
369
+ assert size_k % pack_factor == 0
370
+
371
+ orig_device = q_w.device
372
+
373
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
+
375
+ q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
+
377
+ for i in range(pack_factor):
378
+ q_res |= q_w[i::pack_factor, :] << num_bits * i
379
+
380
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
+ return q_res
382
+
383
+
384
+ def pack_cols(
385
+ q_w: torch.Tensor,
386
+ num_bits: int,
387
+ size_k: int,
388
+ size_n: int,
389
+ ):
390
+ assert q_w.shape == (size_k, size_n)
391
+
392
+ pack_factor = get_pack_factor(num_bits)
393
+ assert size_n % pack_factor == 0
394
+
395
+ orig_device = q_w.device
396
+
397
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
+
399
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
+
401
+ for i in range(pack_factor):
402
+ q_res |= q_w[:, i::pack_factor] << num_bits * i
403
+
404
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
+ q_res = q_res.contiguous()
406
+
407
+ return q_res
408
+
409
+
410
+ def unpack_cols(
411
+ packed_q_w: torch.Tensor,
412
+ num_bits: int,
413
+ size_k: int,
414
+ size_n: int,
415
+ ):
416
+ pack_factor = get_pack_factor(num_bits)
417
+ assert size_n % pack_factor == 0
418
+ assert packed_q_w.shape == (
419
+ size_k,
420
+ size_n // pack_factor,
421
+ ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
422
+ packed_q_w.shape, size_k, size_n, pack_factor
423
+ )
424
+
425
+ orig_device = packed_q_w.device
426
+
427
+ packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
+ q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
+
430
+ mask = (1 << num_bits) - 1
431
+ for i in range(pack_factor):
432
+ vals = packed_q_w_cpu & mask
433
+ packed_q_w_cpu >>= num_bits
434
+ q_res[:, i::pack_factor] = vals
435
+
436
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
+ q_res = q_res.contiguous()
438
+
439
+ return q_res
440
+
441
+
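`pack_cols` and `unpack_cols` are exact inverses for values that fit in `num_bits`, which the following hypothetical round trip checks (the import path is an assumption):

```python
# Hypothetical round trip; assumes this module is importable as `quant_utils`.
import torch
from quant_utils import pack_cols, unpack_cols

num_bits, size_k, size_n = 4, 16, 32
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)        # shape (16, 32 // 8)
restored = unpack_cols(packed, num_bits, size_k, size_n)

assert torch.equal(restored, q_w)
```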
442
+ def gptq_pack(
443
+ q_w: torch.Tensor,
444
+ num_bits: int,
445
+ size_k: int,
446
+ size_n: int,
447
+ ):
448
+ return pack_rows(q_w, num_bits, size_k, size_n)
449
+
450
+
451
+ def awq_pack(
452
+ q_w: torch.Tensor,
453
+ num_bits: int,
454
+ size_k: int,
455
+ size_n: int,
456
+ ):
457
+ assert q_w.shape == (size_k, size_n)
458
+
459
+ # Interleave column dim (for the dequantize code) and pack it to int32
460
+ if num_bits == 4:
461
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
+ elif num_bits == 8:
463
+ interleave = numpy.array([0, 2, 1, 3])
464
+ else:
465
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
+
467
+ q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
+ q_w = q_w.reshape((-1, size_n)).contiguous()
469
+
470
+ return pack_cols(q_w, num_bits, size_k, size_n)
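The AWQ packer first interleaves each run of 8 (4-bit) or 4 (8-bit) columns so that the packed `int32` layout matches what the dequantization code expects. A standalone sketch of the 4-bit column reorder on its own:

```python
import numpy

# Illustrative only: the 4-bit AWQ column interleave used in awq_pack above.
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])

cols = numpy.arange(16)  # two runs of 8 column indices
reordered = cols.reshape(-1, 8)[:, interleave].ravel()
print(reordered)  # [ 0  2  4  6  1  3  5  7  8 10 12 14  9 11 13 15]
```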
flake.nix ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ description = "Flake for activation kernels";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs ./.;
14
+ }
marlin-moe/marlin_kernels/marlin_moe_kernel.h ADDED
@@ -0,0 +1,1616 @@
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <cuda.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_runtime.h>
10
+
11
+ #include <iostream>
12
+
13
+ #include "core/scalar_type.hpp"
14
+
15
+ namespace marlin_moe {
16
+
17
+ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
18
+
19
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
20
+
21
+ // Instances of `Vec` are used to organize groups of >>registers<<, as needed
22
+ // for instance as inputs to tensor core operations. Consequently, all
23
+ // corresponding index accesses must be compile-time constants, which is why we
24
+ // extensively use `#pragma unroll` throughout the kernel code to guarantee
25
+ // this.
26
+ template <typename T, int n>
27
+ struct Vec {
28
+ T elems[n];
29
+ __device__ T& operator[](int i) { return elems[i]; }
30
+ };
31
+
32
+ using I4 = Vec<int, 4>;
33
+
34
+ // Matrix fragments for tensor core instructions; their precise layout is
35
+ // documented here:
36
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
37
+ using FragA = Vec<half2, 4>;
38
+ using FragB = Vec<half2, 2>;
39
+ using FragC = Vec<float, 4>;
40
+ using FragS = Vec<half2, 1>; // quantization scales
41
+ using FragZP = Vec<half2, 4>;
42
+
43
+ // Predicated asynchronous global->shared copy; used for inputs A where we apply
44
+ // predication to handle batchsizes that are not multiples of 16.
45
+ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
46
+ bool pred = true) {
47
+ const int BYTES = 16;
48
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
49
+ asm volatile(
50
+ "{\n"
51
+ " .reg .pred p;\n"
52
+ " setp.ne.b32 p, %0, 0;\n"
53
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
54
+ "}\n" ::"r"((int)pred),
55
+ "r"(smem), "l"(glob_ptr), "n"(BYTES));
56
+ }
57
+
58
+ // Asynchronous global->shared copy
59
+ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
60
+ const int BYTES = 16;
61
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
62
+ asm volatile(
63
+ "{\n"
64
+ " cp.async.cg.shared.global [%0], [%1], %2;\n"
65
+ "}\n" ::"r"(smem),
66
+ "l"(glob_ptr), "n"(BYTES));
67
+ }
68
+
69
+ // Async copy fence.
70
+ __device__ inline void cp_async_fence() {
71
+ asm volatile("cp.async.commit_group;\n" ::);
72
+ }
73
+
74
+ // Wait until at most `n` async copy stages are still pending.
75
+ template <int n>
76
+ __device__ inline void cp_async_wait() {
77
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
78
+ }
79
+
80
+ // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
81
+ // output/accumulation.
82
+ __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
83
+ FragC& frag_c) {
84
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
85
+ const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
86
+ float* c = reinterpret_cast<float*>(&frag_c);
87
+ asm volatile(
88
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
89
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
90
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
91
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
92
+ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
93
+ }
94
+
95
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
96
+ // memory, directly in tensor core layout.
97
+ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
98
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
99
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
100
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
101
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
102
+ : "r"(smem));
103
+ }
104
+
105
+ // Lookup-table based 3-input logical operation; explicitly used for
106
+ // dequantization as the compiler does not seem to automatically recognize it in
107
+ // all cases.
108
+ template <int lut>
109
+ __device__ inline int lop3(int a, int b, int c) {
110
+ int res;
111
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
112
+ : "=r"(res)
113
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
114
+ return res;
115
+ }
116
+
117
+ // Constructs destination register by taking bytes from 2 sources (based on
118
+ // mask)
119
+ template <int start_byte, int mask>
120
+ __device__ inline uint32_t prmt(uint32_t a) {
121
+ uint32_t res;
122
+ asm volatile("prmt.b32 %0, %1, %2, %3;\n"
123
+ : "=r"(res)
124
+ : "r"(a), "n"(start_byte), "n"(mask));
125
+ return res;
126
+ }
127
+
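`lop3` takes its truth table as an immediate; in the dequantization code below, the immediate `(0xf0 & 0xcc) | 0xaa` is exactly the table of the function `(a & b) | c`, because 0xF0, 0xCC, and 0xAA are the canonical truth tables of the three inputs themselves. A standalone check of that encoding:

```python
# Truth-table immediates as used by PTX lop3: 0xF0, 0xCC and 0xAA are the
# tables of inputs a, b and c, so composing them yields the table of any
# three-input boolean function.
ta, tb, tc = 0xF0, 0xCC, 0xAA

lut = (ta & tb) | tc  # table of (a & b) | c
print(hex(lut))       # 0xea

# Verify against a direct evaluation of (a & b) | c over all 8 input patterns.
for bit in range(8):
    a, b, c = (bit >> 2) & 1, (bit >> 1) & 1, bit & 1
    assert ((lut >> bit) & 1) == ((a & b) | c)
```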
128
+ template <vllm::ScalarTypeId w_type_id>
129
+ __device__ inline FragB dequant(int q);
130
+
131
+ // Efficiently dequantize 4bit values packed in an int32 value into a full
132
+ // B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
133
+ // with some small changes:
134
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
135
+ template <>
136
+ __device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
137
+ const int LO = 0x000f000f;
138
+ const int HI = 0x00f000f0;
139
+ const int EX = 0x64006400;
140
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
141
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
142
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
143
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
144
+ // directly into `SUB` and `ADD`.
145
+ const int SUB = 0x64086408;
146
+ const int MUL = 0x2c002c00;
147
+ const int ADD = 0xd480d480;
148
+ FragB frag_b;
149
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
150
+ *reinterpret_cast<const half2*>(&SUB));
151
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
152
+ *reinterpret_cast<const half2*>(&MUL),
153
+ *reinterpret_cast<const half2*>(&ADD));
154
+ return frag_b;
155
+ }
156
+
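The magic constants above rely on the fp16 bit layout: OR-ing a 4-bit value x into the low mantissa bits of 0x6400 (which is 1024.0 in fp16) yields exactly 1024 + x, and subtracting 0x6408 (1032.0) then produces the signed value x - 8. A standalone numpy check of that identity:

```python
import numpy as np

# Illustrative check of the fp16 "magic number" dequantization used above.
x = np.arange(16, dtype=np.uint16)
halves = (np.uint16(0x6400) | x).view(np.float16)  # bit pattern 0x6400 | x

assert np.array_equal(halves, 1024.0 + x.astype(np.float16))
assert np.array_equal(halves - np.float16(1032.0), x.astype(np.float16) - 8)
```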
157
+ // Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16
158
+ // Reference:
159
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
160
+ template <>
161
+ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
162
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
163
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
164
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
165
+
166
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
167
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
168
+
169
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
170
+
171
+ FragB frag_b;
172
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
173
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
174
+ frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
175
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
176
+ return frag_b;
177
+ }
178
+
179
+ template <>
180
+ __device__ inline FragB dequant<vllm::kU4.id()>(int q) {
181
+ const int LO = 0x000f000f;
182
+ const int HI = 0x00f000f0;
183
+ const int EX = 0x64006400;
184
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
185
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
186
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
187
+
188
+ const int SUB = 0x64006400;
189
+ const int MUL = 0x2c002c00;
190
+ const int ADD = 0xd400d400;
191
+ FragB frag_b;
192
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
193
+ *reinterpret_cast<const half2*>(&SUB));
194
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
195
+ *reinterpret_cast<const half2*>(&MUL),
196
+ *reinterpret_cast<const half2*>(&ADD));
197
+ return frag_b;
198
+ }
199
+
200
+ template <>
201
+ __device__ inline FragB dequant<vllm::kU8.id()>(int q) {
202
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
203
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
204
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
205
+
206
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
207
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
208
+
209
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
210
+
211
+ FragB frag_b;
212
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
213
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
214
+ frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
215
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
216
+ return frag_b;
217
+ }
218
+
219
+ // Multiply dequantized values by the corresponding quantization scale; used
220
+ // only for grouped quantization.
221
+ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
222
+ half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
223
+ frag_b[0] = __hmul2(frag_b[0], s);
224
+ frag_b[1] = __hmul2(frag_b[1], s);
225
+ }
226
+
227
+ __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
228
+ half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]);
229
+ frag_b[0] = __hsub2(frag_b[0], zp);
230
+ frag_b[1] = __hsub2(frag_b[1], zp);
231
+ }
232
+
233
+ // Same as above, but for act_order (each K is multiplied individually)
234
+ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
235
+ FragS& frag_s_3, FragS& frag_s_4, int i) {
236
+ __half2 s_val_1_2;
237
+ s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i];
238
+ s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i];
239
+
240
+ __half2 s_val_3_4;
241
+ s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i];
242
+ s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i];
243
+
244
+ frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
245
+ frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
246
+ }
247
+
248
+ // Given 2 floats multiply by 2 scales (halves)
249
+ __device__ inline void scale_float(float* c, FragS& s) {
250
+ __half* s_ptr = reinterpret_cast<__half*>(&s);
251
+ c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
252
+ c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
253
+ }
254
+
255
+ // Wait until barrier reaches `count`, then lock for current threadblock.
256
+ __device__ inline void barrier_acquire(int* lock, int count) {
257
+ if (threadIdx.x == 0) {
258
+ int state = -1;
259
+ do
260
+ // Guarantee that subsequent writes by this threadblock will be visible
261
+ // globally.
262
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
263
+ : "=r"(state)
264
+ : "l"(lock));
265
+ while (state != count);
266
+ }
267
+ __syncthreads();
268
+ }
269
+
270
+ // Release barrier and increment visitation count.
271
+ __device__ inline void barrier_release(int* lock, bool reset = false) {
272
+ __syncthreads();
273
+ if (threadIdx.x == 0) {
274
+ if (reset) {
275
+ lock[0] = 0;
276
+ return;
277
+ }
278
+ int val = 1;
279
+ // Make sure that all writes since acquiring this barrier are visible
280
+ // globally, while releasing the barrier.
281
+ asm volatile("fence.acq_rel.gpu;\n");
282
+ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
283
+ :
284
+ : "l"(lock), "r"(val));
285
+ }
286
+ }
287
+
288
+ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
289
+ const int threads, // number of threads in a threadblock
290
+ const int thread_m_blocks, // number of 16x16 blocks in the m
291
+ // dimension (batchsize) of the
292
+ // threadblock
293
+ const int thread_n_blocks, // same for n dimension (output)
294
+ const int thread_k_blocks, // same for k dimension (reduction)
295
+ const int stages, // number of stages for the async global->shared
296
+ // fetch pipeline
297
+ const bool has_act_order, // whether act_order is enabled
298
+ const bool has_zp, // whether zero-points are enabled
299
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
300
+ // with a separate quantization scale
301
+ >
302
+ __device__ void MarlinMoESingle(
303
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
304
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
305
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
306
+ const int* __restrict__ sorted_ids, // int32 sorted ids of experts
307
+ const float* __restrict__ topk_weights, // float topk weights
308
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
309
+ // (k/groupsize)xn
310
+ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
311
+ // (k/groupsize)x(n/pack_factor)
312
+ const int* __restrict__ g_idx, // int32 group indices of shape k
313
+ const int* __restrict__ expert_offsets,
314
+ int num_groups, // number of scale groups per output channel
315
+ int expert_idx, // idx of current expert
316
+ int num_experts, // number of experts
317
+ int topk, // topk parameter of moe
318
+ int prob_m, // batch dimension m
319
+ int prob_n, // output dimension n
320
+ int prob_k, // reduction dimension k
321
+ int tot_m, // total number of rows in A and C
322
+ int* locks, // extra global storage for barrier synchronization
323
+ bool replicate_input, // do we use the same input for each expert?
324
+ bool apply_weights, // apply weights to output
325
+ int current_m_block // current m block to start kernel computation from
326
+ ) {
327
+ static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
328
+ constexpr int pack_factor = 32 / w_type.size_bits();
329
+
330
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
331
+ // better partitioning with less reductions
332
+ int parallel = 1;
333
+ if (prob_m > 16 * thread_m_blocks) {
334
+ parallel = prob_m / (16 * thread_m_blocks);
335
+ prob_m = 16 * thread_m_blocks;
336
+ }
337
+
338
+ int k_tiles = prob_k / 16 / thread_k_blocks;
339
+ int n_tiles = prob_n / 16 / thread_n_blocks;
340
+ int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
341
+
342
+ if constexpr (!has_act_order && group_blocks != -1) {
343
+ if (group_blocks >= thread_k_blocks) {
344
+ // Ensure that the number of tiles in each stripe is a multiple of the
345
+ // groupsize; this avoids an annoying special case where a stripe starts
346
+ // in the middle of group.
347
+ iters = (group_blocks / thread_k_blocks) *
348
+ ceildiv(iters, (group_blocks / thread_k_blocks));
349
+ }
350
+ }
351
+
352
+ int slice_row = (iters * blockIdx.x) % k_tiles;
353
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
354
+ int slice_col = slice_col_par;
355
+ int slice_iters; // number of threadblock tiles in the current slice
356
+ int slice_count =
357
+ 0; // total number of active threadblocks in the current slice
358
+ int slice_idx; // index of threadblock in current slice; numbered bottom to
359
+ // top
360
+
361
+ // We can easily implement parallel problem execution by just remapping
362
+ // indices and advancing global pointers
363
+ if (slice_col_par >= n_tiles) {
364
+ locks += (slice_col_par / n_tiles) * n_tiles;
365
+ slice_col = slice_col_par % n_tiles;
366
+ sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
367
+ }
368
+
369
+ // Compute all information about the current slice which is required for
370
+ // synchronization.
371
+ auto init_slice = [&]() {
372
+ slice_iters =
373
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
374
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
375
+ if (slice_iters == 0) return;
376
+ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
377
+ slice_count = 1;
378
+ slice_idx = 0;
379
+ int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
380
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
381
+ int col_off = col_first - k_tiles * slice_col_par;
382
+ slice_count = ceildiv(k_tiles - col_off, iters);
383
+ if (col_off > 0) slice_count++;
384
+ int delta_first = iters * blockIdx.x - col_first;
385
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
386
+ slice_idx = slice_count - 1;
387
+ else {
388
+ slice_idx = slice_count - 1 - delta_first / iters;
389
+ if (col_off > 0) slice_idx--;
390
+ }
391
+ }
392
+ if (slice_col == n_tiles) {
393
+ sorted_ids += 16 * thread_m_blocks;
394
+ locks += n_tiles;
395
+ slice_col = 0;
396
+ }
397
+ };
398
+ init_slice();
399
+
400
+ // A sizes/strides
401
+
402
+ // stride of the A matrix in global memory
403
+ int a_gl_stride = prob_k / 8;
404
+ // stride of an A matrix tile in shared memory
405
+ constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
406
+ // delta between subsequent A tiles in global memory
407
+ constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
408
+ // between subsequent accesses within a tile
409
+ int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
410
+ // between shared memory writes
411
+ constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
412
+ // between shared memory tile reads
413
+ constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
414
+ // within a shared memory tile
415
+ constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
416
+ // overall size of a tile
417
+ constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
418
+ // number of shared write iterations for a tile
419
+ constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
420
+
421
+ // B sizes/strides
422
+ int b_gl_stride = 16 * prob_n / (pack_factor * 4);
423
+ constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
424
+ constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
425
+ constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
426
+
427
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
428
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
429
+ constexpr int b_sh_wr_delta = threads * b_thread_vecs;
430
+ constexpr int b_sh_rd_delta = threads * b_thread_vecs;
431
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
432
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
433
+
434
+ // Scale sizes/strides without act_order
435
+ int s_gl_stride = prob_n / 8;
436
+ constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
437
+ constexpr int s_tb_groups =
438
+ !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
439
+ ? thread_k_blocks / group_blocks
440
+ : 1;
441
+ constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
442
+ int s_gl_rd_delta = s_gl_stride;
443
+ // Scale size/strides with act_order
444
+ constexpr int tb_k = 16 * thread_k_blocks;
445
+ constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
446
+ // constexpr int act_s_row_stride = 1;
447
+ // int act_s_col_stride = act_s_row_stride * num_groups;
448
+ int act_s_col_stride = 1;
449
+ int act_s_col_warp_stride = act_s_col_stride * 8;
450
+ int tb_n_warps = thread_n_blocks / 4;
451
+ int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
452
+
453
+ // Zero-points sizes/strides
454
+ int zp_gl_stride = (prob_n / pack_factor) / 4;
455
+ constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
456
+ constexpr int zp_tb_groups = s_tb_groups;
457
+ constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
458
+ int zp_gl_rd_delta = zp_gl_stride;
459
+
460
+ // Global A read index of current thread.
461
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
462
+ (threadIdx.x % a_gl_rd_delta_o);
463
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
464
+ // Shared write index of current thread.
465
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
466
+ (threadIdx.x % a_gl_rd_delta_o);
467
+ // Shared read index.
468
+ int a_sh_rd =
469
+ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
470
+ a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
471
+
472
+ int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
473
+ (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
474
+ b_gl_rd += b_sh_stride * slice_col;
475
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
476
+ int b_sh_wr = threadIdx.x * b_thread_vecs;
477
+ int b_sh_rd = threadIdx.x * b_thread_vecs;
478
+
479
+ // For act_order
480
+ constexpr int k_iter_size = tb_k / b_sh_wr_iters;
481
+ int slice_k_start = tb_k * slice_row;
482
+ int slice_k_finish = slice_k_start + tb_k * slice_iters;
483
+ int slice_k_start_shared_fetch = slice_k_start;
484
+ int slice_n_offset = act_s_col_tb_stride * slice_col;
485
+
486
+ // No act_order
487
+ int s_gl_rd;
488
+ if constexpr (!has_act_order) {
489
+ if constexpr (group_blocks == -1) {
490
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
491
+ } else {
492
+ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
493
+ s_sh_stride * slice_col + threadIdx.x;
494
+ }
495
+ }
496
+ int s_sh_wr = threadIdx.x;
497
+ bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
498
+
499
+ // Zero-points
500
+ int zp_gl_rd;
501
+ if constexpr (has_zp) {
502
+ if constexpr (group_blocks == -1) {
503
+ zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
504
+ } else {
505
+ zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
506
+ zp_sh_stride * slice_col + threadIdx.x;
507
+ }
508
+ }
509
+ int zp_sh_wr = threadIdx.x;
510
+ bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
511
+
512
+ // We use a different scale layout for grouped and column-wise quantization as
513
+ // we scale a `half2` tile in column-major layout in the former and in
514
+ // row-major in the latter case.
515
+ int s_sh_rd;
516
+ if constexpr (group_blocks != -1)
517
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
518
+ (threadIdx.x % 32) / 4;
519
+ else
520
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
521
+ (threadIdx.x % 32) % 4;
522
+
523
+ // Zero-points have the same read layout as the scales
524
+ // (without column-wise case)
525
+ constexpr int num_col_threads = 8;
526
+ constexpr int num_row_threads = 4;
527
+ constexpr int num_ints_per_thread = 8 / pack_factor;
528
+ int zp_sh_rd;
529
+ if constexpr (has_zp) {
530
+ zp_sh_rd = num_ints_per_thread * num_col_threads *
531
+ ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
532
+ num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
533
+ }
534
+
535
+ int sh_first_group_id = -1;
536
+ int sh_num_groups = -1;
537
+ constexpr int sh_max_num_groups = 32;
538
+
539
+ extern __shared__ int4 sh[];
540
+ // Shared memory storage for global fetch pipelines.
541
+ int4* sh_a = sh;
542
+ int4* sh_b = sh_a + (stages * a_sh_stage);
543
+ int4* sh_g_idx = sh_b + (stages * b_sh_stage);
544
+ int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
545
+ int4* sh_s = sh_zp + (stages * zp_sh_stage);
546
+
547
+ // Precompute which thread should not read memory in which iterations; this is
548
+ // needed if there are more threads than required for a certain tilesize or
549
+ // when the batchsize is not a multiple of 16.
550
+ bool a_sh_wr_pred[a_sh_wr_iters];
551
+ #pragma unroll
552
+ for (int i = 0; i < a_sh_wr_iters; i++) {
553
+ int a_idx = a_sh_wr_delta * i + a_sh_wr;
554
+ int row = a_idx / a_gl_rd_delta_o;
555
+ if (row >= prob_m) {
556
+ a_sh_wr_pred[i] = false;
557
+ } else {
558
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
559
+ }
560
+ }
561
+
562
+ // To ensure that writing and reading A tiles to/from shared memory, the
563
+ // latter in fragment format, is fully bank conflict free, we need to use a
564
+ // rather fancy XOR-based layout. The key here is that neither reads nor
565
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
566
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
567
+ // each warp must also write a consecutive memory segment?
568
+ auto transform_a = [&](int i) {
569
+ int row = i / a_gl_rd_delta_o;
570
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
571
+ };
572
+ // Since the computation of this remapping is non-trivial and, due to our main
573
+ // loop unrolls, all shared memory accesses are static, we simply precompute
574
+ // both transformed reads and writes.
575
+ int a_sh_wr_trans[a_sh_wr_iters];
576
+ #pragma unroll
577
+ for (int i = 0; i < a_sh_wr_iters; i++)
578
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
579
+ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
580
+ #pragma unroll
581
+ for (int i = 0; i < b_sh_wr_iters; i++) {
582
+ #pragma unroll
583
+ for (int j = 0; j < thread_m_blocks; j++)
584
+ a_sh_rd_trans[i][j] =
585
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
586
+ }
587
+
588
+ // Since B-accesses have non-constant stride they have to be computed at
589
+ // runtime; we break dependencies between subsequent accesses within a tile by
590
+ // maintaining multiple pointers (we have enough registers), a tiny
591
+ // optimization.
592
+ const int4* B_ptr[b_sh_wr_iters];
593
+ #pragma unroll
594
+ for (int i = 0; i < b_sh_wr_iters; i++)
595
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
596
+
597
+ // Register storage for double buffer of shared memory reads.
598
+ FragA frag_a[2][thread_m_blocks];
599
+ I4 frag_b_quant[2][b_thread_vecs];
600
+ FragC frag_c[thread_m_blocks][4][2];
601
+ FragS frag_s[2][4]; // No act-order
602
+ FragS act_frag_s[2][4][4]; // For act-order
603
+ int frag_qzp[2][num_ints_per_thread]; // Zero-points
604
+ FragZP frag_zp; // Zero-points in fp16
605
+
606
+ // Zero accumulators.
607
+ auto zero_accums = [&]() {
608
+ #pragma unroll
609
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
610
+ reinterpret_cast<float*>(frag_c)[i] = 0;
611
+ };
612
+
613
+ auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
614
+ int last_group_id) {
615
+ sh_first_group_id = first_group_id;
616
+ sh_num_groups = last_group_id - first_group_id + 1;
617
+
618
+ if (sh_num_groups < sh_max_num_groups) {
619
+ sh_num_groups = sh_max_num_groups;
620
+ }
621
+
622
+ if (sh_first_group_id + sh_num_groups > num_groups) {
623
+ sh_num_groups = num_groups - sh_first_group_id;
624
+ }
625
+
626
+ int row_offset = first_group_id * s_gl_stride;
627
+
628
+ if (is_async) {
629
+ for (int i = 0; i < sh_num_groups; i++) {
630
+ if (threadIdx.x < s_sh_stride) {
631
+ cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
632
+ &scales_ptr[row_offset + (i * s_gl_stride) +
633
+ slice_n_offset + threadIdx.x]);
634
+ }
635
+ }
636
+ } else {
637
+ for (int i = 0; i < sh_num_groups; i++) {
638
+ if (threadIdx.x < s_sh_stride) {
639
+ sh_s[(i * s_sh_stride) + threadIdx.x] =
640
+ scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
641
+ threadIdx.x];
642
+ }
643
+ }
644
+ }
645
+ };
646
+ // Asynchronously fetch the next A, B and s tile from global to the next
647
+ // shared memory pipeline location.
648
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
649
+ if (pred) {
650
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
651
+ #pragma unroll
652
+ for (int i = 0; i < a_sh_wr_iters; i++) {
653
+ int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
654
+ int row = a_idx / a_gl_stride;
655
+ int sorted_row =
656
+ replicate_input ? sorted_ids[row] / topk : sorted_ids[row];
657
+ int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
658
+ if (sorted_row < tot_m * (replicate_input ? 1 : topk) &&
659
+ new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) {
660
+ cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx],
661
+ a_sh_wr_pred[i]);
662
+ }
663
+ }
664
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
665
+ #pragma unroll
666
+ for (int i = 0; i < b_sh_wr_iters; i++) {
667
+ #pragma unroll
668
+ for (int j = 0; j < b_thread_vecs; j++) {
669
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
670
+ }
671
+ B_ptr[i] += b_gl_rd_delta_o;
672
+ }
673
+
674
+ if constexpr (has_act_order) {
675
+ // Fetch g_idx thread-block portion
676
+ int full_pipe = a_off;
677
+ int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
678
+ if (cur_k < prob_k && cur_k < slice_k_finish) {
679
+ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
680
+
681
+ int4 const* cur_g_idx_stage_ptr =
682
+ reinterpret_cast<int4 const*>(&g_idx[cur_k]);
683
+
684
+ if (threadIdx.x < g_idx_stage) {
685
+ cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
686
+ &cur_g_idx_stage_ptr[threadIdx.x]);
687
+ }
688
+ }
689
+ } else {
690
+ if constexpr (group_blocks != -1) {
691
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
692
+
693
+ if constexpr (group_blocks >= thread_k_blocks) {
694
+ // Only fetch scales if this tile starts a new group
695
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
696
+ if (s_sh_wr_pred) {
697
+ cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
698
+ }
699
+ s_gl_rd += s_gl_rd_delta;
700
+ }
701
+ } else {
702
+ for (int i = 0; i < s_tb_groups; i++) {
703
+ if (s_sh_wr_pred) {
704
+ cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
705
+ &scales_ptr[s_gl_rd]);
706
+ }
707
+ s_gl_rd += s_gl_rd_delta;
708
+ }
709
+ }
710
+ }
711
+
712
+ if constexpr (has_zp && group_blocks != -1) {
713
+ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
714
+
715
+ if constexpr (group_blocks >= thread_k_blocks) {
716
+ // Only fetch zero-points if this tile starts a new group
717
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
718
+ if (zp_sh_wr_pred) {
719
+ cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
720
+ }
721
+ zp_gl_rd += zp_gl_rd_delta;
722
+ }
723
+ } else {
724
+ for (int i = 0; i < zp_tb_groups; i++) {
725
+ if (zp_sh_wr_pred) {
726
+ cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
727
+ &zp_ptr[zp_gl_rd]);
728
+ }
729
+ zp_gl_rd += zp_gl_rd_delta;
730
+ }
731
+ }
732
+ }
733
+ }
734
+ }
735
+ // Insert a fence even when we are winding down the pipeline to ensure that
736
+ // waiting is also correct at this point.
737
+ cp_async_fence();
738
+ };
739
+
740
+ auto fetch_zp_to_shared = [&]() {
741
+ if (zp_sh_wr_pred) {
742
+ cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
743
+ }
744
+ };
745
+
746
+ // Wait until the next thread tile has been loaded to shared memory.
747
+ auto wait_for_stage = [&]() {
748
+ // We only have `stages - 2` active fetches since we are double buffering
749
+ // and can only issue the next fetch when it is guaranteed that the previous
750
+ // shared memory load is fully complete (as it may otherwise be
751
+ // overwritten).
752
+ cp_async_wait<stages - 2>();
753
+ __syncthreads();
754
+ };
755
+
756
+ // Load the next sub-tile from the current location in the shared memory pipe
757
+ // into the current register buffer.
758
+ auto fetch_to_registers = [&](int k, int pipe) {
759
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
760
+ #pragma unroll
761
+ for (int i = 0; i < thread_m_blocks; i++)
762
+ ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
763
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
764
+
765
+ #pragma unroll
766
+ for (int i = 0; i < b_thread_vecs; i++) {
767
+ frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
768
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
769
+ }
770
+ };
771
+
772
+ bool is_same_group[stages];
773
+ int same_group_id[stages];
774
+
775
+ auto init_same_group = [&](int pipe) {
776
+ if constexpr (!has_act_order) {
777
+ is_same_group[pipe] = false;
778
+ same_group_id[pipe] = 0;
779
+ return;
780
+ }
781
+
782
+ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
783
+ int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
784
+
785
+ int group_id_1 = sh_g_idx_int_ptr[0];
786
+ int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
787
+
788
+ is_same_group[pipe] = group_id_1 == group_id_2;
789
+ same_group_id[pipe] = group_id_1;
790
+ };
791
+
792
+ auto fetch_scales_to_registers = [&](int k, int full_pipe) {
793
+ int pipe = full_pipe % stages;
794
+
795
+ if constexpr (!has_act_order) {
796
+ // No act-order case
797
+ if constexpr (group_blocks != -1) {
798
+ if constexpr (group_blocks >= thread_k_blocks) {
799
+ int4* sh_s_stage =
800
+ sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
801
+ (pipe / (group_blocks / thread_k_blocks)));
802
+ reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
803
+ } else {
804
+ int warp_id = threadIdx.x / 32;
805
+ int n_warps = thread_n_blocks / 4;
806
+
807
+ int warp_row = warp_id / n_warps;
808
+
809
+ int cur_k = warp_row * 16;
810
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
811
+
812
+ int k_blocks = cur_k / 16;
813
+ int cur_group_id = k_blocks / group_blocks;
814
+
815
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
816
+
817
+ reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
818
+ sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
819
+ }
820
+ }
821
+
822
+ return;
823
+ }
824
+
825
+ // Act-order case
826
+
827
+ // Determine K of the "current" thread-block
828
+ int cur_k = slice_k_start + tb_k * full_pipe;
829
+ if (cur_k >= prob_k || cur_k >= slice_k_finish) {
830
+ return;
831
+ }
832
+
833
+ // Reset (to current thread-block) since we read g_idx portion from the
834
+ // shared memory
835
+ cur_k = 0;
836
+
837
+ // Progress to current iteration
838
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
839
+
840
+ // Determine "position" inside the thread-block (based on warp and
841
+ // thread-id)
842
+ int warp_id = threadIdx.x / 32;
843
+ int n_warps =
844
+ thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
845
+
846
+ int warp_row = warp_id / n_warps;
847
+ int warp_col = warp_id % n_warps;
848
+
849
+ cur_k += warp_row * 16;
850
+
851
+ int th_id = threadIdx.x % 32;
852
+ cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
853
+
854
+ int s_col_shift =
855
+ /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
856
+ (th_id / 4) * act_s_col_stride;
857
+
858
+ if (is_same_group[pipe]) {
859
+ if (k % 2 == 0) {
860
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
861
+ sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
862
+ s_col_shift];
863
+ } else {
864
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
865
+ *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
866
+ }
867
+
868
+ for (int i = 1; i < 4; i++) {
869
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
870
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
871
+ }
872
+ return;
873
+ }
874
+
875
+ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
876
+ int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
877
+
878
+ constexpr int k_frag_offsets[4] = {0, 1, 8,
879
+ 9}; // Tensor core offsets per thread
880
+
881
+ #pragma unroll
882
+ for (int i = 0; i < 4; i++) {
883
+ int actual_k = cur_k + k_frag_offsets[i];
884
+
885
+ int group_id = sh_g_idx_int_ptr[actual_k];
886
+ int rel_group_id = group_id - sh_first_group_id;
887
+
888
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
889
+ sh_s[rel_group_id * s_sh_stride + s_col_shift];
890
+ }
891
+ };
892
+
893
+ auto fetch_zp_to_registers = [&](int k, int full_pipe) {
894
+ // This code does not handle group_blocks == 0,
895
+ // which signifies act_order.
896
+ // has_zp implies AWQ, which doesn't have act_order.
897
+ static_assert(!has_zp || group_blocks != 0);
898
+
899
+ if constexpr (has_zp) {
900
+ int pipe = full_pipe % stages;
901
+
902
+ if constexpr (group_blocks == -1) {
903
+ for (int i = 0; i < num_ints_per_thread; i++) {
904
+ frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
905
+ }
906
+
907
+ } else if constexpr (group_blocks >= thread_k_blocks) {
908
+ int4* sh_zp_stage =
909
+ sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
910
+ (pipe / (group_blocks / thread_k_blocks)));
911
+ for (int i = 0; i < num_ints_per_thread; i++) {
912
+ frag_qzp[k % 2][i] =
913
+ (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
914
+ }
915
+ } else {
916
+ int warp_id = threadIdx.x / 32;
917
+ int n_warps = thread_n_blocks / 4;
918
+
919
+ int warp_row = warp_id / n_warps;
920
+
921
+ int cur_k = warp_row * 16;
922
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
923
+
924
+ int k_blocks = cur_k / 16;
925
+ int cur_group_id = 0;
926
+
927
+ // Suppress bogus and persistent divide-by-zero warning
928
+ #pragma nv_diagnostic push
929
+ #pragma nv_diag_suppress divide_by_zero
930
+ cur_group_id = k_blocks / group_blocks;
931
+ #pragma nv_diagnostic pop
932
+
933
+ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
934
+
935
+ sh_zp_stage += cur_group_id * zp_sh_stride;
936
+
937
+ for (int i = 0; i < num_ints_per_thread; i++) {
938
+ frag_qzp[k % 2][i] =
939
+ (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
940
+ }
941
+ }
942
+ }
943
+ };
944
+
945
+ // Execute the actual tensor core matmul of a sub-tile.
946
+ auto matmul = [&](int k) {
947
+ if constexpr (has_zp) {
948
+ FragB frag_zp_0;
949
+ FragB frag_zp_1;
950
+ int zp_quant_0, zp_quant_1;
951
+
952
+ if constexpr (w_type.size_bits() == 4) {
953
+ zp_quant_0 = frag_qzp[k % 2][0];
954
+ zp_quant_1 = zp_quant_0 >> 8;
955
+ } else {
956
+ static_assert(w_type.size_bits() == 8);
957
+ zp_quant_0 = frag_qzp[k % 2][0];
958
+ zp_quant_1 = frag_qzp[k % 2][1];
959
+ }
960
+
961
+ frag_zp_0 = dequant<w_type_id>(zp_quant_0);
962
+ frag_zp_1 = dequant<w_type_id>(zp_quant_1);
963
+
964
+ frag_zp[0] = frag_zp_0[0];
965
+ frag_zp[1] = frag_zp_0[1];
966
+ frag_zp[2] = frag_zp_1[0];
967
+ frag_zp[3] = frag_zp_1[1];
968
+ }
969
+
970
+ // We have the m dimension as the inner loop in order to encourage overlapping
971
+ // dequantization and matmul operations.
972
+ #pragma unroll
973
+ for (int j = 0; j < 4; j++) {
974
+ int b_quant_0, b_quant_1;
975
+ if constexpr (w_type.size_bits() == 4) {
976
+ b_quant_0 = frag_b_quant[k % 2][0][j];
977
+ b_quant_1 = b_quant_0 >> 8;
978
+ } else {
979
+ static_assert(w_type.size_bits() == 8);
980
+ int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
981
+ b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
982
+ b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
983
+ }
984
+
985
+ FragB frag_b0 = dequant<w_type_id>(b_quant_0);
986
+ FragB frag_b1 = dequant<w_type_id>(b_quant_1);
987
+ // Apply zero-point to frag_b0
988
+ if constexpr (has_zp) {
989
+ sub_zp(frag_b0, frag_zp[j], 0);
990
+ }
991
+
992
+ // Apply scale to frag_b0
993
+ if constexpr (has_act_order) {
994
+ scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
995
+ act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
996
+ } else {
997
+ if constexpr (group_blocks != -1) {
998
+ scale(frag_b0, frag_s[k % 2][j], 0);
999
+ }
1000
+ }
1001
+
1002
+ // Apply zero-point to frag_b1
1003
+ if constexpr (has_zp) {
1004
+ sub_zp(frag_b1, frag_zp[j], 1);
1005
+ }
1006
+
1007
+ // Apply scale to frag_b1
1008
+ if constexpr (has_act_order) {
1009
+ scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
1010
+ act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
1011
+
1012
+ } else {
1013
+ if constexpr (group_blocks != -1) {
1014
+ scale(frag_b1, frag_s[k % 2][j], 1);
1015
+ }
1016
+ }
1017
+
1018
+ #pragma unroll
1019
+ for (int i = 0; i < thread_m_blocks; i++) {
1020
+ mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
1021
+ mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
1022
+ }
1023
+ }
1024
+ };
1025
+
1026
+ // Since we slice across the k dimension of a tile in order to increase the
1027
+ // number of warps while keeping the n dimension of a tile reasonable, we have
1028
+ // multiple warps that accumulate their partial sums of the same output
1029
+ // location, which we have to reduce over in the end. We do this in shared memory.
1030
+ auto thread_block_reduce = [&]() {
1031
+ constexpr int red_off = threads / b_sh_stride_threads / 2;
1032
+ if (red_off >= 1) {
1033
+ int red_idx = threadIdx.x / b_sh_stride_threads;
1034
+ constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
1035
+ constexpr int red_sh_delta = b_sh_stride_threads;
1036
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
1037
+ (threadIdx.x % b_sh_stride_threads);
1038
+
1039
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
1040
+ // unnecessary read or write iterations, e.g., for two warps we write only
1041
+ // once by warp 1 and read only once by warp 0.
1042
+
1043
+ #pragma unroll
1044
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
1045
+ #pragma unroll
1046
+ for (int i = red_off; i > 0; i /= 2) {
1047
+ if (i <= red_idx && red_idx < 2 * i) {
1048
+ #pragma unroll
1049
+ for (int j = 0; j < 4 * 2; j++) {
1050
+ int red_sh_wr =
1051
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
1052
+ if (i < red_off) {
1053
+ float* c_rd =
1054
+ reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
1055
+ float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
1056
+ #pragma unroll
1057
+ for (int k = 0; k < 4; k++)
1058
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
1059
+ c_rd[k] + c_wr[k];
1060
+ }
1061
+ sh[red_sh_wr] =
1062
+ reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
1063
+ }
1064
+ }
1065
+ __syncthreads();
1066
+ }
1067
+ if (red_idx == 0) {
1068
+ #pragma unroll
1069
+ for (int i = 0; i < 4 * 2; i++) {
1070
+ float* c_rd =
1071
+ reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
1072
+ #pragma unroll
1073
+ for (int j = 0; j < 4; j++)
1074
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
1075
+ c_rd[j];
1076
+ }
1077
+ }
1078
+ __syncthreads();
1079
+ }
1080
+ }
1081
+ };
1082
+
1083
+ // Since multiple threadblocks may process parts of the same column slice, we
1084
+ // finally have to globally reduce over the results. As the striped
1085
+ // partitioning minimizes the number of such reductions and our outputs are
1086
+ // usually rather small, we perform this reduction serially in L2 cache.
1087
+ auto global_reduce = [&](bool first = false, bool last = false) {
1088
+ // We are very careful here to reduce directly in the output buffer to
1089
+ // maximize L2 cache utilization in this step. To do this, we write out
1090
+ // results in FP16 (but still reduce with FP32 compute).
1091
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
1092
+ if (threadIdx.x < active_threads) {
1093
+ int c_gl_stride = prob_n / 8;
1094
+ int c_gl_wr_delta_o = 8 * c_gl_stride;
1095
+ int c_gl_wr_delta_i = 4 * (active_threads / 32);
1096
+ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
1097
+ 4 * (threadIdx.x / 32) + threadIdx.x % 4;
1098
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
1099
+ constexpr int c_sh_wr_delta = active_threads;
1100
+ int c_sh_wr = threadIdx.x;
1101
+
1102
+ int row = (threadIdx.x % 32) / 4;
1103
+
1104
+ if (!first) {
1105
+ // Interestingly, doing direct global accesses here really seems to mess up
1106
+ // the compiler and lead to slowdowns, hence we also use async-copies even
1107
+ // though these fetches are not actually asynchronous.
1108
+ #pragma unroll
1109
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
1110
+ int c_idx =
1111
+ c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
1112
+ int sorted_row = sorted_ids[c_idx / c_gl_stride];
1113
+ int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
1114
+ cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx],
1115
+ sorted_row < tot_m * topk &&
1116
+ (8 * (i / 2) + row < prob_m &&
1117
+ (i < (thread_m_blocks - 1) * 4 ||
1118
+ sorted_ids[8 * (i / 2) + row] < tot_m * topk)));
1119
+ }
1120
+ cp_async_fence();
1121
+ cp_async_wait<0>();
1122
+ }
1123
+
1124
+ #pragma unroll
1125
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
1126
+ if (8 * (i / 2) + row < prob_m &&
1127
+ (i < (thread_m_blocks - 1) * 4 ||
1128
+ sorted_ids[8 * (i / 2) + row] < tot_m * topk)) {
1129
+ if (!first) {
1130
+ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
1131
+ #pragma unroll
1132
+ for (int j = 0; j < 2 * 4; j++) {
1133
+ reinterpret_cast<float*>(
1134
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
1135
+ __half2float(reinterpret_cast<__half*>(&c_red)[j]);
1136
+ }
1137
+ }
1138
+ if (!last) {
1139
+ int4 c;
1140
+ #pragma unroll
1141
+ for (int j = 0; j < 2 * 4; j++) {
1142
+ reinterpret_cast<__half*>(&c)[j] =
1143
+ __float2half(reinterpret_cast<float*>(
1144
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
1145
+ }
1146
+ int c_idx =
1147
+ c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
1148
+ int row = sorted_ids[c_idx / c_gl_stride];
1149
+ if (row < tot_m * topk) {
1150
+ int new_idx = row * c_gl_stride + c_idx % c_gl_stride;
1151
+ C[new_idx] = c;
1152
+ }
1153
+ }
1154
+ }
1155
+ }
1156
+ }
1157
+ };
1158
+
1159
+ // Write out the reduced final result in the correct layout. We only actually
1160
+ // reshuffle matrix fragments in this step; the reduction above is performed
1161
+ // in fragment layout.
1162
+ auto write_result = [&]() {
1163
+ int c_gl_stride = prob_n / 8;
1164
+ constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
1165
+ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
1166
+ constexpr int c_sh_rd_delta =
1167
+ c_sh_stride * (threads / (2 * thread_n_blocks));
1168
+
1169
+ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
1170
+ (threadIdx.x % (2 * thread_n_blocks));
1171
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
1172
+ int c_sh_wr =
1173
+ (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
1174
+ c_sh_wr += 32 * (threadIdx.x / 32);
1175
+ int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
1176
+ (threadIdx.x % (2 * thread_n_blocks));
1177
+
1178
+ int c_gl_wr_end = c_gl_stride * prob_m;
1179
+
1180
+ // We first reorder in shared memory to guarantee the most efficient final
1181
+ // global write patterns
1182
+ auto write = [&](int idx, float c0, float c1, FragS& s) {
1183
+ half2 res = __halves2half2(__float2half(c0), __float2half(c1));
1184
+
1185
+ // For per-column quantization we finally apply the scale here (only for
1186
+ // 4-bit)
1187
+ if constexpr (!has_act_order && group_blocks == -1 &&
1188
+ w_type.size_bits() == 4) {
1189
+ res = __hmul2(res, s[0]);
1190
+ }
1191
+
1192
+ ((half2*)sh)[idx] = res;
1193
+ };
1194
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1195
+ #pragma unroll
1196
+ for (int i = 0; i < thread_m_blocks; i++) {
1197
+ #pragma unroll
1198
+ for (int j = 0; j < 4; j++) {
1199
+ int wr = c_sh_wr + 8 * j;
1200
+ write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
1201
+ frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
1202
+ write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
1203
+ frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
1204
+ write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
1205
+ frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
1206
+ write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
1207
+ frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
1208
+ }
1209
+ c_sh_wr += 16 * (4 * c_sh_stride);
1210
+ }
1211
+ }
1212
+ __syncthreads();
1213
+
1214
+ #pragma unroll
1215
+ for (int i = 0;
1216
+ i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
1217
+ i++) {
1218
+ if (c_gl_wr < c_gl_wr_end) {
1219
+ int row = sorted_ids[c_gl_wr / c_gl_stride];
1220
+ if (row < tot_m * topk) {
1221
+ int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
1222
+ if (!apply_weights) {
1223
+ C[off] = sh[c_sh_rd];
1224
+ } else {
1225
+ __half* ctrg = reinterpret_cast<__half*>(&C[off]);
1226
+ __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
1227
+ for (int j = 0; j < 8; ++j) {
1228
+ ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j]));
1229
+ }
1230
+ }
1231
+ c_gl_wr += c_gl_wr_delta;
1232
+ c_sh_rd += c_sh_rd_delta;
1233
+ }
1234
+ }
1235
+ }
1236
+ };
1237
+
1238
+ // Start global fetch and register load pipelines.
1239
+ auto start_pipes = [&]() {
1240
+
1241
+ #pragma unroll
1242
+ for (int i = 0; i < stages - 1; i++) {
1243
+ if (has_act_order && i == 0) {
1244
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
1245
+ if (last_g_idx >= prob_k) {
1246
+ last_g_idx = prob_k - 1;
1247
+ }
1248
+ fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
1249
+ }
1250
+
1251
+ if constexpr (has_zp && group_blocks == -1) {
1252
+ if (i == 0) {
1253
+ fetch_zp_to_shared();
1254
+ }
1255
+ }
1256
+ fetch_to_shared(i, i, i < slice_iters);
1257
+ }
1258
+
1259
+ zero_accums();
1260
+ wait_for_stage();
1261
+ init_same_group(0);
1262
+ fetch_to_registers(0, 0);
1263
+ fetch_scales_to_registers(0, 0);
1264
+ fetch_zp_to_registers(0, 0);
1265
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
1266
+ slice_k_start_shared_fetch += tb_k * (stages - 1);
1267
+ };
1268
+ if (slice_iters) {
1269
+ start_pipes();
1270
+ }
1271
+
1272
+ // Main loop.
1273
+ while (slice_iters) {
1274
+ // We unroll over both the global fetch and the register load pipeline to
1275
+ // ensure all shared memory accesses are static. Note that both pipelines
1276
+ // have even length meaning that the next iteration will always start at
1277
+ // index 0.
1278
+ #pragma unroll
1279
+ for (int pipe = 0; pipe < stages;) {
1280
+ #pragma unroll
1281
+ for (int k = 0; k < b_sh_wr_iters; k++) {
1282
+ fetch_to_registers(k + 1, pipe % stages);
1283
+ fetch_scales_to_registers(k + 1, pipe);
1284
+ fetch_zp_to_registers(k + 1, pipe);
1285
+ if (k == b_sh_wr_iters - 2) {
1286
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
1287
+ slice_iters >= stages);
1288
+ pipe++;
1289
+ wait_for_stage();
1290
+ init_same_group(pipe % stages);
1291
+ }
1292
+ matmul(k);
1293
+ }
1294
+ slice_iters--;
1295
+ if (slice_iters == 0) {
1296
+ break;
1297
+ }
1298
+ }
1299
+
1300
+ a_gl_rd += a_gl_rd_delta_o * stages;
1301
+ slice_k_start += tb_k * stages;
1302
+ slice_k_start_shared_fetch += tb_k * stages;
1303
+
1304
+ if constexpr (has_act_order) {
1305
+ int first_group_id = g_idx[slice_k_start];
1306
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
1307
+ if (last_g_idx >= prob_k) {
1308
+ last_g_idx = prob_k - 1;
1309
+ }
1310
+ int last_group_id = g_idx[last_g_idx];
1311
+ if (last_group_id >= sh_first_group_id + sh_num_groups) {
1312
+ fetch_scales_to_shared(false, first_group_id, last_group_id);
1313
+ __syncthreads();
1314
+ }
1315
+ }
1316
+
1317
+ // Process results and, if necessary, proceed to the next column slice.
1318
+ // While this pattern may not be the most readable, other ways of writing
1319
+ // the loop seemed to result in noticeably worse performance after compilation.
1320
+ if (slice_iters == 0) {
1321
+ cp_async_wait<0>();
1322
+ bool last = slice_idx == slice_count - 1;
1323
+ if constexpr (!has_act_order && group_blocks == -1) {
1324
+ if constexpr (w_type.size_bits() == 8) {
1325
+ if (s_sh_wr_pred) {
1326
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
1327
+ }
1328
+ cp_async_fence();
1329
+ } else {
1330
+ // For 4-bit per-column scales, we only fetch them here in the
1331
+ // final step before write-out
1332
+ if (last) {
1333
+ if (s_sh_wr_pred) {
1334
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
1335
+ }
1336
+ cp_async_fence();
1337
+ }
1338
+ }
1339
+ }
1340
+
1341
+ thread_block_reduce();
1342
+ if constexpr (!has_act_order && group_blocks == -1) {
1343
+ if constexpr (w_type.size_bits() == 8) {
1344
+ cp_async_wait<0>();
1345
+ __syncthreads();
1346
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1347
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
1348
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
1349
+ }
1350
+
1351
+ } else {
1352
+ if (last) {
1353
+ cp_async_wait<0>();
1354
+ __syncthreads();
1355
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1356
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
1357
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
1358
+ }
1359
+ }
1360
+ }
1361
+ }
1362
+
1363
+ // For 8-bit channelwise, we apply the scale before the global reduction
1364
+ // that converts the fp32 results to fp16 (so that we avoid possible
1365
+ // overflow in fp16)
1366
+ if constexpr (!has_act_order && group_blocks == -1 &&
1367
+ w_type.size_bits() == 8) {
1368
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1369
+ #pragma unroll
1370
+ for (int i = 0; i < thread_m_blocks; i++) {
1371
+ #pragma unroll
1372
+ for (int j = 0; j < 4; j++) {
1373
+ scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
1374
+ frag_s[j / 2][2 * (j % 2) + 0]);
1375
+ scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
1376
+ frag_s[j / 2][2 * (j % 2) + 0]);
1377
+
1378
+ scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
1379
+ frag_s[j / 2][2 * (j % 2) + 1]);
1380
+ scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
1381
+ frag_s[j / 2][2 * (j % 2) + 1]);
1382
+ }
1383
+ }
1384
+ }
1385
+ }
1386
+
1387
+ if (slice_count > 1) { // only globally reduce if there is more than one
1388
+ // block in a slice
1389
+ barrier_acquire(&locks[slice_col], slice_idx);
1390
+ global_reduce(slice_idx == 0, last);
1391
+ barrier_release(&locks[slice_col], last);
1392
+ }
1393
+ if (last) // only the last block in a slice actually writes the result
1394
+ write_result();
1395
+ slice_row = 0;
1396
+ slice_col_par++;
1397
+ slice_col++;
1398
+ init_slice();
1399
+ if (slice_iters) {
1400
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
1401
+ (threadIdx.x % a_gl_rd_delta_o);
1402
+ #pragma unroll
1403
+ for (int i = 0; i < b_sh_wr_iters; i++)
1404
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
1405
+ if (slice_col == 0) {
1406
+ #pragma unroll
1407
+ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
1408
+ }
1409
+
1410
+ // Update slice k/n for scales loading
1411
+ if constexpr (has_act_order) {
1412
+ slice_k_start = tb_k * slice_row;
1413
+ slice_k_finish = slice_k_start + tb_k * slice_iters;
1414
+ slice_k_start_shared_fetch = slice_k_start;
1415
+ slice_n_offset = act_s_col_tb_stride * slice_col;
1416
+
1417
+ } else {
1418
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
1419
+ zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
1420
+ }
1421
+
1422
+ start_pipes();
1423
+ }
1424
+ }
1425
+ }
1426
+ }
1427
+
1428
+ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
1429
+ const int threads, // number of threads in a threadblock
1430
+ const int thread_n_blocks, // same for n dimension (output)
1431
+ const int thread_k_blocks, // same for k dimension (reduction)
1432
+ const int stages, // number of stages for the async global->shared
1433
+ // fetch pipeline
1434
+ const bool has_act_order, // whether act_order is enabled
1435
+ const bool has_zp, // whether zero-points are enabled
1436
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
1437
+ // with a separate quantization scale
1438
+ >
1439
+ __global__ void MarlinMoE(
1440
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
1441
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
1442
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
1443
+ const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts
1444
+ const float* __restrict__ topk_weights, // float topk weights
1445
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
1446
+ // (k/groupsize)xn
1447
+ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
1448
+ // (k/groupsize)x(n/pack_factor)
1449
+ const int* __restrict__ g_idx, // int32 group indices of shape k
1450
+ const int* __restrict__ expert_offsets,
1451
+ int num_groups, // number of scale groups per output channel
1452
+ int expert_idx, // idx of current expert
1453
+ int num_experts, // number of experts
1454
+ int topk, // topk parameter of moe
1455
+ int prob_m, // batch dimension m
1456
+ int prob_n, // output dimension n
1457
+ int prob_k, // reduction dimension k
1458
+ int tot_m, // total number of rows in A and C
1459
+ int* locks, // extra global storage for barrier synchronization
1460
+ bool replicate_input, // do we use the same input for each expert?
1461
+ bool apply_weights, // apply weights to output
1462
+ int current_m_block, // current m block to start kernel computation from
1463
+ int max_par, // maximum parallelism
1464
+ int cfg_max_m_blocks // upper bound on m blocks
1465
+ ) {
1466
+ int m_block_ctr = current_m_block;
1467
+
1468
+ const int* sorted_ids_expert =
1469
+ sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par;
1470
+ int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx];
1471
+ if (tot_its == 0) {
1472
+ return;
1473
+ }
1474
+ int tot_m_blocks = ceildiv(tot_its, 16);
1475
+ int pad = 16 * tot_m_blocks - tot_its;
1476
+
1477
+ if (m_block_ctr >= tot_m_blocks) {
1478
+ return;
1479
+ }
1480
+
1481
+ int max_block = tot_m_blocks - m_block_ctr;
1482
+ prob_m = tot_its - 16 * m_block_ctr;
1483
+
1484
+ int par = 1;
1485
+ if (max_block > cfg_max_m_blocks) {
1486
+ // Note that parallel > 1 currently only works for inputs without any
1487
+ // padding
1488
+ par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);
1489
+ if (par > max_par) par = max_par;
1490
+ prob_m = (16 * cfg_max_m_blocks) * par;
1491
+ m_block_ctr += cfg_max_m_blocks * (par - 1);
1492
+ max_block = cfg_max_m_blocks;
1493
+ }
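+ // The number of m blocks handled per threadblock is a compile-time template
+ // parameter of MarlinMoESingle, so the runtime max_block value (1..4) is
+ // dispatched to the matching instantiation below.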
1494
+
1495
+ if (max_block == 1) {
1496
+ MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
1497
+ stages, has_act_order, has_zp, group_blocks>(
1498
+ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
1499
+ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
1500
+ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
1501
+ current_m_block);
1502
+ } else if (max_block == 2) {
1503
+ MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
1504
+ stages, has_act_order, has_zp, group_blocks>(
1505
+ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
1506
+ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
1507
+ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
1508
+ current_m_block);
1509
+ } else if (max_block == 3) {
1510
+ MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
1511
+ stages, has_act_order, has_zp, group_blocks>(
1512
+ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
1513
+ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
1514
+ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
1515
+ current_m_block);
1516
+ } else {
1517
+ MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
1518
+ stages, has_act_order, has_zp, group_blocks>(
1519
+ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
1520
+ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
1521
+ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
1522
+ current_m_block);
1523
+ }
1524
+ }
1525
+
1526
+ #else
1527
+
1528
+ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
1529
+ const int threads, // number of threads in a threadblock
1530
+ const int thread_n_blocks, // same for n dimension (output)
1531
+ const int thread_k_blocks, // same for k dimension (reduction)
1532
+ const int stages, // number of stages for the async global->shared
1533
+ // fetch pipeline
1534
+ const bool has_act_order, // whether act_order is enabled
1535
+ const bool has_zp, // whether zero-points are enabled
1536
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
1537
+ // with a separate quantization scale
1538
+ >
1539
+ __global__ void MarlinMoE(
1540
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
1541
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
1542
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
1543
+ const int* __restrict__ sorted_ids, // int32 sorted ids of experts
1544
+ const float* __restrict__ topk_weights, // float topk weights
1545
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
1546
+ // (k/groupsize)xn
1547
+ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
1548
+ // (k/groupsize)x(n/pack_factor)
1549
+ const int* __restrict__ g_idx, // int32 group indices of shape k
1550
+ const int* __restrict__ expert_offsets,
1551
+ int num_groups, // number of scale groups per output channel
1552
+ int expert_idx, // idx of current expert
1553
+ int num_experts, // number of experts
1554
+ int topk, // topk parameter of moe
1555
+ int prob_m, // batch dimension m
1556
+ int prob_n, // output dimension n
1557
+ int prob_k, // reduction dimension k
1558
+ int tot_m, // total number of rows in A and C
1559
+ int* locks, // extra global storage for barrier synchronization
1560
+ bool replicate_input, // do we use the same input for each expert?
1561
+ bool apply_weights, // apply weights to output
1562
+ int current_m_block, // current m block to start kernel computation from
1563
+ int max_par, // maximum parallelism
1564
+ int cfg_max_m_blocks // upper bound on m blocks
1565
+ ) {
1566
+ // Marlin is not implemented yet for SM < 8.0
1567
+ assert(false);
1568
+ return;
1569
+ }
1570
+
1571
+ #endif
1572
+
1573
+ // 8 warps are a good choice since every SM has 4 schedulers and having more
1574
+ // than 1 warp per scheduler allows some more latency hiding. At the same time,
1575
+ // we want relatively few warps to have many registers per warp and small tiles.
1576
+ const int USER_THREADS =
1577
+ 256; // Note: This is only used with user-provided thread_k/n
1578
+ const int STAGES = 4; // 4 pipeline stages fit into shared memory
1579
+
1580
+ static constexpr int min_thread_n = 64;
1581
+ static constexpr int min_thread_k = 64;
1582
+
1583
+ #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
1584
+ HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
1585
+ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
1586
+ thread_k_blocks == THREAD_K_BLOCKS && \
1587
+ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
1588
+ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
1589
+ cudaFuncSetAttribute( \
1590
+ MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1591
+ STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
1592
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
1593
+ MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1594
+ STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
1595
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
1596
+ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
1597
+ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
1598
+ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
1599
+ replicate_input, apply_weights, m_block, max_par, \
1600
+ cfg_max_m_blocks); \
1601
+ }
1602
+
1603
+ #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
1604
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
1605
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
1606
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
1607
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
1608
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
1609
+
1610
+ #define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
1611
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
1612
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
1613
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
1614
+ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
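+ // Each __CALL_IF_MOE expansion is an `else if` that matches one combination
+ // of weight type, tile shape, act-order/zero-point mode and group size, sets
+ // the kernel's maximum dynamic shared memory and launches the corresponding
+ // MarlinMoE instantiation. The GPTQ/AWQ chains above are intended to follow
+ // an initial `if (false) {}`, as done in the call_marlin_moe_kernel_* files.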
1615
+
1616
+ } // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu ADDED
@@ -0,0 +1,31 @@
1
+ #include "marlin_moe_kernel_ku4.h"
2
+
3
+ namespace marlin_moe {
4
+
5
+ // We return bool so we can create these different kernel calls as a sequence
6
+ // of if-elseif's.
7
+ bool call_marlin_moe_kernel_ku4(
8
+ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
9
+ bool has_act_order, int group_blocks, int num_threads, int blocks,
10
+ int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
11
+ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
12
+ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
13
+ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
14
+ int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
15
+ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
16
+ int m_block, int max_par, int cfg_max_m_blocks) {
17
+ bool has_zp = true;
18
+
19
+ if (false) {
20
+ }
21
+ AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
22
+ AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
23
+ AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
24
+ AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
25
+ else {
26
+ return false;
27
+ }
28
+ return true;
29
+ }
30
+
31
+ } // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h ADDED
@@ -0,0 +1,20 @@
1
+ #pragma once
2
+
3
+ #include "marlin_moe_kernel.h"
4
+
5
+ namespace marlin_moe {
6
+
7
+ // We return bool so we can create these different kernel calls as a sequence
8
+ // of if-elseif's.
9
+ bool call_marlin_moe_kernel_ku4(
10
+ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
11
+ bool has_act_order, int group_blocks, int num_threads, int blocks,
12
+ int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
13
+ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
14
+ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
15
+ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
16
+ int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
17
+ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
18
+ int m_block, int max_par, int cfg_max_m_blocks);
19
+
20
+ } // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu ADDED
@@ -0,0 +1,31 @@
1
+ #include "marlin_moe_kernel_ku4b8.h"
2
+
3
+ namespace marlin_moe {
4
+
5
+ // We return bool so we can create these different kernel calls as a sequence
6
+ // of if-elseif's.
7
+ bool call_marlin_moe_kernel_ku4b8(
8
+ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
9
+ bool has_act_order, int group_blocks, int num_threads, int blocks,
10
+ int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
11
+ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
12
+ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
13
+ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
14
+ int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
15
+ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
16
+ int m_block, int max_par, int cfg_max_m_blocks) {
17
+ bool has_zp = false;
18
+
19
+ if (false) {
20
+ }
21
+ GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
22
+ GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
23
+ GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
24
+ GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
25
+ else {
26
+ return false;
27
+ }
28
+ return true;
29
+ }
30
+
31
+ } // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h ADDED
@@ -0,0 +1,20 @@
1
+ #pragma once
2
+
3
+ #include "marlin_moe_kernel.h"
4
+
5
+ namespace marlin_moe {
6
+
7
+ // We return bool so we can create these different kernel calls as a sequence
8
+ // of if-elseif's.
9
+ bool call_marlin_moe_kernel_ku4b8(
10
+ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
11
+ bool has_act_order, int group_blocks, int num_threads, int blocks,
12
+ int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
13
+ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
14
+ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
15
+ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
16
+ int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
17
+ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
18
+ int m_block, int max_par, int cfg_max_m_blocks);
19
+
20
+ } // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu ADDED
@@ -0,0 +1,31 @@
1
+ #include "marlin_moe_kernel_ku8b128.h"
2
+
3
+ namespace marlin_moe {
4
+
5
+ // We return bool so we can create these different kernel calls as a sequence
6
+ // of if-elseif's.
7
+ bool call_marlin_moe_kernel_ku8b128(
8
+ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
9
+ bool has_act_order, int group_blocks, int num_threads, int blocks,
10
+ int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
11
+ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
12
+ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
13
+ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
14
+ int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
15
+ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
16
+ int m_block, int max_par, int cfg_max_m_blocks) {
17
+ bool has_zp = false;
18
+
19
+ if (false) {
20
+ }
21
+ GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
22
+ GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
23
+ GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
24
+ GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
25
+ else {
26
+ return false;
27
+ }
28
+ return true;
29
+ }
30
+
31
+ } // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h ADDED
@@ -0,0 +1,18 @@
1
+ #pragma once
2
+
3
+ #include "marlin_moe_kernel.h"
4
+
5
+ namespace marlin_moe {
6
+
7
+ bool call_marlin_moe_kernel_ku8b128(
8
+ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
9
+ bool has_act_order, int group_blocks, int num_threads, int blocks,
10
+ int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
11
+ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
12
+ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
13
+ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
14
+ int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
15
+ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
16
+ int m_block, int max_par, int cfg_max_m_blocks);
17
+
18
+ }
marlin-moe/marlin_moe_ops.cu ADDED
@@ -0,0 +1,584 @@
1
+ /*
2
+ * Modified by Neural Magic
3
+ * Copyright (C) Marlin.2024 Elias Frantar
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include <torch/all.h>
19
+
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <c10/cuda/CUDAGuard.h>
22
+ #include <cuda.h>
23
+ #include <cuda_fp16.h>
24
+ #include <cuda_runtime.h>
25
+
26
+ #include <iostream>
27
+
28
+ #include "core/exception.hpp"
29
+ #include "core/scalar_type.hpp"
30
+ #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
31
+ #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
32
+ #include "marlin_kernels/marlin_moe_kernel_ku4.h"
33
+
34
+ template <typename T>
35
+ inline std::string str(T x) {
36
+ return std::to_string(x);
37
+ }
38
+
39
+ namespace marlin_moe {
40
+
41
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
42
+
43
+ // For a given "a" of size [M,K], performs a permutation of the K columns based
44
+ // on the given "perm" indices.
45
+ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
46
+ int const* __restrict__ perm_int_ptr,
47
+ int4* __restrict__ out_int4_ptr, int size_m,
48
+ int size_k, int block_rows) {
49
+ int start_row = block_rows * blockIdx.x;
50
+ int finish_row = start_row + block_rows;
51
+ if (finish_row > size_m) {
52
+ finish_row = size_m;
53
+ }
54
+ int cur_block_rows = finish_row - start_row;
55
+
56
+ int row_stride = size_k * sizeof(half) / 16;
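+ // (= size_k / 8: each 16-byte int4 holds eight fp16 values)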
57
+
58
+ auto permute_row = [&](int row) {
59
+ int iters = size_k / blockDim.x;
60
+ int rest = size_k % blockDim.x;
61
+
62
+ int offset = row * row_stride;
63
+
64
+ half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
65
+ half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
66
+
67
+ int base_k = 0;
68
+
69
+ for (int i = 0; i < iters; i++) {
70
+ int cur_k = base_k + threadIdx.x;
71
+ int src_pos = perm_int_ptr[cur_k];
72
+
73
+ out_half[cur_k] = a_row_half[src_pos];
74
+
75
+ base_k += blockDim.x;
76
+ }
77
+
78
+ if (rest) {
79
+ if (threadIdx.x < rest) {
80
+ int cur_k = base_k + threadIdx.x;
81
+ int src_pos = perm_int_ptr[cur_k];
82
+
83
+ out_half[cur_k] = a_row_half[src_pos];
84
+ }
85
+ }
86
+ };
87
+
88
+ for (int i = 0; i < cur_block_rows; i++) {
89
+ int cur_row = start_row + i;
90
+ if (cur_row < size_m) {
91
+ permute_row(cur_row);
92
+ }
93
+ }
94
+ }
95
+
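+ // Counts, per expert, how many of the topk assignments it received and turns
+ // the counts into block-padded start offsets. Illustrative example: with
+ // topk_ids = [0, 1, 2, 1, 2, 3, 0, 3, 4], block_size = 4 and 5 experts, the
+ // per-expert counts are [2, 2, 2, 2, 1]; padding each count up to a multiple
+ // of block_size gives expert_offsets = [0, 4, 8, 12, 16, 20].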
96
+ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
97
+ int* __restrict__ expert_offsets,
98
+ int topk_length, int block_size) {
99
+ int expert_id = threadIdx.x;
100
+ int num_experts = blockDim.x;
101
+
102
+ int occurrences = 0;
103
+ for (int i = 0; i < topk_length; ++i) {
104
+ occurrences += (topk_ids[i] == expert_id);
105
+ }
106
+ expert_offsets[expert_id + 1] = occurrences;
107
+ __syncthreads();
108
+
109
+ if (threadIdx.x == 0) {
110
+ int tot_offset = 0;
111
+ expert_offsets[0] = 0;
112
+ for (int i = 0; i < num_experts; ++i) {
113
+ tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
114
+ expert_offsets[i + 1] = tot_offset;
115
+ }
116
+ }
117
+ __syncthreads();
118
+ }
119
+
120
+ #else
121
+
122
+ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
123
+ int const* __restrict__ perm_int_ptr,
124
+ int4* __restrict__ out_int4_ptr, int size_m,
125
+ int size_k, int block_rows) {
126
+ // Marlin is not implemented yet for SM < 8.0
127
+ assert(false);
128
+ return;
129
+ }
130
+
131
+ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
132
+ int* __restrict__ expert_offsets,
133
+ int topk_length, int block_size) {
134
+ // Marlin is not implemented yet for SM < 8.0
135
+ assert(false);
136
+ return;
137
+ }
138
+
139
+ #endif
140
+
141
+ typedef struct {
142
+ int thread_k;
143
+ int thread_n;
144
+ int num_threads;
145
+ } thread_config_t;
146
+
147
+ typedef struct {
148
+ int max_m_blocks;
149
+ thread_config_t tb_cfg;
150
+ } exec_config_t;
151
+
152
+ thread_config_t small_batch_thread_configs[] = {
153
+ // Ordered by priority
154
+
155
+ // thread_k, thread_n, num_threads
156
+ {128, 128, 256}, // Default
157
+ {128, 64, 128}, // Reduce N 2X, same K
158
+ {64, 256, 256}, // Reduce K 2X, increase N 2X
159
+ {64, 128, 128}, // Reduce K 2X, same N
160
+ {64, 64, 128}, // Reduce both 2X
161
+ };
162
+
163
+ thread_config_t large_batch_thread_configs[] = {
164
+ // Ordered by priority
165
+
166
+ // thread_k, thread_n, num_threads
167
+ {64, 256, 256}, // Default
168
+ {128, 128, 256}, // Reduce N 2X, increase K 2X
169
+ {64, 128, 128}, // Reduce N 2X, same K
170
+ {128, 64, 128}, // Reduce N 4X, increase K 2X
171
+ {64, 64, 128}, // Reduce N 4X, same K
172
+ };
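+ // determine_thread_config() below walks one of these priority lists and picks
+ // the first entry whose pipeline fits into shared memory, lowering
+ // max_m_blocks from 4 towards 1 if none of the configs fit.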
173
+
174
+ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
175
+ int prob_n, int prob_k, int num_bits, int group_size,
176
+ bool has_act_order, bool is_k_full) {
177
+ bool cache_scales_chunk = has_act_order && !is_k_full;
178
+
179
+ int tb_n = th_config.thread_n;
180
+ int tb_k = th_config.thread_k;
181
+
182
+ // Get max scale groups per thread-block
183
+ int tb_groups;
184
+ if (group_size == -1) {
185
+ tb_groups = 1;
186
+ } else if (group_size == 0) {
187
+ tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size
188
+ } else {
189
+ tb_groups = ceildiv(tb_k, group_size);
190
+ }
191
+
192
+ if (cache_scales_chunk) {
193
+ int load_groups =
194
+ tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
195
+ load_groups = max(load_groups, 32); // We load at least 32 scale groups
196
+ return load_groups * tb_n * 4;
197
+
198
+ } else {
199
+ int tb_scales = tb_groups * tb_n * 2;
200
+
201
+ return tb_scales * STAGES;
202
+ }
203
+ }
204
+
205
+ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
206
+ int prob_m, int prob_n, int prob_k, int num_bits,
207
+ int scales_cache_size, int max_shared_mem) {
208
+ int pack_factor = 32 / num_bits;
209
+
210
+ // Get B size
211
+ int tb_k = th_config.thread_k;
212
+ int tb_n = th_config.thread_n;
213
+
214
+ int b_size = (tb_k * tb_n / pack_factor) * 4;
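+ // bytes: pack_factor quantized weights are packed into each 32-bit word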
215
+
216
+ // Get A size
217
+ int m_blocks = ceildiv(prob_m, 16);
218
+ int tb_max_m = 16;
219
+
220
+ while (true) {
221
+ if (m_blocks >= max_m_blocks) {
222
+ tb_max_m *= max_m_blocks;
223
+ break;
224
+ }
225
+
226
+ max_m_blocks--;
227
+ if (max_m_blocks == 0) {
228
+ TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
229
+ }
230
+ }
231
+
232
+ int a_size = (tb_max_m * tb_k) * 2;
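+ // bytes: fp16 A tile, 2 bytes per element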
233
+
234
+ float pipe_size = (a_size + b_size) * STAGES;
235
+
236
+ TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
237
+
238
+ return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
239
+ }
240
+
241
+ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
242
+ int prob_m, int prob_n, int prob_k, int num_bits,
243
+ int group_size, bool has_act_order, bool is_k_full,
244
+ int max_shared_mem) {
245
+ // Sanity
246
+ if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
247
+ th_config.num_threads == -1) {
248
+ return false;
249
+ }
250
+
251
+ // Verify K/N are divisible by thread K/N
252
+ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
253
+ return false;
254
+ }
255
+
256
+ // thread_k can be only 128 or 64 (because it must be less than groupsize
257
+ // which is 128)
258
+ if (th_config.thread_k != 128 && th_config.thread_k != 64) {
259
+ return false;
260
+ }
261
+
262
+ // Verify min for thread K/N
263
+ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
264
+ return false;
265
+ }
266
+
267
+ // num_threads must be at least 128 (= 4 warps)
268
+ if (th_config.num_threads < 128) {
269
+ return false;
270
+ }
271
+
272
+ // Determine cache for scales
273
+ int scales_cache_size =
274
+ get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
275
+ group_size, has_act_order, is_k_full);
276
+
277
+ // Check that pipeline fits into cache
278
+ if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
279
+ num_bits, scales_cache_size, max_shared_mem)) {
280
+ return false;
281
+ }
282
+
283
+ return true;
284
+ }
285
+
286
+ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
287
+ int num_bits, int group_size,
288
+ bool has_act_order, bool is_k_full,
289
+ int max_shared_mem) {
290
+ int max_m_blocks = 4;
291
+ while (max_m_blocks > 0) {
292
+ if (prob_m <= 16) {
293
+ for (auto th_config : small_batch_thread_configs) {
294
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
295
+ num_bits, group_size, has_act_order, is_k_full,
296
+ max_shared_mem)) {
297
+ return exec_config_t{max_m_blocks, th_config};
298
+ }
299
+ }
300
+ } else {
301
+ for (auto th_config : large_batch_thread_configs) {
302
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
303
+ num_bits, group_size, has_act_order, is_k_full,
304
+ max_shared_mem)) {
305
+ return exec_config_t{max_m_blocks, th_config};
306
+ }
307
+ }
308
+ }
309
+
310
+ max_m_blocks--;  // Process fewer M blocks per invocation to reduce cache
311
+ // usage
312
+ }
313
+
314
+ return exec_config_t{0, {-1, -1, -1}};
315
+ }
316
+
317
+ #define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
318
+ else if (KERNEL_FUNCTION( \
319
+ q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
320
+ group_blocks, num_threads, blocks, max_shared_mem, stream, \
321
+ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
322
+ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
323
+ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
324
+ replicate_input, apply_weights, m_block, max_par, \
325
+ exec_cfg.max_m_blocks)) { \
326
+ }
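+ // Each call_marlin_moe_kernel_* helper returns true only if it handled the
+ // given weight type and configuration, so chaining these macro expansions
+ // after `if (false) {}` selects exactly one implementation and leaves the
+ // final else to report an unsupported shape.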
327
+
328
+ void marlin_mm_moe(const void* A, const void* B, void* C,
329
+ const void* sorted_ids, const void* topk_weights,
330
+ const void* topk_ids, const void* s, void* zp,
331
+ const void* g_idx, const void* perm, void* a_tmp,
332
+ void* expert_offsets, int prob_m, int prob_n, int prob_k,
333
+ void* workspace, vllm::ScalarType const& q_type,
334
+ bool has_act_order, bool is_k_full, bool has_zp,
335
+ int num_groups, int group_size, int num_experts, int topk,
336
+ int moe_block_size, int dev, cudaStream_t stream,
337
+ int thread_k, int thread_n, int sms, int max_par,
338
+ bool replicate_input, bool apply_weights) {
339
+ TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
340
+ ", ", prob_n, ", ", prob_k, "]");
341
+
342
+ if (sms == -1) {
343
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
344
+ }
345
+
346
+ int max_shared_mem = 0;
347
+ cudaDeviceGetAttribute(&max_shared_mem,
348
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
349
+ TORCH_CHECK(max_shared_mem > 0);
350
+
351
+ int num_bits = q_type.size_bits();
352
+
353
+ // Set thread config
354
+ exec_config_t exec_cfg;
355
+ if (thread_k != -1 && thread_n != -1) {
356
+ // User-defined config
357
+ exec_cfg =
358
+ exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
359
+ } else {
360
+ // Auto config
361
+ exec_cfg =
362
+ determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
363
+ has_act_order, is_k_full, max_shared_mem);
364
+ }
365
+
366
+ TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
367
+ is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
368
+ prob_m, prob_n, prob_k, num_bits, group_size,
369
+ has_act_order, is_k_full, max_shared_mem),
370
+ "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
371
+ ", thread_k = ", exec_cfg.tb_cfg.thread_k,
372
+ ", thread_n = ", exec_cfg.tb_cfg.thread_n,
373
+ ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
374
+ prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
375
+ ", group_size = ", group_size,
376
+ ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
377
+ ", max_shared_mem = ", max_shared_mem);
378
+
379
+ int num_threads = exec_cfg.tb_cfg.num_threads;
380
+ thread_k = exec_cfg.tb_cfg.thread_k;
381
+ thread_n = exec_cfg.tb_cfg.thread_n;
382
+
383
+ int thread_k_blocks = thread_k / 16;
384
+ int thread_n_blocks = thread_n / 16;
385
+
386
+ int blocks = sms;
387
+
388
+ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
389
+ " is not divisible by thread_n = ", thread_n);
390
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
391
+ " is not divisible by thread_k = ", thread_k);
392
+
393
+ int group_blocks = 0;
394
+ if (has_act_order) {
395
+ if (is_k_full) {
396
+ TORCH_CHECK(group_size != -1);
397
+ group_blocks = group_size / 16;
398
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
399
+ " is not divisible by group_blocks = ", group_blocks);
400
+ } else {
401
+ TORCH_CHECK(group_size == 0);
402
+ group_blocks = 0;
403
+ }
404
+
405
+ } else {
406
+ if (group_size == -1) {
407
+ group_blocks = -1;
408
+ } else {
409
+ group_blocks = group_size / 16;
410
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
411
+ " is not divisible by group_blocks = ", group_blocks);
412
+ }
413
+ }
414
+
415
+ int tot_m = prob_m;
416
+
417
+ const int* topk_ids_ptr = (const int*)topk_ids;
418
+ int* expert_offsets_ptr = (int*)expert_offsets;
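+ // One threadblock with one thread per expert: each thread counts the tokens
+ // routed to its expert and thread 0 builds the block-padded prefix sums.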
419
+ compute_expert_offsets<<<1, num_experts, 0, stream>>>(
420
+ topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
421
+
422
+ bool do_permute_a = has_act_order;
423
+
424
+ // If we have a full K, then we can run the non-act-order version of Marlin
425
+ // (since the weight rows are reordered by increasing group ids, and by
426
+ // having a full K, we have full original groups)
427
+ if (is_k_full) {
428
+ has_act_order = false;
429
+ }
430
+
431
+ int pack_factor = 32 / q_type.size_bits();
432
+
433
+ for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
434
+ const int4* A_ptr = (const int4*)A;
435
+ int4* a_tmp_ptr = (int4*)a_tmp;
436
+ const int4* B_ptr =
437
+ (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
438
+ int4* C_ptr = (int4*)C;
439
+ const float* topk_weights_ptr = (const float*)topk_weights;
440
+ const int* sorted_ids_ptr = (const int*)sorted_ids;
441
+ const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
442
+ const int4* zp_ptr =
443
+ (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
444
+ const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
445
+ const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
446
+ int* locks = (int*)workspace;
447
+
448
+ if (do_permute_a) {
449
+ // Permute A columns
450
+ int topk_rows = replicate_input ? tot_m : tot_m * topk;
451
+ int block_rows = ceildiv(topk_rows, blocks);
452
+ permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
453
+ A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
454
+ A_ptr = a_tmp_ptr;
455
+ }
456
+
457
+ int tot_m_blocks = ceildiv(tot_m, 16);
458
+ for (int m_block = 0; m_block < tot_m_blocks;
459
+ m_block += 4 * exec_cfg.max_m_blocks) {
460
+ if (false) {
461
+ }
462
+ CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
463
+ CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
464
+ CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
465
+ else {
466
+ TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
467
+ str(prob_n) + ", " + str(prob_k) + "]" +
468
+ ", has_act_order = " + str(has_act_order) +
469
+ ", num_groups = " + str(num_groups) +
470
+ ", group_size = " + str(group_size) +
471
+ ", thread_n_blocks = " + str(thread_n_blocks) +
472
+ ", thread_k_blocks = " + str(thread_k_blocks));
473
+ }
474
+ }
475
+ }
476
+ }
477
+
478
+ } // namespace marlin_moe
479
+
480
+ torch::Tensor marlin_gemm_moe(
481
+ const torch::Tensor& a, const torch::Tensor& b_q_weights,
482
+ const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
483
+ const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
484
+ torch::Tensor& b_zeros, const torch::Tensor& g_idx,
485
+ const torch::Tensor& perm, torch::Tensor& workspace,
486
+ vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
487
+ int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
488
+ int64_t moe_block_size, bool replicate_input, bool apply_weights) {
489
+ vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
490
+ bool has_zp = b_zeros.size(1) != 0;
491
+ if (has_zp) {
492
+ TORCH_CHECK(
493
+ b_q_type == vllm::kU4,
494
+ "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
495
+ } else {
496
+ TORCH_CHECK(
497
+ b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
498
+ "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
499
+ }
500
+
501
+ int pack_factor = 32 / b_q_type.size_bits();
502
+
503
+ int max_par = 4;
504
+
505
+ int dev = a.get_device();
506
+
507
+ auto options_dtype =
508
+ torch::TensorOptions().dtype(a.dtype()).device(a.device());
509
+ auto options_int =
510
+ torch::TensorOptions().dtype(torch::kInt).device(a.device());
511
+ torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
512
+ torch::Tensor a_tmp =
513
+ replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
514
+ : torch::zeros({size_m, topk, size_k}, options_dtype);
515
+ torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);
516
+
517
+ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
518
+ // auto -1)
519
+ int thread_k = -1;
520
+ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
521
+ // auto -1)
522
+ int thread_n = -1;
523
+ // sms: number of SMs to use for the kernel (can usually be left as auto -1)
524
+ int sms = -1;
525
+
526
+ // Detect groupsize and act_order
527
+ int num_groups = -1;
528
+ int group_size = -1;
529
+ bool has_act_order = g_idx.size(1) != 0;
530
+
531
+ int b_rank = b_scales.sizes().size();
532
+ TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
533
+ TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
534
+ " is not size_n = ", size_n);
535
+ num_groups = b_scales.size(1);
536
+
537
+ TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
538
+ "if is_k_full is false, has_act_order must be true");
539
+
540
+ if (has_act_order) {
541
+ if (is_k_full) {
542
+ TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
543
+ TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
544
+ ", is not divisible by num_groups = ", num_groups);
545
+ group_size = size_k / num_groups;
546
+ } else {
547
+ group_size = 0;
548
+ }
549
+
550
+ } else {
551
+ if (num_groups > 1) {
552
+ TORCH_CHECK(
553
+ size_k % num_groups == 0, "size_k = ", size_k,
554
+ ", is not divisible by num_groups = ", num_groups);
555
+ group_size = size_k / num_groups;
556
+ } else {
557
+ group_size = -1;
558
+ }
559
+ }
560
+
561
+ // Verify b_zeros
562
+ if (has_zp) {
563
+ int rank = b_zeros.sizes().size();
564
+ TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
565
+ TORCH_CHECK(b_zeros.size(1) == num_groups,
566
+ "b_zeros dim 1 = ", b_zeros.size(1),
567
+ " is not num_groups = ", num_groups);
568
+ TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
569
+ "b_zeros dim 2 = ", b_zeros.size(2),
570
+ " is not size_n / pack_factor = ", size_n / pack_factor);
571
+ }
572
+
573
+ marlin_moe::marlin_mm_moe(
574
+ a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
575
+ topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
576
+ b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
577
+ expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
578
+ b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
579
+ num_experts, topk, moe_block_size, dev,
580
+ at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
581
+ replicate_input, apply_weights);
582
+ return c;
583
+ }
584
+
moe/moe_align_sum_kernels.cu ADDED
@@ -0,0 +1,202 @@
1
+ #include <torch/all.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include <ATen/ATen.h>
6
+ #include <THC/THCAtomics.cuh>
7
+
8
+ #include "../cuda_compat.h"
9
+ #include "../dispatch_utils.h"
10
+
11
+ #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
12
+
13
+ namespace vllm {
14
+ namespace moe {
15
+
16
+ namespace {
17
+ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
18
+ int32_t col) {
19
+ // don't worry about overflow because num_experts is relatively small
20
+ return row * total_col + col;
21
+ }
22
+ } // namespace
23
+
24
+ template <typename scalar_t>
25
+ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
26
+ int32_t* sorted_token_ids,
27
+ int32_t* expert_ids,
28
+ int32_t* total_tokens_post_pad,
29
+ int32_t num_experts,
30
+ int32_t block_size, size_t numel) {
31
+ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
32
+ const size_t start_idx = threadIdx.x * tokens_per_thread;
33
+
34
+ extern __shared__ int32_t shared_mem[];
35
+
36
+ int32_t* tokens_cnts =
37
+ shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
38
+ int32_t* cumsum =
39
+ shared_mem +
40
+ (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
41
+
42
+ for (int i = 0; i < num_experts; ++i) {
43
+ tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
44
+ }
45
+
46
+ /**
47
+ * In the first step we compute token_cnts[thread_index + 1][expert_index],
48
+ * which counts how many tokens in the token shard of thread_index are
49
+ * assigned to expert expert_index.
50
+ */
51
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
52
+ ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
53
+ }
54
+
55
+ __syncthreads();
56
+
57
+ // For each expert we accumulate the token counts from the different threads.
58
+ if (threadIdx.x < num_experts) {
59
+ tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
60
+ for (int i = 1; i <= blockDim.x; ++i) {
61
+ tokens_cnts[index(num_experts, i, threadIdx.x)] +=
62
+ tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
63
+ }
64
+ }
65
+
66
+ __syncthreads();
67
+
68
+ // We accumulate the token counts of all experts in thread 0.
69
+ if (threadIdx.x == 0) {
70
+ cumsum[0] = 0;
71
+ for (int i = 1; i <= num_experts; ++i) {
72
+ cumsum[i] = cumsum[i - 1] +
73
+ CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
74
+ block_size) *
75
+ block_size;
76
+ }
77
+ *total_tokens_post_pad = cumsum[num_experts];
78
+ }
79
+
80
+ __syncthreads();
81
+
82
+ /**
83
+ * For each expert, each thread processes the tokens of the corresponding
84
+ * blocks and stores the corresponding expert_id for each block.
85
+ */
86
+ if (threadIdx.x < num_experts) {
87
+ for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
88
+ i += block_size) {
89
+ expert_ids[i / block_size] = threadIdx.x;
90
+ }
91
+ }
92
+
93
+ /**
94
+ * Each thread processes a token shard, calculating the index of each token
95
+ * after sorting by expert number. Given the example topk_ids =
96
+ * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
97
+ * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
98
+ * padding value(preset in python).
99
+ */
100
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
101
+ int32_t expert_id = topk_ids[i];
102
+ /** The cumsum[expert_id] stores the starting index of the tokens that the
103
+ * expert with expert_id needs to process, and
104
+ * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
105
+ * processed by the expert with expert_id within the current thread's token
106
+ * shard.
107
+ */
108
+ int32_t rank_post_pad =
109
+ tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
110
+ cumsum[expert_id];
111
+ sorted_token_ids[rank_post_pad] = i;
112
+ ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
113
+ }
114
+ }
115
+
116
+ template <typename scalar_t, int TOPK>
117
+ __global__ void moe_sum_kernel(
118
+ scalar_t* __restrict__ out, // [..., d]
119
+ const scalar_t* __restrict__ input, // [..., topk, d]
120
+ const int d) {
121
+ const int64_t token_idx = blockIdx.x;
122
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
123
+ scalar_t x = 0.0;
124
+ #pragma unroll
125
+ for (int k = 0; k < TOPK; ++k) {
126
+ x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
127
+ }
128
+ out[token_idx * d + idx] = x;
129
+ }
130
+ }
131
+
132
+ } // namespace moe
133
+ } // namespace vllm
134
+
135
+ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
136
+ int64_t block_size, torch::Tensor sorted_token_ids,
137
+ torch::Tensor experts_ids,
138
+ torch::Tensor num_tokens_post_pad) {
139
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
140
+ VLLM_DISPATCH_INTEGRAL_TYPES(
141
+ topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
142
+ // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
143
+ // tensors
144
+ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
145
+ const int32_t shared_mem =
146
+ ((num_thread + 1) * num_experts + (num_experts + 1)) *
147
+ sizeof(int32_t);
148
+
149
+ // set dynamic shared mem
150
+ auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
151
+ AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
152
+ (void*)kernel, shared_mem));
153
+ kernel<<<1, num_thread, shared_mem, stream>>>(
154
+ topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
155
+ experts_ids.data_ptr<int32_t>(),
156
+ num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
157
+ topk_ids.numel());
158
+ });
159
+ }
160
+
161
+ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
162
+ torch::Tensor& output) // [num_tokens, hidden_size]
163
+ {
164
+ const int hidden_size = input.size(-1);
165
+ const int num_tokens = output.numel() / hidden_size;
166
+ const int topk = input.size(1);
167
+
168
+ dim3 grid(num_tokens);
169
+ dim3 block(std::min(hidden_size, 1024));
170
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
171
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
172
+
173
+ switch (topk) {
174
+ case 2:
175
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
176
+ vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
177
+ output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
178
+ hidden_size);
179
+ });
180
+ break;
181
+
182
+ case 3:
183
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
184
+ vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
185
+ output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
186
+ hidden_size);
187
+ });
188
+ break;
189
+
190
+ case 4:
191
+ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
192
+ vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
193
+ output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
194
+ hidden_size);
195
+ });
196
+ break;
197
+
198
+ default:
199
+ at::sum_out(output, input, 1);
200
+ break;
201
+ }
202
+ }
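To make the padded layout produced by `moe_align_block_size` concrete, the sketch below reimplements the sorting/padding logic described in the kernel comments above as plain Python (illustrative only, not part of this commit; the function name `moe_align_block_size_ref` is hypothetical). For `topk_ids = [0, 1, 2, 1, 2, 3, 0, 3, 4]`, `num_experts = 5` and `block_size = 4` it reproduces the `[0, 6, *, *, 1, 3, *, *, ...]` ordering from the comment, with padding slots holding `topk_ids.numel()`.

```python
import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """Pure-Python reference for the padded, expert-sorted layout built by the kernel."""
    flat = topk_ids.flatten().tolist()
    counts = [0] * num_experts
    for e in flat:
        counts[e] += 1

    # Round each expert's token count up to a multiple of block_size (CEILDIV * block_size).
    cumsum = [0]
    for e in range(num_experts):
        cumsum.append(cumsum[-1] + -(-counts[e] // block_size) * block_size)
    total_padded = cumsum[-1]

    # Padding slots keep the preset value topk_ids.numel(), as in the Python caller.
    sorted_token_ids = [len(flat)] * total_padded
    expert_ids = [0] * (total_padded // block_size)
    for e in range(num_experts):
        for blk in range(cumsum[e] // block_size, cumsum[e + 1] // block_size):
            expert_ids[blk] = e

    # Scatter each token into its expert's region, preserving original order within an expert.
    offsets = cumsum[:-1]
    for i, e in enumerate(flat):
        sorted_token_ids[offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, total_padded
```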
moe/topk_softmax_kernels.cu ADDED
@@ -0,0 +1,506 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
3
+ * Copyright (c) 2024, The vLLM team.
4
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
5
+ * SPDX-License-Identifier: Apache-2.0
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+ #include <torch/all.h>
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <c10/cuda/CUDAGuard.h>
22
+ #include "../cuda_compat.h"
23
+
24
+ #ifndef USE_ROCM
25
+ #include <cub/util_type.cuh>
26
+ #include <cub/cub.cuh>
27
+ #else
28
+ #include <hipcub/util_type.hpp>
29
+ #include <hipcub/hipcub.hpp>
30
+ #endif
31
+
32
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
33
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
34
+
35
+ namespace vllm {
36
+ namespace moe {
37
+
38
+ /// Aligned array type
39
+ template <
40
+ typename T,
41
+ /// Number of elements in the array
42
+ int N,
43
+ /// Alignment requirement in bytes
44
+ int Alignment = sizeof(T) * N
45
+ >
46
+ class alignas(Alignment) AlignedArray {
47
+ T data[N];
48
+ };
49
+
50
+ // ====================== Softmax things ===============================
51
+ // We have our own implementation of softmax here so we can support transposing the output
52
+ // in the softmax kernel when we extend this module to support expert-choice routing.
53
+ template <int TPB>
54
+ __launch_bounds__(TPB) __global__
55
+ void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
56
+ {
57
+ using BlockReduce = cub::BlockReduce<float, TPB>;
58
+ __shared__ typename BlockReduce::TempStorage tmpStorage;
59
+
60
+ __shared__ float normalizing_factor;
61
+ __shared__ float float_max;
62
+
63
+ const int thread_row_offset = blockIdx.x * num_cols;
64
+
65
+ cub::Sum sum;
66
+ float threadData(-FLT_MAX);
67
+
68
+ // Don't touch finished rows.
69
+ if ((finished != nullptr) && finished[blockIdx.x])
70
+ {
71
+ return;
72
+ }
73
+
74
+ for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
75
+ {
76
+ const int idx = thread_row_offset + ii;
77
+ threadData = max(static_cast<float>(input[idx]), threadData);
78
+ }
79
+
80
+ const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
81
+ if (threadIdx.x == 0)
82
+ {
83
+ float_max = maxElem;
84
+ }
85
+ __syncthreads();
86
+
87
+ threadData = 0;
88
+
89
+ for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
90
+ {
91
+ const int idx = thread_row_offset + ii;
92
+ threadData += exp((static_cast<float>(input[idx]) - float_max));
93
+ }
94
+
95
+ const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
96
+
97
+ if (threadIdx.x == 0)
98
+ {
99
+ normalizing_factor = 1.f / Z;
100
+ }
101
+ __syncthreads();
102
+
103
+ for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
104
+ {
105
+ const int idx = thread_row_offset + ii;
106
+ const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
107
+ output[idx] = val;
108
+ }
109
+ }
110
+
111
+ template <int TPB>
112
+ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
113
+ int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
114
+ {
115
+
116
+ using cub_kvp = cub::KeyValuePair<int, float>;
117
+ using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
118
+ __shared__ typename BlockReduce::TempStorage tmpStorage;
119
+
120
+ cub_kvp thread_kvp;
121
+ cub::ArgMax arg_max;
122
+
123
+ const int num_rows = gridDim.x;
124
+ const int block_row = blockIdx.x;
125
+
126
+ const bool row_is_active = finished ? !finished[block_row] : true;
127
+ const int thread_read_offset = blockIdx.x * num_experts;
128
+ for (int k_idx = 0; k_idx < k; ++k_idx)
129
+ {
130
+ thread_kvp.key = 0;
131
+ thread_kvp.value = -1.f; // This is OK because inputs are probabilities
132
+
133
+ cub_kvp inp_kvp;
134
+ for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
135
+ {
136
+ const int idx = thread_read_offset + expert;
137
+ inp_kvp.key = expert;
138
+ inp_kvp.value = inputs_after_softmax[idx];
139
+
140
+ for (int prior_k = 0; prior_k < k_idx; ++prior_k)
141
+ {
142
+ const int prior_winning_expert = indices[k * block_row + prior_k];
143
+
144
+ if (prior_winning_expert == expert)
145
+ {
146
+ inp_kvp = thread_kvp;
147
+ }
148
+ }
149
+
150
+ thread_kvp = arg_max(inp_kvp, thread_kvp);
151
+ }
152
+
153
+ const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
154
+ if (threadIdx.x == 0)
155
+ {
156
+ // Ignore experts the node isn't responsible for with expert parallelism
157
+ const int expert = result_kvp.key;
158
+ const bool node_uses_expert = expert >= start_expert && expert < end_expert;
159
+ const bool should_process_row = row_is_active && node_uses_expert;
160
+
161
+ const int idx = k * block_row + k_idx;
162
+ output[idx] = result_kvp.value;
163
+ indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
164
+ assert(indices[idx] >= 0);
165
+ source_rows[idx] = k_idx * num_rows + block_row;
166
+ }
167
+ __syncthreads();
168
+ }
169
+ }
170
+
171
+ // ====================== TopK softmax things ===============================
172
+
173
+ /*
174
+ A Top-K gating softmax written to exploit when the number of experts in the MoE layers
175
+ is a small power of 2. This allows us to cleanly share the rows among the threads in
176
+ a single warp and eliminate communication between warps (so no need to use shared mem).
177
+
178
+ It fuses the softmax, max and argmax into a single kernel.
179
+
180
+ Limitations:
181
+ 1) This implementation is intended for when the number of experts is a small power of 2.
182
+ 2) This implementation assumes k is small, but will work for any k.
183
+ */
184
+
185
+ template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
186
+ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
187
+ void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
188
+ int* source_rows, const int k, const int start_expert, const int end_expert)
189
+ {
190
+ // We begin by enforcing compile time assertions and setting up compile time constants.
191
+ static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
192
+ static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
193
+ static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
194
+ static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
195
+
196
+ // Number of bytes each thread pulls in per load
197
+ static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
198
+ static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
199
+ static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
200
+ static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
201
+
202
+ // Restrictions based on previous section.
203
+ static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
204
+ static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
205
+ static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
206
+ static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
207
+
208
+ // We have NUM_EXPERTS elements per row. We specialize for small #experts
209
+ static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
210
+ static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
211
+ static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
212
+
213
+ // Restrictions for previous section.
214
+ static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
215
+
216
+ // ===================== From this point, we finally start computing run-time variables. ========================
217
+
218
+ // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
219
+ // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
220
+ const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
221
+
222
+ // Now, using the base row per thread block, we compute the base row per warp.
223
+ const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
224
+
225
+ // The threads in a warp are split into sub-groups that will work on a row.
226
+ // We compute row offset for each thread sub-group
227
+ const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
228
+ const int thread_row = warp_base_row + thread_row_in_warp;
229
+
230
+ // Threads with indices out of bounds should early exit here.
231
+ if (thread_row >= num_rows)
232
+ {
233
+ return;
234
+ }
235
+ const bool row_is_active = finished ? !finished[thread_row] : true;
236
+
237
+ // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
238
+ // row it will read.
239
+ const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
240
+
241
+ // Now, we compute the group each thread belongs to in order to determine the first column to start loads.
242
+ const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
243
+ const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
244
+ const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
245
+
246
+ // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
247
+ // this can support all powers of 2 up to 16.
248
+ // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
249
+ // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
250
+ using AccessType = AlignedArray<float, ELTS_PER_LDG>;
251
+
252
+ // Finally, we pull in the data from global mem
253
+ float row_chunk[VPT];
254
+ AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
255
+ const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
256
+ #pragma unroll
257
+ for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
258
+ {
259
+ row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
260
+ }
261
+
262
+ // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
263
+ // convert to float afterwards for the exp + sum reduction.
264
+ float thread_max = row_chunk[0];
265
+ #pragma unroll
266
+ for (int ii = 1; ii < VPT; ++ii)
267
+ {
268
+ thread_max = max(thread_max, row_chunk[ii]);
269
+ }
270
+
271
+ // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
272
+ #pragma unroll
273
+ for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
274
+ {
275
+ thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
276
+ }
277
+
278
+ // From this point, thread_max in all the threads holds the max within the row.
279
+ // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
280
+ float row_sum = 0;
281
+ #pragma unroll
282
+ for (int ii = 0; ii < VPT; ++ii)
283
+ {
284
+ row_chunk[ii] = expf(row_chunk[ii] - thread_max);
285
+ row_sum += row_chunk[ii];
286
+ }
287
+
288
+ // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
289
+ #pragma unroll
290
+ for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
291
+ {
292
+ row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
293
+ }
294
+
295
+ // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
296
+ // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
297
+ // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
298
+ // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
299
+ // argmax after computing the softmax.
300
+ const float reciprocal_row_sum = 1.f / row_sum;
301
+
302
+ #pragma unroll
303
+ for (int ii = 0; ii < VPT; ++ii)
304
+ {
305
+ row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
306
+ }
307
+
308
+ // At this point, row_chunk contains the softmax of this thread's chunk of the row. Next, we find the top-k elements in each row, along
309
+ // with the max index.
310
+ int start_col = first_elt_read_by_thread;
311
+ static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
312
+
313
+ for (int k_idx = 0; k_idx < k; ++k_idx)
314
+ {
315
+ // First, each thread does the local argmax
316
+ float max_val = row_chunk[0];
317
+ int expert = start_col;
318
+ #pragma unroll
319
+ for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
320
+ {
321
+ #pragma unroll
322
+ for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
323
+ {
324
+ float val = row_chunk[ldg * ELTS_PER_LDG + ii];
325
+
326
+ // No check on the experts here since columns with the smallest index are processed first and only
327
+ // updated if > (not >=)
328
+ if (val > max_val)
329
+ {
330
+ max_val = val;
331
+ expert = col + ii;
332
+ }
333
+ }
334
+ }
335
+
336
+ // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
337
+ // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
338
+ // then blank out their max with -inf and the warp can run more iterations...
339
+ #pragma unroll
340
+ for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
341
+ {
342
+ float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
343
+ int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
344
+
345
+ // We want lower indices to "win" in every thread so we break ties this way
346
+ if (other_max > max_val || (other_max == max_val && other_expert < expert))
347
+ {
348
+ max_val = other_max;
349
+ expert = other_expert;
350
+ }
351
+ }
352
+
353
+ // Write the max for this k iteration to global memory.
354
+ if (thread_group_idx == 0)
355
+ {
356
+ // Add a guard to ignore experts not included by this node
357
+ const bool node_uses_expert = expert >= start_expert && expert < end_expert;
358
+ const bool should_process_row = row_is_active && node_uses_expert;
359
+
360
+ // The lead thread from each sub-group will write out the final results to global memory. (This will be a
361
+ // single) thread per row of the input/output matrices.
362
+ const int idx = k * thread_row + k_idx;
363
+ output[idx] = max_val;
364
+ indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
365
+ source_rows[idx] = k_idx * num_rows + thread_row;
366
+ }
367
+
368
+ // Finally, we clear the value in the thread with the current max if there is another iteration to run.
369
+ if (k_idx + 1 < k)
370
+ {
371
+ const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
372
+ const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
373
+
374
+ // Only the thread in the group which produced the max will reset the "winning" value to -inf.
375
+ if (thread_group_idx == thread_to_clear_in_group)
376
+ {
377
+ const int offset_for_expert = expert % ELTS_PER_LDG;
378
+ // Safe to set to any negative value since row_chunk values must be between 0 and 1.
379
+ row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
380
+ }
381
+ }
382
+ }
383
+ }
384
+
385
+ namespace detail
386
+ {
387
+ // Constructs some constants needed to partition the work across threads at compile time.
388
+ template <int EXPERTS, int BYTES_PER_LDG>
389
+ struct TopkConstants
390
+ {
391
+ static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
392
+ static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
393
+ static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
394
+ static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
395
+ static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
396
+ static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
397
+ };
398
+ } // namespace detail
399
+
400
+ template <int EXPERTS, int WARPS_PER_TB>
401
+ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
402
+ int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
403
+ {
404
+ static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
405
+
406
+ static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
407
+ using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
408
+ static constexpr int VPT = Constants::VPT;
409
+ static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
410
+ const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
411
+ const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
412
+
413
+ dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
414
+ topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
415
+ input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
416
+ }
417
+
418
+ #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
419
+ topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
420
+ gating_output, nullptr, topk_weights, topk_indicies, \
421
+ token_expert_indices, num_tokens, topk, 0, num_experts, \
422
+ stream);
423
+
424
+ void topkGatingSoftmaxKernelLauncher(
425
+ const float* gating_output,
426
+ float* topk_weights,
427
+ int* topk_indicies,
428
+ int* token_expert_indices,
429
+ float* softmax_workspace,
430
+ const int num_tokens,
431
+ const int num_experts,
432
+ const int topk,
433
+ cudaStream_t stream) {
434
+ static constexpr int WARPS_PER_TB = 4;
435
+ switch (num_experts) {
436
+ case 1:
437
+ LAUNCH_SOFTMAX(1, WARPS_PER_TB);
438
+ break;
439
+ case 2:
440
+ LAUNCH_SOFTMAX(2, WARPS_PER_TB);
441
+ break;
442
+ case 4:
443
+ LAUNCH_SOFTMAX(4, WARPS_PER_TB);
444
+ break;
445
+ case 8:
446
+ LAUNCH_SOFTMAX(8, WARPS_PER_TB);
447
+ break;
448
+ case 16:
449
+ LAUNCH_SOFTMAX(16, WARPS_PER_TB);
450
+ break;
451
+ case 32:
452
+ LAUNCH_SOFTMAX(32, WARPS_PER_TB);
453
+ break;
454
+ case 64:
455
+ LAUNCH_SOFTMAX(64, WARPS_PER_TB);
456
+ break;
457
+ case 128:
458
+ LAUNCH_SOFTMAX(128, WARPS_PER_TB);
459
+ break;
460
+ case 256:
461
+ LAUNCH_SOFTMAX(256, WARPS_PER_TB);
462
+ break;
463
+ default: {
464
+ TORCH_CHECK(softmax_workspace != nullptr,
465
+ "softmax_workspace must be provided for num_experts that are not a power of 2.");
466
+ static constexpr int TPB = 256;
467
+ moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
468
+ gating_output, nullptr, softmax_workspace, num_experts);
469
+ moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
470
+ softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
471
+ num_experts, topk, 0, num_experts);
472
+ }
473
+ }
474
+ }
475
+
476
+ } // namespace moe
477
+ } // namespace vllm
478
+
479
+ void topk_softmax(
480
+ torch::Tensor& topk_weights, // [num_tokens, topk]
481
+ torch::Tensor& topk_indices, // [num_tokens, topk]
482
+ torch::Tensor& token_expert_indices, // [num_tokens, topk]
483
+ torch::Tensor& gating_output) // [num_tokens, num_experts]
484
+ {
485
+ const int num_experts = gating_output.size(-1);
486
+ const int num_tokens = gating_output.numel() / num_experts;
487
+ const int topk = topk_weights.size(-1);
488
+
489
+ const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
490
+ const bool needs_workspace = !is_pow_2 || num_experts > 256;
491
+ const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
492
+
493
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
494
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
495
+ torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
496
+ vllm::moe::topkGatingSoftmaxKernelLauncher(
497
+ gating_output.data_ptr<float>(),
498
+ topk_weights.data_ptr<float>(),
499
+ topk_indices.data_ptr<int>(),
500
+ token_expert_indices.data_ptr<int>(),
501
+ softmax_workspace.data_ptr<float>(),
502
+ num_tokens,
503
+ num_experts,
504
+ topk,
505
+ stream);
506
+ }
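For orientation, the sketch below shows how the `topk_softmax` binding above is typically driven from Python. It mirrors the `opcheck` call in `test/kernels/test_moe.py` later in this commit and assumes the `moe._ops.ops` binding used there; shapes and dtypes follow the C++ signature (float32 weights and gating output, int32 indices).

```python
import torch
from moe._ops import ops  # Python binding for the kernels in this file

num_tokens, num_experts, topk = 33, 64, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# Output buffers: routing weights plus two index tensors, all of shape [num_tokens, topk].
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

# Fused softmax over the experts followed by per-row top-k selection.
ops.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)
```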
test/__init__.py ADDED
File without changes
test/kernels/__init__.py ADDED
File without changes
test/kernels/test_moe.py ADDED
@@ -0,0 +1,220 @@
1
+ """Tests for the MOE layers.
2
+
3
+ Run `pytest test/kernels/test_moe.py`.
4
+ """
5
+
6
+ from typing import List
7
+
8
+ import pytest
9
+ import torch
10
+
11
+ from moe._ops import ops
12
+ from moe.fused_moe import fused_moe, fused_topk, moe_align_block_size
13
+ from moe.fused_marlin_moe import fused_marlin_moe
14
+ from moe.scalar_type import scalar_types
15
+ from moe.utils.marlin_utils_test import marlin_quantize
16
+
17
+ from .utils import compute_max_diff, opcheck
18
+
19
+
20
+ def stack_and_dev(tensors: List[torch.Tensor]):
21
+ dev = tensors[0].device
22
+ return torch.stack(tensors, dim=0).to(dev)
23
+
24
+
25
+ NUM_EXPERTS = [8, 64]
26
+ TOP_KS = [2, 6]
27
+
28
+
29
+ @pytest.mark.parametrize("m", [1, 33, 64, 222])
30
+ @pytest.mark.parametrize("n", [128, 2048])
31
+ @pytest.mark.parametrize("k", [128, 1024])
32
+ @pytest.mark.parametrize("e", NUM_EXPERTS)
33
+ @pytest.mark.parametrize("topk", TOP_KS)
34
+ @pytest.mark.parametrize("group_size", [-1, 32, 128])
35
+ @pytest.mark.parametrize("act_order", [True, False])
36
+ @pytest.mark.parametrize("num_bits", [4, 8])
37
+ @pytest.mark.parametrize("is_k_full", [True, False])
38
+ # @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
39
+ def test_fused_marlin_moe(
40
+ m: int,
41
+ n: int,
42
+ k: int,
43
+ e: int,
44
+ topk: int,
45
+ group_size: int,
46
+ act_order: bool,
47
+ num_bits: int,
48
+ is_k_full: bool,
49
+ ):
50
+ torch.manual_seed(7)
51
+
52
+ # Filter act_order
53
+ if act_order:
54
+ if group_size == -1:
55
+ return
56
+ if group_size in (k, n):
57
+ return
58
+ else:
59
+ if not is_k_full:
60
+ return
61
+
62
+ quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
63
+ dtype = torch.float16
64
+ a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
65
+ w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
66
+ w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
67
+
68
+ w_ref1_l = []
69
+ qweight1_l = []
70
+ scales1_l = []
71
+ g_idx1_l = []
72
+ sort_indices1_l = []
73
+
74
+ for i in range(w1.shape[0]):
75
+ test_perm = torch.randperm(k)
76
+ w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
77
+ w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
78
+ )
79
+ w_ref1_l.append(w_ref1)
80
+ qweight1_l.append(qweight1)
81
+ scales1_l.append(scales1)
82
+ g_idx1_l.append(g_idx1)
83
+ sort_indices1_l.append(sort_indices1)
84
+
85
+ w_ref1 = stack_and_dev(w_ref1_l)
86
+ qweight1 = stack_and_dev(qweight1_l).contiguous()
87
+ scales1 = stack_and_dev(scales1_l)
88
+ g_idx1 = stack_and_dev(g_idx1_l)
89
+ sort_indices1 = stack_and_dev(sort_indices1_l)
90
+
91
+ w_ref2_l = []
92
+ qweight2_l = []
93
+ scales2_l = []
94
+ g_idx2_l = []
95
+ sort_indices2_l = []
96
+
97
+ for i in range(w2.shape[0]):
98
+ test_perm = torch.randperm(n)
99
+ w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
100
+ w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
101
+ )
102
+ w_ref2_l.append(w_ref2)
103
+ qweight2_l.append(qweight2)
104
+ scales2_l.append(scales2)
105
+ g_idx2_l.append(g_idx2)
106
+ sort_indices2_l.append(sort_indices2)
107
+
108
+ w_ref2 = stack_and_dev(w_ref2_l)
109
+ qweight2 = stack_and_dev(qweight2_l).contiguous()
110
+ scales2 = stack_and_dev(scales2_l)
111
+ g_idx2 = stack_and_dev(g_idx2_l)
112
+ sort_indices2 = stack_and_dev(sort_indices2_l)
113
+
114
+ score = torch.randn((m, e), device="cuda", dtype=dtype)
115
+
116
+ topk_weights, topk_ids = fused_topk(a, score, topk, False)
117
+
118
+ triton_output = fused_moe(
119
+ a,
120
+ w_ref1.transpose(1, 2).contiguous(),
121
+ w_ref2.transpose(1, 2).contiguous(),
122
+ score,
123
+ topk,
124
+ renormalize=False,
125
+ )
126
+ marlin_output = fused_marlin_moe(
127
+ a,
128
+ qweight1,
129
+ qweight2,
130
+ scales1,
131
+ scales2,
132
+ score,
133
+ topk_weights,
134
+ topk_ids,
135
+ g_idx1=g_idx1,
136
+ g_idx2=g_idx2,
137
+ sort_indices1=sort_indices1,
138
+ sort_indices2=sort_indices2,
139
+ num_bits=num_bits,
140
+ is_k_full=is_k_full,
141
+ )
142
+
143
+ assert compute_max_diff(marlin_output, triton_output) < 4e-2
144
+
145
+ token_expert_indicies = torch.empty(m, topk, dtype=torch.int32, device=a.device)
146
+
147
+ opcheck(
148
+ ops.topk_softmax,
149
+ (
150
+ topk_weights,
151
+ topk_ids,
152
+ token_expert_indicies,
153
+ score.float(),
154
+ ),
155
+ )
156
+
157
+ block_size_m = 4
158
+
159
+ sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, e)
160
+
161
+ max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
162
+ workspace = torch.zeros(
163
+ max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
164
+ )
165
+
166
+ zp = torch.empty((0, 0), dtype=dtype, device="cuda", requires_grad=False)
167
+ opcheck(
168
+ ops.marlin_gemm_moe,
169
+ (
170
+ a,
171
+ qweight1,
172
+ sorted_token_ids,
173
+ topk_weights,
174
+ topk_ids,
175
+ scales1,
176
+ zp,
177
+ g_idx1,
178
+ sort_indices1,
179
+ workspace,
180
+ quant_type.id,
181
+ m,
182
+ 2 * n,
183
+ k,
184
+ True,
185
+ e,
186
+ topk,
187
+ block_size_m,
188
+ True,
189
+ False,
190
+ ),
191
+ )
192
+
193
+
194
+ def test_moe_align_block_size_opcheck():
195
+ num_experts = 4
196
+ block_size = 4
197
+ topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
198
+
199
+ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
200
+ sorted_ids = torch.empty(
201
+ (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
202
+ )
203
+ sorted_ids.fill_(topk_ids.numel())
204
+ max_num_m_blocks = max_num_tokens_padded // block_size
205
+ expert_ids = torch.empty(
206
+ (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
207
+ )
208
+ num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
209
+
210
+ opcheck(
211
+ ops.moe_align_block_size,
212
+ (
213
+ topk_ids,
214
+ num_experts,
215
+ block_size,
216
+ sorted_ids,
217
+ expert_ids,
218
+ num_tokens_post_pad,
219
+ ),
220
+ )
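As a quick check on the buffer sizing used in `test_moe_align_block_size_opcheck` (illustrative arithmetic only): in the worst case every expert's last block is partially filled, so each expert contributes at most `block_size - 1` padding slots.

```python
# With topk_ids of shape (3, 4), num_experts = 4 and block_size = 4:
numel = 3 * 4                                  # 12 top-k assignments
max_num_tokens_padded = numel + 4 * (4 - 1)    # numel + num_experts * (block_size - 1) = 24
max_num_m_blocks = max_num_tokens_padded // 4  # 24 // 4 = 6 expert_ids entries
```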
test/kernels/utils.py ADDED
@@ -0,0 +1,78 @@
1
+ """Kernel test utils"""
2
+
3
+ import itertools
4
+ import random
5
+ import unittest
6
+ from numbers import Number
7
+ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
8
+
9
+ import pytest
10
+ import torch
11
+ from torch._prims_common import TensorLikeType
12
+
13
+ # For now, disable "test_aot_dispatch_dynamic" since there are some
14
+ # bugs related to this test in PyTorch 2.4.
15
+ DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
16
+ "test_schema",
17
+ "test_autograd_registration",
18
+ "test_faketensor",
19
+ )
20
+
21
+ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
22
+ "test_schema",
23
+ "test_autograd_registration",
24
+ "test_faketensor",
25
+ "test_aot_dispatch_dynamic",
26
+ )
27
+
28
+
29
+ # Copied/modified from torch._refs.__init__.py
30
+ def fp8_allclose(
31
+ a: TensorLikeType,
32
+ b: TensorLikeType,
33
+ rtol: float = 1e-05,
34
+ atol: float = 1e-08,
35
+ equal_nan: bool = False,
36
+ ) -> bool:
37
+ """
38
+ Reference implementation of torch.allclose
39
+ """
40
+ torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
41
+
42
+ return bool(
43
+ torch.all(
44
+ torch.isclose(
45
+ a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
46
+ )
47
+ ).item()
48
+ )
49
+
50
+
51
+ def compute_max_diff(output, output_ref):
52
+ return torch.mean(torch.abs(output - output_ref)) / torch.mean(
53
+ torch.abs(output_ref))
54
+
55
+
56
+ # A special version of op check that has a restricted default set of test_utils
57
+ # and a patched version of allclose that supports fp8 types.
58
+ def opcheck(
59
+ op: Union[
60
+ torch._ops.OpOverload,
61
+ torch._ops.OpOverloadPacket,
62
+ torch._library.custom_ops.CustomOpDef,
63
+ ],
64
+ args: Tuple[Any, ...],
65
+ kwargs: Optional[Dict[str, Any]] = None,
66
+ *,
67
+ test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
68
+ raise_exception: bool = True,
69
+ cond: bool = True
70
+ ) -> Dict[str, str]:
71
+ with unittest.mock.patch("torch.allclose", new=fp8_allclose):
72
+ return (
73
+ torch.library.opcheck(
74
+ op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
75
+ )
76
+ if cond
77
+ else {}
78
+ )