| #include "common.h" |
| #include "gemm.h" |
| #include "vec.h" |
|
|
| namespace { |
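// Anonymous-namespace helpers for the fused MoE int8 (W8A8) path below:
// small vectorized copy/scale/accumulate stubs and AVX512-VNNI GEMM micro-kernels.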
|
|
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
  using Vec = at::vec::Vectorized<scalar_t>;

#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += Vec::size()) {
    Vec data = Vec::loadu(input + d);
    data.store(out + d);
  }
}
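// Note: the generic copy_stub has no scalar tail, so callers are expected to pass
// sizes that are a multiple of Vec::size(). The uint8_t specialization below copies
// already-quantized rows with a plain memcpy.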
template <>
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
  std::memcpy(out, input, size * sizeof(uint8_t));
}
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec weight_vec = fVec(weight);
  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - kVecSize; d += kVecSize) {
    fVec data0 = fVec::loadu(input + d) * weight_vec;
    fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
    bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
    out_vec.store(out + d);
  }
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(input[d] * weight);
  }
}
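// Accumulate the topk rows of a [topk, K] block into a single output row of length K,
// summing in fp32 and converting back to scalar_t on store. topk == 1 is a plain copy.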
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  if (topk == 1) {
    copy_stub(out, input, K);
  } else {
    int64_t d;
#pragma GCC unroll 4
    for (d = 0; d <= K - kVecSize; d += kVecSize) {
      fVec sum_fvec0 = fVec(0.f);
      fVec sum_fvec1 = fVec(0.f);
      for (int t = 0; t < topk; ++t) {
        bVec x_bvec = bVec::loadu(input + t * K + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        sum_fvec0 += x_fvec0;
        sum_fvec1 += x_fvec1;
      }
      bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
      out_bvec.store(out + d);
    }
    for (; d < K; ++d) {
      float sum_val = 0.f;
      for (int t = 0; t < topk; ++t) {
        sum_val += static_cast<float>(input[t * K + d]);
      }
      out[d] = static_cast<scalar_t>(sum_val);
    }
  }
}
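// out[i] = input[i] + float(input2[i]) * scale. Used to blend the shared-expert
// result with the routed-experts output (see shared_expert_int8_kernel_impl).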
template <typename scalar_t>
inline void add_mul_stub(
    scalar_t* __restrict__ out,
    const float* __restrict__ input,
    const scalar_t* __restrict__ input2,
    float scale,
    int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec s_vec = fVec(scale);
  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - kVecSize; d += kVecSize) {
    fVec x0 = fVec::loadu(input + d);
    fVec x1 = fVec::loadu(input + d + fVec::size());

    bVec y_bvec = bVec::loadu(input2 + d);
    fVec y0, y1;
    std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);

    x0 = x0 + y0 * s_vec;
    x1 = x1 + y1 * s_vec;
    bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
    out_vec.store(out + d);
  }
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
  }
}
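// Dequantize two int32 accumulator tiles (gate and up projections), apply SiLU to
// the gate, multiply by the up projection, and store as scalar_t:
//   C[m][n] = silu((C0[m][n] - Bcomp0[n]) * As[m] * Bs0[n]) *
//                  ((C1[m][n] - Bcomp1[n]) * As[m] * Bs1[n])
// Bcomp* are per-column compensation terms produced at weight-packing time
// (presumably 128 * sum_k B[k][n], since the u8 x s8 VNNI path shifts signed
// activations into the unsigned range).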
template <typename scalar_t, int BLOCK_N>
inline void silu_and_mul(
    scalar_t* __restrict__ C,
    const int32_t* __restrict__ C0,
    const int32_t* __restrict__ C1,
    const float* __restrict__ As,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    const int32_t* __restrict__ Bcomp0,
    const int32_t* __restrict__ Bcomp1,
    int64_t m_size,
    int64_t N) {
#if defined(CPU_CAPABILITY_AVX512)
  constexpr int COLS = BLOCK_N / 16;
  static_assert(COLS % 2 == 0);

  __m512 vc0[COLS];
  __m512 vc1[COLS];
  __m512i vcomp0[COLS];
  __m512i vcomp1[COLS];
  __m512 vas;
  __m512 vbs0[COLS];
  __m512 vbs1[COLS];

  auto load_scale_and_comp = [&](auto col) {
    vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
    vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
    vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
    vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
  };
  Unroll<COLS>{}(load_scale_and_comp);

  auto scalec = [&](auto col, int64_t m) {
    vas = _mm512_set1_ps(As[m]);

    __m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16);
    __m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16);
    vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col]));
    vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col]));
    vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]);
    vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]);
  };

  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  const fVec one = fVec(1.f);
  auto silu_and_mul = [&](auto col) {
    fVec x = fVec(vc0[col]);
    fVec y = fVec(vc1[col]);
    x = x / (one + x.neg().exp_u20());
    vc0[col] = x * y;
  };

  auto storec = [&](auto col, int64_t m) {
    if constexpr (col % 2 == 0) {
      fVec x0 = fVec(vc0[col + 0]);
      fVec x1 = fVec(vc0[col + 1]);
      bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
      out_vec.store(C + m * N + col * 16);
    }
  };

  for (int64_t m = 0; m < m_size; ++m) {
    Unroll<COLS>{}(scalec, m);
    Unroll<COLS>{}(silu_and_mul);
    Unroll<COLS>{}(storec, m);
  }
#else
  TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!");
#endif
}
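// Dequantize a single int32 accumulator tile into fp32:
//   C[m][n] = (Ctmp[m][n] - Bcomp[n]) * As[m] * Bs[n]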
template <int BLOCK_N>
inline void scale_C(
    float* __restrict__ C,
    const int32_t* __restrict__ Ctmp,
    const float* __restrict__ As,
    const float* __restrict__ Bs,
    const int32_t* __restrict__ Bcomp,
    int64_t m_size) {
#if defined(CPU_CAPABILITY_AVX512)
  constexpr int COLS = BLOCK_N / 16;
  static_assert(COLS % 2 == 0);

  __m512 vc[COLS];
  __m512i vcomp[COLS];
  __m512 vas;
  __m512 vbs[COLS];

  auto load_scale_and_comp = [&](auto col) {
    vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16);
    vbs[col] = _mm512_loadu_ps(Bs + col * 16);
  };
  Unroll<COLS>{}(load_scale_and_comp);

  auto scalec = [&](auto col, int64_t m) {
    vas = _mm512_set1_ps(As[m]);

    __m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16);
    vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col]));
    vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]);
    _mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]);
  };

  for (int64_t m = 0; m < m_size; ++m) {
    Unroll<COLS>{}(scalec, m);
  }
#else
  TORCH_CHECK(false, "scale_C: scalar path not implemented!");
#endif
}
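// Tiny GEMM micro-kernels used on the non-brgemm path (small M). The _vnni variant
// computes the gate and up projections against two B matrices and fuses
// dequantization plus SiLU-and-mul into the epilogue. The generic template is a stub
// that fails loudly on unsupported ISAs / dtypes.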
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B0,
      const int8_t* __restrict__ B1,
      scalar_t* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs0,
      const float* __restrict__ Bs1,
      const int32_t* __restrict__ Bcomp0,
      const int32_t* __restrict__ Bcomp1,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_vnni: scalar path not implemented!");
  }
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B0,
      const int8_t* __restrict__ B1,
      at::BFloat16* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs0,
      const float* __restrict__ Bs1,
      const int32_t* __restrict__ Bcomp0,
      const int32_t* __restrict__ Bcomp1,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;
    static_assert(COLS % 2 == 0);

    __m512i va;
    __m512i vb0[COLS];
    __m512i vb1[COLS];
    __m512i vc0[ROWS * COLS];
    __m512i vc1[ROWS * COLS];
    __m512i vcomp0[COLS];
    __m512i vcomp1[COLS];
    __m512 vas;
    __m512 vbs0[COLS];
    __m512 vbs1[COLS];

    auto loadc = [&](auto i) {
      vc0[i] = _mm512_set1_epi32(0);
      vc1[i] = _mm512_set1_epi32(0);
    };
    Unroll<ROWS * COLS>{}(loadc);

    const int64_t K4 = K >> 2;
    const int64_t lda4 = lda >> 2;
    const int64_t ldb4 = ldb;
    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
    const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
    const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
      }
      if constexpr (row == 0) {
        vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
        vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
      }
      vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
      vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
    };
    for (int64_t k = 0; k < K4; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    auto scalec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        vas = _mm512_set1_ps(As[row]);
      }
      if constexpr (row == 0) {
        vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
        vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
        vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
        vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
      }
      __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
      __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
      vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, vas), vbs0[col]));
      vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, vas), vbs1[col]));
    };
    Unroll<ROWS * COLS>{}(scalec);

    using Vec = at::vec::Vectorized<float>;
    const Vec one = Vec(1.f);
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col % 2 == 0) {
        Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
        Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
        Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
        Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);

        x0 = x0 / (one + x0.neg().exp_u20());
        x1 = x1 / (one + x1.neg().exp_u20());

        x0 = x0 * y0;
        x1 = x1 * y1;

        _mm512_storeu_si512(
            reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
            (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
      }
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE)        \
  tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply(   \
      A + mb_start * lda,                                    \
      B0 + nb_start * 4,                                     \
      B1 + nb_start * 4,                                     \
      C + mb_start * ldc + nb_start,                         \
      As + mb_start,                                         \
      Bs0 + nb_start,                                        \
      Bs1 + nb_start,                                        \
      Bcomp0 + nb_start,                                     \
      Bcomp1 + nb_start,                                     \
      K,                                                     \
      lda,                                                   \
      ldb,                                                   \
      ldc);
template <typename scalar_t>
void tinygemm_kernel(
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B0,
    const int8_t* __restrict__ B1,
    scalar_t* __restrict__ C,
    const float* __restrict__ As,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc) {
  const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
  const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);

  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 32;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  for (int64_t mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);

      switch (mb_size << 4 | nb_size >> 4) {
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32);
          break;
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32);
          break;
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32);
          break;
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
      }
    }
  }
}
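// Second micro-kernel family: a single B matrix with fp32 output and dequantization
// fused into the store (no activation). Used for the down-projection GEMM.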
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni2 {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B,
      float* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs,
      const int32_t* __restrict__ Bcomp,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_vnni2: scalar path not implemented!");
  }
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B,
      float* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs,
      const int32_t* __restrict__ Bcomp,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;
    static_assert(COLS % 2 == 0);

    __m512i va;
    __m512i vb[COLS];
    __m512i vc[ROWS * COLS];
    __m512i vcomp[COLS];
    __m512 vas;
    __m512 vbs[COLS];

    auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
    Unroll<ROWS * COLS>{}(loadc);

    const int64_t K4 = K >> 2;
    const int64_t lda4 = lda >> 2;
    const int64_t ldb4 = ldb;
    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
    const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
      }
      if constexpr (row == 0) {
        vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
      }
      vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
    };
    for (int64_t k = 0; k < K4; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        vas = _mm512_set1_ps(As[row]);
      }
      if constexpr (row == 0) {
        if constexpr (col % 2 == 0) {
          vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
          vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
          vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
          vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
        }
      }
      __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
      x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]);
      _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE)       \
  tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply(  \
      A + mb_start * lda,                                    \
      B + nb_start * 4,                                      \
      C + mb_start * ldc + nb_start,                         \
      As + mb_start,                                         \
      Bs + nb_start,                                         \
      Bcomp + nb_start,                                      \
      K,                                                     \
      lda,                                                   \
      ldb,                                                   \
      ldc);
template <typename scalar_t>
void tinygemm_kernel(
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B,
    float* __restrict__ C,
    const float* __restrict__ As,
    const float* __restrict__ Bs,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc) {
  const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);

  // BLOCK_N is fixed at 32: the switch below only dispatches 1..4 x 32 tiles.
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 32;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  for (int64_t mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);

      switch (mb_size << 4 | nb_size >> 4) {
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32);
          break;
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32);
          break;
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32);
          break;
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
      }
    }
  }
}

}  // anonymous namespace
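// Fused MoE with int8 (W8A8) quantization. High-level flow:
//   1. quantize the input activations to uint8 per row (Aq_tmp, As_tmp);
//   2. for each [BLOCK_M x BLOCK_N] tile, gather the tokens routed to the tile's
//      expert (sorted_ids / offsets), run the gate and up int8 GEMMs, and fuse
//      dequantization + SiLU-and-mul into ic1;
//   3. re-quantize ic1, run the down-projection GEMM, and scatter the result into
//      ic2 scaled by topk_weights;
//   4. sum the topk partial results per token into the output.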
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
    scalar_t* __restrict__ output,
    scalar_t* __restrict__ ic1,
    scalar_t* __restrict__ ic2,
    uint8_t* __restrict__ A_tmp,
    float* __restrict__ C_tmp,
    uint8_t* __restrict__ Aq_tmp,
    float* __restrict__ As_tmp,
    const scalar_t* __restrict__ input,
    const int8_t* __restrict__ packed_w1,
    const int8_t* __restrict__ packed_w2,
    const float* __restrict__ w1s,
    const float* __restrict__ w2s,
    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ offsets,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t E,
    int64_t topk,
    int64_t num_tokens_post_pad) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();

  // stage 0: quantize the input activations to uint8, one scale per row
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      quantize_row_int8<scalar_t>(Aq_tmp + m * K, As_tmp[m], input + m * K, K);
    }
  });

  const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);

  const int64_t packed_K = get_row_size<int8_t>(K);
  const int64_t packed_N = get_row_size<int8_t>(N);

  // strides for packed_w1: [E, 2N, K]
  const int64_t stride_e = 2 * N * packed_K;
  const int64_t stride_n = packed_K;

  // estimate the per-expert workload to decide between brgemm and the tiny kernels
  int64_t avg_M = std::max(int64_t(1), M * topk / E);
  const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);

  // stage 1: ic1 = silu(hidden_states @ w1_gate) * (hidden_states @ w1_up)
  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    // per-thread scratch: gathered A rows and two int32 accumulator tiles
    int tid = get_thread_num();
    uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
    int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
    int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;

    alignas(64) float As[BLOCK_M];

    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      // column block nb addresses the gate half of w1, nb + NB the up half
      int64_t nb_upper = nb, nb_lower = nb + NB;
      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);

      // B blocks and scales for the expert assigned to this row block
      int32_t expert_id = expert_ids[mb];
      const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
      const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
      const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
      const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;

      // gather the quantized rows of the tokens routed to this block
      const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
      int64_t m_size = offsets[mb + 1] - offsets[mb];

      for (int64_t m = 0; m < m_size; ++m) {
        int32_t index = A_ids[m] / topk;
        copy_stub(A + m * K, Aq_tmp + index * K, K);
        As[m] = As_tmp[index];
      }

      if (use_brgemm) {
        // 1.a gemm: C0 = A @ B0 (gate)
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            K,
            K,
            n_size,
            BLOCK_N,
            false,
            A,
            B0,
            C0);

        // 1.b gemm: C1 = A @ B1 (up)
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            K,
            K,
            n_size,
            BLOCK_N,
            false,
            A,
            B1,
            C1);

        const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
        const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);

        // 1.c dequantize + SiLU-and-mul into ic1
        const int64_t offset = offsets[mb];
        silu_and_mul<scalar_t, BLOCK_N>(
            ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
      } else {
        // small-M path: fused tiny gemm
        const int64_t offset = offsets[mb];
        tinygemm_kernel(
            A,
            B0,
            B1,
            ic1 + offset * N + nb * BLOCK_N,
            As,
            Bs0,
            Bs1,
            m_size,
            n_size,
            K,
            K,
            n_size,
            N);
      }
    });

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });

  // stage 1.5: quantize ic1 to uint8, one scale per row
  at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      quantize_row_int8<scalar_t>(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N);
    }
  });

  // stage 2: ic2 = (ic1 @ w2) * topk_weights
  //   per expert, w2 is [OC = K, IC = N]
  const int64_t OC = K;
  const int64_t IC = N;
  const int64_t MB2 = MB;
  const int64_t NB2 = div_up(OC, BLOCK_N);
  const int64_t stride_e2 = OC * packed_N;
  const int64_t stride_oc = packed_N;

  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    // per-thread scratch: fp32 output tile and int32 accumulator tile
    int tid = get_thread_num();
    float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
    int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);

    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t m_size = offsets[mb + 1] - offsets[mb];
      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

      // A block: rows of ic1 for this block are already contiguous via offsets
      const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
      const float* __restrict__ As = As_tmp + offsets[mb];
      const int32_t* A_ids = sorted_ids + mb * BLOCK_M;

      // B block and scales for the expert assigned to this row block
      int32_t expert_id = expert_ids[mb];
      const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
      const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;

      if (use_brgemm) {
        // 2.a gemm: C32 = A @ B
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            IC,
            IC,
            n_size,
            BLOCK_N,
            false,
            A,
            B,
            C32);

        // 2.b dequantize: C = (C32 - Bcomp) * As * Bs
        const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
        scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
      } else {
        tinygemm_kernel<scalar_t>(
            A,
            B,
            C,
            As,
            Bs,
            m_size,
            n_size,
            IC,
            IC,
            n_size,
            BLOCK_N);
      }

      // 2.c scatter C into ic2 in original token order, scaled by topk_weights
      for (int64_t m = 0; m < m_size; ++m) {
        int32_t index = A_ids[m];
        float weight = topk_weights[index];
        copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
      }
    });

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });

  // stage 3: output = sum of ic2 over the topk experts per token
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
    }
  });
}
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE)            \
  template void fused_experts_int8_kernel_impl<TYPE>(  \
      TYPE* __restrict__ output,                       \
      TYPE* __restrict__ ic1,                          \
      TYPE* __restrict__ ic2,                          \
      uint8_t* __restrict__ A_tmp,                     \
      float* __restrict__ C_tmp,                       \
      uint8_t* __restrict__ Aq_tmp,                    \
      float* __restrict__ As_tmp,                      \
      const TYPE* __restrict__ input,                  \
      const int8_t* __restrict__ packed_w1,            \
      const int8_t* __restrict__ packed_w2,            \
      const float* __restrict__ w1s,                   \
      const float* __restrict__ w2s,                   \
      const float* __restrict__ topk_weights,          \
      const int32_t* __restrict__ sorted_ids,          \
      const int32_t* __restrict__ expert_ids,          \
      const int32_t* __restrict__ offsets,             \
      int64_t M,                                       \
      int64_t N,                                       \
      int64_t K,                                       \
      int64_t E,                                       \
      int64_t topk,                                    \
      int64_t num_tokens_post_pad)

INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
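// Shared-expert path: same two GEMM stages as fused_experts_int8_kernel_impl, but for
// a single expert with no routing. The epilogue adds the routed-experts output scaled
// by routed_scaling_factor (see add_mul_stub).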
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
    scalar_t* __restrict__ output,
    scalar_t* __restrict__ ic1,
    float* __restrict__ C_tmp,
    uint8_t* __restrict__ Aq_tmp,
    float* __restrict__ As_tmp,
    const scalar_t* __restrict__ input,
    const int8_t* __restrict__ packed_w1,
    const int8_t* __restrict__ packed_w2,
    const float* __restrict__ w1s,
    const float* __restrict__ w2s,
    const scalar_t* __restrict__ fused_experts_out,
    float routed_scaling_factor,
    int64_t M,
    int64_t N,
    int64_t K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();

  // stage 0: quantize the input activations to uint8, one scale per row
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      quantize_row_int8<scalar_t>(Aq_tmp + m * K, As_tmp[m], input + m * K, K);
    }
  });

  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);

  const int64_t packed_K = get_row_size<int8_t>(K);
  const int64_t packed_N = get_row_size<int8_t>(N);
  const int64_t stride_n = packed_K;

  const bool use_brgemm = can_use_brgemm<int8_t>(M);

  // stage 1: ic1 = silu(hidden_states @ w1_gate) * (hidden_states @ w1_up)
  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    // per-thread scratch: two int32 accumulator tiles
    int tid = get_thread_num();
    int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
    int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;

    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      // column block nb addresses the gate half of w1, nb + NB the up half
      int64_t nb_upper = nb, nb_lower = nb + NB;
      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);

      // A block: no routing, rows are used in order
      const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
      const float* As = As_tmp + mb * BLOCK_M;

      // B blocks and scales
      const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
      const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
      const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
      const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;

      if (use_brgemm) {
        // 1.a gemm: C0 = A @ B0 (gate)
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            K,
            K,
            n_size,
            BLOCK_N,
            false,
            A,
            B0,
            C0);

        // 1.b gemm: C1 = A @ B1 (up)
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            K,
            K,
            n_size,
            BLOCK_N,
            false,
            A,
            B1,
            C1);

        const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
        const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);

        // 1.c dequantize + SiLU-and-mul into ic1
        silu_and_mul<scalar_t, BLOCK_N>(
            ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
      } else {
        // small-M path: fused tiny gemm
        tinygemm_kernel(
            A,
            B0,
            B1,
            ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
            As,
            Bs0,
            Bs1,
            m_size,
            n_size,
            K,
            K,
            n_size,
            N);
      }
    });

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });

  // stage 1.5: quantize ic1 to uint8, one scale per row
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      quantize_row_int8<scalar_t>(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N);
    }
  });

  // stage 2: output = (ic1 @ w2) + fused_experts_out * routed_scaling_factor
  //   w2 is [OC = K, IC = N]
  const int64_t OC = K;
  const int64_t IC = N;
  const int64_t MB2 = MB;
  const int64_t NB2 = div_up(OC, BLOCK_N);
  const int64_t stride_oc = packed_N;

  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    // per-thread scratch: fp32 output tile and int32 accumulator tile
    int tid = get_thread_num();
    float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
    int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);

    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

      // A block
      const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
      const float* __restrict__ As = As_tmp + mb * BLOCK_M;

      // B block and scales
      const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
      const float* __restrict__ Bs = w2s + nb * BLOCK_N;

      if (use_brgemm) {
        // 2.a gemm: C32 = A @ B
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            IC,
            IC,
            n_size,
            BLOCK_N,
            false,
            A,
            B,
            C32);

        // 2.b dequantize: C = (C32 - Bcomp) * As * Bs
        const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
        scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
      } else {
        // small-M path: fused tiny gemm
        tinygemm_kernel<scalar_t>(
            A,
            B,
            C,
            As,
            Bs,
            m_size,
            n_size,
            IC,
            IC,
            n_size,
            BLOCK_N);
      }

      // 2.c blend with the routed-experts output
      scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
      const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
      for (int64_t m = 0; m < m_size; ++m) {
        add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
      }
    });

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE)  \
  template void shared_expert_int8_kernel_impl<TYPE>(  \
      TYPE* __restrict__ output,                       \
      TYPE* __restrict__ ic1,                          \
      float* __restrict__ C_tmp,                       \
      uint8_t* __restrict__ Aq_tmp,                    \
      float* __restrict__ As_tmp,                      \
      const TYPE* __restrict__ input,                  \
      const int8_t* __restrict__ packed_w1,            \
      const int8_t* __restrict__ packed_w2,            \
      const float* __restrict__ w1s,                   \
      const float* __restrict__ w2s,                   \
      const TYPE* __restrict__ fused_experts_out,      \
      float routed_scaling_factor,                     \
      int64_t M,                                       \
      int64_t N,                                       \
      int64_t K)

INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);