| import torch |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
def rmsnorm_reference(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Return *x* divided by its root-mean-square over the last dimension.

    This is the unweighted RMS normalization used as ground truth for the
    kernel under test: no learnable scale is applied.

    Args:
        x: Input tensor; the reduction runs over ``dim=-1``.
        eps: Constant added to the mean of squares *before* the square root,
            so an all-zero row does not divide by zero.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # keepdim=True keeps the reduced axis so the division broadcasts per row.
    rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
    return x / rms
|
|
|
|
class TinygradRmsBenchmark(Benchmark):
    """Benchmark ``tinygrad_rms_norm_simple`` against a PyTorch reference.

    Two workloads are defined, each as a ``setup*`` / ``benchmark_*`` /
    ``verify_*`` trio:

    * ``base``  -- input of shape (32, 512, 1024), float32
    * ``large`` -- input of shape (64, 1024, 1024), float32

    ``self.device`` and ``self.kernel`` are read but never assigned here;
    presumably the ``Benchmark`` base class / harness provides them
    -- TODO confirm.
    """

    # Presumably consumed by the Benchmark harness to seed the RNG so the
    # random inputs are reproducible -- TODO confirm against the base class.
    seed: int = 42

    def _allocate_inputs(
        self,
        batch_size: int,
        seq_len: int,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Store ``eps``, a random float32 input ``x`` and an output buffer."""
        self.eps = eps
        self.x = torch.randn(
            batch_size, seq_len, hidden_size, device=self.device, dtype=torch.float32
        )
        # Same shape/dtype/device as the input; replaced by the kernel result
        # when the benchmark method runs.
        self.out = torch.empty_like(self.x)

    def setup(self) -> None:
        """Prepare the ``base`` workload: a (32, 512, 1024) input."""
        self._allocate_inputs(batch_size=32, seq_len=512, hidden_size=1024)

    def benchmark_base(self) -> None:
        """Run the kernel on the ``base`` input and keep the result in ``self.out``."""
        self.out = self.kernel.tinygrad_rms_norm_simple(self.x, self.eps)

    def verify_base(self) -> torch.Tensor:
        """Return the PyTorch reference result for the ``base`` input."""
        return rmsnorm_reference(self.x, self.eps)

    def setup_large(self) -> None:
        """Prepare the ``large`` workload: a (64, 1024, 1024) input."""
        self._allocate_inputs(batch_size=64, seq_len=1024, hidden_size=1024)

    def benchmark_large(self) -> None:
        """Run the kernel on the ``large`` input and keep the result in ``self.out``."""
        self.out = self.kernel.tinygrad_rms_norm_simple(self.x, self.eps)

    def verify_large(self) -> torch.Tensor:
        """Return the PyTorch reference result for the ``large`` input."""
        return rmsnorm_reference(self.x, self.eps)
|
|