#include #include #include "ATen/ATen.h" #include #define MIN_VALUE (-1e38) typedef at::Half fp16; __half *cast(fp16 *ptr) { return reinterpret_cast<__half *>(ptr); } template __global__ void kernel_wkv_forward(const int B, const int T, const int C, const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; const int _b = idx / C; const int _c = idx % C; const int _offset = _b * T * C + _c; const int _state_offset = _b * C + _c; float u = _u[_c]; float w = _w[_c]; const F *__restrict__ const k = _k + _offset; const F *__restrict__ const v = _v + _offset; F *__restrict__ const y = _y + _offset; float aa = _aa[_state_offset]; float bb = _bb[_state_offset]; float pp = _pp[_state_offset]; for (int i = 0; i < T; i++) { const int ii = i * C; const float kk = float(k[ii]); const float vv = float(v[ii]); float ww = u + kk; float p = max(pp, ww); float e1 = exp(pp - p); float e2 = exp(ww - p); y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2)); ww = w + pp; p = max(ww, kk); e1 = exp(ww - p); e2 = exp(kk - p); aa = e1 * aa + e2 * vv; bb = e1 * bb + e2; pp = p; } _aa[_state_offset] = aa; _bb[_state_offset] = bb; _pp[_state_offset] = pp; } template void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) { dim3 threadsPerBlock( min(C, 32) ); assert(B * C % threadsPerBlock.x == 0); dim3 numBlocks(B * C / threadsPerBlock.x); kernel_wkv_forward<<>>(B, T, C, w, u, k, v, y, aa, bb, pp); } template void cuda_wkv_forward( int B, int T, int C, float *w, float *u, fp16 *k, fp16 *v, fp16 *y, float *aa, float *bb, float *pp); template void cuda_wkv_forward( int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *aa, float *bb, float *pp); __global__ void kernel_mm_seq_fp32i8( const int B, const int N, const int M, const float *__restrict__ const x, const int x_stride, const uint8_t *__restrict__ const w, const int w_stride, const float *__restrict__ const mx, const float *__restrict__ const rx, const float *__restrict__ const my, const float *__restrict__ const ry, float *__restrict__ const y, const int y_stride) { const int i = blockIdx.x * blockDim.x + threadIdx.x; const int k = blockIdx.y * blockDim.y + threadIdx.y; if (i < B && k < M) { float y_local = 0; for (int j = 0; j < N; ++j) { y_local += x[i * x_stride + j] * ( (float(w[j * w_stride + k]) + 0.5f) * rx[k] * ry[j] + mx[k] + my[j] ); } y[i * y_stride + k] = y_local; } } template void cuda_mm8_seq(int B, int N, int M, F *x, int x_stride, uint8_t *w, int w_stride, F *mx, F *rx, F *my, F *ry, F *y, int y_stride); template <> void cuda_mm8_seq(int B, int N, int M, float *x, int x_stride, uint8_t *w, int w_stride, float *mx, float *rx, float *my, float *ry, float *y, int y_stride) { dim3 blockSize(1, 128); dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); kernel_mm_seq_fp32i8<<>>( B, N, M, x, x_stride, w, w_stride, mx, rx, my, ry, y, y_stride); } __global__ void kernel_mm_seq_fp16i8( const int B, const int N, const int M, const __half *__restrict__ const x, const int x_stride, const uint8_t *__restrict__ const w, const int w_stride, const __half *__restrict__ const mx, const __half *__restrict__ const rx, const __half *__restrict__ const my, const __half *__restrict__ const ry, __half *__restrict__ const y, const int y_stride) { const int i = blockIdx.x * blockDim.x + threadIdx.x; const int k = blockIdx.y * blockDim.y + threadIdx.y; if (i < B && k < M) { float y_local = 0; for (int j = 0; j < N; ++j) { y_local += __half2float(x[i * x_stride + j]) * ( (float(w[j * w_stride + k]) + 0.5f) * __half2float(rx[k]) * __half2float(ry[j]) + __half2float(mx[k]) + __half2float(my[j]) ); } y[i * y_stride + k] = __float2half(y_local); } } template <> void cuda_mm8_seq(int B, int N, int M, fp16 *x, int x_stride, uint8_t *w, int w_stride, fp16 *mx, fp16 *rx, fp16 *my, fp16 *ry, fp16 *y, int y_stride) { dim3 blockSize(1, 128); dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); kernel_mm_seq_fp16i8<<>>( B, N, M, cast(x), x_stride, w, w_stride, cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride); } #define MM8_ONE_JSPLIT 24 #define MM8_ONE_TILE 1024 __global__ void kernel_mm_one_fp32i8( const int N, const int M, const float *__restrict__ const x, const uint8_t *__restrict__ const w, const int w_stride, const float *__restrict__ const mx, const float *__restrict__ const rx, const float *__restrict__ const my, const float *__restrict__ const ry, float *__restrict__ const y) { const int k = blockIdx.y * blockDim.y + threadIdx.y; const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); if (k < M) { float y_local = 0; for (int j = j0; j < j1; ++j) { y_local += x[j] * ( (float(w[j * w_stride + k]) + 0.5f) * rx[k] * ry[j] + mx[k] + my[j] ); } atomicAdd(&y[k], y_local); } } template void cuda_mm8_one(int N, int M, F *x, uint8_t *w, int w_stride, F *mx, F *rx, F *my, F *ry, float *y); template <> void cuda_mm8_one(int N, int M, float *x, uint8_t *w, int w_stride, float *mx, float *rx, float *my, float *ry, float *y) { dim3 blockSize(1, MM8_ONE_TILE); dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); kernel_mm_one_fp32i8<<>>( N, M, x, w, w_stride, mx, rx, my, ry, y); } __global__ void kernel_mm_one_fp16i8( const int N, const int M, const __half *__restrict__ const x, const uint8_t *__restrict__ const w, const int w_stride, const __half *__restrict__ const mx, const __half *__restrict__ const rx, const __half *__restrict__ const my, const __half *__restrict__ const ry, float *__restrict__ const y) { const int k = blockIdx.y * blockDim.y + threadIdx.y; const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); if (k < M) { float y_local = 0; for (int j = j0; j < j1; ++j) { y_local += __half2float(x[j]) * ( (float(w[j * w_stride + k]) + 0.5f) * __half2float(rx[k]) * __half2float(ry[j]) + __half2float(mx[k]) + __half2float(my[j]) ); } atomicAdd(&y[k], y_local); } } template <> void cuda_mm8_one(int N, int M, fp16 *x, uint8_t *w, int w_stride, fp16 *mx, fp16 *rx, fp16 *my, fp16 *ry, float *y) { dim3 blockSize(1, MM8_ONE_TILE); dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); kernel_mm_one_fp16i8<<>>( N, M, cast(x), w, w_stride, cast(mx), cast(rx), cast(my), cast(ry), y); }