"""Time a single call to ``quiptools_cuda.decompress_e8p_origorder`` on random inputs.

The script builds a random float16 codebook, a random int16 index matrix and a
zero-initialised float16 output buffer on the GPU, times one kernel call
(bracketed by ``torch.cuda.synchronize()``), and prints the elapsed wall-clock
time followed by the output buffer.
"""

import time

import torch
import quiptools_cuda

# Fixed seed so the random codebook and indices are reproducible between runs.
torch.manual_seed(0)

# Output matrix dimensions: m rows by n columns.
m = 8192*2
n = 8192//2

# Codebook: 256 rows of 8 float16 values.
cb = torch.randn(256, 8, dtype=torch.float16, device="cuda")
# Per-row boolean flag: True where the row's entries sum to a positive value.
cb_even = cb.sum(dim=-1) > 0
# One index per group of 8 output columns; drawn from [0, 2**16) and then
# narrowed to int16.
yidxs = torch.randint(2**16, (m, n//8), device="cuda").to(torch.int16)
# Output buffer handed to the kernel as its last argument.
y1 = torch.zeros(m, n, dtype=torch.float16, device="cuda")

# Disabled pure-PyTorch baseline, kept for reference only.
# NOTE(review): as written the gather index spans [0, 2**16) while `cb` has
# only 256 rows, so this would index out of range if re-enabled -- confirm the
# intended codebook size before turning it back on.
#
# torch.cuda.synchronize()
# start = time.time()
# y = cb[yidxs.view(-1).to(torch.int32)+2**15,:].view(m,n)
# torch.cuda.synchronize()
# end = time.time()
# print(f"elapsed for pure torch: {end - start}")

# Synchronise before and after so the interval covers the whole kernel
# execution rather than just the asynchronous launch.
# NOTE(review): there is no warm-up call, so this single measurement may
# include one-time launch/initialisation overhead -- confirm whether a warm-up
# run is wanted.
torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock for intervals
quiptools_cuda.decompress_e8p_origorder(yidxs, cb, cb_even, y1)
torch.cuda.synchronize()
end = time.perf_counter()
print(f"elapsed for orig decompress_e8p: {end - start}")

# Dump the output buffer for a visual sanity check.
print(y1)