#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(
    double* address,
    double val
) {
  unsigned long long int* address_as_ull = (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(
      address_as_ull,
      assumed,
      __double_as_longlong(val + __longlong_as_double(assumed))
    );
  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

__global__ void VecQuant2MatMulKernelFaster(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

__global__ void VecQuant3MatMulKernelFaster(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

__global__ void VecQuant4MatMulKernelFaster(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
);

// Each thread block covers BLOCKWIDTH unpacked weight rows; BLOCKHEIGHTn is the
// number of packed int32 rows holding those values at n bits per weight
// (BLOCKHEIGHTn = BLOCKWIDTH * n / 32).
const int BLOCKWIDTH   = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

void vecquant2matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
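// ---------------------------------------------------------------------------
// Illustrative sketch only (hypothetical helper, not part of the original
// kernels): the tensor shapes the 2-bit launcher above expects, assuming all
// dimensions are multiples of the block sizes. Each int32 row of `mat` packs
// 16 two-bit weights (so height == vec_height / 16), and each int32 column of
// `zeros` packs 16 two-bit zero points.
// ---------------------------------------------------------------------------
static void example_vecquant2matmul_call(int batch, int vec_height, int width, int groupsize) {
  auto opts_f = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto opts_i = torch::dtype(torch::kInt32).device(torch::kCUDA);
  int height = vec_height / 16;         // packed int32 rows of the weight matrix
  int groups = vec_height / groupsize;  // one scale / zero point per group and column
  torch::Tensor vec    = torch::randn({batch, vec_height}, opts_f);
  torch::Tensor mat    = torch::zeros({height, width}, opts_i);       // packed 2-bit weights
  torch::Tensor mul    = torch::zeros({batch, width}, opts_f);        // accumulated via atomicAdd
  torch::Tensor scales = torch::ones({groups, width}, opts_f);
  torch::Tensor zeros  = torch::zeros({groups, width / 16}, opts_i);  // packed 2-bit zero points
  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}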
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 16;          // first unpacked row handled by this block
  int k = 0;

  int z_w = w / 16;          // int32 column holding this output's zero point
  int z_mod = (w % 16) * 2;  // bit offset of the zero point inside that column

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    // Zero points are stored minus one, hence the "+ 1" when reconstructing.
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];

    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}
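// Hedged reference (hypothetical, not part of the original file): host-side
// unpacking of one packed int32 entry, mirroring the dequantization rule used
// in the kernel above, where the stored zero point is offset by one:
// w = scale * q - scale * (z + 1).
static void dequant2_reference(unsigned int packed, float scale, unsigned int z,
                               float out[16]) {
  float zero = scale * (float)(z + 1);
  for (int j = 0; j < 16; ++j)
    out[j] = scale * (float)((packed >> (2 * j)) & 0x3) - zero;
}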
void vecquant3matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  // 32 three-bit zero points occupy three consecutive int32 words; entries 10
  // and 21 straddle a word boundary and are handled separately in the loop.
  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;

  if (z_mod != 10) {
    if (z_mod != 21) {
      z_bit = z_mod;
      if (z_bit > 21) {
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10) {
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scale * scalar_t((z_tmp) + 1);
    } else if (z_mod == 21) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scale * scalar_t((z_tmp) + 1);
    } else {
      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);  // weight split across two words
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);  // weight split across two words
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}
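// Hedged reference (hypothetical, not part of the original file): how 32
// three-bit values sit in three consecutive int32 words. Entries 10 and 21
// straddle word boundaries, which is what the z_mod == 10 / z_mod == 21
// special cases above handle.
static void unpack3_reference(unsigned int w0, unsigned int w1, unsigned int w2,
                              unsigned int out[32]) {
  for (int j = 0; j < 10; ++j) out[j]      = (w0 >> (3 * j)) & 0x7;
  out[10] = (w0 >> 30) | ((w1 << 2) & 0x4);            // 2 bits from w0, 1 bit from w1
  for (int j = 0; j < 10; ++j) out[11 + j] = (w1 >> (3 * j + 1)) & 0x7;
  out[21] = (w1 >> 31) | ((w2 << 1) & 0x6);            // 1 bit from w1, 2 bits from w2
  for (int j = 0; j < 10; ++j) out[22 + j] = (w2 >> (3 * j + 2)) & 0x7;
}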
void vecquant4matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 8;          // 8 four-bit weights per int32 row
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];

    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant8matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 4;          // 4 eight-bit weights per int32 row
  int k = 0;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];

    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant2matmul_faster_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize, int vec_height
) {
  int batch = vec.size(0);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant2MatMulKernelFaster<<<blocks, threads>>>(
    (half2*) vec.data_ptr(), mat.data_ptr<int>(), mul.data_ptr<float>(),
    scales.data_ptr<float>(), zeros.data_ptr<int>(),
    batch, vec_height, height, width, zero_width, groupsize
  );
}

__global__ void VecQuant2MatMulKernelFaster(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2)
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];

  // Lookup table: a 4-bit nibble holds two adjacent 2-bit codes; deq2 maps it
  // to the half2 (low code, high code), replicated along the second index so
  // neighbouring threads read different columns.
  __shared__ half2 deq2[16][16];
  int val = threadIdx.x / 16;
  int off = threadIdx.x % 16;
  for (; val < 16; val += BLOCKWIDTH / 16) {
    deq2[val][off] = __halves2half2(
      __int2half_rn(val & 0x3), __int2half_rn(val >> 2)
    );
  }

  int i = width * h + w;
  int g_h = h * 16;
  int k = 0;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  float res = 0;
  half2 res2;

  unsigned int tmp;

  __syncthreads();

  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    // `zero` is pre-negated so a single __hfma2(q, scale, zero) yields the
    // dequantized weight pair.
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));

    res2 = {};
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xf][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 12) & 0xf][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xf][off], scale, zero), blockvec[k + 4], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 20) & 0xf][off], scale, zero), blockvec[k + 5], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xf][off], scale, zero), blockvec[k + 6], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scale, zero), blockvec[k + 7], res2);
    i += width;
    k += 8;
    res += __half2float(res2.x) + __half2float(res2.y);
  }

  atomicAdd(&mul[b * width + w], res);
}
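// Hedged reference (hypothetical, not part of the original file): what one
// __hfma2 step in the kernel above computes, written out in plain float for a
// pair of adjacent rows (2k, 2k + 1) that share the same scale and zero point.
static float faster2_pair_reference(unsigned int nibble, float scale, unsigned int z,
                                    float v_lo, float v_hi) {
  float zero = -(scale * (float)(z + 1));             // pre-negated, as in the kernel
  float w_lo = (float)(nibble & 0x3) * scale + zero;  // inner __hfma2, low lane
  float w_hi = (float)(nibble >> 2) * scale + zero;   // inner __hfma2, high lane
  return w_lo * v_lo + w_hi * v_hi;                   // outer __hfma2 accumulates into res2
}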
void vecquant3matmul_faster_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize, int vec_height
) {
  int batch = vec.size(0);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant3MatMulKernelFaster<<<blocks, threads>>>(
    (half2*) vec.data_ptr(), mat.data_ptr<int>(), mul.data_ptr<float>(),
    scales.data_ptr<float>(), zeros.data_ptr<int>(),
    batch, vec_height, height, width, zero_width, groupsize
  );
}

__global__ void VecQuant3MatMulKernelFaster(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2)
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];

  // deq2 maps a 6-bit index holding two adjacent 3-bit codes to the half2
  // (low code, high code).
  __shared__ half2 deq2[64][32];
  int val = threadIdx.x / 32;
  int off = threadIdx.x % 32;
  for (; val < 64; val += BLOCKWIDTH / 32) {
    deq2[val][off] = __halves2half2(
      __int2half_rn(val & 0x7), __int2half_rn(val >> 3)
    );
  }

  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;

  if (z_mod != 10) {
    if (z_mod != 21) {
      z_bit = z_mod;
      if (z_bit > 21) {
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10) {
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  float res = 0;
  half2 res2;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  __syncthreads();

  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    half2 zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
    } else if (z_mod == 21) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
    } else {
      zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
    }

    res2 = {};
    tmp1 = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c);  // code pair split across two words
    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2);
    tmp2 >>= 4;
    k += 6;
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30);  // code pair split across two words
    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2);
    tmp1 >>= 2;
    k += 5;
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    k += 5;
    res += __half2float(res2.x) + __half2float(res2.y);
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant4matmul_faster_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros,
  int groupsize, int vec_height
) {
  int batch = vec.size(0);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant4MatMulKernelFaster<<<blocks, threads>>>(
    (half2*) vec.data_ptr(), mat.data_ptr<int>(), mul.data_ptr<float>(),
    scales.data_ptr<float>(), zeros.data_ptr<int>(),
    batch, vec_height, height, width, zero_width, groupsize
  );
}

__global__ void VecQuant4MatMulKernelFaster(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width, int groupsize
) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2)
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * blockwidth2 + threadIdx.x];

  // deq2 maps a byte holding two adjacent 4-bit codes to the half2
  // (low code, high code).
  __shared__ half2 deq2[256][8];
  int val = threadIdx.x / 8;
  int off = threadIdx.x % 8;
  for (; val < 256; val += BLOCKWIDTH / 8) {
    deq2[val][off] = __halves2half2(
      __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
    );
  }

  int i = width * h + w;
  int g_h = h * 8;
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  float res = 0;
  half2 res2;

  unsigned int tmp;

  __syncthreads();

  while (k < blockwidth2) {
    int g = (g_h + (k * 2)) / groupsize;
    float scale_f = scales[g * width + w];
    half2 scale = __float2half2_rn(scale_f);
    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));

    res2 = {};
    tmp = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
    i += width;
    k += 4;
    res += __half2float(res2.x) + __half2float(res2.y);
  }

  atomicAdd(&mul[b * width + w], res);
}
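// Hedged sketch (hypothetical, not part of this file): launchers like the ones
// above are typically exposed to Python from a separate binding translation
// unit built with the same torch extension, for example:
//
//   #include <torch/extension.h>
//   void vecquant4matmul_cuda(torch::Tensor, torch::Tensor, torch::Tensor,
//                             torch::Tensor, torch::Tensor, int);
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("vecquant4matmul", &vecquant4matmul_cuda,
//           "4-bit quantized mat-vec (CUDA); 2/3/8-bit variants bound the same way");
//   }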