diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..758e3f23497b40c3017c83994fe2be87efb70d09
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,110 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu126-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu128-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu129-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu126-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu128-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu129-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu126-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu128-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu129-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
+build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8e376826f5a5cf7085a8d2fcf8afe60551c2a799
--- /dev/null
+++ b/README.md
@@ -0,0 +1,16 @@
+---
+tags:
+ - kernels
+---
+
+RWKV kernel for transformers
+### Performance
+
+
+
+
+
+
+
+
+
diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..d997b615e88d1ff417b3b1c159014806440034de
--- /dev/null
+++ b/benchmarks/benchmark.py
@@ -0,0 +1,81 @@
+import torch
+
+from kernels.benchmark import Benchmark
+
+
+def rwkv_wkv_reference(
+ w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> torch.Tensor:
+ B, T, C = k.shape
+ device = k.device
+ dtype = k.dtype
+
+ y = torch.zeros(B, T, C, device=device, dtype=dtype)
+
+ # State: accumulated numerator, denominator, and max exponent
+ aa = torch.zeros(B, C, device=device, dtype=torch.float32)
+ bb = torch.zeros(B, C, device=device, dtype=torch.float32)
+ pp = torch.full((B, C), -1e38, device=device, dtype=torch.float32)
+
+ w = w.float()
+ u = u.float()
+
+ for t in range(T):
+ kt = k[:, t, :].float() # [B, C]
+ vt = v[:, t, :].float() # [B, C]
+
+ # Output computation
+ ww = u + kt
+ p = torch.maximum(pp, ww)
+ e1 = torch.exp(pp - p)
+ e2 = torch.exp(ww - p)
+ y[:, t, :] = ((e1 * aa + e2 * vt) / (e1 * bb + e2)).to(dtype)
+
+ # State update (note: w + pp, not pp - w)
+ ww = w + pp
+ p = torch.maximum(ww, kt)
+ e1 = torch.exp(ww - p)
+ e2 = torch.exp(kt - p)
+ aa = e1 * aa + e2 * vt
+ bb = e1 * bb + e2
+ pp = p
+
+ return y
+
+
+class RwkvBenchmark(Benchmark):
+ seed: int = 42
+
+ def setup(self):
+ B, T, C = 2, 64, 256
+
+ self.w = torch.randn(
+ C, device=self.device, dtype=torch.float32
+ ).abs() # Decay should be positive
+ self.u = torch.randn(C, device=self.device, dtype=torch.float32)
+ self.k = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1
+ self.v = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1
+ self.out = torch.zeros(B, T, C, device=self.device, dtype=torch.float32)
+
+ def benchmark_base(self):
+ self.out.zero_()
+ self.kernel.forward(self.w, self.u, self.k, self.v, self.out)
+
+ def verify_base(self) -> torch.Tensor:
+ return rwkv_wkv_reference(self.w, self.u, self.k, self.v)
+
+ def setup_large(self):
+ B, T, C = 8, 256, 512
+
+ self.w = torch.randn(C, device=self.device, dtype=torch.float32).abs()
+ self.u = torch.randn(C, device=self.device, dtype=torch.float32)
+ self.k = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1
+ self.v = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1
+ self.out = torch.zeros(B, T, C, device=self.device, dtype=torch.float32)
+
+ def benchmark_large(self):
+ self.out.zero_()
+ self.kernel.forward(self.w, self.u, self.k, self.v, self.out)
+
+ def verify_large(self) -> torch.Tensor:
+ return rwkv_wkv_reference(self.w, self.u, self.k, self.v)
diff --git a/build.toml b/build.toml
new file mode 100644
index 0000000000000000000000000000000000000000..212d2cdef56c5edd6e5a38954dba357f181e69f8
--- /dev/null
+++ b/build.toml
@@ -0,0 +1,31 @@
+[general]
+name = "rwkv"
+universal = false
+
+[torch]
+src = [
+ "torch-ext/torch_binding.cpp",
+]
+
+[kernel.rwkv]
+depends = ["torch"]
+backend = "cuda"
+cuda-capabilities = [
+ "8.0",
+ "8.9",
+ "9.0",
+ "10.0",
+ "12.0",
+]
+include = ["."]
+src = [
+ "rwkv/wkv_cuda.cu",
+ "rwkv/wkv_cuda_bf16.cu",
+]
+cuda-flags = [
+ "-res-usage",
+ "--use_fast_math",
+ "-O3",
+ "--extra-device-vectorization",
+ "-DTmax=1024",
+]
diff --git a/build/torch210-cu128-x86_64-windows/__init__.py b/build/torch210-cu128-x86_64-windows/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd6a665271bbf1428c46acfb331ac4ff0258f5e
--- /dev/null
+++ b/build/torch210-cu128-x86_64-windows/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import Tensor for type annotations (note: this does import torch at module import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``; k, v, y: forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``; k, v, y: forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cu128-x86_64-windows/_ops.py b/build/torch210-cu128-x86_64-windows/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42b492624cadfc963e93ba90b8a1cdecc4b0999
--- /dev/null
+++ b/build/torch210-cu128-x86_64-windows/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_38ccc47
+ops = torch.ops._rwkv_cuda_38ccc47
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_38ccc47::{op_name}"
diff --git a/build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd b/build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd
new file mode 100644
index 0000000000000000000000000000000000000000..3ae6662c009f80685af977e90beb43c793b78c61
--- /dev/null
+++ b/build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e29ecb0b1d346a26decac4671633cf56fbc0b12df10027b1cca06a41f48976a7
+size 423424
diff --git a/build/torch210-cu128-x86_64-windows/metadata.json b/build/torch210-cu128-x86_64-windows/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..a1693688376cea25e6195d1985d05148d8e8d6fe
--- /dev/null
+++ b/build/torch210-cu128-x86_64-windows/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cu128-x86_64-windows/rwkv/__init__.py b/build/torch210-cu128-x86_64-windows/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc434ef44e63409acb52a8f3fff54a4adc46ed6a
--- /dev/null
+++ b/build/torch210-cu128-x86_64-windows/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import Tensor for type annotations (note: this does import torch at module import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``; k, v, y: forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+        u: Per-channel bonus, shape ``[C]``; k, v, y: forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..03c28f3b8000c9ad78d6c224fe17e1a5089de3c1
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6eedf562b917d7e6865bbeb76296d462c5dd519365ab0454bc31558788c889c3
+size 2232072
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/metadata.json b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,13 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu126-aarch64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu126-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Use a broad Tensor alias to avoid importing torch at import time.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..47529547732afe5c9cf217545484fbbf8b9d06b8
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef57a3b7b3028cf74874ad5f99fd237a942d20322b6c47e824cdcde75a612ac7
+size 2116408
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,13 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Use a broad Tensor alias to avoid importing torch at import time.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..c9b63aa134adf3218288823bacb450b9e5bb4114
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7baac5a668392dd6029e0b87b9d54f163e0c41ca6052a35a5d03116773784114
+size 2428792
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu128-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Use a broad Tensor alias to avoid importing torch at import time.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..9a02a6025049256b06b95909cb4734b5d53cc744
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb33f213f9412b8120ed540d92ef1e9c70ed35590feff5d53113140ce5298ee0
+size 2318768
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Use a broad Tensor alias to avoid importing torch at import time.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..018b21e75e17a5a78cd972bf46f8484e6accd848
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a90cc433fa1453bb04a4993f2af0485266c227d0ece13808be55d3a7f280f4f
+size 2432776
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for the annotations below (note: this imports torch eagerly).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..94af622484b088865eb866947b7bfbd73ad56149
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:024c9733ab553416a3165fabcfd1f8957e99b2c256c3fc5f51d1e3032200ff96
+size 2348248
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for the annotations below (note: this imports torch eagerly).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..b30b2056e569db65fccc0195498ab93179539ef0
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f97c2ec3958d76db37b74a83093e2f127317b51c94d089bc0ac710499fbdc49a
+size 2232072
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/metadata.json b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,13 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu126-aarch64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu126-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for the annotations below (note: this imports torch eagerly).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..f9bd0e3c7d67e019fcca722f0d1f3706e040a58a
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a55661a0185ca9375700bbd80fe85272180c0f081e115d45ef53b471f55fc1e9
+size 2116408
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,13 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu126-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (note: this does import torch at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..0b930c093e6f03c71f6d22f4f1bfb95e64aaeeb4
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01ce97da4775c7e455e8002a81050197073a3455e1465674e0f08f41eee082fe
+size 2428792
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu128-aarch64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu128-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (note: this does import torch at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..fa4f44f581e035f206ca3b10b7567118736def1f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ba2b12cfcc7f264fcb312f71be6f98763a697e1870300120fc4294a7715d5de
+size 2318768
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu128-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (note: this does import torch at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..0b55948d42d16a03617534637377957b2be5d966
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3095e84d107fbbe7425a85b2f9cde793d0fef4c25afea4f7f42ff16974268106
+size 2432776
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu130-aarch64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu130-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (note: this does import torch at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..8ae414352979a91551bbb10891b2a223333279d6
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:582930c14cdc9d0f9a5b83ad198433bbf0d53e2a3b904a63d788f4918487f347
+size 2348248
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch211-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch211-cxx11-cu130-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is: after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (torch is already loaded via ._ops above).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c881de2e0f54ba27fce13d55f35013e2b90dea0b
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49bded237b33ffeb440b398ca5da96497d92364c
Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c54b92a73953ef7168a8a45b1d3caa2e64acc4cb
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_0aa7cb0
+ops = torch.ops._rwkv_0aa7cb0
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_0aa7cb0::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..af67a351934b49c158ae3e5d7a4fa465ace19ab1
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a561edd91db3da01aad4f484eca06f44e98843f859317aeaa9a3df85339ce850
+size 2065400
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (torch is already loaded via ._ops above).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dee5151cb6937785bd5a63c6a8f02f9bb3d351c
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bad2999c65772f5fd868bf6af0fc88b592fa665
Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c54b92a73953ef7168a8a45b1d3caa2e64acc4cb
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_0aa7cb0
+ops = torch.ops._rwkv_0aa7cb0
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_0aa7cb0::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..fdf4405e0825430c46f683f9a763da3a39d96873
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b115095e24c64d21eecefdc0ed4d9c13ea9485930f8f56cc99874762c501f278
+size 2106408
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (torch is already loaded via ._ops above).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..382a6ba0237ee4e4fe73cbb610a447987d9005e7
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fdac79c69dd79159f35740b56bde8c18f82a375
Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c54b92a73953ef7168a8a45b1d3caa2e64acc4cb
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_0aa7cb0
+ops = torch.ops._rwkv_0aa7cb0
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_0aa7cb0::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..df248c48b997c805b9ec6368141f1d027a9c641b
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19029bb516a8cf3428f4e47a2b5f26f091d169206e51df7c4e89880553a15df7
+size 2308848
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations (torch is already loaded via ._ops above).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_44f2fa4
+ops = torch.ops._rwkv_44f2fa4
+
def add_op_namespace_prefix(op_name: str):
    """Return ``op_name`` qualified with this extension's torch op namespace."""
    return "_rwkv_44f2fa4::" + op_name
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..4cc06763b7d171a586c906cf1e6bb55cba0e4709
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d383c1186dde17bb427291205ae83bdb2e49abf15b4eb482151aa0940e51d4d
+size 2111504
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations; torch is a hard dependency of this module.
+from torch import Tensor
+
def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass in float32; the kernel fills ``y`` in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u: Input, ``[B, T, C]`` float32.
        k: Keys, ``[B, T, C]`` float32.
        v: Values, ``[B, T, C]`` float32.
        y: Output buffer, ``[B, T, C]`` float32, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y)
    _validate_device_match(args)
    ops.forward(*args)
+
+
def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass in bfloat16 (``w`` stays float32); fills ``y`` in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u: Input, ``[B, T, C]`` bfloat16.
        k: Keys, ``[B, T, C]`` bfloat16.
        v: Values, ``[B, T, C]`` bfloat16.
        y: Output buffer, ``[B, T, C]`` bfloat16, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y)
    _validate_device_match(args)
    ops.forward_bf16(*args)
+
+
def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV forward pass (float32) that carries a running state.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v: Inputs, keys and values, ``[B, T, C]`` float32.
        y: Output buffer, ``[B, T, C]`` float32, overwritten.
        s: Running state, ``[B, C]`` float32, read and updated in-place.

    All tensors must live on one CUDA device; ``s`` must agree on ``B``/``C``.
    """
    args = (w, u, k, v, y, s)
    _validate_device_match(args)
    ops.forward_with_state(*args)
+
+
def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV stateful forward pass in bfloat16 (``w`` and ``s`` stay float32).

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v: Inputs, keys and values, ``[B, T, C]`` bfloat16.
        y: Output buffer, ``[B, T, C]`` bfloat16, overwritten.
        s: Running state, ``[B, C]`` float32, read and updated in-place.

    All tensors must live on one CUDA device; ``s`` must agree on ``B``/``C``.
    """
    args = (w, u, k, v, y, s)
    _validate_device_match(args)
    ops.forward_with_state_bf16(*args)
+
+
def backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass in float32; gradients are written in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v, y: Saved forward tensors, ``[B, T, C]`` float32.
        gy: Upstream gradient of ``y``, ``[B, T, C]`` float32.
        gw: Output gradient for ``w``, ``[C]`` float32, overwritten.
        gu, gk, gv: Output gradients, ``[B, T, C]`` float32, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y, gy, gw, gu, gk, gv)
    _validate_device_match(args)
    ops.backward(*args)
+
+
def backward_bf16(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass in bfloat16 (``w`` stays float32); in-place gradients.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v, y: Saved forward tensors, ``[B, T, C]`` bfloat16.
        gy: Upstream gradient of ``y``, ``[B, T, C]`` bfloat16.
        gw: Output gradient for ``w``, ``[C]`` bfloat16, overwritten.
        gu, gk, gv: Output gradients, ``[B, T, C]`` bfloat16, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y, gy, gw, gu, gk, gv)
    _validate_device_match(args)
    ops.backward_bf16(*args)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_44f2fa4
+ops = torch.ops._rwkv_44f2fa4
+
def add_op_namespace_prefix(op_name: str):
    """Return ``op_name`` qualified with this extension's torch op namespace."""
    return "_rwkv_44f2fa4::" + op_name
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..026e5fa245ea37eee14a4233f948fd28abe48747
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5fb0335b36148895aac63e1822e14c4be054dd842dbdddcfab668b65b2335ff2
+size 2318032
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations; torch is a hard dependency of this module.
+from torch import Tensor
+
def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass in float32; the kernel fills ``y`` in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u: Input, ``[B, T, C]`` float32.
        k: Keys, ``[B, T, C]`` float32.
        v: Values, ``[B, T, C]`` float32.
        y: Output buffer, ``[B, T, C]`` float32, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y)
    _validate_device_match(args)
    ops.forward(*args)
+
+
def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass in bfloat16 (``w`` stays float32); fills ``y`` in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u: Input, ``[B, T, C]`` bfloat16.
        k: Keys, ``[B, T, C]`` bfloat16.
        v: Values, ``[B, T, C]`` bfloat16.
        y: Output buffer, ``[B, T, C]`` bfloat16, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y)
    _validate_device_match(args)
    ops.forward_bf16(*args)
+
+
def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV forward pass (float32) that carries a running state.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v: Inputs, keys and values, ``[B, T, C]`` float32.
        y: Output buffer, ``[B, T, C]`` float32, overwritten.
        s: Running state, ``[B, C]`` float32, read and updated in-place.

    All tensors must live on one CUDA device; ``s`` must agree on ``B``/``C``.
    """
    args = (w, u, k, v, y, s)
    _validate_device_match(args)
    ops.forward_with_state(*args)
+
+
def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV stateful forward pass in bfloat16 (``w`` and ``s`` stay float32).

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v: Inputs, keys and values, ``[B, T, C]`` bfloat16.
        y: Output buffer, ``[B, T, C]`` bfloat16, overwritten.
        s: Running state, ``[B, C]`` float32, read and updated in-place.

    All tensors must live on one CUDA device; ``s`` must agree on ``B``/``C``.
    """
    args = (w, u, k, v, y, s)
    _validate_device_match(args)
    ops.forward_with_state_bf16(*args)
+
+
def backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass in float32; gradients are written in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v, y: Saved forward tensors, ``[B, T, C]`` float32.
        gy: Upstream gradient of ``y``, ``[B, T, C]`` float32.
        gw: Output gradient for ``w``, ``[C]`` float32, overwritten.
        gu, gk, gv: Output gradients, ``[B, T, C]`` float32, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y, gy, gw, gu, gk, gv)
    _validate_device_match(args)
    ops.backward(*args)
+
+
def backward_bf16(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass in bfloat16 (``w`` stays float32); in-place gradients.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v, y: Saved forward tensors, ``[B, T, C]`` bfloat16.
        gy: Upstream gradient of ``y``, ``[B, T, C]`` bfloat16.
        gw: Output gradient for ``w``, ``[C]`` bfloat16, overwritten.
        gu, gk, gv: Output gradients, ``[B, T, C]`` bfloat16, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y, gy, gw, gu, gk, gv)
    _validate_device_match(args)
    ops.backward_bf16(*args)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_44f2fa4
+ops = torch.ops._rwkv_44f2fa4
+
def add_op_namespace_prefix(op_name: str):
    """Return ``op_name`` qualified with this extension's torch op namespace."""
    return "_rwkv_44f2fa4::" + op_name
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..62492882cb398511bfca1748b134d07316503937
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2178eb41ec85f3861d61391de45ec47c3127490fbd902f1e2167683514cf867b
+size 2335432
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import the Tensor type for annotations; torch is a hard dependency of this module.
+from torch import Tensor
+
def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass in float32; the kernel fills ``y`` in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u: Input, ``[B, T, C]`` float32.
        k: Keys, ``[B, T, C]`` float32.
        v: Values, ``[B, T, C]`` float32.
        y: Output buffer, ``[B, T, C]`` float32, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y)
    _validate_device_match(args)
    ops.forward(*args)
