#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"

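// Traits for the forward kernel on the default (batch, dim, seqlen) layout. Each of
// the kNThreads threads owns kNElts consecutive sequence elements per chunk; with
// 2- or 4-byte inputs this is always a 16-byte vector per thread, loaded as one
// vector (kIsVecLoad, chosen when seqlen is a multiple of kNElts) or element by
// element through cub's warp-transposed BlockLoad otherwise.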
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static_assert(kWidth <= kNElts);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};

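// Forward kernel for the default layout. The grid is (batch, dim): each block handles
// one (batch, channel) row and walks the sequence in chunks of kNThreads * kNElts
// elements, carrying the last kNElts elements of each chunk through shared memory
// (smem_exchange) so the convolution stays causal across chunk boundaries.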
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

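    // Prime the exchange buffer: on the first chunk, thread 0 reads slot kNThreads - 1
    // as its "previous chunk" context, so it must start out as zeros (causal
    // zero-padding on the left).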
    if (tidx == 0) {
        input_t zeros[kNElts] = {0};
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
    }

    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t x_vals_load[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            __syncthreads();
            Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
        }
        x += kChunkSize;
        __syncthreads();

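        // Exchange so each thread also has the kNElts elements preceding its own:
        // thread kNThreads - 1 holds back its write so thread 0 can still read the
        // previous chunk's tail, then publishes it below for the next iteration.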
        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
        __syncthreads();
        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }

        float x_vals[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }

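        // Causal convolution: output position i depends only on inputs i - (kWidth - 1)
        // through i, which sit at x_vals[kNElts + i - (kWidth - 1 - w)] thanks to the
        // kNElts elements of left context copied in above.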
        float out_vals[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals[i] = bias_val;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
            }
        }

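        // Optional fused SiLU activation: x * sigmoid(x) = x / (1 + exp(-x)).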
        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
            }
        }

        input_t out_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
        if constexpr(kIsVecLoad) {
            Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
        }
        out += kChunkSize;
    }
}

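// Host-side launcher: picks the vectorized IO path only when seqlen is a multiple of
// kNElts, and opts the kernel into more than 48 KB of dynamic shared memory when the
// traits require it.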
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);
        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
        if (kSmemSize >= 48 * 1024) {
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

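// Width dispatch: only filter widths 2, 3, and 4 are instantiated; other widths are
// silently ignored here and are presumably rejected by the caller before reaching
// this point.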
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

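// Traits for the channel-last layout, where channels are contiguous in memory. Loads
// are arranged in 128-byte rows (kNEltsPerRow channels), with kNThreadsPerRow threads
// per row and kNLoads cooperative loads per kChunkSizeL x kNEltsPerRow tile.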
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static_assert(kNThreads % 32 == 0);
    static constexpr int kNWarps = kNThreads / 32;
    static constexpr int kWidth = kWidth_;
    static constexpr int kChunkSizeL = kChunkSizeL_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static constexpr int kNEltsPerRow = 128 / kNBytes;
    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;
    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;
    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
};

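// Channel-last forward kernel. Phase 1: cooperatively stage a
// (kWidth - 1 + kChunkSizeL) x kChunkSizeC tile of x into shared memory, including
// kWidth - 1 rows of left context taken from the previous L-chunk (or zeros at the
// start of the sequence). Phase 2: re-partition the threads so each one owns a single
// channel and a contiguous run of sequence positions, and compute in registers.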
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNWarp = Ktraits::kNWarps;
    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];

    const int tid = threadIdx.x;
    const int l_idx = tid / kNThreadsPerC;
    const int c_idx = tid % kNThreadsPerC;
    const int batch_id = blockIdx.x;
    const int chunk_l_id = blockIdx.y;
    const int chunk_c_id = blockIdx.z;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
        }
        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }

    if (l_idx < kWidth - 1) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
        }
        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }

    __syncthreads();

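    // Re-partition the block for the compute phase: each thread now owns one channel
    // (row_idx) and kLPerThread consecutive sequence positions (col_idx selects which
    // segment of the L-chunk).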
    constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);

    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
    static_assert(kNThreadsPerRow <= 32);

    const int row_idx = tid / kNThreadsPerRow;
    const int col_idx = tid % kNThreadsPerRow;

    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
    float weight_vals[kWidth] = {0};
    // The last C-chunk can be partial, so guard bias and weight reads against
    // channels past params.dim.
    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
        }
    }
    float x_vals[kWidth - 1 + kLPerThread];
    #pragma unroll
    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
    }

    float out_vals[kLPerThread];
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) {
        out_vals[i] = bias_val;
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; }
        if (params.silu_activation) { out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
    }

    // Writing the results back into x_smem below overlaps with the reads above by the
    // thread that owns the previous segment of the same channel; those threads always
    // share a warp (kNThreadsPerRow is a power of 2 and <= 32), so __syncwarp() is
    // sufficient here.
    __syncwarp();
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
    __syncthreads();

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t out_vals_store[kNElts];
        reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
        }
    }
}

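// Launcher for the channel-last kernel: one block per (batch, L-chunk, C-chunk) tile,
// with a fixed chunk length of 64 along the sequence. The tile lives in static shared
// memory, so the launch requests no dynamic shared memory.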
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
    const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
    const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
    dim3 grid(params.batch, n_chunks_L, n_chunks_C);
    dim3 block(Ktraits::kNThreads);
    auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

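// Explicit instantiations for every supported (input_t, weight_t) combination of
// float, at::Half, and at::BFloat16.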
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);