Kernels
rotary / benchmarks /benchmark.py
danieldk's picture
danieldk HF Staff
Benchmarks uploaded using `kernels`.
74d778e verified
import torch
from kernels.benchmark import Benchmark
def apply_rotary_reference(
x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool
) -> tuple[torch.Tensor, torch.Tensor]:
if not conj:
out1 = x1 * cos - x2 * sin
out2 = x1 * sin + x2 * cos
else:
out1 = x1 * cos + x2 * sin
out2 = -x1 * sin + x2 * cos
return out1, out2
class RotaryBenchmark(Benchmark):
seed: int = 42
def setup(self):
batch_size = 2
seqlen = 128
num_heads = 8
head_dim = 64
rotary_dim = 32
# Query tensor split into rotary parts
self.x1 = torch.randn(
batch_size,
seqlen,
num_heads,
rotary_dim,
device=self.device,
dtype=torch.float32,
)
self.x2 = torch.randn(
batch_size,
seqlen,
num_heads,
rotary_dim,
device=self.device,
dtype=torch.float32,
)
# Rotary position embeddings
self.cos = torch.randn(
seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
)
self.sin = torch.randn(
seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
)
# Output tensors (in-place, so clone inputs)
self.out1 = self.x1.clone()
self.out2 = self.x2.clone()
def benchmark_base(self):
# Reset outputs to input values for in-place operation
self.out1.copy_(self.x1)
self.out2.copy_(self.x2)
self.kernel.apply_rotary(
self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
)
def verify_base(self) -> torch.Tensor:
ref_out1, ref_out2 = apply_rotary_reference(
self.x1, self.x2, self.cos, self.sin, False
)
# Concatenate for comparison (benchmark compares self.out with returned tensor)
self.out = torch.cat([self.out1, self.out2], dim=-1)
return torch.cat([ref_out1, ref_out2], dim=-1)
def setup_large(self):
batch_size = 8
seqlen = 512
num_heads = 32
rotary_dim = 64
self.x1 = torch.randn(
batch_size,
seqlen,
num_heads,
rotary_dim,
device=self.device,
dtype=torch.float32,
)
self.x2 = torch.randn(
batch_size,
seqlen,
num_heads,
rotary_dim,
device=self.device,
dtype=torch.float32,
)
self.cos = torch.randn(
seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
)
self.sin = torch.randn(
seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
)
self.out1 = self.x1.clone()
self.out2 = self.x2.clone()
def benchmark_large(self):
self.out1.copy_(self.x1)
self.out2.copy_(self.x2)
self.kernel.apply_rotary(
self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
)
def verify_large(self) -> torch.Tensor:
ref_out1, ref_out2 = apply_rotary_reference(
self.x1, self.x2, self.cos, self.sin, False
)
self.out = torch.cat([self.out1, self.out2], dim=-1)
return torch.cat([ref_out1, ref_out2], dim=-1)