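"""Benchmark and sanity-check the quant_cuda kernels.

Times an FP16 torch.matmul baseline against the 2/3/4/8-bit quantized
matvec kernels on a LLaMa-7B FC2-shaped layer, then checks each packed
QuantLinear kernel against simulated quantization computed in PyTorch.
"""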
import os
import time

# CUDA_LAUNCH_BLOCKING serializes kernel launches so errors surface at the
# offending call; set it before torch initializes CUDA.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
import torch.nn as nn

import quant_cuda

# Disable TF32 so the dense baselines use true FP16/FP32 matmuls.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
print('Benchmarking LLaMa-7B FC2 matvec ...')
DEV = torch.device('cuda:0')
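# Batch of B sequences, L tokens each; M = 4096 is LLaMa-7B's hidden size
# and N = 11008 its FFN intermediate size.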
B = 5
L = 128
M = 4096
N = 11008
DTYPE = torch.half
mat = torch.randn((M, N), device=DEV, dtype=DTYPE)
vec = torch.randn((B, M), device=DEV, dtype=DTYPE)
mul = torch.zeros((B, N), device=DEV, dtype=DTYPE)
COUNT = 1000

# FP16 dense baseline: average per-launch latency over COUNT matmuls.
tick = time.time()
for _ in range(COUNT):
torch.matmul(vec, mat, out=mul)
torch.cuda.synchronize()
print('FP16:', (time.time() - tick) / COUNT)
# Switch the activation and output buffers to FP32 for the quantized kernels.
DTYPE = torch.float
vec = vec.to(DTYPE)
mul = mul.to(DTYPE)

# Random packed weights and quantization parameters. The values are garbage
# and are reused across the 2/3/4/8-bit kernels below: this section measures
# kernel speed only; correctness is checked separately at the end.
mat = torch.randint(-1000000000, 1000000000, (M // 256 * 32, N), device=DEV, dtype=torch.int)
scales = torch.randn(N, device=DEV, dtype=DTYPE)
zeros = torch.randint(-1000000000, 1000000000, (1, N // 256 * 32), device=DEV, dtype=torch.int)
COUNT = 1000

# Each quantized kernel has a full-precision variant (FP32 activations) and
# a 'faster' variant (FP16 activations); both are timed below.
vec = vec.float()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant2matmul(vec, mat, mul, scales, zeros, M)
torch.cuda.synchronize()
print('2bit:', (time.time() - tick) / COUNT)
vec = vec.half()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant2matmul_faster(vec, mat, mul, scales, zeros, M, M//2)
torch.cuda.synchronize()
print('2bit:', (time.time() - tick) / COUNT, '(faster)')
vec = vec.float()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant3matmul(vec, mat, mul, scales, zeros, M)
torch.cuda.synchronize()
print('3bit:', (time.time() - tick) / COUNT)
vec = vec.half()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant3matmul_faster(vec, mat, mul, scales, zeros, M, M//2)
torch.cuda.synchronize()
print('3bit:', (time.time() - tick) / COUNT, '(faster)')
vec = vec.float()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant4matmul(vec, mat, mul, scales, zeros, M)
torch.cuda.synchronize()
print('4bit:', (time.time() - tick) / COUNT)
vec = vec.half()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant4matmul_faster(vec, mat, mul, scales, zeros, M, M//2)
torch.cuda.synchronize()
print('4bit:', (time.time() - tick) / COUNT, '(faster)')
vec = vec.float()
tick = time.time()
for _ in range(COUNT):
quant_cuda.vecquant8matmul(vec, mat, mul, scales, zeros, M)
torch.cuda.synchronize()
print('8bit:', (time.time() - tick) / COUNT)
print('Verifying kernel correctness ...')
M = 4096
N = 11008

# Simulated quantization (Quantizer/quantize) and the packed QuantLinear
# layer come from this repo's quant module.
from quant import *
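# For each bit-width: quantize a Linear layer's weights in place (simulated
# quantization, still dense FP32), pack them into a QuantLinear
# (bits, groupsize=-1 i.e. per-channel, kernel_switch_threshold=False so the
# custom CUDA kernel is always used), and compare the two outputs.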
layer = nn.Linear(M, N)
vec = torch.randn(B,L,M).to(DEV)
quantizer = Quantizer()
quantizer.configure(2, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)
qlayer = QuantLinear(2, -1, layer.in_features, layer.out_features, kernel_switch_threshold=False)
qlayer.pack(layer, quantizer.scale, quantizer.zero)
qlayer = qlayer.to(DEV)
layer = layer.to(DEV)
with torch.no_grad():
    print('2bit Simu:', layer(vec))
    print('2bit Kern:', qlayer(vec))
    qlayer.faster = True
    print('2bit Kern:', qlayer(vec.half()), '(faster)')
print('\n')
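# Same check at 3 bits.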
layer = nn.Linear(M, N)
vec = torch.randn(B,L,M).to(DEV)
quantizer = Quantizer()
quantizer.configure(3, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)
qlayer = QuantLinear(3, -1, layer.in_features, layer.out_features, kernel_switch_threshold=False)
qlayer.pack(layer, quantizer.scale, quantizer.zero)
qlayer = qlayer.to(DEV)
layer = layer.to(DEV)
with torch.no_grad():
    print('3bit Simu:', layer(vec))
    print('3bit Kern:', qlayer(vec))
    qlayer.faster = True
    print('3bit Kern:', qlayer(vec.half()), '(faster)')
print('\n')
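# Same check at 4 bits.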
layer = nn.Linear(M, N)
vec = torch.randn(B,L,M).to(DEV)
quantizer = Quantizer()
quantizer.configure(4, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)
qlayer = QuantLinear(4, -1, layer.in_features, layer.out_features, kernel_switch_threshold=False)
qlayer.pack(layer, quantizer.scale, quantizer.zero)
qlayer = qlayer.to(DEV)
layer = layer.to(DEV)
with torch.no_grad():
    print('4bit Simu:', layer(vec))
    print('4bit Kern:', qlayer(vec))
    qlayer.faster = True
    print('4bit Kern:', qlayer(vec.half()), '(faster)')
print('\n')
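# Same check at 8 bits; the 8-bit kernel has no 'faster' half-precision path.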
layer = nn.Linear(M, N)
vec = torch.randn(B,L,M).to(DEV)
quantizer = Quantizer()
quantizer.configure(8, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)
qlayer = QuantLinear(8, -1, layer.in_features, layer.out_features, kernel_switch_threshold=False)
qlayer.pack(layer, quantizer.scale, quantizer.zero)
qlayer = qlayer.to(DEV)
layer = layer.to(DEV)
with torch.no_grad():
    print('8bit Simu:', layer(vec))
    print('8bit Kern:', qlayer(vec))
print('\n')
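# Printing full tensors makes differences hard to judge by eye; below is a
# minimal sketch of a numeric check instead, assuming the 8-bit layer,
# qlayer, and vec from the block above are still in scope.
with torch.no_grad():
    # Maximum absolute difference between the simulated-quantization
    # reference and the packed-kernel output.
    err = (layer(vec) - qlayer(vec)).abs().max().item()
    print('8bit max abs err:', err)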