KnutJaegersberg's picture
Upload 132 files
c1a41d7
raw
history blame contribute delete
No virus
2.54 kB
import torch
import quiptools_cuda
import time
# Benchmark and sanity check for the quiptools_cuda D4 codebook kernels.
#
# A weight matrix ``y`` of shape (m, n) is represented by uint8 indices into
# a codebook of 256 four-element fp16 vectors.  The script builds the dense
# reference with plain PyTorch, then times each CUDA path and compares its
# output with that reference.


def _timed(label, fn):
    """Run ``fn()`` between two CUDA synchronisations and print the duration.

    The synchronisation before the clock starts keeps previously queued GPU
    work out of the measurement; the one after makes sure the work launched
    by ``fn`` has finished before the clock is read.

    Args:
        label: Text printed in front of the elapsed time.
        fn: Zero-argument callable to measure.

    Returns:
        Whatever ``fn()`` returns.
    """
    torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution; time.time() is neither
    # guaranteed, which matters for intervals this short.
    begin = time.perf_counter()
    result = fn()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - begin
    print(f"{label}: {elapsed}")
    return result


def _relative_sq_error(actual, expected):
    """Return ``sum((actual - expected)^2) / (sum(expected^2) + 1e-10)``.

    Both tensors are promoted to float32 first so the reduction is not done
    in fp16; the small constant guards against division by zero.
    """
    actual32 = actual.to(torch.float32)
    expected32 = expected.to(torch.float32)
    return (actual32 - expected32).square().sum() / (expected32.square().sum() + 1e-10)


# Problem size: x is (k, n), the weight matrix is (m, n), the product is (k, m).
k = 32 * 32
m = 8192 * 2
n = 8192 // 2

x = torch.randn(k, n, dtype=torch.float16, device="cuda")
# Output buffer shared by the fused lookup+matmul kernels.
z = torch.zeros(k, m, dtype=torch.float16, device="cuda")
# Codebook: 256 entries of 4 fp16 values each.
cb = torch.randn(256, 4, dtype=torch.float16, device="cuda")
# One codebook index per group of 4 weights, in plain row-major order.
yidxs = torch.randint(256, (m, n // 4), device="cuda").to(torch.uint8)
# Same indices regrouped in chunks of four per row with the chunk axis moved
# outermost; decompress_d4 on this layout must reproduce the reference below.
yidxs_reordered = (
    yidxs.view(m, n // (4 * 4), 4).permute(1, 0, 2).reshape(m, n // 4).contiguous()
)
# Destination buffer for the decompress kernels.
y1 = torch.zeros(m, n, dtype=torch.float16, device="cuda")

# Alternative index layouts kept for reference (NOTE(review): presumably
# earlier experiments for the k32 / k16 kernels -- confirm before reviving):
# yidxs_reordered = yidxs.view(m//32,32,n//(4*4),4).permute(0,2,1,3).reshape(m,n//4).contiguous()
# yidxs_reordered_k16 = yidxs.view(m//16,16,n//(4*4),4).permute(0,2,1,3).reshape(m,n//4).contiguous()


def _torch_reference():
    """Expand the indices through the codebook and multiply, using only torch."""
    dense = cb[yidxs.view(-1).to(torch.int32), :].view(m, n)
    return dense, x @ dense.t()


# NOTE(review): this is the first matmul issued by the script, so the figure
# likely includes one-time library initialisation; consider a warm-up run if
# the absolute number matters.
y, z0 = _timed("elapsed for pure torch", _torch_reference)

# Decompression only, indices in their original order.
_timed(
    "elapsed for orig decompress_d4",
    lambda: quiptools_cuda.decompress_d4_origorder(yidxs, cb, y1),
)
assert torch.equal(y1, y), "decompress_d4_origorder does not match the torch reference"

# Clear the buffer so the next check cannot pass on the previous result.
y1.zero_()


def _decompress_then_matmul():
    """Decompress from the reordered indices into ``y1`` and multiply by ``x``."""
    quiptools_cuda.decompress_d4(yidxs_reordered, cb, y1)
    return x @ y1.t()


z1 = _timed("elapsed for decompress_d4 and multiply", _decompress_then_matmul)
assert torch.equal(y1, y), "decompress_d4 does not match the torch reference"

# Fused lookup + matmul kernels.  ``z`` is cleared before every run: the
# three kernels share the buffer, and without the reset a kernel that wrote
# nothing (or only part of the output) would still report a small error
# thanks to the values left behind by its predecessor.
z.zero_()
_timed(
    " elapsed for k8 cuda",
    lambda: quiptools_cuda.lookupmatmul_d4_k8(x, yidxs_reordered, cb, z),
)
lookupmatmul_d4_k8_err = _relative_sq_error(z, z0)
print(f"lookupmatmul_d4_k8 error: {lookupmatmul_d4_k8_err}")

z.zero_()
_timed(
    " elapsed for k16 cuda",
    lambda: quiptools_cuda.lookupmatmul_d4_k16(x, yidxs_reordered, cb, z),
)
lookupmatmul_d4_k16_err = _relative_sq_error(z, z0)
print(f"lookupmatmul_d4_k16 error: {lookupmatmul_d4_k16_err}")

z.zero_()
_timed(
    " elapsed for k32 cuda",
    lambda: quiptools_cuda.lookupmatmul_d4_k32(x, yidxs_reordered, cb, z),
)
lookupmatmul_d4_k32_err = _relative_sq_error(z, z0)
print(f"lookupmatmul_d4_k32 error: {lookupmatmul_d4_k32_err}")