+
+
def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass in bfloat16 (``w`` stays float32); fills ``y`` in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u: Input, ``[B, T, C]`` bfloat16.
        k: Keys, ``[B, T, C]`` bfloat16.
        v: Values, ``[B, T, C]`` bfloat16.
        y: Output buffer, ``[B, T, C]`` bfloat16, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y)
    _validate_device_match(args)
    ops.forward_bf16(*args)
+
+
def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV forward pass (float32) that carries a running state.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v: Inputs, keys and values, ``[B, T, C]`` float32.
        y: Output buffer, ``[B, T, C]`` float32, overwritten.
        s: Running state, ``[B, C]`` float32, read and updated in-place.

    All tensors must live on one CUDA device; ``s`` must agree on ``B``/``C``.
    """
    args = (w, u, k, v, y, s)
    _validate_device_match(args)
    ops.forward_with_state(*args)
+
+
def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV stateful forward pass in bfloat16 (``w`` and ``s`` stay float32).

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v: Inputs, keys and values, ``[B, T, C]`` bfloat16.
        y: Output buffer, ``[B, T, C]`` bfloat16, overwritten.
        s: Running state, ``[B, C]`` float32, read and updated in-place.

    All tensors must live on one CUDA device; ``s`` must agree on ``B``/``C``.
    """
    args = (w, u, k, v, y, s)
    _validate_device_match(args)
    ops.forward_with_state_bf16(*args)
