diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..758e3f23497b40c3017c83994fe2be87efb70d09 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,110 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text 
+build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_eb0e3e5_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text 
+build/torch28-cxx11-cu126-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_rwkv_d5e72cf.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_rwkv_1aa1cb1.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text 
+build/torch28-cxx11-cu129-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_rwkv_86a3859.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd filter=lfs diff=lfs merge=lfs -text 
+build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8e376826f5a5cf7085a8d2fcf8afe60551c2a799 --- /dev/null +++ b/README.md @@ -0,0 +1,16 @@ +--- +tags: + - kernels +--- + +RWKV kernel for transformers +### Performance + + + + + + + + + diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..d997b615e88d1ff417b3b1c159014806440034de --- /dev/null +++ b/benchmarks/benchmark.py @@ -0,0 +1,81 @@ +import torch + 
+from kernels.benchmark import Benchmark + + +def rwkv_wkv_reference( + w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + B, T, C = k.shape + device = k.device + dtype = k.dtype + + y = torch.zeros(B, T, C, device=device, dtype=dtype) + + # State: accumulated numerator, denominator, and max exponent + aa = torch.zeros(B, C, device=device, dtype=torch.float32) + bb = torch.zeros(B, C, device=device, dtype=torch.float32) + pp = torch.full((B, C), -1e38, device=device, dtype=torch.float32) + + w = w.float() + u = u.float() + + for t in range(T): + kt = k[:, t, :].float() # [B, C] + vt = v[:, t, :].float() # [B, C] + + # Output computation + ww = u + kt + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + y[:, t, :] = ((e1 * aa + e2 * vt) / (e1 * bb + e2)).to(dtype) + + # State update (note: w + pp, not pp - w) + ww = w + pp + p = torch.maximum(ww, kt) + e1 = torch.exp(ww - p) + e2 = torch.exp(kt - p) + aa = e1 * aa + e2 * vt + bb = e1 * bb + e2 + pp = p + + return y + + +class RwkvBenchmark(Benchmark): + seed: int = 42 + + def setup(self): + B, T, C = 2, 64, 256 + + self.w = torch.randn( + C, device=self.device, dtype=torch.float32 + ).abs() # Decay should be positive + self.u = torch.randn(C, device=self.device, dtype=torch.float32) + self.k = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1 + self.v = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1 + self.out = torch.zeros(B, T, C, device=self.device, dtype=torch.float32) + + def benchmark_base(self): + self.out.zero_() + self.kernel.forward(self.w, self.u, self.k, self.v, self.out) + + def verify_base(self) -> torch.Tensor: + return rwkv_wkv_reference(self.w, self.u, self.k, self.v) + + def setup_large(self): + B, T, C = 8, 256, 512 + + self.w = torch.randn(C, device=self.device, dtype=torch.float32).abs() + self.u = torch.randn(C, device=self.device, dtype=torch.float32) + self.k = torch.randn(B, T, C, 
device=self.device, dtype=torch.float32) * 0.1 + self.v = torch.randn(B, T, C, device=self.device, dtype=torch.float32) * 0.1 + self.out = torch.zeros(B, T, C, device=self.device, dtype=torch.float32) + + def benchmark_large(self): + self.out.zero_() + self.kernel.forward(self.w, self.u, self.k, self.v, self.out) + + def verify_large(self) -> torch.Tensor: + return rwkv_wkv_reference(self.w, self.u, self.k, self.v) diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..212d2cdef56c5edd6e5a38954dba357f181e69f8 --- /dev/null +++ b/build.toml @@ -0,0 +1,31 @@ +[general] +name = "rwkv" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", +] + +[kernel.rwkv] +depends = ["torch"] +backend = "cuda" +cuda-capabilities = [ + "8.0", + "8.9", + "9.0", + "10.0", + "12.0", +] +include = ["."] +src = [ + "rwkv/wkv_cuda.cu", + "rwkv/wkv_cuda_bf16.cu", +] +cuda-flags = [ + "-res-usage", + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + "-DTmax=1024", +] diff --git a/build/torch210-cu128-x86_64-windows/__init__.py b/build/torch210-cu128-x86_64-windows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd6a665271bbf1428c46acfb331ac4ff0258f5e --- /dev/null +++ b/build/torch210-cu128-x86_64-windows/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. 
+ """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cu128-x86_64-windows/_ops.py b/build/torch210-cu128-x86_64-windows/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d42b492624cadfc963e93ba90b8a1cdecc4b0999 --- /dev/null +++ b/build/torch210-cu128-x86_64-windows/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _rwkv_cuda_38ccc47 +ops = torch.ops._rwkv_cuda_38ccc47 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_cuda_38ccc47::{op_name}" diff --git a/build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd b/build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd new file mode 100644 index 0000000000000000000000000000000000000000..3ae6662c009f80685af977e90beb43c793b78c61 --- /dev/null +++ b/build/torch210-cu128-x86_64-windows/_rwkv_cuda_38ccc47.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e29ecb0b1d346a26decac4671633cf56fbc0b12df10027b1cca06a41f48976a7 +size 423424 diff --git a/build/torch210-cu128-x86_64-windows/metadata.json b/build/torch210-cu128-x86_64-windows/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a1693688376cea25e6195d1985d05148d8e8d6fe --- /dev/null +++ b/build/torch210-cu128-x86_64-windows/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cu128-x86_64-windows/rwkv/__init__.py b/build/torch210-cu128-x86_64-windows/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc434ef44e63409acb52a8f3fff54a4adc46ed6a --- /dev/null +++ b/build/torch210-cu128-x86_64-windows/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-aarch64-linux/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..03c28f3b8000c9ad78d6c224fe17e1a5089de3c1 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6eedf562b917d7e6865bbeb76296d462c5dd519365ab0454bc31558788c889c3 +size 2232072 diff --git a/build/torch210-cxx11-cu126-aarch64-linux/metadata.json b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,13 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu126-aarch64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..47529547732afe5c9cf217545484fbbf8b9d06b8 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef57a3b7b3028cf74874ad5f99fd237a942d20322b6c47e824cdcde75a612ac7 +size 2116408 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,13 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c9b63aa134adf3218288823bacb450b9e5bb4114 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7baac5a668392dd6029e0b87b9d54f163e0c41ca6052a35a5d03116773784114 +size 2428792 diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu128-aarch64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..9a02a6025049256b06b95909cb4734b5d53cc744 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb33f213f9412b8120ed540d92ef1e9c70ed35590feff5d53113140ce5298ee0 +size 2318768 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..018b21e75e17a5a78cd972bf46f8484e6accd848 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a90cc433fa1453bb04a4993f2af0485266c227d0ece13808be55d3a7f280f4f +size 2432776 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..94af622484b088865eb866947b7bfbd73ad56149 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:024c9733ab553416a3165fabcfd1f8957e99b2c256c3fc5f51d1e3032200ff96 +size 2348248 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu126-aarch64-linux/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b30b2056e569db65fccc0195498ab93179539ef0 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f97c2ec3958d76db37b74a83093e2f127317b51c94d089bc0ac710499fbdc49a +size 2232072 diff --git a/build/torch211-cxx11-cu126-aarch64-linux/metadata.json b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,13 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu126-aarch64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f9bd0e3c7d67e019fcca722f0d1f3706e040a58a --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a55661a0185ca9375700bbd80fe85272180c0f081e115d45ef53b471f55fc1e9 +size 2116408 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,13 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0b930c093e6f03c71f6d22f4f1bfb95e64aaeeb4 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01ce97da4775c7e455e8002a81050197073a3455e1465674e0f08f41eee082fe +size 2428792 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu128-aarch64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..fa4f44f581e035f206ca3b10b7567118736def1f --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ba2b12cfcc7f264fcb312f71be6f98763a697e1870300120fc4294a7715d5de +size 2318768 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0b55948d42d16a03617534637377957b2be5d966 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3095e84d107fbbe7425a85b2f9cde793d0fef4c25afea4f7f42ff16974268106 +size 2432776 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu130-aarch64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8ae414352979a91551bbb10891b2a223333279d6 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:582930c14cdc9d0f9a5b83ad198433bbf0d53e2a3b904a63d788f4918487f347 +size 2348248 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c881de2e0f54ba27fce13d55f35013e2b90dea0b Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49bded237b33ffeb440b398ca5da96497d92364c Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..c54b92a73953ef7168a8a45b1d3caa2e64acc4cb --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_0aa7cb0 +ops = torch.ops._rwkv_0aa7cb0 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_0aa7cb0::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..af67a351934b49c158ae3e5d7a4fa465ace19ab1 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a561edd91db3da01aad4f484eca06f44e98843f859317aeaa9a3df85339ce850 +size 2065400 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. 
+ - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. 
+ """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dee5151cb6937785bd5a63c6a8f02f9bb3d351c Binary files /dev/null and 
b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bad2999c65772f5fd868bf6af0fc88b592fa665 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c54b92a73953ef7168a8a45b1d3caa2e64acc4cb --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_0aa7cb0 +ops = torch.ops._rwkv_0aa7cb0 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_rwkv_0aa7cb0::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..fdf4405e0825430c46f683f9a763da3a39d96873 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b115095e24c64d21eecefdc0ed4d9c13ea9485930f8f56cc99874762c501f278 +size 2106408 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. 
+from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. 
+ """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..382a6ba0237ee4e4fe73cbb610a447987d9005e7 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fdac79c69dd79159f35740b56bde8c18f82a375 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c54b92a73953ef7168a8a45b1d3caa2e64acc4cb --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_0aa7cb0 +ops = torch.ops._rwkv_0aa7cb0 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_0aa7cb0::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..df248c48b997c805b9ec6368141f1d027a9c641b --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/rwkv/_rwkv_0aa7cb0.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19029bb516a8cf3428f4e47a2b5f26f091d169206e51df7c4e89880553a15df7 +size 2308848 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). + + Runs the CUDA kernel and writes the result into ``y`` in-place. 
+ + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_44f2fa4 +ops = torch.ops._rwkv_44f2fa4 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_44f2fa4::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4cc06763b7d171a586c906cf1e6bb55cba0e4709 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d383c1186dde17bb427291205ae83bdb2e49abf15b4eb482151aa0940e51d4d +size 2111504 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_44f2fa4 +ops = torch.ops._rwkv_44f2fa4 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_44f2fa4::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..026e5fa245ea37eee14a4233f948fd28abe48747 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fb0335b36148895aac63e1822e14c4be054dd842dbdddcfab668b65b2335ff2 +size 2318032 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_44f2fa4 +ops = torch.ops._rwkv_44f2fa4 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_44f2fa4::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..62492882cb398511bfca1748b134d07316503937 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_rwkv_44f2fa4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2178eb41ec85f3861d61391de45ec47c3127490fbd902f1e2167683514cf867b +size 2335432 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5dffb89844d9c7644252ba9bc5b0635b106174cd --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_efd954c +ops = torch.ops._rwkv_cuda_efd954c + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_efd954c::{op_name}" diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..894d88388663606729b454e35fd4a17f9de139dc --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_rwkv_cuda_efd954c.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6d39263049bd9e5823007907a9b2f366226c9dacc5a6157f68f8e317bd1108c +size 2230968 diff --git a/build/torch29-cxx11-cu126-aarch64-linux/metadata.json b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..724a12daaf58a25e6b5fc3e62fc3b76e4b54ccd9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,13 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_44f2fa4 +ops = torch.ops._rwkv_44f2fa4 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_44f2fa4::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..87947ce206483ed63225e5ac32379232817c05b3 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_rwkv_44f2fa4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64e542cd9e7f9a7ea5bf22a91a27fb34f6352b7fbbc39463f2f5d183f6386d05 +size 2111480 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5dffb89844d9c7644252ba9bc5b0635b106174cd --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_efd954c +ops = torch.ops._rwkv_cuda_efd954c + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_efd954c::{op_name}" diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..87d47bd77178df8cfa9089434aa9eca8f59562de --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_rwkv_cuda_efd954c.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1cde967371ab4341efc9a43eead44a82fe2e80099f170e0f4781f96f8d68961 +size 2427688 diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu128-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_44f2fa4 +ops = torch.ops._rwkv_44f2fa4 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_44f2fa4::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..602898b1e3cb5930b41c1772b31a77942b830400 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_rwkv_44f2fa4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e37b16a151528c1ad0eecfeb1c3296ac45e6856d3b5b08f7acc68499e3ceca4c +size 2318000 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..07ad4b66966eb57d772dffb889dafeabb11fa217 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03745749fe48ca7bda892fa48ee9ca71897488a03287aa18dcd103b1baf63c85 +size 2429112 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6985878866ce4c1c622d1d977162df0e6c564cee --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_5849bdb +ops = torch.ops._rwkv_cuda_5849bdb + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_5849bdb::{op_name}" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d3c16bd6d3b3f3eec3b740fe15241cc7faba92d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_rwkv_cuda_5849bdb.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e124cba0e78954c0aa034427c32ff799bd43f5050eeaceca2c8989841b45495 +size 2335432 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5dffb89844d9c7644252ba9bc5b0635b106174cd --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_cuda_efd954c +ops = torch.ops._rwkv_cuda_efd954c + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_cuda_efd954c::{op_name}" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8325463d49dec8d7492c677dbe0863e4056c206f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_rwkv_cuda_efd954c.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab7e3bbeaf2bb811191f7dd30fa68f7c5c772423020b9fe994d31273bf77dc4a +size 2431672 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..890b88c81c1da1a10714a26cdca182aa27c63f6f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "12.0", + "8.0", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea22dbc160d2b75b32c3c8c5d882d25c77b494 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,170 @@ +from ._ops import ops +from typing import Tuple, Any + +# Use a broad Tensor alias to avoid importing torch at import time. +from torch import Tensor + +def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (float32). + + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward(w, u, k, v, y) + + +def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None: + """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``). 
+ + Runs the CUDA kernel and writes the result into ``y`` in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y)) + ops.forward_bf16(w, u, k, v, y) + + +def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (float32). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state(w, u, k, v, y, s) + + +def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None: + """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``). + + Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u: Input tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + s: Stateful tensor, shape ``[B, C]``, dtype ``torch.float32`` (updated in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs. + """ + _validate_device_match((w, u, k, v, y, s)) + ops.forward_with_state_bf16(w, u, k, v, y, s) + + +def backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (float32). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.float32``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``. + gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.float32`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def backward_bf16( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + y: Tensor, + gy: Tensor, + gw: Tensor, + gu: Tensor, + gk: Tensor, + gv: Tensor, +) -> None: + """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``). + + Writes gradients into the provided tensors in-place. + + Args: + w: Decay weights, shape ``[C]``, dtype ``torch.float32``. + u, k, v, y: Forward-pass tensors, shape ``[B, T, C]``, dtype ``torch.bfloat16``. + gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``. 
+ gw: Gradient for ``w``, shape ``[C]``, dtype ``torch.bfloat16`` (written in-place). + gu, gk, gv: Gradients for ``u``, ``k``, ``v`` respectively, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place). + + Notes: + - All tensors must be on the same CUDA device. + - Shapes must agree on ``B``, ``T`` and ``C``. + """ + _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv)) + ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv) + + +def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None: + """Minimal runtime validation that all tensors live on the same CUDA device.""" + if not tensors: + return + device = tensors[0].device + if not device.type == "cuda": + raise RuntimeError("RWKV CUDA ops require CUDA tensors") + for t in tensors[1:]: + if t.device != device: + raise RuntimeError("All tensors must be on the same CUDA device") + + +__all__ = [ + "forward", + "forward_bf16", + "forward_with_state", + "forward_with_state_bf16", + "backward", + "backward_bf16", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2ab024f61c605a79ddc243163cfb3b8a39ef7 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _rwkv_44f2fa4 +ops = torch.ops._rwkv_44f2fa4 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_rwkv_44f2fa4::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..6844b7973cff37a87fa06f89130dba801198b046 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_rwkv_44f2fa4.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66177f2e067a2624a79e80664f3d87547c1a5916a91711f792495c4a77f5b70f +size 2343392 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/rwkv/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..85b5d60a855bf4c19555cc9b8de8ca88d6fd3ae9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,168 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + 
"lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1759851564, + "narHash": "sha256-Xybkhm0FM/VzlZ5WndTYq/X/9MAeddd4EQ2Vz8GdkOA=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "351655d9f124805ed7c1193aa61550ce245f4570", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760035358, + "narHash": "sha256-N5vmCrgwcIluPclf/hmnofLK77EJJYh5PR8SRvw++es=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "a48cbd19ae7e425dfc1865188ef06dac43ab9244", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1755963616, + "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + 
"owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..f4954b5ca991ee9ab5893205c29997dd1c1112c2 --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for rwkv kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/media/benches_dark_animation.svg b/media/benches_dark_animation.svg new file mode 100644 index 0000000000000000000000000000000000000000..5c6d7df0404cd13941099d4dbdac79a6f9a49cae --- /dev/null +++ b/media/benches_dark_animation.svg @@ -0,0 +1,33 @@ + +kernels-community/rwkv vs Torch - Relative Speed +PyTorch 2.11.0+cu130 · CPU + +RwkvBenchmark.base +455.24x + + + + + + + +RwkvBenchmark.large +1137.57x + + + + + + + +Kernel + +Torch (ref) + + + + + + + + \ No newline at end of file diff --git a/media/benches_dark_latency.svg b/media/benches_dark_latency.svg new file mode 100644 index 0000000000000000000000000000000000000000..b725426e75ca91a0130ef63bc8dc9b5988685d79 --- /dev/null +++ b/media/benches_dark_latency.svg @@ -0,0 +1,1944 @@ + + + + + + + + 2026-03-25T20:26:09.339152 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/media/benches_dark_throughput.svg b/media/benches_dark_throughput.svg new file mode 100644 index 0000000000000000000000000000000000000000..0cf5976f02e5f5c049fcc4f08db1795e86245f32 --- /dev/null +++ b/media/benches_dark_throughput.svg @@ -0,0 +1,2054 @@ + + + + + + + + 2026-03-25T20:26:09.509614 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/media/benches_light_animation.svg b/media/benches_light_animation.svg new file mode 100644 index 0000000000000000000000000000000000000000..c24d1ef7199943fff7a5ae2267d07fcc788c2251 --- /dev/null +++ b/media/benches_light_animation.svg @@ -0,0 +1,33 @@ + +kernels-community/rwkv vs Torch - Relative Speed +PyTorch 2.11.0+cu130 · CPU + +RwkvBenchmark.base +455.24x + + + + + + + +RwkvBenchmark.large +1137.57x + + + + + + + +Kernel + +Torch (ref) + + + + + + + + \ No newline at end of file diff --git a/media/benches_light_latency.svg b/media/benches_light_latency.svg new file mode 100644 index 0000000000000000000000000000000000000000..72158f4383b74fe8054c8aa3115f986ab59ba297 --- /dev/null +++ b/media/benches_light_latency.svg @@ -0,0 +1,1944 @@ + + + + + + + + 2026-03-25T20:26:08.556480 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/media/benches_light_throughput.svg b/media/benches_light_throughput.svg new file mode 100644 index 0000000000000000000000000000000000000000..571a3e12bdbdf7761cb37e68c2d3bf55d776bbce --- /dev/null +++ b/media/benches_light_throughput.svg @@ -0,0 +1,2054 @@ + + + + + + + + 2026-03-25T20:26:08.956848 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rwkv/wkv_cuda.cu b/rwkv/wkv_cuda.cu new file mode 100644 
index 0000000000000000000000000000000000000000..571d5a8a8307e95aac689eb3c9333d1ad350c7de --- /dev/null +++ b/rwkv/wkv_cuda.cu @@ -0,0 +1,187 @@ +#include +#include + +#define MIN_VALUE (-1e38) + +template +__global__ void kernel_forward( + const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, + const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + F aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +template +__global__ void kernel_forward_with_state( + const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, + const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset_s = _b * C * 3 + _c * 3; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + F *__restrict__ const s = _s + _offset_s; + + // aa and bb are running sums divided by exp(pp) (to avoid 
overflow) + F aa = s[0], bb = s[1], pp = s[2]; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + s[0] = aa; + s[1] = bb; + s[2] = pp; +} + +template +__global__ void kernel_backward( + const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, + const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y, + const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, + F *__restrict__ const _gv +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const y = _y + _offset; + const F *__restrict__ const gy = _gy + _offset; + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F q[Tmax], r[Tmax]; + + F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + const F qq = gy[ii] / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = gw * 
_w[_c]; // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = gu; + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + const F qq = q[i]; + const F rr = r[i]; + + F e1 = qq * exp(rr); + F e2 = exp(kk + pp); + gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); + gv[ii] = e1 + e2 * aa; + + const F ww = w + pp; + const F www = rr - u - kk; + const F p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward_with_state<<>>(B, T, C, w, u, k, v, y, s); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/rwkv/wkv_cuda_bf16.cu b/rwkv/wkv_cuda_bf16.cu new file mode 100644 index 0000000000000000000000000000000000000000..042cb4aba1db98be5916aea1de86a7fed0b6510d --- /dev/null +++ b/rwkv/wkv_cuda_bf16.cu @@ -0,0 +1,186 @@ +#include +#include +#include "ATen/ATen.h" +#define MIN_VALUE (-1e38) +typedef 
at::BFloat16 bf16; + +__global__ void kernel_forward_bf16( + const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, + const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + bf16 *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + float aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +__global__ void kernel_forward_with_state_bf16( + const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, + const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y, + float *__restrict__ const _s +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset_s = _b * C * 3 + _c * 3; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + bf16 *__restrict__ const y = _y + _offset; + float *__restrict__ const s = _s + _offset_s; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + float aa = s[0], bb = s[1], pp = s[2]; 
+ for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = bf16(e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + s[0] = aa; + s[1] = bb; + s[2] = pp; +} + +__global__ void kernel_backward_bf16( + const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, + const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y, + const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, + bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + const bf16 *__restrict__ const y = _y + _offset; + const bf16 *__restrict__ const gy = _gy + _offset; + bf16 *__restrict__ const gk = _gk + _offset; + bf16 *__restrict__ const gv = _gv + _offset; + + float q[Tmax], r[Tmax]; + + float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + const float qq = float(gy[ii]) / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + 
aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = bf16(gu); + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + const float qq = q[i]; + const float rr = r[i]; + + float e1 = qq * exp(rr); + float e2 = exp(kk + pp); + gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb)); + gv[ii] = bf16(e1 + e2 * aa); + + const float ww = w + pp; + const float www = rr - u - kk; + const float p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward_bf16<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward_with_state_bf16<<>>(B, T, C, w, u, k, v, y, s); +} + +void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward_bf16<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/torch-ext/rwkv/__init__.py b/torch-ext/rwkv/__init__.py new file mode 100644 
from ._ops import ops
from typing import Tuple, Any

# Use a broad Tensor alias to avoid importing torch at import time.
from torch import Tensor


def forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass (float32).

    Runs the CUDA kernel and writes the result into ``y`` in-place.

    Args:
        w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
        u: Per-channel bonus weights, shape ``[C]``, dtype ``torch.float32``
           (the CUDA kernel indexes ``u`` by channel only).
        k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
        v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
        y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).

    Notes:
        - All tensors must be on the same CUDA device.
        - Shapes must agree on ``B``, ``T`` and ``C``.
    """
    _validate_device_match((w, u, k, v, y))
    ops.forward(w, u, k, v, y)


def forward_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor) -> None:
    """RWKV WKV forward pass (bfloat16 inputs/outputs, float32 ``w``).

    Runs the CUDA kernel and writes the result into ``y`` in-place.

    Args:
        w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
        u: Per-channel bonus weights, shape ``[C]``, dtype ``torch.bfloat16``
           (the CUDA kernel indexes ``u`` by channel only).
        k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
        v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
        y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).

    Notes:
        - All tensors must be on the same CUDA device.
        - Shapes must agree on ``B``, ``T`` and ``C``.
    """
    _validate_device_match((w, u, k, v, y))
    ops.forward_bf16(w, u, k, v, y)


def forward_with_state(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV forward pass with persistent state (float32).

    Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.

    Args:
        w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
        u: Per-channel bonus weights, shape ``[C]``, dtype ``torch.float32``.
        k: Key tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
        v: Value tensor, shape ``[B, T, C]``, dtype ``torch.float32``.
        y: Output tensor, shape ``[B, T, C]``, dtype ``torch.float32`` (written in-place).
        s: State tensor, shape ``[B, C, 3]``, dtype ``torch.float32`` (updated in-place).
           The kernel stores three floats ``(aa, bb, pp)`` per (batch, channel).

    Notes:
        - All tensors must be on the same CUDA device.
        - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
    """
    _validate_device_match((w, u, k, v, y, s))
    ops.forward_with_state(w, u, k, v, y, s)


def forward_with_state_bf16(w: Tensor, u: Tensor, k: Tensor, v: Tensor, y: Tensor, s: Tensor) -> None:
    """RWKV WKV forward pass with persistent state (bfloat16 inputs/outputs, float32 ``w`` and ``s``).

    Runs the CUDA kernel using and updating state ``s`` and writes the result into ``y``.

    Args:
        w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
        u: Per-channel bonus weights, shape ``[C]``, dtype ``torch.bfloat16``.
        k: Key tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
        v: Value tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
        y: Output tensor, shape ``[B, T, C]``, dtype ``torch.bfloat16`` (written in-place).
        s: State tensor, shape ``[B, C, 3]``, dtype ``torch.float32`` (updated in-place).
           The kernel stores three floats ``(aa, bb, pp)`` per (batch, channel).

    Notes:
        - All tensors must be on the same CUDA device.
        - Shapes must agree on ``B`` and ``C``; ``y`` shares ``[B, T, C]`` with inputs.
    """
    _validate_device_match((w, u, k, v, y, s))
    ops.forward_with_state_bf16(w, u, k, v, y, s)


def backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass (float32).

    Writes gradients into the provided tensors in-place.

    Args:
        w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
        u, k, v, y: Forward-pass tensors; ``u`` is ``[C]``, the rest ``[B, T, C]``,
            dtype ``torch.float32``.
        gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.float32``.
        gw: Gradient buffer for ``w``, shape ``[B, C]``, dtype ``torch.float32``
            (written in-place; the kernel writes one value per (batch, channel) —
            reduce over the batch dimension to obtain the ``[C]`` gradient).
        gu: Gradient buffer for ``u``, shape ``[B, C]``, dtype ``torch.float32``
            (written in-place; per (batch, channel), same reduction as ``gw``).
        gk, gv: Gradients for ``k``, ``v``, shape ``[B, T, C]``, dtype ``torch.float32``
            (written in-place).

    Notes:
        - All tensors must be on the same CUDA device.
        - Shapes must agree on ``B``, ``T`` and ``C``.
    """
    _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
    ops.backward(w, u, k, v, y, gy, gw, gu, gk, gv)


def backward_bf16(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    y: Tensor,
    gy: Tensor,
    gw: Tensor,
    gu: Tensor,
    gk: Tensor,
    gv: Tensor,
) -> None:
    """RWKV WKV backward pass (bfloat16 inputs/outputs/gradients, float32 ``w``).

    Writes gradients into the provided tensors in-place.

    Args:
        w: Decay weights, shape ``[C]``, dtype ``torch.float32``.
        u, k, v, y: Forward-pass tensors; ``u`` is ``[C]``, the rest ``[B, T, C]``,
            dtype ``torch.bfloat16``.
        gy: Gradient of ``y``, shape ``[B, T, C]``, dtype ``torch.bfloat16``.
        gw: Gradient buffer for ``w``, shape ``[B, C]``, dtype ``torch.bfloat16``
            (written in-place; the kernel writes one value per (batch, channel) —
            reduce over the batch dimension to obtain the ``[C]`` gradient).
        gu: Gradient buffer for ``u``, shape ``[B, C]``, dtype ``torch.bfloat16``
            (written in-place; per (batch, channel), same reduction as ``gw``).
        gk, gv: Gradients for ``k``, ``v``, shape ``[B, T, C]``, dtype ``torch.bfloat16``
            (written in-place).

    Notes:
        - All tensors must be on the same CUDA device.
        - Shapes must agree on ``B``, ``T`` and ``C``.
    """
    _validate_device_match((w, u, k, v, y, gy, gw, gu, gk, gv))
    ops.backward_bf16(w, u, k, v, y, gy, gw, gu, gk, gv)


def _validate_device_match(tensors: Tuple[Tensor, ...]) -> None:
    """Minimal runtime validation that all tensors live on the same CUDA device."""
    if not tensors:
        return
    device = tensors[0].device
    if device.type != "cuda":
        raise RuntimeError("RWKV CUDA ops require CUDA tensors")
    for t in tensors[1:]:
        if t.device != device:
            raise RuntimeError("All tensors must be on the same CUDA device")


__all__ = [
    "forward",
    "forward_bf16",
    "forward_with_state",
    "forward_with_state_bf16",
    "backward",
    "backward_bf16",
]
k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward_with_state(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr()); +} +void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward_with_state_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr()); +} +void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} +void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_backward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), + gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + 
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("forward", forward); + ops.impl("forward", torch::kCUDA, &forward); + + ops.def("forward_bf16", forward_bf16); + ops.impl("forward_bf16", torch::kCUDA, &forward_bf16); + + ops.def("forward_with_state", forward_with_state); + ops.impl("forward_with_state", torch::kCUDA, &forward_with_state); + + ops.def("forward_with_state_bf16", forward_with_state_bf16); + ops.impl("forward_with_state_bf16", torch::kCUDA, &forward_with_state_bf16); + + ops.def("backward", backward); + ops.impl("backward", torch::kCUDA, &backward); + + ops.def("backward_bf16", backward_bf16); + ops.impl("backward_bf16", torch::kCUDA, &backward_bf16); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file