#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>

typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

// Forward declarations of the CUDA kernel launchers (defined in the accompanying .cu file).
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    // Ensure the kernel launches on the device holding the input tensors.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}

void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}

void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

// Expose the wrappers both as a pybind11 extension module and as torch library ops.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}

TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
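
// Usage sketch (an assumption, not part of this file): a binding file like this is
// typically compiled together with the CUDA launcher source via
// torch.utils.cpp_extension.load, after which the ops registered by
// TORCH_LIBRARY(rwkv6, ...) are callable under torch.ops.rwkv6. The source file
// names below are hypothetical placeholders; only the op names come from this file.
//
//   # Python (illustrative)
//   from torch.utils.cpp_extension import load
//   load(name="rwkv6", sources=["rwkv6_op.cpp", "rwkv6_cuda.cu"], verbose=True)
//   # torch.ops.rwkv6.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
//   # writes the result into y in place; state/w are float32, the rest bfloat16.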