Update quant.py
Browse files
quant.py
CHANGED
@@ -657,7 +657,8 @@ __global__ void VecQuant8MatMulKernel(
|
|
657 |
atomicAdd(&mul[b * width + w], res);
|
658 |
}
|
659 |
'''
|
660 |
-
open("quant_cuda_kernel.cu","w")
|
|
|
661 |
cppcode = '''
|
662 |
#include <torch/all.h>
|
663 |
#include <torch/python.h>
|
@@ -730,7 +731,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
730 |
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
|
731 |
}
|
732 |
'''
|
733 |
-
open("quant_cuda.cpp","w")
|
|
|
734 |
setup(
|
735 |
name='quant_cuda',
|
736 |
ext_modules=[cpp_extension.CUDAExtension(
|
|
|
657 |
atomicAdd(&mul[b * width + w], res);
|
658 |
}
|
659 |
'''
|
660 |
+
open("quant_cuda_kernel.cu","w") as f:
|
661 |
+
f.write(cucode)
|
662 |
cppcode = '''
|
663 |
#include <torch/all.h>
|
664 |
#include <torch/python.h>
|
|
|
731 |
m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
|
732 |
}
|
733 |
'''
|
734 |
+
with open("quant_cuda.cpp","w") as f:
|
735 |
+
f.write(cppcode)
|
736 |
setup(
|
737 |
name='quant_cuda',
|
738 |
ext_modules=[cpp_extension.CUDAExtension(
|