chinoll committed
Commit 84093e7
1 Parent(s): 27e85b0

Update quant.py

Files changed (1)
  1. quant.py +606 -2
quant.py CHANGED
@@ -120,7 +120,8 @@ class Quantizer(nn.Module):
 
 
 try:
-    import quant_cuda
+    import importlib
+    quant_cuda = importlib.import_module("quant_cuda")
 except:
     import os
     import sys
@@ -130,6 +131,608 @@ except:
     from setuptools import setup, Extension
     from torch.utils import cpp_extension
     os.chdir(dir_path)
+    cucode = '''
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+// atomicAdd for double-precision floating-point numbers on hardware with
+// compute capability < 6.0 from:
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+__device__ double atomicAdd(
+  double* address,
+  double val
+) {
+  unsigned long long int* address_as_ull = (unsigned long long int*)address;
+  unsigned long long int old = *address_as_ull, assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(
+      address_as_ull,
+      assumed,
+      __double_as_longlong(val + __longlong_as_double(assumed))
+    );
+
+  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+  } while (assumed != old);
+
+  return __longlong_as_double(old);
+}
+#endif
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel(
+  const scalar_t* __restrict__ vec,
+  const int* __restrict__ mat,
+  scalar_t* __restrict__ mul,
+  const scalar_t* __restrict__ scales,
+  const int* __restrict__ zeros,
+  int batch,
+  int vec_height,
+  int height,
+  int width,
+  int zero_width,
+  int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+  const scalar_t* __restrict__ vec,
+  const int* __restrict__ mat,
+  scalar_t* __restrict__ mul,
+  const scalar_t* __restrict__ scales,
+  const int* __restrict__ zeros,
+  int batch,
+  int vec_height,
+  int height,
+  int width,
+  int zero_width,
+  int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel(
+  const scalar_t* __restrict__ vec,
+  const int* __restrict__ mat,
+  scalar_t* __restrict__ mul,
+  const scalar_t* __restrict__ scales,
+  const int* __restrict__ zeros,
+  int batch,
+  int vec_height,
+  int height,
+  int width,
+  int zero_width,
+  int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+  const scalar_t* __restrict__ vec,
+  const int* __restrict__ mat,
+  scalar_t* __restrict__ mul,
+  const scalar_t* __restrict__ scales,
+  const int* __restrict__ zeros,
+  int batch,
+  int vec_height,
+  int height,
+  int width,
+  int zero_width,
+  int groupsize
+);
+
+const int BLOCKWIDTH = 256;
+const int BLOCKHEIGHT2 = 16;
+const int BLOCKHEIGHT3 = 24;
+const int BLOCKHEIGHT4 = 32;
+const int BLOCKHEIGHT8 = 64;
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  int groupsize
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant2matmul_cuda", ([&] {
+      VecQuant2MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(),
+        batch, vec_height, height, width, zero_width, groupsize
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel(
+  const scalar_t* __restrict__ vec,
+  const int* __restrict__ mat,
+  scalar_t* __restrict__ mul,
+  const scalar_t* __restrict__ scales,
+  const int* __restrict__ zeros,
+  int batch,
+  int vec_height,
+  int height,
+  int width,
+  int zero_width,
+  int groupsize
+) {
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT2 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+  __syncthreads();
+
+  scalar_t res = 0;
+  int i = width * h + w;
+  int g_h = h * 16;
+  int k = 0;
+
+  int z_w = w / 16;
+  int z_mod = (w % 16) * 2;
+
+  unsigned int tmp;
+
+  while (k < BLOCKWIDTH) {
+    tmp = as_unsigned(mat[i]);
+
+    int g = (g_h + k) / groupsize;
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
+
+    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
+    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
+    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
+    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
+    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
+    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
+    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
+    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
+    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
+    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
+    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
+    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
+    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
+    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
+    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
+    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
+
+    i += width;
+    k += 16;
+  }
+
+  atomicAdd(&mul[b * width + w], res);
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  int groupsize
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
+    batch
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant3matmul_cuda", ([&] {
+      VecQuant3MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(),
+        batch, vec_height, height, width, zero_width, groupsize
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+  const scalar_t* __restrict__ vec,
+  const int* __restrict__ mat,
+  scalar_t* __restrict__ mul,
+  const scalar_t* __restrict__ scales,
+  const int* __restrict__ zeros,
+  int batch,
+  int vec_height,
+  int height,
+  int width,
+  int zero_width,
+  int groupsize
+) {
+  int b = blockIdx.z;
+  int h = BLOCKHEIGHT3 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+  __syncthreads();
+
+  scalar_t res = 0;
+  int i = width * h + w;
+  int g_h = (h / 3) * 32;
+  int k = 0;
+
+  int z_w = (w / 32) * 3; // ((w / 256) * 24) / 3
+  int z_mod = w % 32;
+  int z_bit;
+
+  if (z_mod != 10){
+    if (z_mod != 21){
+      z_bit = z_mod;
+      if (z_bit > 21){
+        z_bit -= 22;
+        z_bit *= 3;
+        z_bit += 2;
+        z_w += 2;
+      } else if (z_bit > 10){
+        z_bit -= 11;
+        z_bit *= 3;
+        z_bit += 1;
+        z_w += 1;
+      } else {
+        z_bit *= 3;
+      }
+    } else {
+      z_w += 1;
+    }
+  }
+
+  unsigned int tmp1;
+  unsigned int tmp2;
+  unsigned int tmp;
+  unsigned int z_tmp;
+
+  while (k < BLOCKWIDTH) {
+    tmp1 = as_unsigned(mat[i]);
+
+    int g = (g_h + k) / groupsize;
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero;
+    if (z_mod == 10) {
+      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
+      zero = scale * scalar_t((z_tmp) + 1);
+    } else if (z_mod == 21){
+      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
+      zero = scale * scalar_t((z_tmp) + 1);
+    } else {
+      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+    }
+
+    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
+    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
+    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
+    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
+    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
+    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
+    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
+    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
+    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
+    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
+
+    i += width;
+    tmp2 = as_unsigned(mat[i]);
+    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
+    tmp2 >>= 1;
+    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
+    k += 11;
+
+    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
+    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
+    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
+    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
+    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
+    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
+    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
+    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
+    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
+    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
+
+    i += width;
+    tmp1 = as_unsigned(mat[i]);
+    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
+    tmp1 >>= 2;
+    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
+    k += 11;
+
+    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
+    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
+    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
+    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
+    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
+    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
+    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
+    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
+    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
+    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
+
+    i += width;
+    k += 10;
+  }
+
+  atomicAdd(&mul[b * width + w], res);
+}
+
491
+ torch::Tensor vec,
492
+ torch::Tensor mat,
493
+ torch::Tensor mul,
494
+ torch::Tensor scales,
495
+ torch::Tensor zeros,
496
+ int groupsize
497
+ ) {
498
+ int batch = vec.size(0);
499
+ int vec_height = vec.size(1);
500
+ int height = mat.size(0);
501
+ int width = mat.size(1);
502
+ int zero_width = zeros.size(1);
503
+
504
+ dim3 blocks(
505
+ (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
506
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
507
+ batch
508
+ );
509
+ dim3 threads(BLOCKWIDTH);
510
+
511
+ AT_DISPATCH_FLOATING_TYPES(
512
+ vec.type(), "vecquant4matmul_cuda", ([&] {
513
+ VecQuant4MatMulKernel<<<blocks, threads>>>(
514
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
515
+ scales.data<scalar_t>(), zeros.data<int>(),
516
+ batch, vec_height, height, width, zero_width, groupsize
517
+ );
518
+ })
519
+ );
520
+ }
521
+
522
+ template <typename scalar_t>
523
+ __global__ void VecQuant4MatMulKernel(
524
+ const scalar_t* __restrict__ vec,
525
+ const int* __restrict__ mat,
526
+ scalar_t* __restrict__ mul,
527
+ const scalar_t* __restrict__ scales,
528
+ const int* __restrict__ zeros,
529
+ int batch,
530
+ int vec_height,
531
+ int height,
532
+ int width,
533
+ int zero_width,
534
+ int groupsize
535
+ ) {
536
+ int b = blockIdx.z;
537
+ int h = BLOCKHEIGHT4 * blockIdx.x;
538
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
539
+
540
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
541
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
542
+ __syncthreads();
543
+
544
+ scalar_t res = 0;
545
+ int i = width * h + w;
546
+ int g_h = h * 8;
547
+ int k = 0;
548
+
549
+ int z_w = w / 8;
550
+ int z_mod = (w % 8) * 4;
551
+
552
+ unsigned int tmp;
553
+
554
+ while (k < BLOCKWIDTH) {
555
+ tmp = as_unsigned(mat[i]);
556
+
557
+ int g = (g_h + k) / groupsize;
558
+ scalar_t scale = scales[g * width + w];
559
+ scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
560
+
561
+ res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
562
+ res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
563
+ res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
564
+ res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
565
+ res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
566
+ res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
567
+ res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
568
+ res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
569
+
570
+ i += width;
571
+ k += 8;
572
+ }
573
+
574
+ atomicAdd(&mul[b * width + w], res);
575
+ }
576
+
577
+ void vecquant8matmul_cuda(
578
+ torch::Tensor vec,
579
+ torch::Tensor mat,
580
+ torch::Tensor mul,
581
+ torch::Tensor scales,
582
+ torch::Tensor zeros,
583
+ int groupsize
584
+ ) {
585
+ int batch = vec.size(0);
586
+ int vec_height = vec.size(1);
587
+ int height = mat.size(0);
588
+ int width = mat.size(1);
589
+ int zero_width = zeros.size(1);
590
+
591
+ dim3 blocks(
592
+ (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
593
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
594
+ batch
595
+ );
596
+ dim3 threads(BLOCKWIDTH);
597
+
598
+ AT_DISPATCH_FLOATING_TYPES(
599
+ vec.type(), "vecquant8matmul_cuda", ([&] {
600
+ VecQuant8MatMulKernel<<<blocks, threads>>>(
601
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
602
+ scales.data<scalar_t>(), zeros.data<int>(),
603
+ batch, vec_height, height, width, zero_width, groupsize
604
+ );
605
+ })
606
+ );
607
+ }
608
+
609
+ template <typename scalar_t>
610
+ __global__ void VecQuant8MatMulKernel(
611
+ const scalar_t* __restrict__ vec,
612
+ const int* __restrict__ mat,
613
+ scalar_t* __restrict__ mul,
614
+ const scalar_t* __restrict__ scales,
615
+ const int* __restrict__ zeros,
616
+ int batch,
617
+ int vec_height,
618
+ int height,
619
+ int width,
620
+ int zero_width,
621
+ int groupsize
622
+ ) {
623
+ int b = blockIdx.z;
624
+ int h = BLOCKHEIGHT8 * blockIdx.x;
625
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
626
+
627
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
628
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
629
+ __syncthreads();
630
+
631
+ scalar_t res = 0;
632
+ int i = width * h + w;
633
+ int g_h = h * 4;
634
+ int k = 0;
635
+
636
+ int z_w = w / 4;
637
+ int z_mod = (w % 4) * 8;
638
+
639
+ unsigned int tmp;
640
+
641
+ while (k < BLOCKWIDTH) {
642
+ tmp = as_unsigned(mat[i]);
643
+
644
+ int g = (g_h + k) / groupsize;
645
+ scalar_t scale = scales[g * width + w];
646
+ scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
647
+
648
+ res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
649
+ res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
650
+ res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
651
+ res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
652
+
653
+ i += width;
654
+ k += 4;
655
+ }
656
+
657
+ atomicAdd(&mul[b * width + w], res);
658
+ }
+'''
+    with open("quant_cuda_kernel.cu","w") as f:
+        f.write(cucode)
+    cppcode = '''
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+);
+
+void vecquant2matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+);
+
+void vecquant3matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
+}
+
+void vecquant4matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+);
+
+void vecquant4matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
+}
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+);
+
+void vecquant8matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  int groupsize
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
+}
+'''
+    with open("quant_cuda.cpp","w") as f:
+        f.write(cppcode)
     setup(
         name='quant_cuda',
         ext_modules=[cpp_extension.CUDAExtension(
@@ -146,7 +749,8 @@ except:
             sys.path.append(os.path.join(i,j))
             break
         break
-    import quant_cuda
+    import importlib
+    quant_cuda = importlib.import_module("quant_cuda")
 
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
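
The fallback path added above follows a build-on-first-import pattern: when `import quant_cuda` fails, quant.py writes the CUDA kernels (quant_cuda_kernel.cu) and the pybind11 bindings (quant_cuda.cpp) to disk, compiles them with setuptools' CUDAExtension, appends the setuptools build output directory to sys.path, and re-imports the freshly built module through importlib. The sketch below shows the same idea with torch's JIT extension loader instead of a full setup() invocation; it is illustrative only, not part of this commit, and assumes the two generated source files already sit next to quant.py.

import os

try:
    import quant_cuda
except ImportError:
    # Compile the extension on first use and bind it to the same name.
    # torch.utils.cpp_extension.load() builds the listed sources with nvcc
    # and imports the resulting shared library, so no manual sys.path
    # manipulation is needed.
    from torch.utils.cpp_extension import load

    _here = os.path.dirname(os.path.realpath(__file__))
    quant_cuda = load(
        name="quant_cuda",
        sources=[
            os.path.join(_here, "quant_cuda.cpp"),
            os.path.join(_here, "quant_cuda_kernel.cu"),
        ],
        verbose=True,
    )

Either way, the end state is the same: quant_cuda.vecquant2matmul, vecquant3matmul, vecquant4matmul and vecquant8matmul are importable for the rest of quant.py.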