Delete rwkv5_op.cpp
Browse files- rwkv5_op.cpp +0 -34
rwkv5_op.cpp
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
#include <torch/extension.h>
|
2 |
-
#include "ATen/ATen.h"
|
3 |
-
#include <c10/cuda/CUDAGuard.h>
|
4 |
-
typedef at::BFloat16 bf16;
|
5 |
-
typedef at::Half fp16;
|
6 |
-
typedef float fp32;
|
7 |
-
|
8 |
-
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
9 |
-
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
|
10 |
-
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
|
11 |
-
|
12 |
-
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
13 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
14 |
-
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
15 |
-
}
|
16 |
-
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
17 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
18 |
-
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
|
19 |
-
}
|
20 |
-
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
21 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
22 |
-
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
|
23 |
-
}
|
24 |
-
|
25 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
26 |
-
m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
|
27 |
-
m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
|
28 |
-
m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
|
29 |
-
}
|
30 |
-
TORCH_LIBRARY(rwkv5, m) {
|
31 |
-
m.def("forward_bf16", forward_bf16);
|
32 |
-
m.def("forward_fp16", forward_fp16);
|
33 |
-
m.def("forward_fp32", forward_fp32);
|
34 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|