"""Timing and correctness smoke test for the ``quiptools_cuda`` D4 kernels.

The script builds a random fp16 codebook ``cb`` of shape (256, 4) and a
random uint8 index matrix ``yidxs`` of shape (m, n // 4).  Looking every
index up in the codebook yields a dense (m, n) fp16 matrix ``y``; the pure
PyTorch lookup followed by ``x @ y.t()`` is the reference that the CUDA
kernels are compared against.

Each section prints the wall-clock time of one operation, measured between
two ``torch.cuda.synchronize()`` calls.  The decompress kernels are checked
for exact equality with ``y``; the fused lookup+matmul kernels are checked by
printing their relative squared error against the reference product ``z0``.

Requires a CUDA device and the ``quiptools_cuda`` extension module.

NOTE(review): there is no warm-up pass, so the first timed section may
include one-time library initialisation cost -- confirm whether a warm-up
is wanted before comparing the printed numbers.
"""

import time
from contextlib import contextmanager

import torch
import quiptools_cuda  # kept after torch, matching the original import order


@contextmanager
def _timed(label):
    """Time the enclosed statements on the GPU and print ``"<label>: <seconds>"``.

    The device is synchronised before the clock starts and again before it
    stops, so queued asynchronous CUDA work from earlier statements is not
    billed to this section and the section's own kernels are fully counted.
    If the enclosed statements raise, nothing is printed.
    """
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    yield
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed}")


def _relative_sq_error(actual, reference):
    """Return ``sum((actual - reference)^2) / (sum(reference^2) + 1e-10)``.

    Both tensors are promoted to float32 first so the reductions are not
    carried out in half precision; the 1e-10 term guards against division
    by zero for an all-zero reference.
    """
    actual32 = actual.to(torch.float32)
    reference32 = reference.to(torch.float32)
    return (actual32 - reference32).square().sum() / (reference32.square().sum() + 1e-10)


# ---------------------------------------------------------------------------
# Problem sizes and random inputs
# ---------------------------------------------------------------------------
k = 32 * 32      # rows of x (and of the products z / z0 / z1)
m = 8192 * 2     # rows of the decompressed matrix y
n = 8192 // 2    # columns of x and y (the contraction dimension)

x = torch.randn(k, n, dtype=torch.float16, device="cuda")
z = torch.zeros(k, m, dtype=torch.float16, device="cuda")    # output of the fused kernels
cb = torch.randn(256, 4, dtype=torch.float16, device="cuda")  # 256 entries x 4 values
yidxs = torch.randint(256, (m, n // 4), device="cuda").to(torch.uint8)

# Same indices regrouped: each row is split into n // 16 groups of 4 indices,
# the group axis is moved in front of the row axis, and the result is
# flattened back to (m, n // 4).  decompress_d4 and the lookupmatmul kernels
# are fed this layout (decompress_d4 is asserted below to reproduce ``y``
# from it).
yidxs_reordered = (
    yidxs.view(m, n // (4 * 4), 4)
    .permute(1, 0, 2)
    .reshape(m, n // 4)
    .contiguous()
)

y1 = torch.zeros(m, n, dtype=torch.float16, device="cuda")   # output of the decompress kernels


def _check_lookupmatmul(tag, kernel):
    """Time one fused lookup+matmul kernel and report its error against ``z0``.

    Args:
        tag: Short kernel label (e.g. ``"k8"``) used in the printed lines.
        kernel: Callable invoked as ``kernel(x, yidxs_reordered, cb, z)``.

    Returns:
        The relative squared error of ``z`` versus ``z0`` as a 0-dim tensor.
    """
    # z is shared by all kernels: clear it (outside the timed region, since
    # _timed synchronises before starting the clock) so that leftovers from a
    # previous kernel can neither mask a kernel that writes nothing nor be
    # accumulated into.  This mirrors the y1.zero_() done for the decompress
    # kernels.
    z.zero_()
    with _timed(f" elapsed for {tag} cuda"):
        kernel(x, yidxs_reordered, cb, z)
    err = _relative_sq_error(z, z0)
    print(f"lookupmatmul_d4_{tag} error: {err}")
    return err


# ---------------------------------------------------------------------------
# Reference: codebook lookup + matmul in pure PyTorch
# ---------------------------------------------------------------------------
with _timed("elapsed for pure torch"):
    # Indices are widened to int32 so they are used as integer indices;
    # each one selects a 4-value codebook row, giving m * n values in total.
    y = cb[yidxs.view(-1).to(torch.int32), :].view(m, n)
    z0 = x @ y.t()

# ---------------------------------------------------------------------------
# Decompress kernel, original index order
# ---------------------------------------------------------------------------
with _timed("elapsed for orig decompress_d4"):
    quiptools_cuda.decompress_d4_origorder(yidxs, cb, y1)

assert (y1 == y).all(), "decompress_d4_origorder output differs from the torch lookup"

# ---------------------------------------------------------------------------
# Decompress kernel, reordered indices, followed by a regular matmul
# ---------------------------------------------------------------------------
y1.zero_()  # discard the previous kernel's output before re-checking
with _timed("elapsed for decompress_d4 and multiply"):
    quiptools_cuda.decompress_d4(yidxs_reordered, cb, y1)
    z1 = x @ y1.t()

assert (y1 == y).all(), "decompress_d4 output differs from the torch lookup"

# ---------------------------------------------------------------------------
# Fused lookup + matmul kernels
# ---------------------------------------------------------------------------
lookupmatmul_d4_k8_err = _check_lookupmatmul("k8", quiptools_cuda.lookupmatmul_d4_k8)
lookupmatmul_d4_k16_err = _check_lookupmatmul("k16", quiptools_cuda.lookupmatmul_d4_k16)
lookupmatmul_d4_k32_err = _check_lookupmatmul("k32", quiptools_cuda.lookupmatmul_d4_k32)