import torch
import quiptools_cuda


def benchmark():
    """Call ``quiptools_cuda.decode_matmul_e8p`` once on seeded random inputs.

    All tensors are allocated on the ``"cuda"`` device with a fixed RNG seed:

    * a float32 input of shape ``(m, k)``,
    * an int16 index tensor of shape ``(n, k // 8)``,
    * an int64 table with 256 entries.

    The result of the call is discarded and nothing is returned. No timing
    is performed: the CUDA-event instrumentation is left disabled (see the
    commented lines below).
    """
    torch.manual_seed(42)

    device = "cuda"
    m, n, k = 1, 12288, 4096

    # The three draws share one seeded generator, so their order is kept
    # fixed to reproduce the same values run after run.
    activations = torch.randn((m, k), dtype=torch.float32, device=device)
    raw_idxs = torch.randint(
        0, 1 << 15, (n, k // 8), dtype=torch.int16, device=device
    )
    code_table = torch.randint(
        0, (1 << 63) - 1, (256,), dtype=torch.int64, device=device
    )

    # NOTE(review): 0x8000 (32768) is outside the int16 range, so the result
    # of this subtraction depends on how torch converts the Python scalar --
    # confirm that the resulting index values are the intended ones.
    shifted_idxs = raw_idxs - 0x8000

    # Disabled CUDA-event timing; uncomment the lines around the call to
    # measure the elapsed time.
    # start_event = torch.cuda.Event(enable_timing=True)
    # end_event = torch.cuda.Event(enable_timing=True)
    # start_event.record()
    _ = quiptools_cuda.decode_matmul_e8p(activations, shifted_idxs, code_table)
    # end_event.record()
    # torch.cuda.synchronize()
    # elapsed_time_ms = start_event.elapsed_time(end_event)
    # print(f"Elapsed: {elapsed_time_ms:.4f}ms")


if __name__ == "__main__":
    benchmark()