| import torch |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
def rwkv_wkv_reference(
    w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Pure-PyTorch reference for the RWKV WKV forward recurrence.

    The sequence is processed one time step at a time, vectorised over batch
    and channels. The running weighted sums are stored divided by
    ``exp(log_scale)``, where ``log_scale`` is a running maximum exponent.
    Every ``exp()`` argument is therefore <= 0 and cannot overflow.

    Unrolled, output step ``t`` equals::

        y[t] = (sum_{i<t} exp((t-1-i)*w + k[i]) * v[i] + exp(u + k[t]) * v[t])
             / (sum_{i<t} exp((t-1-i)*w + k[i])        + exp(u + k[t]))

    In words: each step the stored state is multiplied by ``exp(w)``
    (``w < 0`` decays it), and the current token gets the extra weight
    ``exp(u)``.

    Args:
        w: Per-channel log-scale added to the state each step. It must
            broadcast against ``(B, C)``; presumably shape ``(C,)``.
        u: Per-channel bonus added to the current token's key. It broadcasts
            like ``w``.
        k: Keys, shape ``(B, T, C)``.
        v: Values. They are indexed like ``k`` and assumed to have the same
            shape.

    Returns:
        A ``(B, T, C)`` tensor with ``k``'s dtype and device. All
        intermediate arithmetic runs in float32.
    """
    batch, steps, channels = k.shape
    out_dtype = k.dtype
    out = torch.zeros(batch, steps, channels, device=k.device, dtype=out_dtype)

    # Convert everything to float32 up front. Slicing afterwards yields the
    # same values as converting each time slice separately.
    decay = w.float()
    bonus = u.float()
    keys = k.float()
    values = v.float()

    # Running state, stored divided by exp(log_scale):
    #   num ~ weighted sum of past values, den ~ sum of the weights.
    # The -1e38 start makes the empty history contribute exactly zero.
    num = torch.zeros(batch, channels, device=k.device, dtype=torch.float32)
    den = torch.zeros_like(num)
    log_scale = torch.full_like(num, -1e38)

    for step in range(steps):
        key = keys[:, step, :]
        val = values[:, step, :]

        # Emit the output by mixing the stored history with the current token,
        # which gets the `u` bonus. Both sides are rescaled to a shared maximum.
        cur = bonus + key
        shift = torch.maximum(log_scale, cur)
        scale_old = torch.exp(log_scale - shift)
        scale_new = torch.exp(cur - shift)
        mixed = (scale_old * num + scale_new * val) / (scale_old * den + scale_new)
        out[:, step, :] = mixed.to(out_dtype)

        # Fold the current token (without the bonus) into the state. The
        # history is shifted by `w` in log space before the merge.
        carried = decay + log_scale
        shift = torch.maximum(carried, key)
        scale_old = torch.exp(carried - shift)
        scale_new = torch.exp(key - shift)
        num = scale_old * num + scale_new * val
        den = scale_old * den + scale_new
        log_scale = shift

    return out
|
|
|
|
class RwkvBenchmark(Benchmark):
    """Benchmark an RWKV WKV forward kernel against ``rwkv_wkv_reference``.

    Each workload (``base``, ``large``) has three methods:

    - a ``setup*`` method that allocates the inputs and the output buffer;
    - a ``benchmark_*`` method that runs ``self.kernel.forward`` into
      ``self.out``;
    - a ``verify_*`` method that returns the reference output.

    NOTE(review): ``self.device``, ``self.kernel`` and the use of ``seed``
    presumably come from the ``Benchmark`` base class; confirm there.
    """

    seed: int = 42

    def _alloc_wkv_inputs(self, B: int, T: int, C: int) -> None:
        """Allocate ``w``, ``u``, ``k``, ``v`` and a zeroed ``out`` for a (B, T, C) problem."""
        device = self.device
        # The recurrence multiplies the stored state by exp(w) every step, so w
        # must be negative to model decay. RWKV-4 passes w = -exp(time_decay).
        # With w >= 0 (the previous .abs()) the history grows instead, new
        # tokens' weights become negligible and eventually underflow, and later
        # outputs barely exercise the kernel's accumulation path.
        self.w = -torch.randn(C, device=device, dtype=torch.float32).exp()
        self.u = torch.randn(C, device=device, dtype=torch.float32)
        self.k = torch.randn(B, T, C, device=device, dtype=torch.float32) * 0.1
        self.v = torch.randn(B, T, C, device=device, dtype=torch.float32) * 0.1
        self.out = torch.zeros(B, T, C, device=device, dtype=torch.float32)

    def setup(self):
        self._alloc_wkv_inputs(2, 64, 256)

    def benchmark_base(self):
        # Clear the buffer so stale results cannot mask a kernel that skips writes.
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def verify_base(self) -> torch.Tensor:
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)

    def setup_large(self):
        self._alloc_wkv_inputs(8, 256, 512)

    def benchmark_large(self):
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def verify_large(self) -> torch.Tensor:
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)
|
|