Commit 7ab774c (unverified)
Author: JohannesGaessler
1 Parent(s): 3ff7660

CUDA: more warps for mmvq on NVIDIA (llama/5394)

Files changed (1):
  1. ggml-cuda.cu +86 -47
ggml-cuda.cu CHANGED
@@ -5310,22 +5310,26 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+#define MMVQ_NWARPS_NVIDIA    4
+#define MMVQ_NWARPS_AMD_RDNA2 1
+#define MMVQ_NWARPS_AMD_OLD   4
+
+template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
 
     const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
 
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-
-    if (row >= nrows_x) {
-        return;
-    }
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int row = blockIdx.x;
 
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
 
     // partial sum for each thread
     float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
@@ -5333,12 +5337,12 @@ static __global__ void mul_mat_vec_q(
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
+    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
         const int ibx = row*blocks_per_row_x + i; // x block index
 
         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
-        const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
+        const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
 
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -5346,9 +5350,25 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+        for (int i = 0; i < nwarps-1; ++i) {
+            tmp[j] += tmp_shared[i][j][threadIdx.x];
+        }
         tmp[j] = warp_reduce_sum(tmp[j]);
 
         if (threadIdx.x == 0) {
@@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= 4);
 
-    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    switch (ncols_y) {
-        case 1:
-            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 2:
-            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 3:
-            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 4:
-            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        // case 5:
-        //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 6:
-        //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 7:
-        //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 8:
-        //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    int nwarps;
+    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
+        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
+    } else {
+        nwarps = MMVQ_NWARPS_NVIDIA;
+    }
+
+    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    switch (nwarps) {
+        case 1: switch(ncols_y) {
+                case 1:
+                    mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 2:
+                    mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 3:
+                    mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 4:
+                    mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            } break;
+        case 4: switch(ncols_y) {
+                case 1:
+                    mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 2:
+                    mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 3:
+                    mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 4:
+                    mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            } break;
+
         default:
             GGML_ASSERT(false);
-            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
-            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
     }
 }
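Note on the kernel change: before this commit, each matrix row was handled by a single warp (a block spanned GGML_CUDA_MMV_Y rows), so partial sums could be combined with a warp-level reduction alone. Now a whole block of nwarps warps works on one row, so warps 1..nwarps-1 first stage their per-thread partial sums in shared memory, and warp 0 accumulates them before the final warp_reduce_sum. The standalone program below is a minimal sketch of that reduction pattern, not ggml code: block_reduce_demo, the plain float input, and the sum-of-ones test are illustrative, and the butterfly loop stands in for ggml's warp_reduce_sum.

#include <cstdio>

#define WARP_SIZE 32

// Each thread accumulates a private partial sum; warps 1..nwarps-1 park
// theirs in shared memory, then warp 0 adds them up and finishes with a
// shuffle-based warp reduction, mirroring the structure of mul_mat_vec_q.
template <int nwarps>
__global__ void block_reduce_demo(const float * in, float * out, int n) {
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;

    // per-thread partial sum over the input, strided across the whole block
    float tmp = 0.0f;
    for (int i = tid; i < n; i += nwarps*WARP_SIZE) {
        tmp += in[i];
    }

    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][WARP_SIZE];
    if (threadIdx.y > 0) {
        tmp_shared[threadIdx.y-1][threadIdx.x] = tmp;
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return; // warps > 0 are done; warp 0 finishes the reduction
    }

    for (int i = 0; i < nwarps-1; ++i) {
        tmp += tmp_shared[i][threadIdx.x];
    }
    // butterfly reduction within warp 0 (the pattern warp_reduce_sum uses)
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
    }
    if (threadIdx.x == 0) {
        *out = tmp;
    }
}

int main() {
    const int n = 1024;
    static float h_in[n];
    for (int i = 0; i < n; ++i) {
        h_in[i] = 1.0f; // input of all ones, so the expected sum is n
    }

    float *d_in, *d_out;
    cudaMalloc(&d_in,  n*sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n*sizeof(float), cudaMemcpyHostToDevice);

    constexpr int nwarps = 4; // same block shape as the NVIDIA path above
    block_reduce_demo<nwarps><<<1, dim3(WARP_SIZE, nwarps, 1)>>>(d_in, d_out, n);

    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f (expected %d)\n", h_out, n);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}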
 
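Note on the launch-side change: nwarps is now chosen per device at runtime (4 on NVIDIA and pre-RDNA2 AMD, 1 on RDNA2), and the grid covers one row per block instead of GGML_CUDA_MMV_Y rows. The switch over nwarps is nested inside-out over ncols_y because both are template parameters of mul_mat_vec_q and must be compile-time constants. A minimal sketch of that runtime-to-compile-time dispatch pattern follows; toy_kernel and launch_toy are hypothetical stand-ins, not ggml functions.

#include <cstdio>

// toy kernel: both knobs are template parameters, so the compiler can
// unroll loops and size per-thread arrays with them at compile time
template <int nwarps, int ncols>
__global__ void toy_kernel() {
    if (threadIdx.x == 0 && threadIdx.y == 0) {
        printf("instantiated with nwarps=%d ncols=%d\n", nwarps, ncols);
    }
}

// runtime -> compile-time dispatch: one case per supported combination,
// mirroring the nested switch over nwarps and ncols_y in the diff above
static void launch_toy(int nwarps, int ncols) {
    const dim3 block_dims(32, nwarps, 1);
    switch (nwarps) {
        case 1: switch (ncols) {
                case 1: toy_kernel<1, 1><<<1, block_dims>>>(); break;
                case 2: toy_kernel<1, 2><<<1, block_dims>>>(); break;
                default: break; // unsupported combination
            } break;
        case 4: switch (ncols) {
                case 1: toy_kernel<4, 1><<<1, block_dims>>>(); break;
                case 2: toy_kernel<4, 2><<<1, block_dims>>>(); break;
                default: break; // unsupported combination
            } break;
        default: break; // unsupported nwarps
    }
}

int main() {
    launch_toy(4, 2); // selects the <4, 2> instantiation at runtime
    cudaDeviceSynchronize();
    return 0;
}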