mra / benchmarks /benchmark.py
danieldk's picture
danieldk HF Staff
Benchmarks uploaded using `kernels`.
c958d25 verified
import torch
from kernels.benchmark import Benchmark
def mm_to_sparse_reference(
dense_A: torch.Tensor,
dense_B: torch.Tensor,
indices: torch.Tensor,
) -> torch.Tensor:
batch_size = dense_A.size(0)
A_num_block = dense_A.size(1)
B_num_block = dense_B.size(1)
dim = dense_A.size(2)
num_block = indices.size(1)
# Output: (batch_size, num_block, 32, 32)
sparse_C = torch.zeros(
batch_size, num_block, 32, 32, device=dense_A.device, dtype=dense_A.dtype
)
for b in range(batch_size):
for blk in range(num_block):
AB_idx = indices[b, blk].item()
A_idx = AB_idx // B_num_block
B_idx = AB_idx % B_num_block
A_block = dense_A[b, A_idx] # (dim, 32)
B_block = dense_B[b, B_idx] # (dim, 32)
# Kernel computes C = B.T @ A: (32, dim) @ (dim, 32) = (32, 32)
sparse_C[b, blk] = B_block.T @ A_block
return sparse_C
class MRABenchmark(Benchmark):
seed: int = 42
def setup(self):
# Config matching the kernel's expected format
batch_size = 2
num_heads = 8
head_dim = 64
block_size = 32 # Fixed by kernel
A_num_block = 4
B_num_block = 4
total_blocks = A_num_block * B_num_block
indices_per_block = 4 # Must be divisible by 4
self.batch_heads = batch_size * num_heads
# dense_A: [batch_size, A_num_block, dim, 32]
self.dense_a = torch.randn(
self.batch_heads,
A_num_block,
head_dim,
block_size,
device=self.device,
dtype=torch.float32,
)
# dense_B: [batch_size, B_num_block, dim, 32]
self.dense_b = torch.randn(
self.batch_heads,
B_num_block,
head_dim,
block_size,
device=self.device,
dtype=torch.float32,
)
# indices: [batch_size, num_block]
self.indices = torch.randint(
0,
total_blocks,
(self.batch_heads, indices_per_block),
device=self.device,
dtype=torch.int32,
)
def benchmark_base(self):
self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)
def verify_base(self) -> torch.Tensor:
return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)
def setup_large(self):
batch_size = 4
num_heads = 8
head_dim = 64
block_size = 32
A_num_block = 8
B_num_block = 8
total_blocks = A_num_block * B_num_block
indices_per_block = 8 # Must be divisible by 4
self.batch_heads = batch_size * num_heads
self.dense_a = torch.randn(
self.batch_heads,
A_num_block,
head_dim,
block_size,
device=self.device,
dtype=torch.float32,
)
self.dense_b = torch.randn(
self.batch_heads,
B_num_block,
head_dim,
block_size,
device=self.device,
dtype=torch.float32,
)
self.indices = torch.randint(
0,
total_blocks,
(self.batch_heads, indices_per_block),
device=self.device,
dtype=torch.int32,
)
def benchmark_large(self):
self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)
def verify_large(self) -> torch.Tensor:
return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)