Add MoE kernels from vLLM
TODO: add MoE configs, but this requires a change to the builder.
- README.md +7 -3
- activation/activation_kernels.cu +204 -0
- activation/cuda_compat.h +49 -0
- activation/dispatch_utils.h +35 -0
- build.toml +49 -0
- core/exception.hpp +3 -0
- core/scalar_type.hpp +347 -0
- cuda_compat.h +49 -0
- dispatch_utils.h +35 -0
- ext-torch/__init__.py +1 -0
- ext-torch/_custom_ops.py +135 -0
- ext-torch/fp8.py +63 -0
- ext-torch/fused_marlin_moe.py +338 -0
- ext-torch/fused_moe.py +703 -0
- ext-torch/platforms.py +22 -0
- ext-torch/registration.h +27 -0
- ext-torch/scalar_type.py +330 -0
- ext-torch/torch_binding.cpp +45 -0
- ext-torch/torch_binding.h +30 -0
- ext-torch/utils/__init__.py +0 -0
- ext-torch/utils/marlin_utils.py +307 -0
- ext-torch/utils/marlin_utils_test.py +162 -0
- ext-torch/utils/quant_utils.py +470 -0
- flake.nix +14 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel.h +1616 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu +31 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h +20 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +31 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +20 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +31 -0
- marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +18 -0
- marlin-moe/marlin_moe_ops.cu +584 -0
- moe/moe_align_sum_kernels.cu +202 -0
- moe/topk_softmax_kernels.cu +506 -0
- test/__init__.py +0 -0
- test/kernels/__init__.py +0 -0
- test/kernels/test_moe.py +220 -0
- test/kernels/utils.py +78 -0
README.md
CHANGED
@@ -1,3 +1,7 @@
 ---
 license: apache-2.0
 ---
+
+## MoE
+
+MoE kernels from [vLLM](https://github.com/vllm-project/).
activation/activation_kernels.cu
ADDED
@@ -0,0 +1,204 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(x) * y;
  }
}

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;
  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

}  // namespace vllm

// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
  int d = input.size(-1) / 2;                                            \
  int64_t num_tokens = input.numel() / input.size(-1);                   \
  dim3 grid(num_tokens);                                                 \
  dim3 block(std::min(d, 1024));                                         \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
  VLLM_DISPATCH_FLOATING_TYPES(                                          \
      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
                                         input.data_ptr<scalar_t>(), d); \
      });

void silu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

//void gelu_and_mul(torch::Tensor& out,    // [..., d]
//                  torch::Tensor& input)  // [..., 2 * d]
//{
//  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
//}

//void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
//                       torch::Tensor& input)  // [..., 2 * d]
//{
//  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
//}

namespace vllm {

template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
  const float f = (float)x;
  return (T)(f > threshold ? f : 0.0f);
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
    const float param) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(x, param) * y;
  }
}

}  // namespace vllm

#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)            \
  int d = input.size(-1) / 2;                                              \
  int64_t num_tokens = input.numel() / input.size(-1);                     \
  dim3 grid(num_tokens);                                                   \
  dim3 block(std::min(d, 1024));                                           \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));        \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();            \
  VLLM_DISPATCH_FLOATING_TYPES(                                            \
      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {          \
        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>>    \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),         \
                                         input.data_ptr<scalar_t>(), d,   \
                                         PARAM);                           \
      });

//void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
//                     torch::Tensor& input,  // [..., 2 * d]
//                     double threshold) {
//  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
//}
namespace vllm {

// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

}  // namespace vllm

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  int d = input.size(-1);                                                      \
  int64_t num_tokens = input.numel() / d;                                      \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \
        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
                                     input.data_ptr<scalar_t>(), d);           \
  });

namespace vllm {

template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float)(x * x * x);
  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
  return ((T)0.5) * x * (((T)1.0) + t);
}

template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const T t =
      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
  return ((T)0.5) * x * (((T)1.0) + t);
}

template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  // x * sigmoid(1.702 * x)
  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

}  // namespace vllm

//void gelu_new(torch::Tensor& out,    // [..., d]
//              torch::Tensor& input)  // [..., d]
//{
//  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
//}

//void gelu_fast(torch::Tensor& out,    // [..., d]
//               torch::Tensor& input)  // [..., d]
//{
//  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
//}

//void gelu_quick(torch::Tensor& out,    // [..., d]
//                torch::Tensor& input)  // [..., d]
//{
//  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
//}
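For reference, the gating contract implemented by `act_and_mul_kernel` can be written in a few lines of PyTorch. This sketch is illustrative only (it is not part of the commit) and assumes the last dimension packs the activation input and the gate side by side:

```python
import torch

# Reference semantics of silu_and_mul: input [..., 2*d] -> output [..., d];
# the first half is passed through SiLU and multiplies the second half.
def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]
```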
activation/cuda_compat.h
ADDED
@@ -0,0 +1,49 @@
#pragma once

#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
    __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
activation/dispatch_utils.h
ADDED
@@ -0,0 +1,35 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
build.toml
ADDED
@@ -0,0 +1,49 @@
[general]
version = "0.0.1"

[torch]
name = "moe"
src = [
  "core/scalar_type.hpp",
  "ext-torch/registration.h",
  "ext-torch/torch_binding.cpp",
  "ext-torch/torch_binding.h"
]
include = [ "." ]
pyroot = "ext-torch"

[kernel.moe]
capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "cuda_compat.h",
  "dispatch_utils.h",
  "moe/moe_align_sum_kernels.cu",
  "moe/topk_softmax_kernels.cu",
]
depends = [ "torch" ]

[kernel.moe-marlin]
capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "core/exception.hpp",
  "core/scalar_type.hpp",
  "marlin-moe/marlin_moe_ops.cu",
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu",
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu",
  "marlin-moe/marlin_kernels/marlin_moe_kernel.h",
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h",
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h",
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu",
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h",
]
include = [ "." ]
depends = [ "torch" ]

[kernel.activation]
capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "activation/activation_kernels.cu",
  "activation/cuda_compat.h",
  "activation/dispatch_utils.h",
]
depends = [ "torch" ]
core/exception.hpp
ADDED
@@ -0,0 +1,3 @@
#pragma once

#define VLLM_IMPLIES(p, q) (!(p) || (q))
core/scalar_type.hpp
ADDED
@@ -0,0 +1,347 @@
#pragma once

// For TORCH_CHECK
#include <torch/library.h>

namespace vllm {

//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in: vllm/scalar_type.py
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
 public:
  enum NanRepr : uint8_t {
    NAN_NONE = 0,                // nans are not supported
    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s

    NAN_REPR_ID_MAX
  };

  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
                       int32_t bias, bool finite_values_only = false,
                       NanRepr nan_repr = NAN_IEEE_754)
      : exponent(exponent),
        mantissa(mantissa),
        signed_(signed_),
        bias(bias),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr){};

  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits - 1, true, bias);
  }

  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits, false, bias);
  }

  // IEEE 754 compliant floating point type
  static constexpr ScalarType float_IEEE754(uint8_t exponent,
                                            uint8_t mantissa) {
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
  }

  // IEEE 754 non-compliant floating point type
  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
                                     bool finite_values_only,
                                     NanRepr nan_repr) {
    TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    TORCH_CHECK(nan_repr != NAN_IEEE_754,
                "use `float_IEEE754` constructor for floating point types that "
                "follow IEEE 754 conventions");
    return ScalarType(exponent, mantissa, true, 0, finite_values_only,
                      nan_repr);
  }

  uint8_t const exponent;  // size of the exponent field (0 for integer types)
  uint8_t const mantissa;  // size of the mantissa field (size of the integer
                           // excluding the sign bit for integer types)
  bool const signed_;  // flag if the type supports negative numbers (i.e. has a
                       // sign bit)
  int32_t const bias;  // stored values equal value + bias,
                       // used for quantized type

  // Extra Floating point info
  bool const finite_values_only;  // i.e. no +/-inf if true
  NanRepr const nan_repr;         // how NaNs are represented
                                  // (not applicable for integer types)

  using Id = int64_t;

 private:
  // Field size in id
  template <typename T_>
  static constexpr size_t member_id_field_width() {
    using T = std::decay_t<T_>;
    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
  }

  template <typename Fn, typename Init, typename Member, typename... Rest>
  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
                                              Rest... rest) {
    auto new_val = f(val, member);
    if constexpr (sizeof...(rest) > 0) {
      return reduce_members_helper(f, new_val, rest...);
    } else {
      return new_val;
    };
  }

  template <typename Fn, typename Init>
  constexpr auto reduce_members(Fn f, Init init) const {
    // Should be in constructor order for `from_id`
    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
                                 finite_values_only, nan_repr);
  };

  template <typename Fn, typename Init>
  static constexpr auto reduce_member_types(Fn f, Init init) {
    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
    return dummy_type.reduce_members(f, init);
  };

  static constexpr auto id_size_bits() {
    return reduce_member_types(
        [](int acc, auto member) -> int {
          return acc + member_id_field_width<decltype(member)>();
        },
        0);
  }

 public:
  // unique id for this scalar type that can be computed at compile time for
  // c++17 template specialization this is not needed once we migrate to
  // c++20 and can pass literal classes as template parameters
  constexpr Id id() const {
    static_assert(id_size_bits() <= sizeof(Id) * 8,
                  "ScalarType id is too large to be stored");

    auto or_and_advance = [](std::pair<Id, uint32_t> result,
                             auto member) -> std::pair<Id, uint32_t> {
      auto [id, bit_offset] = result;
      auto constexpr bits = member_id_field_width<decltype(member)>();
      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
                       << bit_offset,
              bit_offset + bits};
    };
    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
  }

  // create a ScalarType from an id, for c++17 template specialization,
  // this is not needed once we migrate to c++20 and can pass literal
  // classes as template parameters
  static constexpr ScalarType from_id(Id id) {
    auto extract_and_advance = [id](auto result, auto member) {
      using T = decltype(member);
      auto [tuple, bit_offset] = result;
      auto constexpr bits = member_id_field_width<T>();
      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
                                          ((uint64_t(1) << bits) - 1));
      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
    };

    auto [tuple_args, _] = reduce_member_types(extract_and_advance,
                                               std::pair<std::tuple<>, int>{});
    return std::apply([](auto... args) { return ScalarType(args...); },
                      tuple_args);
  }

  constexpr int64_t size_bits() const {
    return mantissa + exponent + is_signed();
  }
  constexpr bool is_signed() const { return signed_; }
  constexpr bool is_integer() const { return exponent == 0; }
  constexpr bool is_floating_point() const { return exponent > 0; }
  constexpr bool is_ieee_754() const {
    return is_floating_point() && finite_values_only == false &&
           nan_repr == NAN_IEEE_754;
  }
  constexpr bool has_nans() const {
    return is_floating_point() && nan_repr != NAN_NONE;
  }
  constexpr bool has_infs() const {
    return is_floating_point() && finite_values_only == false;
  }
  constexpr bool has_bias() const { return bias != 0; }

 private:
  double _floating_point_max() const {
    TORCH_CHECK(mantissa <= 52 && exponent <= 11,
                "Cannot represent max/min as a double for type ", str());

    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
      max_mantissa -= 1;
    }

    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
      TORCH_CHECK(exponent < 11,
                  "Cannot represent max/min as a double for type ", str());
      max_exponent += 1;
    }

    // adjust the exponent to match that of a double
    // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
    // is the exponent bits), there is some precedent for non-standard biases,
    // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
    // but to avoid premature over complication we are just assuming the
    // standard exponent bias until there is a need to support non-standard
    // biases
    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11

    uint64_t max_exponent_double =
        max_exponent - exponent_bias + exponent_bias_double;

    // shift the mantissa into the position for a double and
    // the exponent
    uint64_t double_raw =
        (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

    return *reinterpret_cast<double*>(&double_raw);
  }

  constexpr std::variant<int64_t, double> _raw_max() const {
    if (is_floating_point()) {
      return {_floating_point_max()};
    } else {
      TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
                  "Cannot represent max as a int64_t");
      return {(int64_t(1) << mantissa) - 1};
    }
  }

  constexpr std::variant<int64_t, double> _raw_min() const {
    if (is_floating_point()) {
      TORCH_CHECK(is_signed(),
                  "We currently assume all floating point types are signed");
      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

      double max = _floating_point_max();
      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
      uint64_t min_raw = max_raw | sign_bit_double;
      return {*reinterpret_cast<double*>(&min_raw)};
    } else {
      TORCH_CHECK(!is_signed() || size_bits() <= 64,
                  "Cannot represent min as a int64_t");
      if (is_signed()) {
        // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
        // then perform an arithmetic shift right to set all the bits above
        // (size_bits() - 1) to 1
        return {INT64_MIN >> (64 - size_bits())};
      } else {
        return {int64_t(0)};
      }
    }
  }

 public:
  // Max representable value for this scalar type.
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> max() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_max());
  }

  // Min representable value for this scalar type.
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> min() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_min());
  }

  std::string str() const {
    /* naming generally follows: https://github.com/jax-ml/ml_dtypes
     * for floating point types (leading f) the scheme is:
     *  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
     *  flags:
     *  - no-flags: means it follows IEEE 754 conventions
     *  - f: means finite values only (no infinities)
     *  - n: means nans are supported (non-standard encoding)
     * for integer types the scheme is:
     *  `[u]int<size_bits>[b<bias>]`
     *  - if bias is not present it means its zero
     */
    if (is_floating_point()) {
      auto ret = "float" + std::to_string(size_bits()) + "_e" +
                 std::to_string(exponent) + "m" + std::to_string(mantissa);
      if (!is_ieee_754()) {
        if (finite_values_only) {
          ret += "f";
        }
        if (nan_repr != NAN_NONE) {
          ret += "n";
        }
      }
      return ret;
    } else {
      auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
      if (has_bias()) {
        ret += "b" + std::to_string(bias);
      }
      return ret;
    }
  }

  constexpr bool operator==(ScalarType const& other) const {
    return mantissa == other.mantissa && exponent == other.exponent &&
           bias == other.bias && signed_ == other.signed_ &&
           finite_values_only == other.finite_values_only &&
           nan_repr == other.nan_repr;
  }
};

using ScalarTypeId = ScalarType::Id;

// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE3M2f =
    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
    ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);

// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;

// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

static inline constexpr auto kFloat16Id = kFloat16.id();
};  // namespace vllm
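As a quick illustration of the naming and min/max rules above (not part of the commit), here is the integer branch mirrored in Python; the floating-point branch would additionally need the exponent/mantissa handling from the header:

```python
# Integer-only sketch of ScalarType::str()/min()/max(): stored values equal
# value + bias, so uint4b8 (4-bit unsigned, bias 8) represents -8..7.
def int_scalar_type_info(size_bits: int, signed: bool, bias: int = 0):
    name = ("int" if signed else "uint") + str(size_bits) + (f"b{bias}" if bias else "")
    mantissa = size_bits - 1 if signed else size_bits  # matches int_() / uint() above
    raw_max = (1 << mantissa) - 1
    raw_min = -(1 << mantissa) if signed else 0
    return name, raw_min - bias, raw_max - bias

assert int_scalar_type_info(4, signed=False, bias=8) == ("uint4b8", -8, 7)
assert int_scalar_type_info(8, signed=False, bias=128) == ("uint8b128", -128, 127)
```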
cuda_compat.h
ADDED
@@ -0,0 +1,49 @@
#pragma once

#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
    __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
dispatch_utils.h
ADDED
@@ -0,0 +1,35 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
ext-torch/__init__.py
ADDED
@@ -0,0 +1 @@
import moe._custom_ops as ops
ext-torch/_custom_ops.py
ADDED
@@ -0,0 +1,135 @@
from typing import TYPE_CHECKING

import torch

# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fallback for local development.
    try:
        import _moe

        ops = torch._moe

        def add_op_namespace_prefix(op_name: str):
            return f"_quantization::{op_name}"

    except ImportError:
        raise e

from .scalar_type import ScalarType

def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits
        )
    return output


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = ops.awq_marlin_repack(b_q_weight[e], size_k, size_n, num_bits)
    return output


def moe_sum(input: torch.Tensor, output: torch.Tensor):
    ops.moe_sum(input, output)


def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: float,
) -> None:
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)

if hasattr(ops, "marlin_gemm_moe"):

    @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)


def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    ops.silu_and_mul(out, x)
    return out
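The `topk_softmax` wrapper expects the caller to pre-allocate its output buffers. A hedged usage sketch, roughly following how vLLM's fused_topk drives this op; the shapes and dtypes here are assumptions, not enforced by the wrapper itself:

```python
import torch

def run_topk_softmax(ops, gating_output: torch.Tensor, topk: int):
    # gating_output: [num_tokens, num_experts] router logits (assumed layout).
    M = gating_output.shape[0]
    device = gating_output.device
    topk_weights = torch.empty(M, topk, dtype=torch.float32, device=device)
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=device)
    token_expert_indicies = torch.empty_like(topk_ids)
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float())
    return topk_weights, topk_ids
```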
ext-torch/fp8.py
ADDED
@@ -0,0 +1,63 @@
import torch

from typing import Tuple, Optional, Union


def is_hip() -> bool:
    return torch.version.hip is not None


def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
    out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub
            )
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert scale.numel() == 1 or num_token_padding is None
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
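For intuition, the dynamic per-tensor path is roughly equivalent to the following pure-PyTorch sketch; this is an approximation for illustration only, not the kernel's exact code:

```python
import torch

def scaled_fp8_quant_reference(x: torch.Tensor):
    # Pick a per-tensor scale so max(|x|) maps to the FP8 E4M3 maximum,
    # store x / scale in float8_e4m3fn, and dequantize as q.float() * scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().max().float() / fp8_max
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale
```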
ext-torch/fused_marlin_moe.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Fused MoE utilities for GPTQ."""
|
2 |
+
|
3 |
+
import functools
|
4 |
+
from typing import Any, Dict, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config
|
9 |
+
from .scalar_type import scalar_types
|
10 |
+
import moe._custom_ops as ops
|
11 |
+
|
12 |
+
|
13 |
+
def get_scalar_type(num_bits: int, has_zp: bool):
|
14 |
+
if has_zp:
|
15 |
+
assert num_bits == 4
|
16 |
+
return scalar_types.uint4
|
17 |
+
else:
|
18 |
+
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
|
19 |
+
|
20 |
+
|
21 |
+
def single_marlin_moe(
|
22 |
+
hidden_states: torch.Tensor,
|
23 |
+
w: torch.Tensor,
|
24 |
+
scales: torch.Tensor,
|
25 |
+
gating_output: torch.Tensor,
|
26 |
+
topk: int,
|
27 |
+
renormalize: bool,
|
28 |
+
g_idx: Optional[torch.Tensor] = None,
|
29 |
+
sort_indices: Optional[torch.Tensor] = None,
|
30 |
+
w_zeros: Optional[torch.Tensor] = None,
|
31 |
+
override_config: Optional[Dict[str, Any]] = None,
|
32 |
+
num_bits: int = 8,
|
33 |
+
is_k_full: bool = True,
|
34 |
+
) -> torch.Tensor:
|
35 |
+
"""
|
36 |
+
This function computes the multiplication of hidden_states with expert
|
37 |
+
weights used in Marlin MoE, using weights w and top-k gating mechanism.
|
38 |
+
Its purpose is testing and debugging the fused MoE kernel.
|
39 |
+
|
40 |
+
Parameters:
|
41 |
+
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
|
42 |
+
- w (torch.Tensor): The set of expert weights.
|
43 |
+
- scales (torch.Tensor): The quantization scales.
|
44 |
+
- gating_output (torch.Tensor): The output of the gating operation
|
45 |
+
(before softmax).
|
46 |
+
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
|
47 |
+
- sort_indices (Optional[torch.Tensor]): Optional act_order input
|
48 |
+
permutation.
|
49 |
+
- topk (int): The number of top-k experts to select.
|
50 |
+
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
51 |
+
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
|
52 |
+
- override_config (Optional[Dict[str, Any]]): Optional override
|
53 |
+
for the kernel configuration.
|
54 |
+
- num_bits (bool): The number of bits in expert weights quantization.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
- torch.Tensor: The output tensor after applying the MoE layer.
|
58 |
+
"""
|
59 |
+
# Check constraints.
|
60 |
+
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
61 |
+
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
|
62 |
+
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
|
63 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
64 |
+
assert w.is_contiguous(), "Expert weights must be contiguous"
|
65 |
+
assert hidden_states.dtype == torch.float16
|
66 |
+
assert num_bits in [4, 8]
|
67 |
+
|
68 |
+
M, K = hidden_states.shape
|
69 |
+
E = w.shape[0]
|
70 |
+
N = w.shape[2] // (num_bits // 2)
|
71 |
+
|
72 |
+
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)
|
73 |
+
|
74 |
+
# This might not be an optimal config for a single MMM
|
75 |
+
get_config_func = functools.partial(
|
76 |
+
try_get_optimal_moe_config,
|
77 |
+
w.shape,
|
78 |
+
w.shape,
|
79 |
+
topk_ids.shape[1],
|
80 |
+
None,
|
81 |
+
override_config=override_config,
|
82 |
+
is_marlin=True,
|
83 |
+
)
|
84 |
+
config = get_config_func(M)
|
85 |
+
|
86 |
+
block_size_m = config["BLOCK_SIZE_M"]
|
87 |
+
|
88 |
+
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
|
89 |
+
|
90 |
+
max_workspace_size = (N // 64) * 16
|
91 |
+
workspace = torch.zeros(
|
92 |
+
max_workspace_size,
|
93 |
+
dtype=torch.int,
|
94 |
+
device=hidden_states.device,
|
95 |
+
requires_grad=False,
|
96 |
+
)
|
97 |
+
|
98 |
+
has_zero_point = w_zeros is not None
|
99 |
+
if w_zeros is None:
|
100 |
+
w_zeros = torch.empty(
|
101 |
+
(0, 0),
|
102 |
+
dtype=hidden_states.dtype,
|
103 |
+
device=hidden_states.device,
|
104 |
+
requires_grad=False,
|
105 |
+
)
|
106 |
+
|
107 |
+
if g_idx is None:
|
108 |
+
g_idx = torch.empty(
|
109 |
+
(0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
|
110 |
+
)
|
111 |
+
|
112 |
+
if sort_indices is None:
|
113 |
+
sort_indices = torch.empty(
|
114 |
+
(0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
|
115 |
+
)
|
116 |
+
|
117 |
+
scalar_type = get_scalar_type(num_bits, has_zero_point)
|
118 |
+
|
119 |
+
intermediate_cache = ops.ops.marlin_gemm_moe(
|
120 |
+
hidden_states,
|
121 |
+
w,
|
122 |
+
sorted_token_ids,
|
123 |
+
topk_weights,
|
124 |
+
topk_ids,
|
125 |
+
scales,
|
126 |
+
w_zeros,
|
127 |
+
g_idx,
|
128 |
+
sort_indices,
|
129 |
+
workspace,
|
130 |
+
scalar_type.id,
|
131 |
+
M,
|
132 |
+
N,
|
133 |
+
K,
|
134 |
+
is_k_full,
|
135 |
+
E,
|
136 |
+
topk,
|
137 |
+
block_size_m,
|
138 |
+
True,
|
139 |
+
False,
|
140 |
+
)
|
141 |
+
|
142 |
+
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
|
143 |
+
|
144 |
+
|
145 |
+
def fused_marlin_moe(
|
146 |
+
hidden_states: torch.Tensor,
|
147 |
+
w1: torch.Tensor,
|
148 |
+
w2: torch.Tensor,
|
149 |
+
w1_scale: torch.Tensor,
|
150 |
+
w2_scale: torch.Tensor,
|
151 |
+
gating_output: torch.Tensor,
|
152 |
+
topk_weights: torch.Tensor,
|
153 |
+
topk_ids: torch.Tensor,
|
154 |
+
g_idx1: Optional[torch.Tensor] = None,
|
155 |
+
g_idx2: Optional[torch.Tensor] = None,
|
156 |
+
sort_indices1: Optional[torch.Tensor] = None,
|
157 |
+
sort_indices2: Optional[torch.Tensor] = None,
|
158 |
+
w1_zeros: Optional[torch.Tensor] = None,
|
159 |
+
w2_zeros: Optional[torch.Tensor] = None,
|
160 |
+
override_config: Optional[Dict[str, Any]] = None,
|
161 |
+
num_bits: int = 8,
|
162 |
+
is_k_full: bool = True,
|
163 |
+
) -> torch.Tensor:
|
164 |
+
"""
|
165 |
+
This function computes a Mixture of Experts (MoE) layer using two sets of
|
166 |
+
weights, w1 and w2, and top-k gating mechanism.
|
167 |
+
|
168 |
+
Parameters:
|
169 |
+
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
170 |
+
- w1 (torch.Tensor): The first set of expert weights.
|
171 |
+
- w2 (torch.Tensor): The second set of expert weights.
|
172 |
+
- w1_scale (torch.Tensor): Scale to be used for w1.
|
173 |
+
- w2_scale (torch.Tensor): Scale to be used for w2.
|
174 |
+
- gating_output (torch.Tensor): The output of the gating operation
|
175 |
+
(before softmax).
|
176 |
+
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
177 |
+
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
178 |
+
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
        permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
        permutation.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of top-k elements.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - num_bits (int): The number of bits in expert weights quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

    has_no_act_order = (
        g_idx1 is None
        and g_idx2 is None
        and sort_indices1 is None
        and sort_indices2 is None
    )
    has_all_act_order = (
        g_idx1 is not None
        and g_idx2 is not None
        and sort_indices1 is not None
        and sort_indices2 is not None
    )
    assert has_no_act_order or has_all_act_order, (
        "g_idx and sorted_indices must be all not None or must be all None"
    )

    has_no_zp = w1_zeros is None and w2_zeros is None
    has_all_zp = w1_zeros is not None and w2_zeros is not None
    assert has_no_zp or has_all_zp, (
        "zero points must be both not None or must be both None"
    )

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        override_config=override_config,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    max_workspace_size = (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(
        max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
    )

    if has_no_zp:
        w1_zeros = torch.empty(
            (0, 0),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
            requires_grad=False,
        )
        w2_zeros = torch.empty(
            (0, 0),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
            requires_grad=False,
        )

    if has_no_act_order:
        g_idx1 = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )
        g_idx2 = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )
        sort_indices1 = torch.empty(
            (0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )
        sort_indices2 = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, has_all_zp)
    scalar_type2 = get_scalar_type(num_bits, has_all_zp)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    intermediate_cache1 = ops.ops.marlin_gemm_moe(
        hidden_states,
        w1,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        scalar_type1.id,
        M,
        2 * N,
        K,
        is_k_full,
        E,
        topk,
        block_size_m,
        True,
        False,
    )

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    intermediate_cache3 = ops.ops.marlin_gemm_moe(
        intermediate_cache2,
        w2,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        scalar_type2.id,
        M,
        K,
        N,
        is_k_full,
        E,
        topk,
        block_size_m,
        False,
        True,
    )

    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
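For intuition, the gating step above (`ops.silu_and_mul` applied to the 2 * N wide output of the first Marlin GEMM) is equivalent in effect to the plain-PyTorch sketch below. This is a reference illustration only, not the CUDA path; the helper name is made up here.

# Reference-semantics sketch of the SiLU-and-mul gating step (illustration only).
import torch

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * d]; the first half is passed through SiLU and
    # multiplied element-wise by the second half, giving shape [..., d].
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]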
ext-torch/fused_moe.py
ADDED
@@ -0,0 +1,703 @@
"""Fused MoE kernel."""

import functools
import json
import os
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl

from .platforms import current_platform
from .fp8 import scaled_fp8_quant
import moe._custom_ops as ops

VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
        with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8:
        A, A_scale = scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    elif use_int8_w8a16:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2],
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )


def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = current_platform.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"


@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    return None


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
) -> Dict[str, int]:
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    # A heuristic: fused marlin works faster with this config for small M
    if M <= E or (is_marlin and M <= 32):
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return config


def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
    is_marlin: bool = False,
):
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        configs = get_moe_configs(E, N, dtype)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
    return config


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    ops.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):

    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def get_config_dtype_str(
    dtype: torch.dtype,
    use_int8_w8a16: Optional[bool] = False,
    use_fp8_w8a8: Optional[bool] = False,
):
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        dtype=hidden_states.dtype,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        config_dtype,
        override_config=override_config,
    )

    config = get_config_func(M)

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )

        invoke_fused_moe_kernel(
            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,
            topk_ids.shape[1],
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            intermediate_cache3,
            a2_scale,
            w2_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            True,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.moe_sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
    return out_hidden_states


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        note: Deepseekv2 model uses grouped_topk
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use int8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
        )
    elif custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize
        )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        override_config=override_config,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
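For orientation, a minimal sketch of calling `fused_moe` end to end. The shapes and the `moe` import path are assumptions for illustration; a CUDA device and Triton are required.

# Hypothetical usage sketch (assumed import path and made-up shapes).
import torch
from moe.fused_moe import fused_moe  # assumed package/module name

M, K, E, topk = 16, 128, 8, 2  # tokens, hidden size, experts, top-k
hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * 128, K, dtype=torch.float16, device="cuda")  # gate + up projection
w2 = torch.randn(E, K, 128, dtype=torch.float16, device="cuda")      # down projection
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, topk=topk, renormalize=True)
assert out.shape == (M, K)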
ext-torch/platforms.py
ADDED
@@ -0,0 +1,22 @@
from typing import Callable, ParamSpec, TypeVar
import os
from functools import lru_cache, wraps

import torch

IS_ROCM = torch.version.hip is not None

class CudaPlatform:
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(0)

class RocmPlatform:
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)


current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
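The device name reported here feeds the tuned-config lookup in fused_moe.py. A small illustrative example; the import paths are assumptions and the resulting filename depends on the GPU in use.

# Hypothetical example of the config filename produced for the current GPU.
from moe.platforms import current_platform        # assumed import path
from moe.fused_moe import get_config_file_name    # defined in fused_moe.py above

# e.g. "E=8,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json" on an A100
print(get_config_file_name(E=8, N=14336, dtype=None))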
ext-torch/registration.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                           \
  }
ext-torch/scalar_type.py
ADDED
@@ -0,0 +1,330 @@
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, used `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represent in this scalar type, returns NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        return self.nan_repr != NanRepr.NONE.value

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754.value and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
        `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret


# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
#  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
#  flags:
#  - no-flags: means it follows IEEE 754 conventions
#  - f: means finite values only (no infinities)
#  - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
#  `[u]int<size_bits>[b<bias>]`
#  - if bias is not present it means its zero


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
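As a quick sanity check of the encoding above, the GPTQ-style `uint4b8` type can be inspected directly; the import path is an assumption, while the asserted values follow from the class definition.

# Illustrative check of ScalarType behaviour (assumed import path).
from moe.scalar_type import scalar_types

t = scalar_types.uint4b8            # 4-bit unsigned with a bias of 8 (GPTQ style)
assert str(t) == "uint4b8"
assert t.size_bits == 4
assert t.min() == -8 and t.max() == 7  # raw range [0, 15] shifted by the bias
print(t.id)                          # packed integer passed to the custom ops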
ext-torch/torch_binding.cpp
ADDED
@@ -0,0 +1,45 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Activation used in fused MoE layers.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // Apply topk softmax to the gating outputs.
  ops.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
          "token_expert_indices, Tensor gating_output) -> ()");
  ops.impl("topk_softmax", torch::kCUDA, &topk_softmax);

  // Calculate the result of moe by summing up the partial results
  // from all selected experts.
  ops.def("moe_sum(Tensor! input, Tensor output) -> ()");
  ops.impl("moe_sum", torch::kCUDA, &moe_sum);

  // Aligning the number of tokens to be processed by each expert such
  // that it is divisible by the block size.
  ops.def("moe_align_block_size(Tensor topk_ids, int num_experts,"
          " int block_size, Tensor! sorted_token_ids,"
          " Tensor! experts_ids,"
          " Tensor! num_tokens_post_pad) -> ()");
  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

#ifndef USE_ROCM
  ops.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
          "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
          "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
          "int b_q_type, SymInt size_m, "
          "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
          "topk, "
          "int moe_block_size, bool replicate_input, bool apply_weights)"
          " -> Tensor");
#endif
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
  ops.impl("marlin_gemm_moe", &marlin_gemm_moe);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
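For reference, the `moe_sum` op registered above reduces the per-expert partial results over the top-k dimension before they are written to the output buffer. A plain-PyTorch sketch of the same semantics; this is an illustration only, not the CUDA implementation, and the helper name is made up.

# Reference-semantics sketch of moe_sum (illustration only).
import torch

def moe_sum_reference(partials: torch.Tensor, output: torch.Tensor) -> None:
    # partials: [num_tokens, topk, hidden_size]; output: [num_tokens, hidden_size].
    # The registered CUDA op performs this reduction into the preallocated output.
    output.copy_(partials.sum(dim=1))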
ext-torch/torch_binding.h
ADDED
@@ -0,0 +1,30 @@
#pragma once

#include <torch/torch.h>

#include <core/scalar_type.hpp>

void silu_and_mul(torch::Tensor &out, torch::Tensor &input);

void topk_softmax(torch::Tensor &topk_weights, torch::Tensor &topk_indices,
                  torch::Tensor &token_expert_indices,
                  torch::Tensor &gating_output);

void moe_sum(torch::Tensor &input, torch::Tensor &output);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
torch::Tensor marlin_gemm_moe(
    const torch::Tensor &a, const torch::Tensor &b_q_weights,
    const torch::Tensor &sorted_ids, const torch::Tensor &topk_weights,
    const torch::Tensor &topk_ids, const torch::Tensor &b_scales,
    torch::Tensor &b_zeros, const torch::Tensor &g_idx,
    const torch::Tensor &perm, torch::Tensor &workspace,
    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights);
#endif
ext-torch/utils/__init__.py
ADDED
File without changes
ext-torch/utils/marlin_utils.py
ADDED
@@ -0,0 +1,307 @@
from typing import List, Optional, Tuple

import numpy
import torch

from moe.scalar_type import ScalarType, scalar_types

from .quant_utils import pack_cols, unpack_cols

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True


# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: bool, device_capability: Optional[int] = None
):
    if device_capability is None:
        capability_tuple = torch.cuda.get_device_capability()
        device_capability = capability_tuple[0] * 10 + capability_tuple[1]

    if device_capability < 80:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
        # to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]


def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:

    if device_capability is None:
        capability_tuple = torch.cuda.get_device_capability()
        device_capability = capability_tuple[0] * 10 + capability_tuple[1]

    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)

    if quant_type not in supported_types:
        return (
            False,
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp}).",
        )
    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        return (
            False,
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.",
        )

    return True, None


def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
    return cond


def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}."
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )


def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> Tuple[bool, Optional[str]]:
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as e:
        return False, e.__str__()
    return True, None


def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    max_workspace_size = (
|
169 |
+
output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
|
170 |
+
) * GPTQ_MARLIN_MAX_PARALLEL
|
171 |
+
|
172 |
+
return torch.zeros(
|
173 |
+
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
178 |
+
return (not act_order) or (act_order and not is_row_parallel)
|
179 |
+
|
180 |
+
|
181 |
+
def marlin_repeat_scales_on_all_ranks(
|
182 |
+
act_order: bool, group_size: int, is_row_parallel: bool
|
183 |
+
) -> bool:
|
184 |
+
# Need to repeat scales on every rank if act_ordering or
|
185 |
+
# channelwise and RowParallelLinear
|
186 |
+
is_channelwise = group_size == -1
|
187 |
+
return act_order or (is_channelwise and is_row_parallel)
|
188 |
+
|
189 |
+
|
190 |
+
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
191 |
+
return torch.nn.Parameter(
|
192 |
+
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
|
193 |
+
)
|
194 |
+
|
195 |
+
|
196 |
+
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
|
197 |
+
return torch.nn.Parameter(
|
198 |
+
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
|
199 |
+
)
|
200 |
+
|
201 |
+
|
202 |
+
def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
203 |
+
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
|
204 |
+
return g_idx[g_idx_sort_indices], g_idx_sort_indices
|
205 |
+
|
206 |
+
|
207 |
+
def get_scale_perms():
|
208 |
+
scale_perm: List[int] = []
|
209 |
+
for i in range(8):
|
210 |
+
scale_perm.extend([i + 8 * j for j in range(8)])
|
211 |
+
scale_perm_single: List[int] = []
|
212 |
+
for i in range(4):
|
213 |
+
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
214 |
+
return scale_perm, scale_perm_single
|
215 |
+
|
216 |
+
|
217 |
+
def marlin_permute_scales(
|
218 |
+
s: torch.Tensor, size_k: int, size_n: int, group_size: int
|
219 |
+
) -> torch.Tensor:
|
220 |
+
|
221 |
+
scale_perm, scale_perm_single = get_scale_perms()
|
222 |
+
if group_size < size_k and group_size != -1:
|
223 |
+
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
224 |
+
else:
|
225 |
+
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
226 |
+
s = s.reshape((-1, size_n)).contiguous()
|
227 |
+
|
228 |
+
return s
|
229 |
+
|
230 |
+
|
231 |
+
def marlin_moe_permute_scales(
|
232 |
+
s: torch.Tensor,
|
233 |
+
size_k: int,
|
234 |
+
size_n: int,
|
235 |
+
group_size: int,
|
236 |
+
):
|
237 |
+
num_experts = s.shape[0]
|
238 |
+
output = torch.empty(
|
239 |
+
(num_experts, s.shape[1], s.shape[2]),
|
240 |
+
device=s.device,
|
241 |
+
dtype=s.dtype,
|
242 |
+
)
|
243 |
+
|
244 |
+
for e in range(num_experts):
|
245 |
+
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
|
246 |
+
return output
|
247 |
+
|
248 |
+
|
249 |
+
def marlin_zero_points(
|
250 |
+
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
251 |
+
) -> torch.Tensor:
|
252 |
+
# Permute zero-points in a similar way to scales, but do not use the
|
253 |
+
# "single" permutation, since zero-points are applied on every MMA
|
254 |
+
scale_perm, _ = get_scale_perms()
|
255 |
+
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
|
256 |
+
|
257 |
+
# Interleave column dim (for the dequantize code) and pack it to int32
|
258 |
+
if num_bits == 4:
|
259 |
+
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
260 |
+
elif num_bits == 8:
|
261 |
+
interleave = numpy.array([0, 2, 1, 3])
|
262 |
+
else:
|
263 |
+
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
264 |
+
|
265 |
+
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
|
266 |
+
zp = zp.reshape((-1, size_n)).contiguous()
|
267 |
+
zp = pack_cols(zp, num_bits, size_k, size_n)
|
268 |
+
|
269 |
+
return zp
|
270 |
+
|
271 |
+
|
272 |
+
def awq_to_marlin_zero_points(
|
273 |
+
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
274 |
+
) -> torch.Tensor:
|
275 |
+
# AWQ zero-points are quantized and packed on the column dim.
|
276 |
+
# In addition, the values are permuted based on dequantizer.
|
277 |
+
# Here we undo both of these, and then apply marlin permutation
|
278 |
+
# and pack it back.
|
279 |
+
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
|
280 |
+
|
281 |
+
# Undo interleaving (use argsort(..) to get inverse perm)
|
282 |
+
if num_bits == 4:
|
283 |
+
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
|
284 |
+
elif num_bits == 8:
|
285 |
+
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
|
286 |
+
else:
|
287 |
+
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
288 |
+
|
289 |
+
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
|
290 |
+
q_zp = q_zp.reshape((-1, size_n)).contiguous()
|
291 |
+
|
292 |
+
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
|
293 |
+
return marlin_zp
|
294 |
+
|
295 |
+
|
296 |
+
def moe_awq_to_marlin_zero_points(
|
297 |
+
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
298 |
+
):
|
299 |
+
num_experts = q_zp_packed.shape[0]
|
300 |
+
output = torch.empty(
|
301 |
+
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
|
302 |
+
device=q_zp_packed.device,
|
303 |
+
dtype=q_zp_packed.dtype,
|
304 |
+
)
|
305 |
+
for e in range(num_experts):
|
306 |
+
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
|
307 |
+
return output
|
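A minimal usage sketch for the helpers above (illustrative only: the import path `ext_torch.utils.marlin_utils` is hypothetical, and a CUDA device plus the repo's `moe.scalar_type` module are assumed):

```python
# Sketch only: import paths are hypothetical and a CUDA device is assumed.
import torch

from moe.scalar_type import scalar_types
from ext_torch.utils.marlin_utils import (  # hypothetical import path
    check_marlin_supported,
    marlin_make_workspace,
    marlin_sort_g_idx,
)

quant_type = scalar_types.uint4b8  # GPTQ-style 4-bit weights
if check_marlin_supported(quant_type, group_size=128):
    # Workspace sized from the output partition of the quantized layer.
    workspace = marlin_make_workspace(4096, torch.device("cuda"))

    # Sort a toy g_idx so that group ids are increasing, as act_order expects.
    g_idx = torch.randint(0, 32, (4096,), dtype=torch.int, device="cuda")
    g_idx_sorted, sort_indices = marlin_sort_g_idx(g_idx)
```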
ext-torch/utils/marlin_utils_test.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from moe.scalar_type import ScalarType
|
9 |
+
|
10 |
+
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
11 |
+
from .quant_utils import (
|
12 |
+
get_pack_factor,
|
13 |
+
gptq_quantize_weights,
|
14 |
+
quantize_weights,
|
15 |
+
sort_weights,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class MarlinWorkspace:
|
20 |
+
|
21 |
+
def __init__(self, out_features, min_thread_n, max_parallel):
|
22 |
+
assert (
|
23 |
+
out_features % min_thread_n == 0
|
24 |
+
), "out_features = {} is undivisible by min_thread_n = {}".format(
|
25 |
+
out_features, min_thread_n
|
26 |
+
)
|
27 |
+
|
28 |
+
max_workspace_size = (out_features // min_thread_n) * max_parallel
|
29 |
+
|
30 |
+
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
|
31 |
+
|
32 |
+
|
33 |
+
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
|
34 |
+
assert q_w.shape == (size_k, size_n)
|
35 |
+
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
|
36 |
+
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
|
37 |
+
|
38 |
+
# Permute weights to 16x64 marlin tiles
|
39 |
+
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
|
40 |
+
q_w = q_w.permute((0, 2, 1, 3))
|
41 |
+
q_w = q_w.reshape((size_k // tile, size_n * tile))
|
42 |
+
|
43 |
+
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
|
44 |
+
|
45 |
+
return q_w
|
46 |
+
|
47 |
+
|
48 |
+
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
|
49 |
+
# Permute
|
50 |
+
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
|
51 |
+
|
52 |
+
# Pack
|
53 |
+
pack_factor = get_pack_factor(num_bits)
|
54 |
+
orig_device = q_w.device
|
55 |
+
|
56 |
+
q_w = q_w.cpu().numpy().astype(np.uint32)
|
57 |
+
|
58 |
+
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
|
59 |
+
for i in range(pack_factor):
|
60 |
+
q_packed |= q_w[:, i::pack_factor] << num_bits * i
|
61 |
+
|
62 |
+
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
|
63 |
+
|
64 |
+
return q_packed
|
65 |
+
|
66 |
+
|
67 |
+
def get_weight_perm(num_bits: int):
|
68 |
+
perm_list: List[int] = []
|
69 |
+
for i in range(32):
|
70 |
+
perm1: List[int] = []
|
71 |
+
col = i // 4
|
72 |
+
for block in [0, 1]:
|
73 |
+
for row in [
|
74 |
+
2 * (i % 4),
|
75 |
+
2 * (i % 4) + 1,
|
76 |
+
2 * (i % 4 + 4),
|
77 |
+
2 * (i % 4 + 4) + 1,
|
78 |
+
]:
|
79 |
+
perm1.append(16 * row + col + 8 * block)
|
80 |
+
for j in range(4):
|
81 |
+
perm_list.extend([p + 256 * j for p in perm1])
|
82 |
+
|
83 |
+
perm = np.array(perm_list)
|
84 |
+
|
85 |
+
if num_bits == 4:
|
86 |
+
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
|
87 |
+
elif num_bits == 8:
|
88 |
+
interleave = np.array([0, 2, 1, 3])
|
89 |
+
else:
|
90 |
+
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
91 |
+
|
92 |
+
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
|
93 |
+
perm = torch.from_numpy(perm)
|
94 |
+
return perm
|
95 |
+
|
96 |
+
|
97 |
+
def marlin_quantize(
|
98 |
+
w: torch.Tensor,
|
99 |
+
quant_type: ScalarType,
|
100 |
+
group_size: int,
|
101 |
+
act_order: bool,
|
102 |
+
test_perm: Optional[torch.Tensor] = None,
|
103 |
+
):
|
104 |
+
size_k, size_n = w.shape
|
105 |
+
num_bits = quant_type.size_bits
|
106 |
+
|
107 |
+
# Normalize group_size
|
108 |
+
if group_size == -1:
|
109 |
+
group_size = size_k
|
110 |
+
assert group_size <= size_k
|
111 |
+
|
112 |
+
# Quantize (and apply act_order if provided)
|
113 |
+
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
|
114 |
+
w, quant_type, group_size, act_order, test_perm
|
115 |
+
)
|
116 |
+
|
117 |
+
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
118 |
+
# increasing
|
119 |
+
sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
|
120 |
+
if act_order:
|
121 |
+
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
122 |
+
|
123 |
+
# Reformat to marlin
|
124 |
+
weight_perm = get_weight_perm(num_bits)
|
125 |
+
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
|
126 |
+
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
|
127 |
+
|
128 |
+
# Create result
|
129 |
+
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
|
130 |
+
for i in range(len(res_list)):
|
131 |
+
res_list[i] = res_list[i].to(w.device)
|
132 |
+
|
133 |
+
return res_list
|
134 |
+
|
135 |
+
|
136 |
+
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
|
137 |
+
size_k, size_n = w.shape
|
138 |
+
|
139 |
+
# Normalize group_size
|
140 |
+
if group_size == -1:
|
141 |
+
group_size = size_k
|
142 |
+
assert group_size <= size_k
|
143 |
+
|
144 |
+
# Detect num groups
|
145 |
+
assert size_k % group_size == 0
|
146 |
+
num_groups = size_k // group_size
|
147 |
+
|
148 |
+
# Quantize with zp
|
149 |
+
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
|
150 |
+
|
151 |
+
# Reformat to marlin
|
152 |
+
weight_perm = get_weight_perm(quant_type.size_bits)
|
153 |
+
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
|
154 |
+
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
|
155 |
+
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
|
156 |
+
|
157 |
+
# Create result
|
158 |
+
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
|
159 |
+
for i in range(len(res_list)):
|
160 |
+
res_list[i] = res_list[i].to(w.device)
|
161 |
+
|
162 |
+
return res_list
|
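For orientation, a short sketch of how these test helpers are typically driven (illustrative: the import path is hypothetical, the shapes are arbitrary, and a CUDA device is assumed for the workspace):

```python
# Sketch only: the import path is hypothetical and a CUDA device is assumed.
import torch

from moe.scalar_type import scalar_types
from ext_torch.utils.marlin_utils_test import MarlinWorkspace, marlin_quantize

# Quantize a random fp16 weight to 4-bit GPTQ-style Marlin format.
w = torch.randn(4096, 2048, dtype=torch.half, device="cuda")
w_ref, q_w, s, g_idx, sort_indices, rand_perm = marlin_quantize(
    w, scalar_types.uint4b8, group_size=128, act_order=False
)

# Workspace consistent with GPTQ_MARLIN_MIN_THREAD_N / GPTQ_MARLIN_MAX_PARALLEL.
workspace = MarlinWorkspace(out_features=2048, min_thread_n=64, max_parallel=16)
```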
ext-torch/utils/quant_utils.py
ADDED
@@ -0,0 +1,470 @@
1 |
+
"""This file is used for /tests and /benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from moe.scalar_type import ScalarType, scalar_types
|
9 |
+
|
10 |
+
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
11 |
+
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
12 |
+
|
13 |
+
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
|
14 |
+
|
15 |
+
# Note: this is a hack. We should update each model to register the
|
16 |
+
# stacked params and get it from there instead in a future PR.
|
17 |
+
# fused_name: List[shard_name]
|
18 |
+
FUSED_LAYER_NAME_MAPPING = {
|
19 |
+
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
20 |
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
def pack_quantized_values_into_int32(
|
25 |
+
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
|
26 |
+
):
|
27 |
+
# move dim to pack to the end
|
28 |
+
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
29 |
+
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
30 |
+
w_q_perm = w_q.permute(perm)
|
31 |
+
|
32 |
+
pack_factor = 32 // wtype.size_bits
|
33 |
+
mask = (1 << wtype.size_bits) - 1
|
34 |
+
|
35 |
+
new_shape_perm = list(w_q_perm.shape)
|
36 |
+
assert w_q_perm.shape[-1] % pack_factor == 0
|
37 |
+
new_shape_perm[-1] //= pack_factor
|
38 |
+
|
39 |
+
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
|
40 |
+
for i in range(pack_factor):
|
41 |
+
res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
|
42 |
+
|
43 |
+
return res.permute(inv_perm)
|
44 |
+
|
45 |
+
|
46 |
+
def unpack_quantized_values_into_int32(
|
47 |
+
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
|
48 |
+
):
|
49 |
+
# move dim to pack to the end
|
50 |
+
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
51 |
+
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
52 |
+
w_q_perm = w_q.permute(perm)
|
53 |
+
|
54 |
+
pack_factor = 32 // wtype.size_bits
|
55 |
+
mask = (1 << wtype.size_bits) - 1
|
56 |
+
|
57 |
+
new_shape_perm = list(w_q_perm.shape)
|
58 |
+
new_shape_perm[-1] *= pack_factor
|
59 |
+
|
60 |
+
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
|
61 |
+
for i in range(pack_factor):
|
62 |
+
res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
|
63 |
+
|
64 |
+
return res.permute(inv_perm)
|
65 |
+
|
66 |
+
|
67 |
+
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
|
68 |
+
# prefix: model.layers.0.self_attn.q_proj
|
69 |
+
# proj_name: q_proj
|
70 |
+
proj_name = prefix.split(".")[-1]
|
71 |
+
if proj_name in FUSED_LAYER_NAME_MAPPING:
|
72 |
+
shard_prefixes = [
|
73 |
+
prefix.replace(proj_name, shard_proj_name)
|
74 |
+
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
|
75 |
+
]
|
76 |
+
|
77 |
+
is_skipped = None
|
78 |
+
for shard_prefix in shard_prefixes:
|
79 |
+
is_shard_skipped = shard_prefix in ignored_layers
|
80 |
+
|
81 |
+
if is_skipped is None:
|
82 |
+
is_skipped = is_shard_skipped
|
83 |
+
elif is_shard_skipped != is_skipped:
|
84 |
+
raise ValueError(
|
85 |
+
f"Detected some but not all shards of {prefix} "
|
86 |
+
"are quantized. All shards of fused layers "
|
87 |
+
"to have the same precision."
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
is_skipped = prefix in ignored_layers
|
91 |
+
|
92 |
+
assert is_skipped is not None
|
93 |
+
return is_skipped
|
94 |
+
|
95 |
+
|
96 |
+
def get_pack_factor(num_bits):
|
97 |
+
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
98 |
+
return 32 // num_bits
|
99 |
+
|
100 |
+
|
101 |
+
def permute_rows(
|
102 |
+
q_w: torch.Tensor,
|
103 |
+
w_ref: torch.Tensor,
|
104 |
+
group_size: int,
|
105 |
+
test_perm: Optional[torch.Tensor] = None,
|
106 |
+
):
|
107 |
+
assert q_w.shape == w_ref.shape
|
108 |
+
|
109 |
+
orig_device = q_w.device
|
110 |
+
k_size, _ = q_w.shape
|
111 |
+
|
112 |
+
g_idx = torch.zeros((k_size,), dtype=torch.int32)
|
113 |
+
for i in range(k_size):
|
114 |
+
g_idx[i] = i // group_size
|
115 |
+
|
116 |
+
# Simulate act_order by doing a random permutation on K
|
117 |
+
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
|
118 |
+
|
119 |
+
g_idx = g_idx[rand_perm].contiguous()
|
120 |
+
q_w = q_w[rand_perm, :].contiguous()
|
121 |
+
w_ref = w_ref[rand_perm, :].contiguous()
|
122 |
+
|
123 |
+
return (
|
124 |
+
w_ref.to(device=orig_device),
|
125 |
+
q_w.to(device=orig_device),
|
126 |
+
g_idx.to(device=orig_device),
|
127 |
+
rand_perm.to(device=orig_device),
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
def quantize_weights(
|
132 |
+
w: torch.Tensor,
|
133 |
+
quant_type: ScalarType,
|
134 |
+
group_size: Optional[int],
|
135 |
+
zero_points: bool = False,
|
136 |
+
ref_zero_points_after_scales: bool = False,
|
137 |
+
):
|
138 |
+
assert (
|
139 |
+
quant_type.is_integer()
|
140 |
+
), "Floating point quantization may work but has not been tested"
|
141 |
+
assert not zero_points or group_size is not None, (
|
142 |
+
"to have group zero points, group_size must be provided "
|
143 |
+
"(-1 group_size is channelwise)"
|
144 |
+
)
|
145 |
+
|
146 |
+
orig_device = w.device
|
147 |
+
orig_type = w.dtype
|
148 |
+
size_k, size_n = w.shape
|
149 |
+
|
150 |
+
assert w.is_floating_point(), "w must be float"
|
151 |
+
|
152 |
+
if group_size == -1:
|
153 |
+
group_size = size_k
|
154 |
+
|
155 |
+
# Reshape to [groupsize, -1]
|
156 |
+
if group_size is not None and group_size < size_k:
|
157 |
+
w = w.reshape((-1, group_size, size_n))
|
158 |
+
w = w.permute(1, 0, 2)
|
159 |
+
w = w.reshape((group_size, -1))
|
160 |
+
|
161 |
+
# Compute scale for each group
|
162 |
+
max_val = torch.max(w, 0, keepdim=True).values
|
163 |
+
min_val = torch.min(w, 0, keepdim=True).values
|
164 |
+
|
165 |
+
max_q_val = quant_type.max()
|
166 |
+
min_q_val = quant_type.min()
|
167 |
+
|
168 |
+
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
|
169 |
+
maybe_w_zp = None
|
170 |
+
if group_size is not None:
|
171 |
+
if zero_points:
|
172 |
+
assert not quant_type.is_signed() and quant_type.max() > 0
|
173 |
+
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
174 |
+
maybe_w_zp = (
|
175 |
+
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
# If the bias is such that there are no possible negative/positive
|
179 |
+
# values, set the max value to inf to avoid divide by 0
|
180 |
+
w_s = torch.max(
|
181 |
+
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
182 |
+
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
|
183 |
+
)
|
184 |
+
|
185 |
+
# Quantize
|
186 |
+
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
187 |
+
w_q = torch.clamp(w_q, min_q_val, max_q_val)
|
188 |
+
|
189 |
+
# Compute ref (dequantized)
|
190 |
+
# For some kernels (namely Machete) the zero-points are applied after the
|
191 |
+
# scales are applied, for this case computing the reference in similar way
|
192 |
+
# allows us to use tighter error tolerances in our unit tests.
|
193 |
+
if ref_zero_points_after_scales and maybe_w_zp is not None:
|
194 |
+
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
|
195 |
+
else:
|
196 |
+
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
197 |
+
|
198 |
+
if quant_type.has_bias():
|
199 |
+
w_q += quant_type.bias
|
200 |
+
|
201 |
+
# Restore original shapes
|
202 |
+
if group_size is not None and group_size < size_k:
|
203 |
+
|
204 |
+
def reshape_w(w):
|
205 |
+
w = w.reshape((group_size, -1, size_n))
|
206 |
+
w = w.permute(1, 0, 2)
|
207 |
+
w = w.reshape((size_k, size_n)).contiguous()
|
208 |
+
return w
|
209 |
+
|
210 |
+
w_q = reshape_w(w_q)
|
211 |
+
w_ref = reshape_w(w_ref)
|
212 |
+
w_s = w_s.reshape((-1, size_n)).contiguous()
|
213 |
+
|
214 |
+
if maybe_w_zp is not None:
|
215 |
+
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
|
216 |
+
maybe_w_zp = maybe_w_zp.to(device=orig_device)
|
217 |
+
|
218 |
+
return (
|
219 |
+
w_ref.to(device=orig_device),
|
220 |
+
w_q.to(device=orig_device),
|
221 |
+
w_s if group_size is not None else None,
|
222 |
+
maybe_w_zp,
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
def gptq_quantize_weights(
|
227 |
+
w: torch.Tensor,
|
228 |
+
quant_type: ScalarType,
|
229 |
+
group_size: int,
|
230 |
+
act_order: bool,
|
231 |
+
test_perm: Optional[torch.Tensor] = None,
|
232 |
+
):
|
233 |
+
size_k, _ = w.shape
|
234 |
+
|
235 |
+
assert w.is_floating_point(), "w must be float"
|
236 |
+
assert (
|
237 |
+
quant_type in SUPPORTED_GPTQ_QUANT_TYPES
|
238 |
+
), f"Unsupported gptq type = {quant_type}"
|
239 |
+
assert group_size in SUPPORTED_GROUP_SIZES + [
|
240 |
+
size_k
|
241 |
+
], f"Unsupported groupsize = {group_size}"
|
242 |
+
|
243 |
+
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
|
244 |
+
|
245 |
+
# Apply act_order
|
246 |
+
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
|
247 |
+
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
|
248 |
+
if act_order:
|
249 |
+
assert (
|
250 |
+
group_size < size_k
|
251 |
+
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
252 |
+
group_size, size_k
|
253 |
+
)
|
254 |
+
|
255 |
+
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
|
256 |
+
|
257 |
+
return w_ref, w_q, w_s, g_idx, rand_perm
|
258 |
+
|
259 |
+
|
260 |
+
# QQQ employs different quant schemes for per-group and
|
261 |
+
# per-channel quantization.
|
262 |
+
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
|
263 |
+
orig_device = w.device
|
264 |
+
size_k, size_n = w.shape
|
265 |
+
|
266 |
+
assert w.is_floating_point(), "w must be float"
|
267 |
+
assert (
|
268 |
+
num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
|
269 |
+
), f"Unsupported num_bits = {num_bits}"
|
270 |
+
assert group_size in SUPPORTED_GROUP_SIZES + [
|
271 |
+
size_k
|
272 |
+
], f"Unsupported groupsize = {group_size}"
|
273 |
+
|
274 |
+
if group_size == -1:
|
275 |
+
group_size = size_k
|
276 |
+
assert group_size <= size_k
|
277 |
+
|
278 |
+
if group_size < size_k:
|
279 |
+
# Reshape to [groupsize, -1]
|
280 |
+
w = w.reshape((-1, group_size, size_n))
|
281 |
+
w = w.permute(1, 0, 2)
|
282 |
+
w = w.reshape((group_size, -1))
|
283 |
+
|
284 |
+
max_q_val = 2**num_bits - 1
|
285 |
+
half_q_val = (max_q_val + 1) // 2
|
286 |
+
|
287 |
+
# Compute scale for each group
|
288 |
+
s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
|
289 |
+
s_group *= 2 / max_q_val # 2 => symmetric
|
290 |
+
|
291 |
+
# Quantize
|
292 |
+
q_w = torch.round(w / s_group).int()
|
293 |
+
q_w += half_q_val
|
294 |
+
q_w = torch.clamp(q_w, 0, max_q_val)
|
295 |
+
# Compute ref (dequantized)
|
296 |
+
w_ref = (q_w - half_q_val).half() * s_group
|
297 |
+
|
298 |
+
# Restore original shapes
|
299 |
+
def reshape_w(w):
|
300 |
+
w = w.reshape((group_size, -1, size_n))
|
301 |
+
w = w.permute(1, 0, 2)
|
302 |
+
w = w.reshape((size_k, size_n)).contiguous()
|
303 |
+
return w
|
304 |
+
|
305 |
+
q_w = reshape_w(q_w)
|
306 |
+
w_ref = reshape_w(w_ref)
|
307 |
+
|
308 |
+
# Compute int8 quantization scale for each channel
|
309 |
+
s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
|
310 |
+
s_channel /= 127.0
|
311 |
+
t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
|
312 |
+
w_ref = t_int8.half() * s_channel
|
313 |
+
s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
|
314 |
+
|
315 |
+
# Fuse scales
|
316 |
+
s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
|
317 |
+
dtype=torch.half
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
max_q_val = 2 ** (num_bits - 1) - 1
|
321 |
+
|
322 |
+
# Compute scale for each channel
|
323 |
+
s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
|
324 |
+
s_channel /= max_q_val
|
325 |
+
|
326 |
+
# Quantize
|
327 |
+
q_w = torch.round(w / s_channel).int()
|
328 |
+
q_w = torch.clamp(q_w, -max_q_val, max_q_val)
|
329 |
+
# Compute ref (dequantized)
|
330 |
+
w_ref = q_w.half() * s_channel
|
331 |
+
|
332 |
+
s_group = torch.tensor([], dtype=torch.half)
|
333 |
+
# divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
|
334 |
+
s_channel /= 2 ** (8 - num_bits)
|
335 |
+
s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
|
336 |
+
|
337 |
+
return (
|
338 |
+
w_ref.to(device=orig_device),
|
339 |
+
q_w.to(device=orig_device),
|
340 |
+
s_group.to(device=orig_device),
|
341 |
+
s_channel.to(device=orig_device),
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
346 |
+
orig_device = q_w.device
|
347 |
+
|
348 |
+
sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
|
349 |
+
|
350 |
+
g_idx = g_idx[sort_indices].contiguous()
|
351 |
+
q_w = q_w[sort_indices, :].contiguous()
|
352 |
+
|
353 |
+
return (
|
354 |
+
q_w.to(device=orig_device),
|
355 |
+
g_idx.to(device=orig_device),
|
356 |
+
sort_indices.to(device=orig_device),
|
357 |
+
)
|
358 |
+
|
359 |
+
|
360 |
+
def pack_rows(
|
361 |
+
q_w: torch.Tensor,
|
362 |
+
num_bits: int,
|
363 |
+
size_k: int,
|
364 |
+
size_n: int,
|
365 |
+
):
|
366 |
+
assert q_w.shape == (size_k, size_n)
|
367 |
+
|
368 |
+
pack_factor = get_pack_factor(num_bits)
|
369 |
+
assert size_k % pack_factor == 0
|
370 |
+
|
371 |
+
orig_device = q_w.device
|
372 |
+
|
373 |
+
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
374 |
+
|
375 |
+
q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
|
376 |
+
|
377 |
+
for i in range(pack_factor):
|
378 |
+
q_res |= q_w[i::pack_factor, :] << num_bits * i
|
379 |
+
|
380 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
381 |
+
return q_res
|
382 |
+
|
383 |
+
|
384 |
+
def pack_cols(
|
385 |
+
q_w: torch.Tensor,
|
386 |
+
num_bits: int,
|
387 |
+
size_k: int,
|
388 |
+
size_n: int,
|
389 |
+
):
|
390 |
+
assert q_w.shape == (size_k, size_n)
|
391 |
+
|
392 |
+
pack_factor = get_pack_factor(num_bits)
|
393 |
+
assert size_n % pack_factor == 0
|
394 |
+
|
395 |
+
orig_device = q_w.device
|
396 |
+
|
397 |
+
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
398 |
+
|
399 |
+
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
|
400 |
+
|
401 |
+
for i in range(pack_factor):
|
402 |
+
q_res |= q_w[:, i::pack_factor] << num_bits * i
|
403 |
+
|
404 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
405 |
+
q_res = q_res.contiguous()
|
406 |
+
|
407 |
+
return q_res
|
408 |
+
|
409 |
+
|
410 |
+
def unpack_cols(
|
411 |
+
packed_q_w: torch.Tensor,
|
412 |
+
num_bits: int,
|
413 |
+
size_k: int,
|
414 |
+
size_n: int,
|
415 |
+
):
|
416 |
+
pack_factor = get_pack_factor(num_bits)
|
417 |
+
assert size_n % pack_factor == 0
|
418 |
+
assert packed_q_w.shape == (
|
419 |
+
size_k,
|
420 |
+
size_n // pack_factor,
|
421 |
+
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
|
422 |
+
packed_q_w.shape, size_k, size_n, pack_factor
|
423 |
+
)
|
424 |
+
|
425 |
+
orig_device = packed_q_w.device
|
426 |
+
|
427 |
+
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
|
428 |
+
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
|
429 |
+
|
430 |
+
mask = (1 << num_bits) - 1
|
431 |
+
for i in range(pack_factor):
|
432 |
+
vals = packed_q_w_cpu & mask
|
433 |
+
packed_q_w_cpu >>= num_bits
|
434 |
+
q_res[:, i::pack_factor] = vals
|
435 |
+
|
436 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
437 |
+
q_res = q_res.contiguous()
|
438 |
+
|
439 |
+
return q_res
|
440 |
+
|
441 |
+
|
442 |
+
def gptq_pack(
|
443 |
+
q_w: torch.Tensor,
|
444 |
+
num_bits: int,
|
445 |
+
size_k: int,
|
446 |
+
size_n: int,
|
447 |
+
):
|
448 |
+
return pack_rows(q_w, num_bits, size_k, size_n)
|
449 |
+
|
450 |
+
|
451 |
+
def awq_pack(
|
452 |
+
q_w: torch.Tensor,
|
453 |
+
num_bits: int,
|
454 |
+
size_k: int,
|
455 |
+
size_n: int,
|
456 |
+
):
|
457 |
+
assert q_w.shape == (size_k, size_n)
|
458 |
+
|
459 |
+
# Interleave column dim (for the dequantize code) and pack it to int32
|
460 |
+
if num_bits == 4:
|
461 |
+
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
462 |
+
elif num_bits == 8:
|
463 |
+
interleave = numpy.array([0, 2, 1, 3])
|
464 |
+
else:
|
465 |
+
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
466 |
+
|
467 |
+
q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
|
468 |
+
q_w = q_w.reshape((-1, size_n)).contiguous()
|
469 |
+
|
470 |
+
return pack_cols(q_w, num_bits, size_k, size_n)
|
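A quick round-trip sanity check for the column packing helpers (a sketch; the import path is hypothetical and it runs on CPU):

```python
# Sketch only: the import path is hypothetical.
import torch

from ext_torch.utils.quant_utils import pack_cols, unpack_cols

num_bits, size_k, size_n = 4, 16, 64
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)        # (16, 64) -> (16, 8) int32
unpacked = unpack_cols(packed, num_bits, size_k, size_n) # exact inverse of pack_cols

assert torch.equal(unpacked, q_w)
```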
flake.nix
ADDED
@@ -0,0 +1,14 @@
1 |
+
{
|
2 |
+
description = "Flake for activation kernels";
|
3 |
+
|
4 |
+
inputs = {
|
5 |
+
kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
|
6 |
+
};
|
7 |
+
|
8 |
+
outputs =
|
9 |
+
{
|
10 |
+
self,
|
11 |
+
kernel-builder,
|
12 |
+
}:
|
13 |
+
kernel-builder.lib.genFlakeOutputs ./.;
|
14 |
+
}
|
marlin-moe/marlin_kernels/marlin_moe_kernel.h
ADDED
@@ -0,0 +1,1616 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <torch/all.h>
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <c10/cuda/CUDAGuard.h>
|
7 |
+
#include <cuda.h>
|
8 |
+
#include <cuda_fp16.h>
|
9 |
+
#include <cuda_runtime.h>
|
10 |
+
|
11 |
+
#include <iostream>
|
12 |
+
|
13 |
+
#include "core/scalar_type.hpp"
|
14 |
+
|
15 |
+
namespace marlin_moe {
|
16 |
+
|
17 |
+
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
18 |
+
|
19 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
20 |
+
|
21 |
+
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
22 |
+
// for instance as inputs to tensor core operations. Consequently, all
|
23 |
+
// corresponding index accesses must be compile-time constants, which is why we
|
24 |
+
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
25 |
+
// this.
|
26 |
+
template <typename T, int n>
|
27 |
+
struct Vec {
|
28 |
+
T elems[n];
|
29 |
+
__device__ T& operator[](int i) { return elems[i]; }
|
30 |
+
};
|
31 |
+
|
32 |
+
using I4 = Vec<int, 4>;
|
33 |
+
|
34 |
+
// Matrix fragments for tensor core instructions; their precise layout is
|
35 |
+
// documented here:
|
36 |
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
|
37 |
+
using FragA = Vec<half2, 4>;
|
38 |
+
using FragB = Vec<half2, 2>;
|
39 |
+
using FragC = Vec<float, 4>;
|
40 |
+
using FragS = Vec<half2, 1>; // quantization scales
|
41 |
+
using FragZP = Vec<half2, 4>;
|
42 |
+
|
43 |
+
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
44 |
+
// predication to handle batchsizes that are not multiples of 16.
|
45 |
+
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
46 |
+
bool pred = true) {
|
47 |
+
const int BYTES = 16;
|
48 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
49 |
+
asm volatile(
|
50 |
+
"{\n"
|
51 |
+
" .reg .pred p;\n"
|
52 |
+
" setp.ne.b32 p, %0, 0;\n"
|
53 |
+
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
54 |
+
"}\n" ::"r"((int)pred),
|
55 |
+
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
56 |
+
}
|
57 |
+
|
58 |
+
// Asynchronous global->shared copy
|
59 |
+
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
60 |
+
const int BYTES = 16;
|
61 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
62 |
+
asm volatile(
|
63 |
+
"{\n"
|
64 |
+
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
65 |
+
"}\n" ::"r"(smem),
|
66 |
+
"l"(glob_ptr), "n"(BYTES));
|
67 |
+
}
|
68 |
+
|
69 |
+
// Async copy fence.
|
70 |
+
__device__ inline void cp_async_fence() {
|
71 |
+
asm volatile("cp.async.commit_group;\n" ::);
|
72 |
+
}
|
73 |
+
|
74 |
+
// Wait until at most `n` async copy stages are still pending.
|
75 |
+
template <int n>
|
76 |
+
__device__ inline void cp_async_wait() {
|
77 |
+
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
78 |
+
}
|
79 |
+
|
80 |
+
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
81 |
+
// output/accumulation.
|
82 |
+
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
|
83 |
+
FragC& frag_c) {
|
84 |
+
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
85 |
+
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
86 |
+
float* c = reinterpret_cast<float*>(&frag_c);
|
87 |
+
asm volatile(
|
88 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
89 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
90 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
91 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
92 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
93 |
+
}
|
94 |
+
|
95 |
+
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
96 |
+
// memory, directly in tensor core layout.
|
97 |
+
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
|
98 |
+
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
99 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
100 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
101 |
+
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
102 |
+
: "r"(smem));
|
103 |
+
}
|
104 |
+
|
105 |
+
// Lookup-table based 3-input logical operation; explicitly used for
|
106 |
+
// dequantization as the compiler does not seem to automatically recognize it in
|
107 |
+
// all cases.
|
108 |
+
template <int lut>
|
109 |
+
__device__ inline int lop3(int a, int b, int c) {
|
110 |
+
int res;
|
111 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
112 |
+
: "=r"(res)
|
113 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
114 |
+
return res;
|
115 |
+
}
|
116 |
+
|
117 |
+
// Constructs destination register by taking bytes from 2 sources (based on
|
118 |
+
// mask)
|
119 |
+
template <int start_byte, int mask>
|
120 |
+
__device__ inline uint32_t prmt(uint32_t a) {
|
121 |
+
uint32_t res;
|
122 |
+
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
123 |
+
: "=r"(res)
|
124 |
+
: "r"(a), "n"(start_byte), "n"(mask));
|
125 |
+
return res;
|
126 |
+
}
|
127 |
+
|
128 |
+
template <vllm::ScalarTypeId w_type_id>
|
129 |
+
__device__ inline FragB dequant(int q);
|
130 |
+
|
131 |
+
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
132 |
+
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
133 |
+
// with some small changes:
|
134 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
135 |
+
template <>
|
136 |
+
__device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
|
137 |
+
const int LO = 0x000f000f;
|
138 |
+
const int HI = 0x00f000f0;
|
139 |
+
const int EX = 0x64006400;
|
140 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
141 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
142 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
143 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
144 |
+
// directly into `SUB` and `ADD`.
|
145 |
+
const int SUB = 0x64086408;
|
146 |
+
const int MUL = 0x2c002c00;
|
147 |
+
const int ADD = 0xd480d480;
|
148 |
+
FragB frag_b;
|
149 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
150 |
+
*reinterpret_cast<const half2*>(&SUB));
|
151 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
152 |
+
*reinterpret_cast<const half2*>(&MUL),
|
153 |
+
*reinterpret_cast<const half2*>(&ADD));
|
154 |
+
return frag_b;
|
155 |
+
}
|
156 |
+
|
157 |
+
// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16
|
158 |
+
// Reference:
|
159 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
160 |
+
template <>
|
161 |
+
__device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
|
162 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
163 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
164 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
165 |
+
|
166 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
167 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
168 |
+
|
169 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
170 |
+
|
171 |
+
FragB frag_b;
|
172 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
173 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
174 |
+
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
175 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
176 |
+
return frag_b;
|
177 |
+
}
|
178 |
+
|
179 |
+
template <>
|
180 |
+
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
|
181 |
+
const int LO = 0x000f000f;
|
182 |
+
const int HI = 0x00f000f0;
|
183 |
+
const int EX = 0x64006400;
|
184 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
185 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
186 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
187 |
+
|
188 |
+
const int SUB = 0x64006400;
|
189 |
+
const int MUL = 0x2c002c00;
|
190 |
+
const int ADD = 0xd400d400;
|
191 |
+
FragB frag_b;
|
192 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
193 |
+
*reinterpret_cast<const half2*>(&SUB));
|
194 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
195 |
+
*reinterpret_cast<const half2*>(&MUL),
|
196 |
+
*reinterpret_cast<const half2*>(&ADD));
|
197 |
+
return frag_b;
|
198 |
+
}
|
199 |
+
|
200 |
+
template <>
|
201 |
+
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
|
202 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
203 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
204 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
205 |
+
|
206 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
207 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
208 |
+
|
209 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
210 |
+
|
211 |
+
FragB frag_b;
|
212 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
213 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
214 |
+
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
215 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
216 |
+
return frag_b;
|
217 |
+
}
|
218 |
+
|
219 |
+
// Multiply dequantized values by the corresponding quantization scale; used
|
220 |
+
// only for grouped quantization.
|
221 |
+
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
222 |
+
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
|
223 |
+
frag_b[0] = __hmul2(frag_b[0], s);
|
224 |
+
frag_b[1] = __hmul2(frag_b[1], s);
|
225 |
+
}
|
226 |
+
|
227 |
+
__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
|
228 |
+
half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]);
|
229 |
+
frag_b[0] = __hsub2(frag_b[0], zp);
|
230 |
+
frag_b[1] = __hsub2(frag_b[1], zp);
|
231 |
+
}
|
232 |
+
|
233 |
+
// Same as above, but for act_order (each K is multiplied individually)
|
234 |
+
__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
|
235 |
+
FragS& frag_s_3, FragS& frag_s_4, int i) {
|
236 |
+
__half2 s_val_1_2;
|
237 |
+
s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i];
|
238 |
+
s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i];
|
239 |
+
|
240 |
+
__half2 s_val_3_4;
|
241 |
+
s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i];
|
242 |
+
s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i];
|
243 |
+
|
244 |
+
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
|
245 |
+
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
|
246 |
+
}
|
247 |
+
|
248 |
+
// Given 2 floats multiply by 2 scales (halves)
|
249 |
+
__device__ inline void scale_float(float* c, FragS& s) {
|
250 |
+
__half* s_ptr = reinterpret_cast<__half*>(&s);
|
251 |
+
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
|
252 |
+
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
|
253 |
+
}
|
254 |
+
|
255 |
+
// Wait until barrier reaches `count`, then lock for current threadblock.
|
256 |
+
__device__ inline void barrier_acquire(int* lock, int count) {
|
257 |
+
if (threadIdx.x == 0) {
|
258 |
+
int state = -1;
|
259 |
+
do
|
260 |
+
// Guarantee that subsequent writes by this threadblock will be visible
|
261 |
+
// globally.
|
262 |
+
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
263 |
+
: "=r"(state)
|
264 |
+
: "l"(lock));
|
265 |
+
while (state != count);
|
266 |
+
}
|
267 |
+
__syncthreads();
|
268 |
+
}
|
269 |
+
|
270 |
+
// Release barrier and increment visitation count.
|
271 |
+
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
272 |
+
__syncthreads();
|
273 |
+
if (threadIdx.x == 0) {
|
274 |
+
if (reset) {
|
275 |
+
lock[0] = 0;
|
276 |
+
return;
|
277 |
+
}
|
278 |
+
int val = 1;
|
279 |
+
// Make sure that all writes since acquiring this barrier are visible
|
280 |
+
// globally, while releasing the barrier.
|
281 |
+
asm volatile("fence.acq_rel.gpu;\n");
|
282 |
+
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
283 |
+
:
|
284 |
+
: "l"(lock), "r"(val));
|
285 |
+
}
|
286 |
+
}
|
287 |
+
|
288 |
+
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
289 |
+
const int threads, // number of threads in a threadblock
|
290 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
291 |
+
// dimension (batchsize) of the
|
292 |
+
// threadblock
|
293 |
+
const int thread_n_blocks, // same for n dimension (output)
|
294 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
295 |
+
const int stages, // number of stages for the async global->shared
|
296 |
+
// fetch pipeline
|
297 |
+
const bool has_act_order, // whether act_order is enabled
|
298 |
+
const bool has_zp, // whether zero-points are enabled
|
299 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
300 |
+
// with a separate quantization scale
|
301 |
+
>
|
302 |
+
__device__ void MarlinMoESingle(
|
303 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
304 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
305 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
306 |
+
const int* __restrict__ sorted_ids, // int32 sorted ids of experts
|
307 |
+
const float* __restrict__ topk_weights, // float topk weights
|
308 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
309 |
+
// (k/groupsize)xn
|
310 |
+
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
311 |
+
// (k/groupsize)x(n/pack_factor)
|
312 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
313 |
+
const int* __restrict__ expert_offsets,
|
314 |
+
int num_groups, // number of scale groups per output channel
|
315 |
+
int expert_idx, // idx of current expert
|
316 |
+
int num_experts, // number of experts
|
317 |
+
int topk, // topk parameter of moe
|
318 |
+
int prob_m, // batch dimension m
|
319 |
+
int prob_n, // output dimension n
|
320 |
+
int prob_k, // reduction dimension k
|
321 |
+
int tot_m, // total number of rows in A and C
|
322 |
+
int* locks, // extra global storage for barrier synchronization
|
323 |
+
bool replicate_input, // do we use the same input for each expert?
|
324 |
+
bool apply_weights, // apply weights to output
|
325 |
+
int current_m_block // current m block to start kernel computation from
|
326 |
+
) {
|
327 |
+
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
328 |
+
constexpr int pack_factor = 32 / w_type.size_bits();
|
329 |
+
|
330 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
331 |
+
// better partitioning with less reductions
|
332 |
+
int parallel = 1;
|
333 |
+
if (prob_m > 16 * thread_m_blocks) {
|
334 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
335 |
+
prob_m = 16 * thread_m_blocks;
|
336 |
+
}
|
337 |
+
|
338 |
+
int k_tiles = prob_k / 16 / thread_k_blocks;
|
339 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
340 |
+
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
|
341 |
+
|
342 |
+
if constexpr (!has_act_order && group_blocks != -1) {
|
343 |
+
if (group_blocks >= thread_k_blocks) {
|
344 |
+
// Ensure that the number of tiles in each stripe is a multiple of the
|
345 |
+
// groupsize; this avoids an annoying special case where a stripe starts
|
346 |
+
// in the middle of group.
|
347 |
+
iters = (group_blocks / thread_k_blocks) *
|
348 |
+
ceildiv(iters, (group_blocks / thread_k_blocks));
|
349 |
+
}
|
350 |
+
}
|
351 |
+
|
352 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
353 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
354 |
+
int slice_col = slice_col_par;
|
355 |
+
int slice_iters; // number of threadblock tiles in the current slice
|
356 |
+
int slice_count =
|
357 |
+
0; // total number of active threadblocks in the current slice
|
358 |
+
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
359 |
+
// top
|
360 |
+
|
361 |
+
// We can easily implement parallel problem execution by just remapping
|
362 |
+
// indices and advancing global pointers
|
363 |
+
if (slice_col_par >= n_tiles) {
|
364 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
365 |
+
slice_col = slice_col_par % n_tiles;
|
366 |
+
sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
|
367 |
+
}
|
368 |
+
|
369 |
+
// Compute all information about the current slice which is required for
|
370 |
+
// synchronization.
|
371 |
+
auto init_slice = [&]() {
|
372 |
+
slice_iters =
|
373 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
374 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
375 |
+
if (slice_iters == 0) return;
|
376 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
377 |
+
slice_count = 1;
|
378 |
+
slice_idx = 0;
|
379 |
+
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
380 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
381 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
382 |
+
slice_count = ceildiv(k_tiles - col_off, iters);
|
383 |
+
if (col_off > 0) slice_count++;
|
384 |
+
int delta_first = iters * blockIdx.x - col_first;
|
385 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
386 |
+
slice_idx = slice_count - 1;
|
387 |
+
else {
|
388 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
389 |
+
if (col_off > 0) slice_idx--;
|
390 |
+
}
|
391 |
+
}
|
392 |
+
if (slice_col == n_tiles) {
|
393 |
+
sorted_ids += 16 * thread_m_blocks;
|
394 |
+
locks += n_tiles;
|
395 |
+
slice_col = 0;
|
396 |
+
}
|
397 |
+
};
|
398 |
+
init_slice();
|
399 |
+
|
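To make the striped partitioning above easier to follow, here is a small host-side sketch (a hypothetical helper that simply mirrors the `iters`/`slice_row`/`slice_col_par` arithmetic, not part of the kernel) that prints which column slice and starting k-tile each threadblock is assigned:

```cuda
// Host-side sketch mirroring the striped work assignment above
// (illustrative only; names and sizes are hypothetical).
#include <cstdio>

static constexpr int ceildiv_sketch(int a, int b) { return (a + b - 1) / b; }

void print_partition(int k_tiles, int n_tiles, int parallel, int num_blocks) {
  int iters = ceildiv_sketch(k_tiles * n_tiles * parallel, num_blocks);
  for (int block = 0; block < num_blocks; ++block) {
    int slice_row = (iters * block) % k_tiles;      // starting k-tile of this block
    int slice_col_par = (iters * block) / k_tiles;  // column slice, incl. parallel copies
    std::printf("block %2d: col %d (copy %d), first k-tile %d, %d iters\n", block,
                slice_col_par % n_tiles, slice_col_par / n_tiles, slice_row, iters);
  }
}

int main() {
  // e.g. 8 k-tiles, 4 n-tiles, no extra batch parallelism, 16 threadblocks
  print_partition(/*k_tiles=*/8, /*n_tiles=*/4, /*parallel=*/1, /*num_blocks=*/16);
  return 0;
}
```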
400 |
+
// A sizes/strides
|
401 |
+
|
402 |
+
// stride of the A matrix in global memory
|
403 |
+
int a_gl_stride = prob_k / 8;
|
404 |
+
// stride of an A matrix tile in shared memory
|
405 |
+
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
|
406 |
+
// delta between subsequent A tiles in global memory
|
407 |
+
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
|
408 |
+
// between subsequent accesses within a tile
|
409 |
+
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
|
410 |
+
// between shared memory writes
|
411 |
+
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
|
412 |
+
// between shared memory tile reads
|
413 |
+
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
|
414 |
+
// within a shared memory tile
|
415 |
+
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
|
416 |
+
// overall size of a tile
|
417 |
+
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
|
418 |
+
// number of shared write iterations for a tile
|
419 |
+
constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
|
420 |
+
|
421 |
+
// B sizes/strides
|
422 |
+
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
|
423 |
+
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
|
424 |
+
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
|
425 |
+
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
|
426 |
+
|
427 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
428 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
|
429 |
+
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
|
430 |
+
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
|
431 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
432 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
433 |
+
|
434 |
+
// Scale sizes/strides without act_order
|
435 |
+
int s_gl_stride = prob_n / 8;
|
436 |
+
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
437 |
+
constexpr int s_tb_groups =
|
438 |
+
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
439 |
+
? thread_k_blocks / group_blocks
|
440 |
+
: 1;
|
441 |
+
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
442 |
+
int s_gl_rd_delta = s_gl_stride;
|
443 |
+
// Scale size/strides with act_order
|
444 |
+
constexpr int tb_k = 16 * thread_k_blocks;
|
445 |
+
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
|
446 |
+
// constexpr int act_s_row_stride = 1;
|
447 |
+
// int act_s_col_stride = act_s_row_stride * num_groups;
|
448 |
+
int act_s_col_stride = 1;
|
449 |
+
int act_s_col_warp_stride = act_s_col_stride * 8;
|
450 |
+
int tb_n_warps = thread_n_blocks / 4;
|
451 |
+
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
452 |
+
|
453 |
+
// Zero-points sizes/strides
|
454 |
+
int zp_gl_stride = (prob_n / pack_factor) / 4;
|
455 |
+
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
|
456 |
+
constexpr int zp_tb_groups = s_tb_groups;
|
457 |
+
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
458 |
+
int zp_gl_rd_delta = zp_gl_stride;
|
459 |
+
|
460 |
+
// Global A read index of current thread.
|
461 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
462 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
463 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
464 |
+
// Shared write index of current thread.
|
465 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
466 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
467 |
+
// Shared read index.
|
468 |
+
int a_sh_rd =
|
469 |
+
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
|
470 |
+
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
471 |
+
|
472 |
+
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
|
473 |
+
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
474 |
+
b_gl_rd += b_sh_stride * slice_col;
|
475 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
476 |
+
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
477 |
+
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
478 |
+
|
479 |
+
// For act_order
|
480 |
+
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
481 |
+
int slice_k_start = tb_k * slice_row;
|
482 |
+
int slice_k_finish = slice_k_start + tb_k * slice_iters;
|
483 |
+
int slice_k_start_shared_fetch = slice_k_start;
|
484 |
+
int slice_n_offset = act_s_col_tb_stride * slice_col;
|
485 |
+
|
486 |
+
// No act_order
|
487 |
+
int s_gl_rd;
|
488 |
+
if constexpr (!has_act_order) {
|
489 |
+
if constexpr (group_blocks == -1) {
|
490 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
491 |
+
} else {
|
492 |
+
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
493 |
+
s_sh_stride * slice_col + threadIdx.x;
|
494 |
+
}
|
495 |
+
}
|
496 |
+
int s_sh_wr = threadIdx.x;
|
497 |
+
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
498 |
+
|
499 |
+
// Zero-points
|
500 |
+
int zp_gl_rd;
|
501 |
+
if constexpr (has_zp) {
|
502 |
+
if constexpr (group_blocks == -1) {
|
503 |
+
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
504 |
+
} else {
|
505 |
+
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
506 |
+
zp_sh_stride * slice_col + threadIdx.x;
|
507 |
+
}
|
508 |
+
}
|
509 |
+
int zp_sh_wr = threadIdx.x;
|
510 |
+
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
511 |
+
|
512 |
+
// We use a different scale layout for grouped and column-wise quantization as
|
513 |
+
// we scale a `half2` tile in column-major layout in the former and in
|
514 |
+
// row-major in the latter case.
|
515 |
+
int s_sh_rd;
|
516 |
+
if constexpr (group_blocks != -1)
|
517 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
518 |
+
(threadIdx.x % 32) / 4;
|
519 |
+
else
|
520 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
521 |
+
(threadIdx.x % 32) % 4;
|
522 |
+
|
523 |
+
// Zero-points have the same read layout as the scales
|
524 |
+
// (without column-wise case)
|
525 |
+
constexpr int num_col_threads = 8;
|
526 |
+
constexpr int num_row_threads = 4;
|
527 |
+
constexpr int num_ints_per_thread = 8 / pack_factor;
|
528 |
+
int zp_sh_rd;
|
529 |
+
if constexpr (has_zp) {
|
530 |
+
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
531 |
+
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
532 |
+
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
533 |
+
}
|
534 |
+
|
535 |
+
int sh_first_group_id = -1;
|
536 |
+
int sh_num_groups = -1;
|
537 |
+
constexpr int sh_max_num_groups = 32;
|
538 |
+
|
539 |
+
extern __shared__ int4 sh[];
|
540 |
+
// Shared memory storage for global fetch pipelines.
|
541 |
+
int4* sh_a = sh;
|
542 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
543 |
+
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
544 |
+
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
545 |
+
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
546 |
+
|
547 |
+
// Precompute which thread should not read memory in which iterations; this is
|
548 |
+
// needed if there are more threads than required for a certain tilesize or
|
549 |
+
// when the batchsize is not a multiple of 16.
|
550 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
551 |
+
#pragma unroll
|
552 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
553 |
+
int a_idx = a_sh_wr_delta * i + a_sh_wr;
|
554 |
+
int row = a_idx / a_gl_rd_delta_o;
|
555 |
+
if (row >= prob_m) {
|
556 |
+
a_sh_wr_pred[i] = false;
|
557 |
+
} else {
|
558 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
559 |
+
}
|
560 |
+
}
|
561 |
+
|
562 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
563 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
564 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
565 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
566 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
567 |
+
// each warp must also write a consecutive memory segment?
|
568 |
+
auto transform_a = [&](int i) {
|
569 |
+
int row = i / a_gl_rd_delta_o;
|
570 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
571 |
+
};
|
572 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
573 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
574 |
+
// both transformed reads and writes.
|
575 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
576 |
+
#pragma unroll
|
577 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
578 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
579 |
+
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
580 |
+
#pragma unroll
|
581 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
582 |
+
#pragma unroll
|
583 |
+
for (int j = 0; j < thread_m_blocks; j++)
|
584 |
+
a_sh_rd_trans[i][j] =
|
585 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
586 |
+
}
|
587 |
+
|
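The XOR in `transform_a` above is the usual swizzling trick for bank-conflict-free shared memory tiles. As a hedged, self-contained illustration of the general idea (hypothetical tile sizes, not this kernel's exact index math): XOR-ing the column index with the row index permutes each row differently, so the strided fragment reads no longer hit the same banks.

```cuda
// Illustrative XOR swizzle for a TILE_ROWS x TILE_COLS tile of int4 in
// shared memory (hypothetical tile sizes, not this kernel's exact layout).
constexpr int TILE_ROWS = 16;
constexpr int TILE_COLS = 8;  // 8 x int4 = 128 bytes per row

__device__ __forceinline__ int swizzle(int row, int col) {
  // XOR-ing the column with the row permutes each row's columns
  // differently, so a logical column no longer maps to one bank group.
  return row * TILE_COLS + (col ^ (row % TILE_COLS));
}

__global__ void swizzle_demo(const int4* __restrict__ in, int4* __restrict__ out) {
  __shared__ int4 tile[TILE_ROWS * TILE_COLS];
  int row = threadIdx.x / TILE_COLS;
  int col = threadIdx.x % TILE_COLS;
  if (threadIdx.x < TILE_ROWS * TILE_COLS) {
    tile[swizzle(row, col)] = in[row * TILE_COLS + col];  // swizzled write
  }
  __syncthreads();
  if (threadIdx.x < TILE_ROWS * TILE_COLS) {
    out[row * TILE_COLS + col] = tile[swizzle(row, col)];  // swizzled read
  }
}
```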
588 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
589 |
+
// runtime; we break dependencies between subsequent accesses within a tile by
|
590 |
+
// maintaining multiple pointers (we have enough registers), a tiny
|
591 |
+
// optimization.
|
592 |
+
const int4* B_ptr[b_sh_wr_iters];
|
593 |
+
#pragma unroll
|
594 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
595 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
596 |
+
|
597 |
+
// Register storage for double buffer of shared memory reads.
|
598 |
+
FragA frag_a[2][thread_m_blocks];
|
599 |
+
I4 frag_b_quant[2][b_thread_vecs];
|
600 |
+
FragC frag_c[thread_m_blocks][4][2];
|
601 |
+
FragS frag_s[2][4]; // No act-order
|
602 |
+
FragS act_frag_s[2][4][4]; // For act-order
|
603 |
+
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
604 |
+
FragZP frag_zp; // Zero-points in fp16
|
605 |
+
|
606 |
+
// Zero accumulators.
|
607 |
+
auto zero_accums = [&]() {
|
608 |
+
#pragma unroll
|
609 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
610 |
+
reinterpret_cast<float*>(frag_c)[i] = 0;
|
611 |
+
};
|
612 |
+
|
613 |
+
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
|
614 |
+
int last_group_id) {
|
615 |
+
sh_first_group_id = first_group_id;
|
616 |
+
sh_num_groups = last_group_id - first_group_id + 1;
|
617 |
+
|
618 |
+
if (sh_num_groups < sh_max_num_groups) {
|
619 |
+
sh_num_groups = sh_max_num_groups;
|
620 |
+
}
|
621 |
+
|
622 |
+
if (sh_first_group_id + sh_num_groups > num_groups) {
|
623 |
+
sh_num_groups = num_groups - sh_first_group_id;
|
624 |
+
}
|
625 |
+
|
626 |
+
int row_offset = first_group_id * s_gl_stride;
|
627 |
+
|
628 |
+
if (is_async) {
|
629 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
630 |
+
if (threadIdx.x < s_sh_stride) {
|
631 |
+
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
|
632 |
+
&scales_ptr[row_offset + (i * s_gl_stride) +
|
633 |
+
slice_n_offset + threadIdx.x]);
|
634 |
+
}
|
635 |
+
}
|
636 |
+
} else {
|
637 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
638 |
+
if (threadIdx.x < s_sh_stride) {
|
639 |
+
sh_s[(i * s_sh_stride) + threadIdx.x] =
|
640 |
+
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
|
641 |
+
threadIdx.x];
|
642 |
+
}
|
643 |
+
}
|
644 |
+
}
|
645 |
+
};
|
646 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
647 |
+
// shared memory pipeline location.
|
648 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
649 |
+
if (pred) {
|
650 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
651 |
+
#pragma unroll
|
652 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
653 |
+
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
|
654 |
+
int row = a_idx / a_gl_stride;
|
655 |
+
int sorted_row =
|
656 |
+
replicate_input ? sorted_ids[row] / topk : sorted_ids[row];
|
657 |
+
int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
|
658 |
+
if (sorted_row < tot_m * (replicate_input ? 1 : topk) &&
|
659 |
+
new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) {
|
660 |
+
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx],
|
661 |
+
a_sh_wr_pred[i]);
|
662 |
+
}
|
663 |
+
}
|
664 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
665 |
+
#pragma unroll
|
666 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
667 |
+
#pragma unroll
|
668 |
+
for (int j = 0; j < b_thread_vecs; j++) {
|
669 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
|
670 |
+
}
|
671 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
672 |
+
}
|
673 |
+
|
674 |
+
if constexpr (has_act_order) {
|
675 |
+
// Fetch g_idx thread-block portion
|
676 |
+
int full_pipe = a_off;
|
677 |
+
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
|
678 |
+
if (cur_k < prob_k && cur_k < slice_k_finish) {
|
679 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
680 |
+
|
681 |
+
int4 const* cur_g_idx_stage_ptr =
|
682 |
+
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
|
683 |
+
|
684 |
+
if (threadIdx.x < g_idx_stage) {
|
685 |
+
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
|
686 |
+
&cur_g_idx_stage_ptr[threadIdx.x]);
|
687 |
+
}
|
688 |
+
}
|
689 |
+
} else {
|
690 |
+
if constexpr (group_blocks != -1) {
|
691 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
692 |
+
|
693 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
694 |
+
// Only fetch scales if this tile starts a new group
|
695 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
696 |
+
if (s_sh_wr_pred) {
|
697 |
+
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
698 |
+
}
|
699 |
+
s_gl_rd += s_gl_rd_delta;
|
700 |
+
}
|
701 |
+
} else {
|
702 |
+
for (int i = 0; i < s_tb_groups; i++) {
|
703 |
+
if (s_sh_wr_pred) {
|
704 |
+
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
|
705 |
+
&scales_ptr[s_gl_rd]);
|
706 |
+
}
|
707 |
+
s_gl_rd += s_gl_rd_delta;
|
708 |
+
}
|
709 |
+
}
|
710 |
+
}
|
711 |
+
|
712 |
+
if constexpr (has_zp && group_blocks != -1) {
|
713 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
714 |
+
|
715 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
716 |
+
// Only fetch zero-points if this tile starts a new group
|
717 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
718 |
+
if (zp_sh_wr_pred) {
|
719 |
+
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
720 |
+
}
|
721 |
+
zp_gl_rd += zp_gl_rd_delta;
|
722 |
+
}
|
723 |
+
} else {
|
724 |
+
for (int i = 0; i < zp_tb_groups; i++) {
|
725 |
+
if (zp_sh_wr_pred) {
|
726 |
+
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
727 |
+
&zp_ptr[zp_gl_rd]);
|
728 |
+
}
|
729 |
+
zp_gl_rd += zp_gl_rd_delta;
|
730 |
+
}
|
731 |
+
}
|
732 |
+
}
|
733 |
+
}
|
734 |
+
}
|
735 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
736 |
+
// waiting is also correct at this point.
|
737 |
+
cp_async_fence();
|
738 |
+
};
|
739 |
+
|
740 |
+
auto fetch_zp_to_shared = [&]() {
|
741 |
+
if (zp_sh_wr_pred) {
|
742 |
+
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
743 |
+
}
|
744 |
+
};
|
745 |
+
|
746 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
747 |
+
auto wait_for_stage = [&]() {
|
748 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
749 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
750 |
+
// shared memory load is fully complete (as it may otherwise be
|
751 |
+
// overwritten).
|
752 |
+
cp_async_wait<stages - 2>();
|
753 |
+
__syncthreads();
|
754 |
+
};
|
755 |
+
|
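`cp_async_fence()` and `cp_async_wait<stages - 2>()` above are thin wrappers around the `cp.async` commit-group/wait-group mechanism: every stage's fetches are committed as one group, and waiting until at most `stages - 2` groups are still in flight guarantees the stage about to be consumed has landed in shared memory. A minimal double-buffered sketch of that pattern (a generic demo for sm_80+, not the kernel's own wrappers) is:

```cuda
// Minimal double-buffered cp.async pipeline (illustrative sketch; assumes
// sm_80+ and blockDim.x == 128; the kernel above keeps `stages` buffers in
// flight instead of two).
#include <cstdint>

__device__ __forceinline__ void cp_async16(void* smem_dst, const void* gmem_src) {
  uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst),
               "l"(gmem_src));
}

__global__ void pipelined_copy(const int4* __restrict__ in, int4* __restrict__ out,
                               int n_tiles) {
  __shared__ int4 buf[2][128];  // two pipeline stages of 128 int4 each
  int tid = threadIdx.x;
  // Prologue: start fetching tile 0.
  cp_async16(&buf[0][tid], &in[tid]);
  asm volatile("cp.async.commit_group;\n" ::);
  for (int t = 0; t < n_tiles; ++t) {
    if (t + 1 < n_tiles) {  // kick off the next tile's fetch
      cp_async16(&buf[(t + 1) & 1][tid], &in[(t + 1) * 128 + tid]);
    }
    // Commit even while winding down the pipeline, so the wait below stays
    // correct (same reasoning as the fence comment after fetch_to_shared).
    asm volatile("cp.async.commit_group;\n" ::);
    // At most one group may still be in flight: the tile we read is ready.
    asm volatile("cp.async.wait_group 1;\n" ::);
    __syncthreads();
    out[t * 128 + tid] = buf[t & 1][tid];  // stand-in for the compute step
    __syncthreads();
  }
}
```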
756 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
757 |
+
// into the current register buffer.
|
758 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
759 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
760 |
+
#pragma unroll
|
761 |
+
for (int i = 0; i < thread_m_blocks; i++)
|
762 |
+
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
763 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
764 |
+
|
765 |
+
#pragma unroll
|
766 |
+
for (int i = 0; i < b_thread_vecs; i++) {
|
767 |
+
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
768 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
769 |
+
}
|
770 |
+
};
|
771 |
+
|
772 |
+
bool is_same_group[stages];
|
773 |
+
int same_group_id[stages];
|
774 |
+
|
775 |
+
auto init_same_group = [&](int pipe) {
|
776 |
+
if constexpr (!has_act_order) {
|
777 |
+
is_same_group[pipe] = false;
|
778 |
+
same_group_id[pipe] = 0;
|
779 |
+
return;
|
780 |
+
}
|
781 |
+
|
782 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
783 |
+
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
784 |
+
|
785 |
+
int group_id_1 = sh_g_idx_int_ptr[0];
|
786 |
+
int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
|
787 |
+
|
788 |
+
is_same_group[pipe] = group_id_1 == group_id_2;
|
789 |
+
same_group_id[pipe] = group_id_1;
|
790 |
+
};
|
791 |
+
|
792 |
+
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
|
793 |
+
int pipe = full_pipe % stages;
|
794 |
+
|
795 |
+
if constexpr (!has_act_order) {
|
796 |
+
// No act-order case
|
797 |
+
if constexpr (group_blocks != -1) {
|
798 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
799 |
+
int4* sh_s_stage =
|
800 |
+
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
801 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
802 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
803 |
+
} else {
|
804 |
+
int warp_id = threadIdx.x / 32;
|
805 |
+
int n_warps = thread_n_blocks / 4;
|
806 |
+
|
807 |
+
int warp_row = warp_id / n_warps;
|
808 |
+
|
809 |
+
int cur_k = warp_row * 16;
|
810 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
811 |
+
|
812 |
+
int k_blocks = cur_k / 16;
|
813 |
+
int cur_group_id = k_blocks / group_blocks;
|
814 |
+
|
815 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
816 |
+
|
817 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
818 |
+
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
819 |
+
}
|
820 |
+
}
|
821 |
+
|
822 |
+
return;
|
823 |
+
}
|
824 |
+
|
825 |
+
// Act-order case
|
826 |
+
|
827 |
+
// Determine K of the "current" thread-block
|
828 |
+
int cur_k = slice_k_start + tb_k * full_pipe;
|
829 |
+
if (cur_k >= prob_k || cur_k >= slice_k_finish) {
|
830 |
+
return;
|
831 |
+
}
|
832 |
+
|
833 |
+
// Reset (to current thread-block) since we read g_idx portion from the
|
834 |
+
// shared memory
|
835 |
+
cur_k = 0;
|
836 |
+
|
837 |
+
// Progress to current iteration
|
838 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
839 |
+
|
840 |
+
// Determine "position" inside the thread-block (based on warp and
|
841 |
+
// thread-id)
|
842 |
+
int warp_id = threadIdx.x / 32;
|
843 |
+
int n_warps =
|
844 |
+
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
845 |
+
|
846 |
+
int warp_row = warp_id / n_warps;
|
847 |
+
int warp_col = warp_id % n_warps;
|
848 |
+
|
849 |
+
cur_k += warp_row * 16;
|
850 |
+
|
851 |
+
int th_id = threadIdx.x % 32;
|
852 |
+
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
853 |
+
|
854 |
+
int s_col_shift =
|
855 |
+
/*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
|
856 |
+
(th_id / 4) * act_s_col_stride;
|
857 |
+
|
858 |
+
if (is_same_group[pipe]) {
|
859 |
+
if (k % 2 == 0) {
|
860 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
|
861 |
+
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
|
862 |
+
s_col_shift];
|
863 |
+
} else {
|
864 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
|
865 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
|
866 |
+
}
|
867 |
+
|
868 |
+
for (int i = 1; i < 4; i++) {
|
869 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
|
870 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
|
871 |
+
}
|
872 |
+
return;
|
873 |
+
}
|
874 |
+
|
875 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
876 |
+
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
877 |
+
|
878 |
+
constexpr int k_frag_offsets[4] = {0, 1, 8,
|
879 |
+
9}; // Tensor core offsets per thread
|
880 |
+
|
881 |
+
#pragma unroll
|
882 |
+
for (int i = 0; i < 4; i++) {
|
883 |
+
int actual_k = cur_k + k_frag_offsets[i];
|
884 |
+
|
885 |
+
int group_id = sh_g_idx_int_ptr[actual_k];
|
886 |
+
int rel_group_id = group_id - sh_first_group_id;
|
887 |
+
|
888 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
|
889 |
+
sh_s[rel_group_id * s_sh_stride + s_col_shift];
|
890 |
+
}
|
891 |
+
};
|
892 |
+
|
893 |
+
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
894 |
+
// This code does not handle group_blocks == 0,
|
895 |
+
// which signifies act_order.
|
896 |
+
// has_zp implies AWQ, which doesn't have act_order,
|
897 |
+
static_assert(!has_zp || group_blocks != 0);
|
898 |
+
|
899 |
+
if constexpr (has_zp) {
|
900 |
+
int pipe = full_pipe % stages;
|
901 |
+
|
902 |
+
if constexpr (group_blocks == -1) {
|
903 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
904 |
+
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
905 |
+
}
|
906 |
+
|
907 |
+
} else if constexpr (group_blocks >= thread_k_blocks) {
|
908 |
+
int4* sh_zp_stage =
|
909 |
+
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
910 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
911 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
912 |
+
frag_qzp[k % 2][i] =
|
913 |
+
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
914 |
+
}
|
915 |
+
} else {
|
916 |
+
int warp_id = threadIdx.x / 32;
|
917 |
+
int n_warps = thread_n_blocks / 4;
|
918 |
+
|
919 |
+
int warp_row = warp_id / n_warps;
|
920 |
+
|
921 |
+
int cur_k = warp_row * 16;
|
922 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
923 |
+
|
924 |
+
int k_blocks = cur_k / 16;
|
925 |
+
int cur_group_id = 0;
|
926 |
+
|
927 |
+
// Suppress bogus and persistent divide-by-zero warning
|
928 |
+
#pragma nv_diagnostic push
|
929 |
+
#pragma nv_diag_suppress divide_by_zero
|
930 |
+
cur_group_id = k_blocks / group_blocks;
|
931 |
+
#pragma nv_diagnostic pop
|
932 |
+
|
933 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
934 |
+
|
935 |
+
sh_zp_stage += cur_group_id * zp_sh_stride;
|
936 |
+
|
937 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
938 |
+
frag_qzp[k % 2][i] =
|
939 |
+
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
940 |
+
}
|
941 |
+
}
|
942 |
+
}
|
943 |
+
};
|
944 |
+
|
945 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
946 |
+
auto matmul = [&](int k) {
|
947 |
+
if constexpr (has_zp) {
|
948 |
+
FragB frag_zp_0;
|
949 |
+
FragB frag_zp_1;
|
950 |
+
int zp_quant_0, zp_quant_1;
|
951 |
+
|
952 |
+
if constexpr (w_type.size_bits() == 4) {
|
953 |
+
zp_quant_0 = frag_qzp[k % 2][0];
|
954 |
+
zp_quant_1 = zp_quant_0 >> 8;
|
955 |
+
} else {
|
956 |
+
static_assert(w_type.size_bits() == 8);
|
957 |
+
zp_quant_0 = frag_qzp[k % 2][0];
|
958 |
+
zp_quant_1 = frag_qzp[k % 2][1];
|
959 |
+
}
|
960 |
+
|
961 |
+
frag_zp_0 = dequant<w_type_id>(zp_quant_0);
|
962 |
+
frag_zp_1 = dequant<w_type_id>(zp_quant_1);
|
963 |
+
|
964 |
+
frag_zp[0] = frag_zp_0[0];
|
965 |
+
frag_zp[1] = frag_zp_0[1];
|
966 |
+
frag_zp[2] = frag_zp_1[0];
|
967 |
+
frag_zp[3] = frag_zp_1[1];
|
968 |
+
}
|
969 |
+
|
970 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
971 |
+
// dequantization and matmul operations.
|
972 |
+
#pragma unroll
|
973 |
+
for (int j = 0; j < 4; j++) {
|
974 |
+
int b_quant_0, b_quant_1;
|
975 |
+
if constexpr (w_type.size_bits() == 4) {
|
976 |
+
b_quant_0 = frag_b_quant[k % 2][0][j];
|
977 |
+
b_quant_1 = b_quant_0 >> 8;
|
978 |
+
} else {
|
979 |
+
static_assert(w_type.size_bits() == 8);
|
980 |
+
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
981 |
+
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
982 |
+
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
983 |
+
}
|
984 |
+
|
985 |
+
FragB frag_b0 = dequant<w_type_id>(b_quant_0);
|
986 |
+
FragB frag_b1 = dequant<w_type_id>(b_quant_1);
|
987 |
+
// Apply zero-point to frag_b0
|
988 |
+
if constexpr (has_zp) {
|
989 |
+
sub_zp(frag_b0, frag_zp[j], 0);
|
990 |
+
}
|
991 |
+
|
992 |
+
// Apply scale to frag_b0
|
993 |
+
if constexpr (has_act_order) {
|
994 |
+
scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
995 |
+
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
|
996 |
+
} else {
|
997 |
+
if constexpr (group_blocks != -1) {
|
998 |
+
scale(frag_b0, frag_s[k % 2][j], 0);
|
999 |
+
}
|
1000 |
+
}
|
1001 |
+
|
1002 |
+
// Apply zero-point to frag_b1
|
1003 |
+
if constexpr (has_zp) {
|
1004 |
+
sub_zp(frag_b1, frag_zp[j], 1);
|
1005 |
+
}
|
1006 |
+
|
1007 |
+
// Apply scale to frag_b1
|
1008 |
+
if constexpr (has_act_order) {
|
1009 |
+
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
1010 |
+
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
|
1011 |
+
|
1012 |
+
} else {
|
1013 |
+
if constexpr (group_blocks != -1) {
|
1014 |
+
scale(frag_b1, frag_s[k % 2][j], 1);
|
1015 |
+
}
|
1016 |
+
}
|
1017 |
+
|
1018 |
+
#pragma unroll
|
1019 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
1020 |
+
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
1021 |
+
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
1022 |
+
}
|
1023 |
+
}
|
1024 |
+
};
|
1025 |
+
|
1026 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
1027 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
1028 |
+
// multiple warps that accumulate their partial sums of the same output
|
1029 |
+
// location, which we have to reduce over in the end. We do this in shared memory.
|
1030 |
+
auto thread_block_reduce = [&]() {
|
1031 |
+
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
1032 |
+
if (red_off >= 1) {
|
1033 |
+
int red_idx = threadIdx.x / b_sh_stride_threads;
|
1034 |
+
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
1035 |
+
constexpr int red_sh_delta = b_sh_stride_threads;
|
1036 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
1037 |
+
(threadIdx.x % b_sh_stride_threads);
|
1038 |
+
|
1039 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
1040 |
+
// unnecessary read or write iterations, e.g., for two warps we write only
|
1041 |
+
// once by warp 1 and read only once by warp 0.
|
1042 |
+
|
1043 |
+
#pragma unroll
|
1044 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
1045 |
+
#pragma unroll
|
1046 |
+
for (int i = red_off; i > 0; i /= 2) {
|
1047 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
1048 |
+
#pragma unroll
|
1049 |
+
for (int j = 0; j < 4 * 2; j++) {
|
1050 |
+
int red_sh_wr =
|
1051 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
1052 |
+
if (i < red_off) {
|
1053 |
+
float* c_rd =
|
1054 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
1055 |
+
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
1056 |
+
#pragma unroll
|
1057 |
+
for (int k = 0; k < 4; k++)
|
1058 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
1059 |
+
c_rd[k] + c_wr[k];
|
1060 |
+
}
|
1061 |
+
sh[red_sh_wr] =
|
1062 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
1063 |
+
}
|
1064 |
+
}
|
1065 |
+
__syncthreads();
|
1066 |
+
}
|
1067 |
+
if (red_idx == 0) {
|
1068 |
+
#pragma unroll
|
1069 |
+
for (int i = 0; i < 4 * 2; i++) {
|
1070 |
+
float* c_rd =
|
1071 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
1072 |
+
#pragma unroll
|
1073 |
+
for (int j = 0; j < 4; j++)
|
1074 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
1075 |
+
c_rd[j];
|
1076 |
+
}
|
1077 |
+
}
|
1078 |
+
__syncthreads();
|
1079 |
+
}
|
1080 |
+
}
|
1081 |
+
};
|
1082 |
+
|
1083 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
1084 |
+
// finally have to globally reduce over the results. As the striped
|
1085 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
1086 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
1087 |
+
auto global_reduce = [&](bool first = false, bool last = false) {
|
1088 |
+
// We are very careful here to reduce directly in the output buffer to
|
1089 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
1090 |
+
// results in FP16 (but still reduce with FP32 compute).
|
1091 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
1092 |
+
if (threadIdx.x < active_threads) {
|
1093 |
+
int c_gl_stride = prob_n / 8;
|
1094 |
+
int c_gl_wr_delta_o = 8 * c_gl_stride;
|
1095 |
+
int c_gl_wr_delta_i = 4 * (active_threads / 32);
|
1096 |
+
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
|
1097 |
+
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
1098 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
1099 |
+
constexpr int c_sh_wr_delta = active_threads;
|
1100 |
+
int c_sh_wr = threadIdx.x;
|
1101 |
+
|
1102 |
+
int row = (threadIdx.x % 32) / 4;
|
1103 |
+
|
1104 |
+
if (!first) {
|
1105 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
1106 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
1107 |
+
// though these fetches are not actually asynchronous.
|
1108 |
+
#pragma unroll
|
1109 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
1110 |
+
int c_idx =
|
1111 |
+
c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
|
1112 |
+
int sorted_row = sorted_ids[c_idx / c_gl_stride];
|
1113 |
+
int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
|
1114 |
+
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx],
|
1115 |
+
sorted_row < tot_m * topk &&
|
1116 |
+
(8 * (i / 2) + row < prob_m &&
|
1117 |
+
(i < (thread_m_blocks - 1) * 4 ||
|
1118 |
+
sorted_ids[8 * (i / 2) + row] < tot_m * topk)));
|
1119 |
+
}
|
1120 |
+
cp_async_fence();
|
1121 |
+
cp_async_wait<0>();
|
1122 |
+
}
|
1123 |
+
|
1124 |
+
#pragma unroll
|
1125 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
1126 |
+
if (8 * (i / 2) + row < prob_m &&
|
1127 |
+
(i < (thread_m_blocks - 1) * 4 ||
|
1128 |
+
sorted_ids[8 * (i / 2) + row] < tot_m * topk)) {
|
1129 |
+
if (!first) {
|
1130 |
+
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
1131 |
+
#pragma unroll
|
1132 |
+
for (int j = 0; j < 2 * 4; j++) {
|
1133 |
+
reinterpret_cast<float*>(
|
1134 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
1135 |
+
__half2float(reinterpret_cast<__half*>(&c_red)[j]);
|
1136 |
+
}
|
1137 |
+
}
|
1138 |
+
if (!last) {
|
1139 |
+
int4 c;
|
1140 |
+
#pragma unroll
|
1141 |
+
for (int j = 0; j < 2 * 4; j++) {
|
1142 |
+
reinterpret_cast<__half*>(&c)[j] =
|
1143 |
+
__float2half(reinterpret_cast<float*>(
|
1144 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
1145 |
+
}
|
1146 |
+
int c_idx =
|
1147 |
+
c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
|
1148 |
+
int row = sorted_ids[c_idx / c_gl_stride];
|
1149 |
+
if (row < tot_m * topk) {
|
1150 |
+
int new_idx = row * c_gl_stride + c_idx % c_gl_stride;
|
1151 |
+
C[new_idx] = c;
|
1152 |
+
}
|
1153 |
+
}
|
1154 |
+
}
|
1155 |
+
}
|
1156 |
+
}
|
1157 |
+
};
|
1158 |
+
|
1159 |
+
// Write out the reduced final result in the correct layout. We only actually
|
1160 |
+
// reshuffle matrix fragments in this step, the reduction above is performed
|
1161 |
+
// in fragment layout.
|
1162 |
+
auto write_result = [&]() {
|
1163 |
+
int c_gl_stride = prob_n / 8;
|
1164 |
+
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
1165 |
+
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
1166 |
+
constexpr int c_sh_rd_delta =
|
1167 |
+
c_sh_stride * (threads / (2 * thread_n_blocks));
|
1168 |
+
|
1169 |
+
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
1170 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
1171 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
1172 |
+
int c_sh_wr =
|
1173 |
+
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
|
1174 |
+
c_sh_wr += 32 * (threadIdx.x / 32);
|
1175 |
+
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
1176 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
1177 |
+
|
1178 |
+
int c_gl_wr_end = c_gl_stride * prob_m;
|
1179 |
+
|
1180 |
+
// We first reorder in shared memory to guarantee the most efficient final
|
1181 |
+
// global write patterns
|
1182 |
+
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
1183 |
+
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
|
1184 |
+
|
1185 |
+
// For per-column quantization we finally apply the scale here (only for
|
1186 |
+
// 4-bit)
|
1187 |
+
if constexpr (!has_act_order && group_blocks == -1 &&
|
1188 |
+
w_type.size_bits() == 4) {
|
1189 |
+
res = __hmul2(res, s[0]);
|
1190 |
+
}
|
1191 |
+
|
1192 |
+
((half2*)sh)[idx] = res;
|
1193 |
+
};
|
1194 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1195 |
+
#pragma unroll
|
1196 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
1197 |
+
#pragma unroll
|
1198 |
+
for (int j = 0; j < 4; j++) {
|
1199 |
+
int wr = c_sh_wr + 8 * j;
|
1200 |
+
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
1201 |
+
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
1202 |
+
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
1203 |
+
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
1204 |
+
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
1205 |
+
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
1206 |
+
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
1207 |
+
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
1208 |
+
}
|
1209 |
+
c_sh_wr += 16 * (4 * c_sh_stride);
|
1210 |
+
}
|
1211 |
+
}
|
1212 |
+
__syncthreads();
|
1213 |
+
|
1214 |
+
#pragma unroll
|
1215 |
+
for (int i = 0;
|
1216 |
+
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
1217 |
+
i++) {
|
1218 |
+
if (c_gl_wr < c_gl_wr_end) {
|
1219 |
+
int row = sorted_ids[c_gl_wr / c_gl_stride];
|
1220 |
+
if (row < tot_m * topk) {
|
1221 |
+
int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
|
1222 |
+
if (!apply_weights) {
|
1223 |
+
C[off] = sh[c_sh_rd];
|
1224 |
+
} else {
|
1225 |
+
__half* ctrg = reinterpret_cast<__half*>(&C[off]);
|
1226 |
+
__half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
|
1227 |
+
for (int j = 0; j < 8; ++j) {
|
1228 |
+
ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j]));
|
1229 |
+
}
|
1230 |
+
}
|
1231 |
+
c_gl_wr += c_gl_wr_delta;
|
1232 |
+
c_sh_rd += c_sh_rd_delta;
|
1233 |
+
}
|
1234 |
+
}
|
1235 |
+
}
|
1236 |
+
};
|
1237 |
+
|
1238 |
+
// Start global fetch and register load pipelines.
|
1239 |
+
auto start_pipes = [&]() {
|
1240 |
+
|
1241 |
+
#pragma unroll
|
1242 |
+
for (int i = 0; i < stages - 1; i++) {
|
1243 |
+
if (has_act_order && i == 0) {
|
1244 |
+
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
1245 |
+
if (last_g_idx >= prob_k) {
|
1246 |
+
last_g_idx = prob_k - 1;
|
1247 |
+
}
|
1248 |
+
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
1249 |
+
}
|
1250 |
+
|
1251 |
+
if constexpr (has_zp && group_blocks == -1) {
|
1252 |
+
if (i == 0) {
|
1253 |
+
fetch_zp_to_shared();
|
1254 |
+
}
|
1255 |
+
}
|
1256 |
+
fetch_to_shared(i, i, i < slice_iters);
|
1257 |
+
}
|
1258 |
+
|
1259 |
+
zero_accums();
|
1260 |
+
wait_for_stage();
|
1261 |
+
init_same_group(0);
|
1262 |
+
fetch_to_registers(0, 0);
|
1263 |
+
fetch_scales_to_registers(0, 0);
|
1264 |
+
fetch_zp_to_registers(0, 0);
|
1265 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
1266 |
+
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
1267 |
+
};
|
1268 |
+
if (slice_iters) {
|
1269 |
+
start_pipes();
|
1270 |
+
}
|
1271 |
+
|
1272 |
+
// Main loop.
|
1273 |
+
while (slice_iters) {
|
1274 |
+
// We unroll over both the global fetch and the register load pipeline to
|
1275 |
+
// ensure all shared memory accesses are static. Note that both pipelines
|
1276 |
+
// have even length meaning that the next iteration will always start at
|
1277 |
+
// index 0.
|
1278 |
+
#pragma unroll
|
1279 |
+
for (int pipe = 0; pipe < stages;) {
|
1280 |
+
#pragma unroll
|
1281 |
+
for (int k = 0; k < b_sh_wr_iters; k++) {
|
1282 |
+
fetch_to_registers(k + 1, pipe % stages);
|
1283 |
+
fetch_scales_to_registers(k + 1, pipe);
|
1284 |
+
fetch_zp_to_registers(k + 1, pipe);
|
1285 |
+
if (k == b_sh_wr_iters - 2) {
|
1286 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
1287 |
+
slice_iters >= stages);
|
1288 |
+
pipe++;
|
1289 |
+
wait_for_stage();
|
1290 |
+
init_same_group(pipe % stages);
|
1291 |
+
}
|
1292 |
+
matmul(k);
|
1293 |
+
}
|
1294 |
+
slice_iters--;
|
1295 |
+
if (slice_iters == 0) {
|
1296 |
+
break;
|
1297 |
+
}
|
1298 |
+
}
|
1299 |
+
|
1300 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
1301 |
+
slice_k_start += tb_k * stages;
|
1302 |
+
slice_k_start_shared_fetch += tb_k * stages;
|
1303 |
+
|
1304 |
+
if constexpr (has_act_order) {
|
1305 |
+
int first_group_id = g_idx[slice_k_start];
|
1306 |
+
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
1307 |
+
if (last_g_idx >= prob_k) {
|
1308 |
+
last_g_idx = prob_k - 1;
|
1309 |
+
}
|
1310 |
+
int last_group_id = g_idx[last_g_idx];
|
1311 |
+
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
1312 |
+
fetch_scales_to_shared(false, first_group_id, last_group_id);
|
1313 |
+
__syncthreads();
|
1314 |
+
}
|
1315 |
+
}
|
1316 |
+
|
1317 |
+
// Process results and, if necessary, proceed to the next column slice.
|
1318 |
+
// While this pattern may not be the most readable, other ways of writing
|
1319 |
+
// the loop seemed to noticeably worsen performance after compilation.
|
1320 |
+
if (slice_iters == 0) {
|
1321 |
+
cp_async_wait<0>();
|
1322 |
+
bool last = slice_idx == slice_count - 1;
|
1323 |
+
if constexpr (!has_act_order && group_blocks == -1) {
|
1324 |
+
if constexpr (w_type.size_bits() == 8) {
|
1325 |
+
if (s_sh_wr_pred) {
|
1326 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
1327 |
+
}
|
1328 |
+
cp_async_fence();
|
1329 |
+
} else {
|
1330 |
+
// For 4-bit per-column scales, we only fetch them here in the
|
1331 |
+
// final step before write-out
|
1332 |
+
if (last) {
|
1333 |
+
if (s_sh_wr_pred) {
|
1334 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
1335 |
+
}
|
1336 |
+
cp_async_fence();
|
1337 |
+
}
|
1338 |
+
}
|
1339 |
+
}
|
1340 |
+
|
1341 |
+
thread_block_reduce();
|
1342 |
+
if constexpr (!has_act_order && group_blocks == -1) {
|
1343 |
+
if constexpr (w_type.size_bits() == 8) {
|
1344 |
+
cp_async_wait<0>();
|
1345 |
+
__syncthreads();
|
1346 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1347 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
1348 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
1349 |
+
}
|
1350 |
+
|
1351 |
+
} else {
|
1352 |
+
if (last) {
|
1353 |
+
cp_async_wait<0>();
|
1354 |
+
__syncthreads();
|
1355 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1356 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
1357 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
1358 |
+
}
|
1359 |
+
}
|
1360 |
+
}
|
1361 |
+
}
|
1362 |
+
|
1363 |
+
// For 8-bit channelwise, we apply the scale before the global reduction
|
1364 |
+
// that converts the fp32 results to fp16 (so that we avoid possible
|
1365 |
+
// overflow in fp16)
|
1366 |
+
if constexpr (!has_act_order && group_blocks == -1 &&
|
1367 |
+
w_type.size_bits() == 8) {
|
1368 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1369 |
+
#pragma unroll
|
1370 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
1371 |
+
#pragma unroll
|
1372 |
+
for (int j = 0; j < 4; j++) {
|
1373 |
+
scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
|
1374 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
1375 |
+
scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
|
1376 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
1377 |
+
|
1378 |
+
scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
|
1379 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
1380 |
+
scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
|
1381 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
1382 |
+
}
|
1383 |
+
}
|
1384 |
+
}
|
1385 |
+
}
|
1386 |
+
|
1387 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
1388 |
+
// block in a slice
|
1389 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
1390 |
+
global_reduce(slice_idx == 0, last);
|
1391 |
+
barrier_release(&locks[slice_col], last);
|
1392 |
+
}
|
1393 |
+
if (last) // only the last block in a slice actually writes the result
|
1394 |
+
write_result();
|
1395 |
+
slice_row = 0;
|
1396 |
+
slice_col_par++;
|
1397 |
+
slice_col++;
|
1398 |
+
init_slice();
|
1399 |
+
if (slice_iters) {
|
1400 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
1401 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
1402 |
+
#pragma unroll
|
1403 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
1404 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
1405 |
+
if (slice_col == 0) {
|
1406 |
+
#pragma unroll
|
1407 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
1408 |
+
}
|
1409 |
+
|
1410 |
+
// Update slice k/n for scales loading
|
1411 |
+
if constexpr (has_act_order) {
|
1412 |
+
slice_k_start = tb_k * slice_row;
|
1413 |
+
slice_k_finish = slice_k_start + tb_k * slice_iters;
|
1414 |
+
slice_k_start_shared_fetch = slice_k_start;
|
1415 |
+
slice_n_offset = act_s_col_tb_stride * slice_col;
|
1416 |
+
|
1417 |
+
} else {
|
1418 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
1419 |
+
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
1420 |
+
}
|
1421 |
+
|
1422 |
+
start_pipes();
|
1423 |
+
}
|
1424 |
+
}
|
1425 |
+
}
|
1426 |
+
}
|
1427 |
+
|
1428 |
+
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
1429 |
+
const int threads, // number of threads in a threadblock
|
1430 |
+
const int thread_n_blocks, // same for n dimension (output)
|
1431 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
1432 |
+
const int stages, // number of stages for the async global->shared
|
1433 |
+
// fetch pipeline
|
1434 |
+
const bool has_act_order, // whether act_order is enabled
|
1435 |
+
const bool has_zp, // whether zero-points are enabled
|
1436 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
1437 |
+
// with a separate quantization scale
|
1438 |
+
>
|
1439 |
+
__global__ void MarlinMoE(
|
1440 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
1441 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
1442 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
1443 |
+
const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts
|
1444 |
+
const float* __restrict__ topk_weights, // float topk weights
|
1445 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
1446 |
+
// (k/groupsize)xn
|
1447 |
+
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
1448 |
+
// (k/groupsize)x(n/pack_factor)
|
1449 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
1450 |
+
const int* __restrict__ expert_offsets,
|
1451 |
+
int num_groups, // number of scale groups per output channel
|
1452 |
+
int expert_idx, // idx of current expert
|
1453 |
+
int num_experts, // number of experts
|
1454 |
+
int topk, // topk parameter of moe
|
1455 |
+
int prob_m, // batch dimension m
|
1456 |
+
int prob_n, // output dimension n
|
1457 |
+
int prob_k, // reduction dimension k
|
1458 |
+
int tot_m, // total number of rows in A and C
|
1459 |
+
int* locks, // extra global storage for barrier synchronization
|
1460 |
+
bool replicate_input, // do we use the same input for each expert?
|
1461 |
+
bool apply_weights, // apply weights to output
|
1462 |
+
int current_m_block, // current m block to start kernel computation from
|
1463 |
+
int max_par, // maximum parallelism
|
1464 |
+
int cfg_max_m_blocks // upper bound on m blocks
|
1465 |
+
) {
|
1466 |
+
int m_block_ctr = current_m_block;
|
1467 |
+
|
1468 |
+
const int* sorted_ids_expert =
|
1469 |
+
sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par;
|
1470 |
+
int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx];
|
1471 |
+
if (tot_its == 0) {
|
1472 |
+
return;
|
1473 |
+
}
|
1474 |
+
int tot_m_blocks = ceildiv(tot_its, 16);
|
1475 |
+
int pad = 16 * tot_m_blocks - tot_its;
|
1476 |
+
|
1477 |
+
if (m_block_ctr >= tot_m_blocks) {
|
1478 |
+
return;
|
1479 |
+
}
|
1480 |
+
|
1481 |
+
int max_block = tot_m_blocks - m_block_ctr;
|
1482 |
+
prob_m = tot_its - 16 * m_block_ctr;
|
1483 |
+
|
1484 |
+
int par = 1;
|
1485 |
+
if (max_block > cfg_max_m_blocks) {
|
1486 |
+
// Note that parallel > 1 currently only works for inputs without any
|
1487 |
+
// padding
|
1488 |
+
par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);
|
1489 |
+
if (par > max_par) par = max_par;
|
1490 |
+
prob_m = (16 * cfg_max_m_blocks) * par;
|
1491 |
+
m_block_ctr += cfg_max_m_blocks * (par - 1);
|
1492 |
+
max_block = cfg_max_m_blocks;
|
1493 |
+
}
|
1494 |
+
|
1495 |
+
if (max_block == 1) {
|
1496 |
+
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
|
1497 |
+
stages, has_act_order, has_zp, group_blocks>(
|
1498 |
+
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
1499 |
+
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
1500 |
+
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
1501 |
+
current_m_block);
|
1502 |
+
} else if (max_block == 2) {
|
1503 |
+
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
|
1504 |
+
stages, has_act_order, has_zp, group_blocks>(
|
1505 |
+
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
1506 |
+
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
1507 |
+
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
1508 |
+
current_m_block);
|
1509 |
+
} else if (max_block == 3) {
|
1510 |
+
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
|
1511 |
+
stages, has_act_order, has_zp, group_blocks>(
|
1512 |
+
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
1513 |
+
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
1514 |
+
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
1515 |
+
current_m_block);
|
1516 |
+
} else {
|
1517 |
+
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
|
1518 |
+
stages, has_act_order, has_zp, group_blocks>(
|
1519 |
+
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
1520 |
+
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
1521 |
+
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
1522 |
+
current_m_block);
|
1523 |
+
}
|
1524 |
+
}
|
1525 |
+
|
1526 |
+
#else
|
1527 |
+
|
1528 |
+
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
1529 |
+
const int threads, // number of threads in a threadblock
|
1530 |
+
const int thread_n_blocks, // same for n dimension (output)
|
1531 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
1532 |
+
const int stages, // number of stages for the async global->shared
|
1533 |
+
// fetch pipeline
|
1534 |
+
const bool has_act_order, // whether act_order is enabled
|
1535 |
+
const bool has_zp, // whether zero-points are enabled
|
1536 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
1537 |
+
// with a separate quantization scale
|
1538 |
+
>
|
1539 |
+
__global__ void MarlinMoE(
|
1540 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
1541 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
1542 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
1543 |
+
const int* __restrict__ sorted_ids, // int32 sorted ids of experts
|
1544 |
+
const float* __restrict__ topk_weights, // float topk weights
|
1545 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
1546 |
+
// (k/groupsize)xn
|
1547 |
+
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
1548 |
+
// (k/groupsize)x(n/pack_factor)
|
1549 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
1550 |
+
const int* __restrict__ expert_offsets,
|
1551 |
+
int num_groups, // number of scale groups per output channel
|
1552 |
+
int expert_idx, // idx of current expert
|
1553 |
+
int num_experts, // number of experts
|
1554 |
+
int topk, // topk parameter of moe
|
1555 |
+
int prob_m, // batch dimension m
|
1556 |
+
int prob_n, // output dimension n
|
1557 |
+
int prob_k, // reduction dimension k
|
1558 |
+
int tot_m, // total number of rows in A and C
|
1559 |
+
int* locks, // extra global storage for barrier synchronization
|
1560 |
+
bool replicate_input, // do we use the same input for each expert?
|
1561 |
+
bool apply_weights, // apply weights to output
|
1562 |
+
int current_m_block, // current m block to start kernel computation from
|
1563 |
+
int max_par, // maximum parallelism
|
1564 |
+
int cfg_max_m_blocks // upper bound on m blocks
|
1565 |
+
) {
|
1566 |
+
// Marlin is not implemented yet for SM < 8.0
|
1567 |
+
assert(false);
|
1568 |
+
return;
|
1569 |
+
}
|
1570 |
+
|
1571 |
+
#endif
|
1572 |
+
|
1573 |
+
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
1574 |
+
// than 1 warp per scheduler allows some more latency hiding. At the same time,
|
1575 |
+
// we want relatively few warps to have many registers per warp and small tiles.
|
1576 |
+
const int USER_THREADS =
|
1577 |
+
256; // Note: This is only used with user-provided thread_k/n
|
1578 |
+
const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
1579 |
+
|
1580 |
+
static constexpr int min_thread_n = 64;
|
1581 |
+
static constexpr int min_thread_k = 64;
|
1582 |
+
|
1583 |
+
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
|
1584 |
+
HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
|
1585 |
+
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
|
1586 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
1587 |
+
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
1588 |
+
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
1589 |
+
cudaFuncSetAttribute( \
|
1590 |
+
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
1591 |
+
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
|
1592 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
1593 |
+
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
1594 |
+
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
|
1595 |
+
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
1596 |
+
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
|
1597 |
+
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
|
1598 |
+
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
|
1599 |
+
replicate_input, apply_weights, m_block, max_par, \
|
1600 |
+
cfg_max_m_blocks); \
|
1601 |
+
}
|
1602 |
+
|
1603 |
+
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
1604 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
1605 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
1606 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
1607 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
1608 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
1609 |
+
|
1610 |
+
#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
1611 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
1612 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
1613 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
1614 |
+
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
1615 |
+
|
1616 |
+
} // namespace marlin_moe
|
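The `__CALL_IF_MOE` / `GPTQ_CALL_IF_MOE` / `AWQ_CALL_IF_MOE` macros above expand into a long `else if` chain that matches the runtime configuration against one compile-time template instantiation of `MarlinMoE` and launches it. A stripped-down sketch of that pattern (hypothetical names, a print instead of an actual kernel launch) shows the shape of the expansion:

```cuda
// Simplified sketch of the macro-generated else-if dispatch used above
// (the hypothetical launch_config stands in for the MarlinMoE launch).
#include <cstdio>

template <int kThreadNBlocks, int kThreadKBlocks, int kNumThreads>
void launch_config() {
  std::printf("selected: n_blocks=%d k_blocks=%d threads=%d\n", kThreadNBlocks,
              kThreadKBlocks, kNumThreads);
}

#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)                          \
  else if (thread_n_blocks == N_BLOCKS &&                                 \
           thread_k_blocks == K_BLOCKS && num_threads == NUM_THREADS) {   \
    launch_config<N_BLOCKS, K_BLOCKS, NUM_THREADS>();                     \
  }

bool dispatch(int thread_n_blocks, int thread_k_blocks, int num_threads) {
  if (false) {  // empty anchor so every macro branch can begin with `else if`
  }
  CALL_IF(16, 4, 256)
  CALL_IF(8, 8, 256)
  CALL_IF(8, 4, 128)
  CALL_IF(4, 8, 128)
  else {
    return false;  // no supported compile-time configuration matched
  }
  return true;
}
```

Returning `bool` from the dispatcher is what allows the per-weight-type wrappers in the files below to be tried in sequence.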
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.cu
ADDED
@@ -0,0 +1,31 @@
1 |
+
#include "marlin_moe_kernel_ku4.h"
|
2 |
+
|
3 |
+
namespace marlin_moe {
|
4 |
+
|
5 |
+
// We return bool so we can create these different kernel calls as a sequence
|
6 |
+
// of if-elseif's.
|
7 |
+
bool call_marlin_moe_kernel_ku4(
|
8 |
+
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
|
9 |
+
bool has_act_order, int group_blocks, int num_threads, int blocks,
|
10 |
+
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
|
11 |
+
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
|
12 |
+
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
|
13 |
+
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
|
14 |
+
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
|
15 |
+
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
|
16 |
+
int m_block, int max_par, int cfg_max_m_blocks) {
|
17 |
+
bool has_zp = true;
|
18 |
+
|
19 |
+
if (false) {
|
20 |
+
}
|
21 |
+
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
|
22 |
+
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
|
23 |
+
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
|
24 |
+
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
|
25 |
+
else {
|
26 |
+
return false;
|
27 |
+
}
|
28 |
+
return true;
|
29 |
+
}
|
30 |
+
|
31 |
+
} // namespace marlin_moe
|
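As the comment in `marlin_moe_kernel_ku4.cu` notes, each `call_marlin_moe_kernel_*` function returns `true` only when it recognized the weight type and launched a kernel, so the caller (not part of this diff) can try the dispatchers in sequence. A self-contained sketch of that chaining, with simplified stand-in signatures rather than the real ones:

```cuda
// Hypothetical, simplified illustration of chaining bool-returning
// dispatchers; the real call site and signatures are not shown in this diff.
#include <stdexcept>

enum class WeightType { kU4, kU4B8, kU8B128 };

bool call_kernel_ku4(WeightType t) { return t == WeightType::kU4; }
bool call_kernel_ku4b8(WeightType t) { return t == WeightType::kU4B8; }
bool call_kernel_ku8b128(WeightType t) { return t == WeightType::kU8B128; }

void dispatch_moe(WeightType t) {
  // Try each specialized dispatcher in turn; the first one that handles the
  // weight type "launches" its kernel and returns true.
  bool launched =
      call_kernel_ku4b8(t) || call_kernel_ku4(t) || call_kernel_ku8b128(t);
  if (!launched) {
    throw std::runtime_error("unsupported weight type for marlin MoE");
  }
}
```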
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4.h
ADDED
@@ -0,0 +1,20 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "marlin_moe_kernel.h"
|
4 |
+
|
5 |
+
namespace marlin_moe {
|
6 |
+
|
7 |
+
// We return bool so we can create these different kernel calls as a sequence
|
8 |
+
// of if-elseif's.
|
9 |
+
bool call_marlin_moe_kernel_ku4(
|
10 |
+
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
|
11 |
+
bool has_act_order, int group_blocks, int num_threads, int blocks,
|
12 |
+
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
|
13 |
+
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
|
14 |
+
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
|
15 |
+
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
|
16 |
+
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
|
17 |
+
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
|
18 |
+
int m_block, int max_par, int cfg_max_m_blocks);
|
19 |
+
|
20 |
+
} // namespace marlin_moe
|
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
ADDED
@@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku4b8.h"

namespace marlin_moe {

// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks) {
  bool has_zp = false;

  if (false) {
  }
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
  else {
    return false;
  }
  return true;
}

}  // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
ADDED
@@ -0,0 +1,20 @@
#pragma once

#include "marlin_moe_kernel.h"

namespace marlin_moe {

// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks);

}  // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
ADDED
@@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku8b128.h"

namespace marlin_moe {

// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks) {
  bool has_zp = false;

  if (false) {
  }
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
  else {
    return false;
  }
  return true;
}

}  // namespace marlin_moe
marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
ADDED
@@ -0,0 +1,18 @@
#pragma once

#include "marlin_moe_kernel.h"

namespace marlin_moe {

bool call_marlin_moe_kernel_ku8b128(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
    int m_block, int max_par, int cfg_max_m_blocks);

}
marlin-moe/marlin_moe_ops.cu
ADDED
@@ -0,0 +1,584 @@
/*
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"

template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace marlin_moe {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = finish_row - start_row;

  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / blockDim.x;
    int rest = size_k % blockDim.x;

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += blockDim.x;
    }

    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}

__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  int expert_id = threadIdx.x;
  int num_experts = blockDim.x;

  int occurrences = 0;
  for (int i = 0; i < topk_length; ++i) {
    occurrences += (topk_ids[i] == expert_id);
  }
  expert_offsets[expert_id + 1] = occurrences;
  __syncthreads();

  if (threadIdx.x == 0) {
    int tot_offset = 0;
    expert_offsets[0] = 0;
    for (int i = 0; i < num_experts; ++i) {
      tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
      expert_offsets[i + 1] = tot_offset;
    }
  }
  __syncthreads();
}

#else

__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

#endif

typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

typedef struct {
  int max_m_blocks;
  thread_config_t tb_cfg;
} exec_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
    {64, 64, 128},    // Reduce both 2X
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
    {64, 64, 128},    // Reduce N 4X, same K
};

int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  bool cache_scales_chunk = has_act_order && !is_k_full;

  int tb_n = th_config.thread_n;
  int tb_k = th_config.thread_k;

  // Get max scale groups per thread-block
  int tb_groups;
  if (group_size == -1) {
    tb_groups = 1;
  } else if (group_size == 0) {
    tb_groups = ceildiv(tb_k, 32);  // Worst case is 32 group size
  } else {
    tb_groups = ceildiv(tb_k, group_size);
  }

  if (cache_scales_chunk) {
    int load_groups =
        tb_groups * STAGES * 2;          // Chunk size is 2x pipeline over dim K
    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
    return load_groups * tb_n * 4;

  } else {
    int tb_scales = tb_groups * tb_n * 2;

    return tb_scales * STAGES;
  }
}

bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int pack_factor = 32 / num_bits;

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;

  int b_size = (tb_k * tb_n / pack_factor) * 4;

  // Get A size
  int m_blocks = ceildiv(prob_m, 16);
  int tb_max_m = 16;

  while (true) {
    if (m_blocks >= max_m_blocks) {
      tb_max_m *= max_m_blocks;
      break;
    }

    max_m_blocks--;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }

  int a_size = (tb_max_m * tb_k) * 2;

  float pipe_size = (a_size + b_size) * STAGES;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}

bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  // Determine cache for scales
  int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);

  // Check that pipeline fits into cache
  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                           num_bits, scales_cache_size, max_shared_mem)) {
    return false;
  }

  return true;
}

exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  int max_m_blocks = 4;
  while (max_m_blocks > 0) {
    if (prob_m <= 16) {
      for (auto th_config : small_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    } else {
      for (auto th_config : large_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }

    max_m_blocks--;  // Process less M blocks per invocation to reduce cache
                     // usage
  }

  return exec_config_t{0, {-1, -1, -1}};
}

#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)                             \
  else if (KERNEL_FUNCTION(                                                   \
               q_type, thread_n_blocks, thread_k_blocks, has_act_order,      \
               group_blocks, num_threads, blocks, max_shared_mem, stream,    \
               A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
               zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,\
               num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,      \
               replicate_input, apply_weights, m_block, max_par,             \
               exec_cfg.max_m_blocks)) {                                     \
  }

void marlin_mm_moe(const void* A, const void* B, void* C,
                   const void* sorted_ids, const void* topk_weights,
                   const void* topk_ids, const void* s, void* zp,
                   const void* g_idx, const void* perm, void* a_tmp,
                   void* expert_offsets, int prob_m, int prob_n, int prob_k,
                   void* workspace, vllm::ScalarType const& q_type,
                   bool has_act_order, bool is_k_full, bool has_zp,
                   int num_groups, int group_size, int num_experts, int topk,
                   int moe_block_size, int dev, cudaStream_t stream,
                   int thread_k, int thread_n, int sms, int max_par,
                   bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  int num_bits = q_type.size_bits();

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }

  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int tot_m = prob_m;

  const int* topk_ids_ptr = (const int*)topk_ids;
  int* expert_offsets_ptr = (int*)expert_offsets;
  compute_expert_offsets<<<1, num_experts, 0, stream>>>(
      topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);

  bool do_permute_a = has_act_order;

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  int pack_factor = 32 / q_type.size_bits();

  for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
    const int4* A_ptr = (const int4*)A;
    int4* a_tmp_ptr = (int4*)a_tmp;
    const int4* B_ptr =
        (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
    int4* C_ptr = (int4*)C;
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
    const int4* zp_ptr =
        (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;

    if (do_permute_a) {
      // Permute A columns
      int topk_rows = replicate_input ? tot_m : tot_m * topk;
      int block_rows = ceildiv(topk_rows, blocks);
      permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
          A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
      A_ptr = a_tmp_ptr;
    }

    int tot_m_blocks = ceildiv(tot_m, 16);
    for (int m_block = 0; m_block < tot_m_blocks;
         m_block += 4 * exec_cfg.max_m_blocks) {
      if (false) {
      }
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
      else {
        TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                               str(prob_n) + ", " + str(prob_k) + "]" +
                               ", has_act_order = " + str(has_act_order) +
                               ", num_groups = " + str(num_groups) +
                               ", group_size = " + str(group_size) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
      }
    }
  }
}

}  // namespace marlin_moe

torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    torch::Tensor& b_zeros, const torch::Tensor& g_idx,
    const torch::Tensor& perm, torch::Tensor& workspace,
    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  bool has_zp = b_zeros.size(1) != 0;
  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
  }

  int pack_factor = 32 / b_q_type.size_bits();

  int max_par = 4;

  int dev = a.get_device();

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
  torch::Tensor a_tmp =
      replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                      : torch::zeros({size_m, topk, size_k}, options_dtype);
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(1) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
              "if is_k_full is false, has_act_order must be true");

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }

  } else {
    if (num_groups > 1) {
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    TORCH_CHECK(b_zeros.size(1) == num_groups,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                "b_zeros dim 2 = ", b_zeros.size(2),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
      num_experts, topk, moe_block_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
      replicate_input, apply_weights);
  return c;
}
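For reference, `compute_expert_offsets` above pads each expert's token count up to a multiple of `moe_block_size` before prefix-summing, so every expert's segment starts on a block boundary. A CPU sketch of that arithmetic (the `topk_ids` values below are a made-up example, not taken from this repository's tests) behaves like this:

```cpp
// CPU reference for the padding + prefix-sum done by compute_expert_offsets.
#include <iostream>
#include <vector>

int ceildiv(int a, int b) { return (a + b - 1) / b; }

std::vector<int> expert_offsets_ref(const std::vector<int>& topk_ids,
                                    int num_experts, int block_size) {
  std::vector<int> offsets(num_experts + 1, 0);
  // Count how many (token, expert) pairs were routed to each expert.
  for (int id : topk_ids) offsets[id + 1] += 1;
  // Pad every expert's segment to a multiple of block_size and prefix-sum.
  int tot = 0;
  for (int e = 0; e < num_experts; ++e) {
    tot += ceildiv(offsets[e + 1], block_size) * block_size;
    offsets[e + 1] = tot;
  }
  return offsets;
}

int main() {
  // Hypothetical routing of 9 (token, expert) pairs to 4 experts, block_size 4.
  std::vector<int> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 1};
  for (int off : expert_offsets_ref(topk_ids, 4, 4)) std::cout << off << ' ';
  std::cout << '\n';  // counts 2,3,2,2 -> padded segments of 4 -> 0 4 8 12 16
  return 0;
}
```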
moe/moe_align_sum_kernels.cu
ADDED
@@ -0,0 +1,202 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {
namespace moe {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
}  // namespace

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* tokens_cnts =
      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
  int32_t* cumsum =
      shared_mem +
      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t x = 0.0;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
    }
    out[token_idx * d + idx] = x;
  }
}

}  // namespace moe
}  // namespace vllm

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
        // tensors
        const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
        const int32_t shared_mem =
            ((num_thread + 1) * num_experts + (num_experts + 1)) *
            sizeof(int32_t);

        // set dynamic shared mem
        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<1, num_thread, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel());
      });
}

void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
  const int hidden_size = input.size(-1);
  const int num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 3:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 4:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    default:
      at::sum_out(output, input, 1);
      break;
  }
}
moe/topk_softmax_kernels.cu
ADDED
@@ -0,0 +1,506 @@
/*
 * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
 * Copyright (c) 2024, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"

#ifndef USE_ROCM
    #include <cub/util_type.cuh>
    #include <cub/cub.cuh>
#else
    #include <hipcub/util_type.hpp>
    #include <hipcub/hipcub.hpp>
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace vllm {
namespace moe {

/// Aligned array type
template <
    typename T,
    /// Number of elements in the array
    int N,
    /// Alignment requirement in bytes
    int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
    float data[N];
};

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    __shared__ float normalizing_factor;
    __shared__ float float_max;

    const int thread_row_offset = blockIdx.x * num_cols;

    cub::Sum sum;
    float threadData(-FLT_MAX);

    // Don't touch finished rows.
    if ((finished != nullptr) && finished[blockIdx.x])
    {
        return;
    }

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        threadData = max(static_cast<float>(input[idx]), threadData);
    }

    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
    if (threadIdx.x == 0)
    {
        float_max = maxElem;
    }
    __syncthreads();

    threadData = 0;

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        threadData += exp((static_cast<float>(input[idx]) - float_max));
    }

    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

    if (threadIdx.x == 0)
    {
        normalizing_factor = 1.f / Z;
    }
    __syncthreads();

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
        output[idx] = val;
    }
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{

    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    cub_kvp thread_kvp;
    cub::ArgMax arg_max;

    const int num_rows = gridDim.x;
    const int block_row = blockIdx.x;

    const bool row_is_active = finished ? !finished[block_row] : true;
    const int thread_read_offset = blockIdx.x * num_experts;
    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        thread_kvp.key = 0;
        thread_kvp.value = -1.f; // This is OK because inputs are probabilities

        cub_kvp inp_kvp;
        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
        {
            const int idx = thread_read_offset + expert;
            inp_kvp.key = expert;
            inp_kvp.value = inputs_after_softmax[idx];

            for (int prior_k = 0; prior_k < k_idx; ++prior_k)
            {
                const int prior_winning_expert = indices[k * block_row + prior_k];

                if (prior_winning_expert == expert)
                {
                    inp_kvp = thread_kvp;
                }
            }

            thread_kvp = arg_max(inp_kvp, thread_kvp);
        }

        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
        if (threadIdx.x == 0)
        {
            // Ignore experts the node isn't responsible for with expert parallelism
            const int expert = result_kvp.key;
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            const int idx = k * block_row + k_idx;
            output[idx] = result_kvp.value;
            indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
            assert(indices[idx] >= 0);
            source_rows[idx] = k_idx * num_rows + block_row;
        }
        __syncthreads();
    }
}

// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit when the number of experts in the MoE layers
  are a small power of 2. This allows us to cleanly share the rows among the threads in
  a single warp and eliminate communication between warps (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
        int* source_rows, const int k, const int start_expert, const int end_expert)
{
    // We begin by enforcing compile time assertions and setting up compile time constants.
    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

    // Number of bytes each thread pulls in per load
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

    // Restrictions based on previous section.
    static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
    static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

    // We have NUM_EXPERTS elements per row. We specialize for small #experts
    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
    static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
    static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

    // Restrictions for previous section.
    static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

    // ===================== From this point, we finally start computing run-time variables. ========================

    // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
    // This, each block processes a chunk of rows. We start by computing the start row for each block.
    const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

    // Now, using the base row per thread block, we compute the base row per warp.
    const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

    // The threads in a warp are split into sub-groups that will work on a row.
    // We compute row offset for each thread sub-group
    const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
    const int thread_row = warp_base_row + thread_row_in_warp;

    // Threads with indices out of bounds should early exit here.
    if (thread_row >= num_rows)
    {
        return;
    }
    const bool row_is_active = finished ? !finished[thread_row] : true;

    // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
    // row it will read.
    const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

    // Now, we compute the group each thread belong to in order to determine the first column to start loads.
    const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
    const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
    const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

    // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
    // this can support all powers of 2 up to 16.
    // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
    // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
    using AccessType = AlignedArray<float, ELTS_PER_LDG>;

    // Finally, we pull in the data from global mem
    float row_chunk[VPT];
    AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
    const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
    for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
    {
        row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
    }

    // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
    // convert to float afterwards for the exp + sum reduction.
    float thread_max = row_chunk[0];
#pragma unroll
    for (int ii = 1; ii < VPT; ++ii)
    {
        thread_max = max(thread_max, row_chunk[ii]);
    }

    // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
    }

    // From this point, thread max in all the threads have the max within the row.
    // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
    float row_sum = 0;
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = expf(row_chunk[ii] - thread_max);
        row_sum += row_chunk[ii];
    }

    // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
    }

    // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
    // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
    // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
    // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
    // argmax after computing the softmax.
    const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
    }

    // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
    // with the max index.
    int start_col = first_elt_read_by_thread;
    static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        // First, each thread does the local argmax
        float max_val = row_chunk[0];
        int expert = start_col;
#pragma unroll
        for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
        {
#pragma unroll
            for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
            {
                float val = row_chunk[ldg * ELTS_PER_LDG + ii];

                // No check on the experts here since columns with the smallest index are processed first and only
                // updated if > (not >=)
                if (val > max_val)
                {
                    max_val = val;
                    expert = col + ii;
                }
            }
        }

        // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
        // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
        // then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
        {
            float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
            int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

            // We want lower indices to "win" in every thread so we break ties this way
            if (other_max > max_val || (other_max == max_val && other_expert < expert))
            {
                max_val = other_max;
                expert = other_expert;
            }
        }

        // Write the max for this k iteration to global memory.
        if (thread_group_idx == 0)
        {
            // Add a guard to ignore experts not included by this node
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            // The lead thread from each sub-group will write out the final results to global memory. (This will be a
            // single) thread per row of the input/output matrices.
            const int idx = k * thread_row + k_idx;
            output[idx] = max_val;
            indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
            source_rows[idx] = k_idx * num_rows + thread_row;
        }

        // Finally, we clear the value in the thread with the current max if there is another iteration to run.
        if (k_idx + 1 < k)
        {
            const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
            const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

            // Only the thread in the group which produced the max will reset the "winning" value to -inf.
            if (thread_group_idx == thread_to_clear_in_group)
            {
                const int offset_for_expert = expert % ELTS_PER_LDG;
                // Safe to set to any negative value since row_chunk values must be between 0 and 1.
                row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
            }
        }
    }
}

namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

    static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
407 |
+
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
408 |
+
static constexpr int VPT = Constants::VPT;
|
409 |
+
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
410 |
+
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
411 |
+
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
412 |
+
|
413 |
+
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
414 |
+
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
415 |
+
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
416 |
+
}
|
417 |
+
|
418 |
+
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
419 |
+
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
420 |
+
gating_output, nullptr, topk_weights, topk_indicies, \
|
421 |
+
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
422 |
+
stream);
|
423 |
+
|
424 |
+
void topkGatingSoftmaxKernelLauncher(
|
425 |
+
const float* gating_output,
|
426 |
+
float* topk_weights,
|
427 |
+
int* topk_indicies,
|
428 |
+
int* token_expert_indices,
|
429 |
+
float* softmax_workspace,
|
430 |
+
const int num_tokens,
|
431 |
+
const int num_experts,
|
432 |
+
const int topk,
|
433 |
+
cudaStream_t stream) {
|
434 |
+
static constexpr int WARPS_PER_TB = 4;
|
435 |
+
switch (num_experts) {
|
436 |
+
case 1:
|
437 |
+
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
438 |
+
break;
|
439 |
+
case 2:
|
440 |
+
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
441 |
+
break;
|
442 |
+
case 4:
|
443 |
+
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
444 |
+
break;
|
445 |
+
case 8:
|
446 |
+
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
447 |
+
break;
|
448 |
+
case 16:
|
449 |
+
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
450 |
+
break;
|
451 |
+
case 32:
|
452 |
+
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
453 |
+
break;
|
454 |
+
case 64:
|
455 |
+
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
456 |
+
break;
|
457 |
+
case 128:
|
458 |
+
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
459 |
+
break;
|
460 |
+
case 256:
|
461 |
+
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
462 |
+
break;
|
463 |
+
default: {
|
464 |
+
TORCH_CHECK(softmax_workspace != nullptr,
|
465 |
+
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
466 |
+
static constexpr int TPB = 256;
|
467 |
+
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
468 |
+
gating_output, nullptr, softmax_workspace, num_experts);
|
469 |
+
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
470 |
+
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
|
471 |
+
num_experts, topk, 0, num_experts);
|
472 |
+
}
|
473 |
+
}
|
474 |
+
}
|
475 |
+
|
476 |
+
} // namespace moe
|
477 |
+
} // namespace vllm
|
478 |
+
|
479 |
+
void topk_softmax(
|
480 |
+
torch::Tensor& topk_weights, // [num_tokens, topk]
|
481 |
+
torch::Tensor& topk_indices, // [num_tokens, topk]
|
482 |
+
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
483 |
+
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
484 |
+
{
|
485 |
+
const int num_experts = gating_output.size(-1);
|
486 |
+
const int num_tokens = gating_output.numel() / num_experts;
|
487 |
+
const int topk = topk_weights.size(-1);
|
488 |
+
|
489 |
+
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
490 |
+
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
491 |
+
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
|
492 |
+
|
493 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
494 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
495 |
+
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
496 |
+
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
497 |
+
gating_output.data_ptr<float>(),
|
498 |
+
topk_weights.data_ptr<float>(),
|
499 |
+
topk_indices.data_ptr<int>(),
|
500 |
+
token_expert_indices.data_ptr<int>(),
|
501 |
+
softmax_workspace.data_ptr<float>(),
|
502 |
+
num_tokens,
|
503 |
+
num_experts,
|
504 |
+
topk,
|
505 |
+
stream);
|
506 |
+
}
|
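
The `topk_softmax` entry point above is the op exposed by the Torch bindings in ext-torch/torch_binding.cpp and exercised by the tests below. The following is a minimal, hedged usage sketch: it assumes the built extension is importable as `moe` with an `ops` namespace (mirroring the imports in test/kernels/test_moe.py); the shapes and the reference check are illustrative, not part of the kernel sources.

```python
# Minimal sketch (assumption: the extension builds as `moe` and exposes `ops`,
# as in test/kernels/test_moe.py).
import torch
from moe._ops import ops  # assumed module path

num_tokens, num_experts, topk = 16, 8, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# The op fills caller-allocated output buffers in place.
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

ops.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)

# The kernel computes the full softmax over experts and then selects the top-k,
# so the weights should match a straightforward torch reference.
ref_weights, _ = torch.softmax(gating_output, dim=-1).topk(topk, dim=-1)
torch.testing.assert_close(topk_weights, ref_weights)
```

Note that for power-of-2 expert counts up to 256 the fused `topkGatingSoftmax` kernel is launched directly; otherwise the launcher falls back to the two-step `moeSoftmax` + `moeTopK` path, which is why the extra `softmax_workspace` buffer is allocated only in that case.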
test/__init__.py
ADDED
File without changes
test/kernels/__init__.py
ADDED
File without changes
test/kernels/test_moe.py
ADDED
@@ -0,0 +1,220 @@
"""Tests for the MOE layers.

Run `pytest test/kernels/test_moe.py`.
"""

from typing import List

import pytest
import torch

from moe._ops import ops
from moe.fused_moe import fused_moe, fused_topk, moe_align_block_size
from moe.fused_marlin_moe import fused_marlin_moe
from moe.scalar_type import scalar_types
from moe.utils.marlin_utils_test import marlin_quantize

from .utils import compute_max_diff, opcheck


def stack_and_dev(tensors: List[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
    is_k_full: bool,
):
    torch.manual_seed(7)

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return
    else:
        if not is_k_full:
            return

    quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    g_idx1_l = []
    sort_indices1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(k)
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref1_l.append(w_ref1)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l)
    sort_indices1 = stack_and_dev(sort_indices1_l)

    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n)
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
        )
        w_ref2_l.append(w_ref2)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l)
    sort_indices2 = stack_and_dev(sort_indices2_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2

    token_expert_indicies = torch.empty(m, topk, dtype=torch.int32, device=a.device)

    opcheck(
        ops.topk_softmax,
        (
            topk_weights,
            topk_ids,
            token_expert_indicies,
            score.float(),
        ),
    )

    block_size_m = 4

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, e)

    max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
    workspace = torch.zeros(
        max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
    )

    zp = torch.empty((0, 0), dtype=dtype, device="cuda", requires_grad=False)
    opcheck(
        ops.marlin_gemm_moe,
        (
            a,
            qweight1,
            sorted_token_ids,
            topk_weights,
            topk_ids,
            scales1,
            zp,
            g_idx1,
            sort_indices1,
            workspace,
            quant_type.id,
            m,
            2 * n,
            k,
            True,
            e,
            topk,
            block_size_m,
            True,
            False,
        ),
    )


def test_moe_align_block_size_opcheck():
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

    opcheck(
        ops.moe_align_block_size,
        (
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
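
To make explicit what `test_fused_marlin_moe` compares: both `fused_moe` (Triton) and `fused_marlin_moe` are expected to approximate the dense per-expert SiLU-and-mul MLP applied to the top-k routed tokens, which is why the test keeps the dequantized `w_ref1`/`w_ref2` weights for the Triton path and the packed `qweight*`/`scales*` tensors for the Marlin path, then checks them with `compute_max_diff` under a 4e-2 threshold. The sketch below is a rough pure-PyTorch reference under assumed semantics (shapes as in the test, `w1` is `[e, 2n, k]`, `w2` is `[e, k, n]`); it is not code from this repo.

```python
import torch

def torch_moe_reference(a, w1, w2, topk_weights, topk_ids):
    """Rough reference for the fused MoE forward checked above (assumed semantics)."""
    out = torch.zeros(a.shape, dtype=torch.float32, device=a.device)
    for expert in range(w1.shape[0]):
        token_idx, slot = torch.where(topk_ids == expert)
        if token_idx.numel() == 0:
            continue
        x = a[token_idx]                                   # [t, k] tokens routed to this expert
        gate_up = x @ w1[expert].t()                       # [t, 2n]
        n = gate_up.shape[-1] // 2
        h = torch.nn.functional.silu(gate_up[:, :n]) * gate_up[:, n:]  # silu_and_mul
        y = h @ w2[expert].t()                             # [t, k]
        out[token_idx] += (topk_weights[token_idx, slot].unsqueeze(-1) * y).float()
    return out.to(a.dtype)
```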
test/kernels/utils.py
ADDED
@@ -0,0 +1,78 @@
"""Kernel test utils"""

import itertools
import random
import unittest
from numbers import Number
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import pytest
import torch
from torch._prims_common import TensorLikeType

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True
) -> Dict[str, str]:
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
            )
            if cond
            else {}
        )
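
As a usage note for the helpers above: `fp8_allclose` routes the comparison through float64 so it also covers fp8 tensors, which plain `torch.isclose` does not handle, and `opcheck` simply runs `torch.library.opcheck` with that patch in place. A small, hedged example follows; it assumes a PyTorch build that provides the `float8_e4m3fn` dtype and that this test package is importable as `test.kernels.utils`.

```python
import torch
from test.kernels.utils import fp8_allclose  # path assumed from this repo's layout

a = torch.randn(32)
b = a.clone()

# Comparing fp8-quantized tensors with torch.allclose directly is not supported,
# which is why the helper upcasts to float64 before calling torch.isclose.
a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn)
assert fp8_allclose(a_fp8, b_fp8)
```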