| import torch |
|
|
| from kernels.benchmark import Benchmark |
|
|
| |
| _original_allclose = torch.allclose |
|
|
|
|
| def _fp8_tolerant_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): |
| """Custom allclose that uses higher tolerance for FP8-related comparisons.""" |
| |
| |
| max_val = max(input.abs().max().item(), other.abs().max().item(), 1.0) |
| fp8_atol = max(atol, max_val * 0.15) |
| return _original_allclose( |
| input, other, rtol=rtol, atol=fp8_atol, equal_nan=equal_nan |
| ) |
|
|
|
|
| |
| torch.allclose = _fp8_tolerant_allclose |
|
|
|
|
| def quantize_fp8_per_row_reference( |
| a: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Reference implementation of FP8 per-row quantization.""" |
| pt_fp8_dtype = torch.float8_e4m3fn |
| max_fp8 = torch.finfo(pt_fp8_dtype).max |
| eps = 1e-12 |
|
|
| original_shape = a.shape |
| a_2d = a.view(-1, a.shape[-1]) |
|
|
| |
| row_max = a_2d.abs().max(dim=-1).values |
| row_max = torch.clamp(row_max, min=eps) |
|
|
| |
| scale = max_fp8 / row_max |
|
|
| |
| a_scaled = a_2d * scale.unsqueeze(-1) |
| a_scaled = torch.clamp(a_scaled, -max_fp8, max_fp8) |
| a_fp8 = a_scaled.to(pt_fp8_dtype) |
|
|
| |
| a_scale = 1.0 / scale |
|
|
| return a_fp8.view(original_shape), a_scale.view(original_shape[:-1]) |
|
|
|
|
| class QuantizeFp8PerRowBenchmark(Benchmark): |
| seed: int = 42 |
|
|
| def setup(self): |
| M, K = 512, 1024 |
| self.a = torch.randn(M, K, device=self.device, dtype=torch.float32) |
| self.out = torch.empty(M, K, device=self.device, dtype=torch.float32) |
|
|
| def benchmark_base(self): |
| a_fp8, a_scale = self.kernel.quantize_fp8_per_row(self.a) |
| self.out = a_fp8.to(torch.float32) |
|
|
| def verify_base(self) -> torch.Tensor: |
| a_fp8, _ = quantize_fp8_per_row_reference(self.a) |
| return a_fp8.to(torch.float32) |
|
|
| def setup_large(self): |
| M, K = 2048, 4096 |
| self.a = torch.randn(M, K, device=self.device, dtype=torch.float32) |
| self.out = torch.empty(M, K, device=self.device, dtype=torch.float32) |
|
|
| def benchmark_large(self): |
| a_fp8, a_scale = self.kernel.quantize_fp8_per_row(self.a) |
| self.out = a_fp8.to(torch.float32) |
|
|
| def verify_large(self) -> torch.Tensor: |
| a_fp8, _ = quantize_fp8_per_row_reference(self.a) |
| return a_fp8.to(torch.float32) |
|
|