+
+
def backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass in float32; gradients are written in-place.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v, y: Saved forward tensors, ``[B, T, C]`` float32.
        gy: Upstream gradient of ``y``, ``[B, T, C]`` float32.
        gw: Output gradient for ``w``, ``[C]`` float32, overwritten.
        gu, gk, gv: Output gradients, ``[B, T, C]`` float32, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y, gy, gw, gu, gk, gv)
    _validate_device_match(args)
    ops.backward(*args)
+
+
def backward_bf16(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass in bfloat16 (``w`` stays float32); in-place gradients.

    Args:
        w: Decay weights, ``[C]`` float32.
        u, k, v, y: Saved forward tensors, ``[B, T, C]`` bfloat16.
        gy: Upstream gradient of ``y``, ``[B, T, C]`` bfloat16.
        gw: Output gradient for ``w``, ``[C]`` bfloat16, overwritten.
        gu, gk, gv: Output gradients, ``[B, T, C]`` bfloat16, overwritten.

    All tensors must live on one CUDA device and agree on ``B``/``T``/``C``.
    """
    args = (w, u, k, v, y, gy, gw, gu, gk, gv)
    _validate_device_match(args)
    ops.backward_bf16(*args)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dffb89844d9c7644252ba9bc5b0635b106174cd
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_efd954c
+ops = torch.ops._rwkv_cuda_efd954c
+
def add_op_namespace_prefix(op_name: str):
    """Return ``op_name`` qualified with this extension's torch op namespace."""
    return "_rwkv_cuda_efd954c::" + op_name
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..894d88388663606729b454e35fd4a17f9de139dc
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6d39263049bd9e5823007907a9b2f366226c9dacc5a6157f68f8e317bd1108c
+size 2230968
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/metadata.json b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json
@@ -0,0 +1,13 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for annotations (torch is required at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_44f2fa4
+ops = torch.ops._rwkv_44f2fa4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_44f2fa4::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..87947ce206483ed63225e5ac32379232817c05b3
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64e542cd9e7f9a7ea5bf22a91a27fb34f6352b7fbbc39463f2f5d183f6386d05
+size 2111480
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for annotations (torch is required at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dffb89844d9c7644252ba9bc5b0635b106174cd
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_efd954c
+ops = torch.ops._rwkv_cuda_efd954c
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_efd954c::{op_name}"
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..87d47bd77178df8cfa9089434aa9eca8f59562de
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1cde967371ab4341efc9a43eead44a82fe2e80099f170e0f4781f96f8d68961
+size 2427688
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for annotations (torch is required at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_44f2fa4
+ops = torch.ops._rwkv_44f2fa4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_44f2fa4::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..602898b1e3cb5930b41c1772b31a77942b830400
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e37b16a151528c1ad0eecfeb1c3296ac45e6856d3b5b08f7acc68499e3ceca4c
+size 2318000
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Import torch's Tensor type for annotations (torch is required at import time).
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..07ad4b66966eb57d772dffb889dafeabb11fa217
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03745749fe48ca7bda892fa48ee9ca71897488a03287aa18dcd103b1baf63c85
+size 2429112
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu129-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-cu129-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Tensor is used only for type annotations; note this still imports torch eagerly.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+    """Minimal runtime validation that all tensors live on the same CUDA device."""
+    if not tensors:
+        return
+    device = tensors[0].device
+    if device.type != "cuda":
+        raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+    for t in tensors[1:]:
+        if t.device != device:
+            raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_5849bdb
+ops = torch.ops._rwkv_cuda_5849bdb
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_5849bdb::{op_name}"
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d3c16bd6d3b3f3eec3b740fe15241cc7faba92d9
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e124cba0e78954c0aa034427c32ff799bd43f5050eeaceca2c8989841b45495
+size 2335432
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import importlib.util
+import sys
+from pathlib import Path
+from types import ModuleType
+
+
+def _import_from_path(file_path: Path) -> ModuleType:
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
+ # it would also be used for other imports. So, we make a module name that
+ # depends on the path for it to be unique using the hex-encoded hash of
+ # the path.
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+ module_name = path_hash
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+ module = importlib.util.module_from_spec(spec)
+ if module is None:
+ raise ImportError(f"Cannot load module {module_name} from spec")
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Tensor is used only for type annotations; note this still imports torch eagerly.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+    """Minimal runtime validation that all tensors live on the same CUDA device."""
+    if not tensors:
+        return
+    device = tensors[0].device
+    if device.type != "cuda":
+        raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+    for t in tensors[1:]:
+        if t.device != device:
+            raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dffb89844d9c7644252ba9bc5b0635b106174cd
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_cuda_efd954c
+ops = torch.ops._rwkv_cuda_efd954c
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_cuda_efd954c::{op_name}"
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..8325463d49dec8d7492c677dbe0863e4056c206f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab7e3bbeaf2bb811191f7dd30fa68f7c5c772423020b9fe994d31273bf77dc4a
+size 2431672
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json
@@ -0,0 +1,15 @@
+{
+ "version": 1,
+ "license": "Apache-2.0",
+ "python-depends": [],
+ "backend": {
+ "type": "cuda",
+ "archs": [
+ "10.0",
+ "12.0",
+ "8.0",
+ "8.9",
+ "9.0"
+ ]
+ }
+}
diff --git a/build/torch29-cxx11-cu130-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu130-aarch64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module) # type: ignore
+    return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Tensor is used only for type annotations; note this still imports torch eagerly.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward(w, u, k, v, y)
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))
+ ops.forward_bf16(w, u, k, v, y)
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state(w, u, k, v, y, s)
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))
+ ops.forward_with_state_bf16(w, u, k, v, y, s)
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+    """Minimal runtime validation that all tensors live on the same CUDA device."""
+    if not tensors:
+        return
+    device = tensors[0].device
+    if device.type != "cuda":
+        raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+    for t in tensors[1:]:
+        if t.device != device:
+            raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _rwkv_44f2fa4
+ops = torch.ops._rwkv_44f2fa4
+
+def add_op_namespace_prefix(op_name: str):
+ """
+ Prefix op by namespace.
+ """
+ return f"_rwkv_44f2fa4::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..6844b7973cff37a87fa06f89130dba801198b046
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66177f2e067a2624a79e80664f3d87547c1a5916a91711f792495c4a77f5b70f
+size 2343392
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
new file mode 100644
index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json
@@ -0,0 +1,4 @@
+{
+ "version": 1,
+ "python-depends": []
+}
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module) # type: ignore
+    return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 0000000000000000000000000000000000000000..85b5d60a855bf4c19555cc9b8de8ca88d6fd3ae9
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,168 @@
+{
+ "nodes": {
+ "flake-compat": {
+ "locked": {
+ "lastModified": 1747046372,
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
+ "flake-compat_2": {
+ "locked": {
+ "lastModified": 1747046372,
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
+ "flake-utils": {
+ "inputs": {
+ "systems": "systems"
+ },
+ "locked": {
+ "lastModified": 1731533236,
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "type": "github"
+ }
+ },
+ "flake-utils_2": {
+ "inputs": {
+ "systems": "systems_2"
+ },
+ "locked": {
+ "lastModified": 1731533236,
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "type": "github"
+ }
+ },
+ "hf-nix": {
+ "inputs": {
+ "flake-compat": "flake-compat_2",
+ "flake-utils": "flake-utils_2",
+ "nixpkgs": "nixpkgs"
+ },
+ "locked": {
+ "lastModified": 1759851564,
+ "narHash": "sha256-Xybkhm0FM/VzlZ5WndTYq/X/9MAeddd4EQ2Vz8GdkOA=",
+ "owner": "huggingface",
+ "repo": "hf-nix",
+ "rev": "351655d9f124805ed7c1193aa61550ce245f4570",
+ "type": "github"
+ },
+ "original": {
+ "owner": "huggingface",
+ "repo": "hf-nix",
+ "type": "github"
+ }
+ },
+ "kernel-builder": {
+ "inputs": {
+ "flake-compat": "flake-compat",
+ "flake-utils": "flake-utils",
+ "hf-nix": "hf-nix",
+ "nixpkgs": [
+ "kernel-builder",
+ "hf-nix",
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1760035358,
+ "narHash": "sha256-N5vmCrgwcIluPclf/hmnofLK77EJJYh5PR8SRvw++es=",
+ "owner": "huggingface",
+ "repo": "kernel-builder",
+ "rev": "a48cbd19ae7e425dfc1865188ef06dac43ab9244",
+ "type": "github"
+ },
+ "original": {
+ "owner": "huggingface",
+ "repo": "kernel-builder",
+ "type": "github"
+ }
+ },
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1755963616,
+ "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
+ "owner": "nixos",
+ "repo": "nixpkgs",
+ "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nixos",
+ "ref": "nixos-unstable-small",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "kernel-builder": "kernel-builder"
+ }
+ },
+ "systems": {
+ "locked": {
+ "lastModified": 1681028828,
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+ "owner": "nix-systems",
+ "repo": "default",
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nix-systems",
+ "repo": "default",
+ "type": "github"
+ }
+ },
+ "systems_2": {
+ "locked": {
+ "lastModified": 1681028828,
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+ "owner": "nix-systems",
+ "repo": "default",
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nix-systems",
+ "repo": "default",
+ "type": "github"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000000000000000000000000000000000000..f4954b5ca991ee9ab5893205c29997dd1c1112c2
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,17 @@
+{
+ description = "Flake for rwkv kernels";
+
+ inputs = {
+ kernel-builder.url = "github:huggingface/kernel-builder";
+ };
+
+ outputs =
+ {
+ self,
+ kernel-builder,
+ }:
+ kernel-builder.lib.genFlakeOutputs {
+ path = ./.;
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+ };
+}
diff --git a/media/benches_dark_animation.svg b/media/benches_dark_animation.svg
new file mode 100644
index 0000000000000000000000000000000000000000..5c6d7df0404cd13941099d4dbdac79a6f9a49cae
--- /dev/null
+++ b/media/benches_dark_animation.svg
@@ -0,0 +1,33 @@
+
\ No newline at end of file
diff --git a/media/benches_dark_latency.svg b/media/benches_dark_latency.svg
new file mode 100644
index 0000000000000000000000000000000000000000..b725426e75ca91a0130ef63bc8dc9b5988685d79
--- /dev/null
+++ b/media/benches_dark_latency.svg
@@ -0,0 +1,1944 @@
+
+
+
diff --git a/media/benches_dark_throughput.svg b/media/benches_dark_throughput.svg
new file mode 100644
index 0000000000000000000000000000000000000000..0cf5976f02e5f5c049fcc4f08db1795e86245f32
--- /dev/null
+++ b/media/benches_dark_throughput.svg
@@ -0,0 +1,2054 @@
+
+
+
diff --git a/media/benches_light_animation.svg b/media/benches_light_animation.svg
new file mode 100644
index 0000000000000000000000000000000000000000..c24d1ef7199943fff7a5ae2267d07fcc788c2251
--- /dev/null
+++ b/media/benches_light_animation.svg
@@ -0,0 +1,33 @@
+
\ No newline at end of file
diff --git a/media/benches_light_latency.svg b/media/benches_light_latency.svg
new file mode 100644
index 0000000000000000000000000000000000000000..72158f4383b74fe8054c8aa3115f986ab59ba297
--- /dev/null
+++ b/media/benches_light_latency.svg
@@ -0,0 +1,1944 @@
+
+
+
diff --git a/media/benches_light_throughput.svg b/media/benches_light_throughput.svg
new file mode 100644
index 0000000000000000000000000000000000000000..571a3e12bdbdf7761cb37e68c2d3bf55d776bbce
--- /dev/null
+++ b/media/benches_light_throughput.svg
@@ -0,0 +1,2054 @@
+
+
+
diff --git a/rwkv/wkv_cuda.cu b/rwkv/wkv_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..571d5a8a8307e95aac689eb3c9333d1ad350c7de
--- /dev/null
+++ b/rwkv/wkv_cuda.cu
@@ -0,0 +1,187 @@
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>  // F is float in practice; see cuda_forward below
+__global__ void kernel_forward(
+    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
+) { // one thread per (batch, channel); the scan over T is serial within the thread
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww); // rescale both terms by the larger exponent before exp()
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+}
+
+template <typename F>  // F is float in practice; see cuda_forward_with_state below
+__global__ void kernel_forward_with_state(
+    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
+) { // like kernel_forward, but (aa, bb, pp) persist in _s across calls
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset_s = _b * C * 3 + _c * 3; // 3 state floats per (batch, channel)
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+    F *__restrict__ const s = _s + _offset_s;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = s[0], bb = s[1], pp = s[2];
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    s[0] = aa;
+    s[1] = bb;
+    s[2] = pp;
+}
+
+template <typename F>  // F is float in practice; see cuda_backward below
+__global__ void kernel_backward(
+    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+    const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
+    const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
+    F *__restrict__ const _gv
+) { // forward sweep accumulates gw/gu, reverse sweep produces gk/gv
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F q[Tmax], r[Tmax]; // per-step ratios saved for the reverse sweep; Tmax is a compile-time bound on T
+
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
+
+    aa = 0, bb = 0, pp = MIN_VALUE; // reset accumulators for the reverse-time sweep
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
+    }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/rwkv/wkv_cuda_bf16.cu b/rwkv/wkv_cuda_bf16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..042cb4aba1db98be5916aea1de86a7fed0b6510d
--- /dev/null
+++ b/rwkv/wkv_cuda_bf16.cu
@@ -0,0 +1,186 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+#define MIN_VALUE (-1e38)
+typedef at::BFloat16 bf16;
+
+__global__ void kernel_forward_bf16(
+    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
+) { // one thread per (batch, channel); math in fp32, storage in bf16
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int batch = idx / C;
+    const int chan = idx % C;
+    const int base = batch * T * C + chan;
+
+    const float u = float(_u[chan]);
+    const float w = _w[chan];
+    const bf16 *__restrict__ const k = _k + base;
+    const bf16 *__restrict__ const v = _v + base;
+    bf16 *__restrict__ const y = _y + base;
+
+    // Accumulators are kept divided by exp(pp) so the exponentials stay bounded.
+    float aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int t = 0; t < T; t++) {
+        const int tc = t * C;
+        const float kt = float(k[tc]);
+        const float vt = float(v[tc]);
+
+        float x = u + kt;
+        float m = max(pp, x);
+        float ea = exp(pp - m);
+        float eb = exp(x - m);
+        y[tc] = bf16((ea * aa + eb * vt) / (ea * bb + eb));
+
+        x = w + pp;
+        m = max(x, kt);
+        ea = exp(x - m);
+        eb = exp(kt - m);
+        aa = ea * aa + eb * vt;
+        bb = ea * bb + eb;
+        pp = m;
+    }
+}
+
+__global__ void kernel_forward_with_state_bf16(
+    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
+    float *__restrict__ const _s
+) { // like kernel_forward_bf16, but (aa, bb, pp) persist in fp32 state _s
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset_s = _b * C * 3 + _c * 3; // 3 state floats per (batch, channel)
+    const int _offset = _b * T * C + _c;
+
+    float u = float(_u[_c]);
+    float w = _w[_c];
+    const bf16 *__restrict__ const k = _k + _offset;
+    const bf16 *__restrict__ const v = _v + _offset;
+    bf16 *__restrict__ const y = _y + _offset;
+    float *__restrict__ const s = _s + _offset_s;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    float aa = s[0], bb = s[1], pp = s[2];
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const float kk = float(k[ii]);
+        const float vv = float(v[ii]);
+
+        float ww = u + kk;
+        float p = max(pp, ww);
+        float e1 = exp(pp - p);
+        float e2 = exp(ww - p);
+        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); // divide in fp32, round to bf16 once (matches the stateless kernel)
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    s[0] = aa;
+    s[1] = bb;
+    s[2] = pp;
+}
+
+__global__ void kernel_backward_bf16( // forward sweep accumulates gw/gu, reverse sweep produces gk/gv; math in fp32
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
+ const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
+ bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ const bf16 *__restrict__ const y = _y + _offset;
+ const bf16 *__restrict__ const gy = _gy + _offset;
+ bf16 *__restrict__ const gk = _gk + _offset;
+ bf16 *__restrict__ const gv = _gv + _offset;
+
+ float q[Tmax], r[Tmax]; // per-step ratios saved for the reverse sweep; Tmax is a compile-time bound on T
+
+ float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; // pp is the running rescaling exponent
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww); // rescale by the larger exponent before exp()
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ const float qq = float(gy[ii]) / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = bf16(gu);
+
+ aa = 0, bb = 0, pp = MIN_VALUE; // reset accumulators for the reverse-time sweep
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+ const float qq = q[i];
+ const float rr = r[i];
+
+ float e1 = qq * exp(rr);
+ float e2 = exp(kk + pp);
+ gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
+ gv[ii] = bf16(e1 + e2 * aa);
+
+ const float ww = w + pp;
+ const float www = rr - u - kk;
+ const float p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+}
+
+void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
+}
+
+void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/torch-ext/rwkv/__init__.py b/torch-ext/rwkv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494
--- /dev/null
+++ b/torch-ext/rwkv/__init__.py
@@ -0,0 +1,170 @@
+from ._ops import ops
+from typing import Tuple, Any
+
+# Use a broad Tensor alias to avoid importing torch at import time.
+from torch import Tensor
+
+def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (float32).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))  # device check only; dtypes/shapes are not validated here
+ ops.forward(w, u, k, v, y)  # compiled CUDA op; result lands in y
+
+
+def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
+ """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).
+
+ Runs the CUDA kernel and writes the result into ``y`` in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y))  # device check only; dtypes/shapes are not validated here
+ ops.forward_bf16(w, u, k, v, y)  # compiled CUDA op; result lands in y
+
+
+def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (float32).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))  # device check only; dtypes/shapes are not validated here
+ ops.forward_with_state(w, u, k, v, y, s)  # compiled CUDA op; y and s are updated in-place
+
+
+def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
+ """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).
+
+ Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+ s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
+ """
+ _validate_device_match((w, u, k, v, y, s))  # device check only; dtypes/shapes are not validated here
+ ops.forward_with_state_bf16(w, u, k, v, y, s)  # compiled CUDA op; y and s are updated in-place
+
+
+def backward(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (float32).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))  # device check only; dtypes/shapes are not validated here
+ ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)  # compiled CUDA op; gradients written in-place
+
+
+def backward_bf16(
+ w: Tensor,
+ u: Tensor,
+ k: Tensor,
+ v: Tensor,
+ y: Tensor,
+ gy: Tensor,
+ gw: Tensor,
+ gu: Tensor,
+ gk: Tensor,
+ gv: Tensor,
+) -> None:
+ """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).
+
+ Writes gradients into the provided tensors in-place.
+
+ Args:
+ w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
+ u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place).
+ gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
+
+ Notes:
+ - All tensors must be on the same CUDA device.
+ - Shapes must agree on ``B``, ``T`` and ``C``.
+ """
+ _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))  # device check only; dtypes/shapes are not validated here
+ ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)  # compiled CUDA op; gradients written in-place
+
+
+def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
+ """Minimal runtime validation that all tensors live on the same CUDA device."""
+ if not tensors:
+ return
+ device = tensors[0].device
+ if not device.type == "cuda":
+ raise RuntimeError("RWKV CUDA ops require CUDA tensors")
+ for t in tensors[1:]:
+ if t.device != device:
+ raise RuntimeError("All tensors must be on the same CUDA device")
+
+
+__all__ = [
+ "forward",
+ "forward_bf16",
+ "forward_with_state",
+ "forward_with_state_bf16",
+ "backward",
+ "backward_bf16",
+]
\ No newline at end of file
diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c42395d0aa9bb14566a58d73f959c9803fc73cde
--- /dev/null
+++ b/torch-ext/torch_binding.cpp
@@ -0,0 +1,74 @@
+#include <torch/all.h>
+#include "ATen/ATen.h"
+#include <torch/library.h>
+
+#include "registration.h"
+
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
+void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
+void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
+
+void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
+}
+void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
+}
+void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    const int B = k.size(0);
+    const int T = k.size(1);
+    const int C = k.size(2);
+    cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
+                       gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
+}
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // registers the ops under the extension's namespace
+ ops.def("forward", forward); // schema inferred from the C++ signature; fp32, writes into y in-place
+ ops.impl("forward", torch::kCUDA, &forward);
+
+ ops.def("forward_bf16", forward_bf16); // bf16 activations, fp32 decay `w`
+ ops.impl("forward_bf16", torch::kCUDA, &forward_bf16);
+
+ ops.def("forward_with_state", forward_with_state); // carries 3 fp32 state values per (batch, channel) in `s`
+ ops.impl("forward_with_state", torch::kCUDA, &forward_with_state);
+
+ ops.def("forward_with_state_bf16", forward_with_state_bf16);
+ ops.impl("forward_with_state_bf16", torch::kCUDA, &forward_with_state_bf16);
+
+ ops.def("backward", backward); // gradients written in-place into gw/gu/gk/gv
+ ops.impl("backward", torch::kCUDA, &backward);
+
+ ops.def("backward_bf16", backward_bf16);
+ ops.impl("backward_bf16", torch::kCUDA, &backward_bf16);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file