File size: 771 Bytes
c1a41d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import quiptools_cuda


def benchmark():
    """Issue one ``quiptools_cuda.decode_matmul_e8p`` call on seeded random inputs.

    The torch RNG is seeded with 42, then three tensors are created on the
    ``"cuda"`` device, in this order:

    * a float32 input of shape ``(1, 4096)`` from ``torch.randn``;
    * int16 indices of shape ``(12288, 512)`` from ``torch.randint(1 << 15, ...)``;
    * a 256-entry int64 codebook from ``torch.randint(0x7FFFFFFFFFFFFFFF, ...)``.

    The indices are handed to the call with ``0x8000`` subtracted from them.
    Nothing is timed, printed or returned; the call's result is discarded.
    """
    torch.manual_seed(42)

    batch, out_dim, in_dim = 1, 12288, 4096
    on_gpu = {"device": "cuda"}

    # Keep this creation order: with a fixed seed it determines which random
    # values end up in which tensor.
    activations = torch.randn((batch, in_dim), dtype=torch.float32, **on_gpu)
    qidxs = torch.randint(1 << 15, (out_dim, in_dim // 8), dtype=torch.int16, **on_gpu)
    codebook = torch.randint(0x7FFFFFFFFFFFFFFF, (256,), dtype=torch.int64, **on_gpu)

    # NOTE: timing is intentionally disabled here. To measure the call, record a
    # pair of torch.cuda.Event(enable_timing=True) objects around it, call
    # torch.cuda.synchronize(), and read start.elapsed_time(end) (milliseconds).
    shifted_idxs = qidxs - 0x8000
    activations = quiptools_cuda.decode_matmul_e8p(activations, shifted_idxs, codebook)


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    benchmark()