#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace
{
// Hard-coded kernel hyperparameters. Each thread produces BUFFER_SIZE output
// elements; the up/downsampling FIR filters have FILTER_SIZE taps, and the
// replication-padding constants govern the boundary handling below.
constexpr int ELEMENTS_PER_LDG_STG = 1;
constexpr int BUFFER_SIZE = 32;
constexpr int FILTER_SIZE = 12;
constexpr int HALF_FILTER_SIZE = 6;
constexpr int UPSAMPLE_REPLICATION_PAD = 5;
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6;

// Fused forward kernel: 2x upsample with an FIR filter, apply the snake
// activation x + (1 / beta) * sin^2(alpha * x), then 2x downsample with a
// second FIR filter.
template <typename input_t, typename output_t, typename acc_t>
__global__ void anti_alias_activation_forward(
    output_t *dst,
    const input_t *src,
    const input_t *up_ftr,
    const input_t *down_ftr,
    const input_t *alpha,
    const input_t *beta,
    int batch_size,
    int channels,
    int seq_len)
{
    // Per-thread register copies of the upsampling and downsampling filter taps.
    input_t up_filter[FILTER_SIZE];
    input_t down_filter[FILTER_SIZE];

    // Working buffers: `elements` holds the zero-stuffed input window including
    // replication padding; `intermediates` holds the upsampled, activated signal.
    input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
    input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

    // Downsampled results are buffered here before the final write to dst.
    output_t output[BUFFER_SIZE];

    // blockDim = (128, 1, 1); gridDim = (seq_blocks, channels, batches).
    int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
    int local_offset = threadIdx.x * BUFFER_SIZE;
    int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;

    // The upsampled intermediate signal has twice the sequence length.
    int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
    int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;

    // Read the boundary values needed for replication padding before the
    // pointers are advanced.
    const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
    input_t seq_left_most_value = right_most_pntr[0];
    input_t seq_right_most_value = right_most_pntr[seq_len - 1];

    // Advance the input/output pointers to this thread's segment.
    src += block_offset + local_offset;
    dst += block_offset + local_offset;

    // Per-channel snake parameters alpha and beta, exponentiated on load.
    alpha = alpha + blockIdx.y;
    input_t alpha_val = expf(alpha[0]);
    beta = beta + blockIdx.y;
    input_t beta_val = expf(beta[0]);

    // Load the filter taps into registers.
    #pragma unroll
    for (int it = 0; it < FILTER_SIZE; it += 1)
    {
        up_filter[it] = up_ftr[it];
        down_filter[it] = down_ftr[it];
    }

    // Gather the input window into `elements`, writing each sample to an even
    // slot (odd slots stay zero, which realizes the 2x zero-stuffed upsampling)
    // and replicating the sequence's edge values at the boundaries. The factor
    // of 2 compensates for the gain lost to zero-stuffing.
    #pragma unroll
    for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
    {
        int element_index = seq_offset + it;
        if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
        {
            elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
        }
        if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
        {
            elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
        }
        if ((element_index >= 0) && (element_index < seq_len))
        {
            elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
        }
    }

    // Upsampling convolution. The first DOWNSAMPLE_REPLICATION_PAD_LEFT slots of
    // `intermediates` are left free for the downsampling stage's padding.
    #pragma unroll
    for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
    {
        input_t acc = 0.0;
        int element_index = intermediate_seq_offset + it;
        #pragma unroll
        for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
        {
            if ((element_index + f_idx) >= 0)
            {
                acc += up_filter[f_idx] * elements[it + f_idx];
            }
        }
        intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
    }

    // Snake activation, x + (1 / beta) * sin^2(alpha * x), applied in place;
    // the small epsilon guards against division by zero.
    double no_div_by_zero = 0.000000001;
    #pragma unroll
    for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
    {
        intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
    }

    // Replicate the edge values of `intermediates` on both sides ahead of the
    // downsampling convolution.
    #pragma unroll
    for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
    {
        intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
    }
    #pragma unroll
    for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
    {
        intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
    }

    // Downsampling convolution with stride 2 over the activated signal.
    #pragma unroll
    for (int it = 0; it < BUFFER_SIZE; it += 1)
    {
        input_t acc = 0.0;
        #pragma unroll
        for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
        {
            // The DOWNSAMPLE_REPLICATION_PAD_RIGHT offset aligns the window
            // with the padded `intermediates` layout.
            acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
        }
        output[it] = acc;
    }

    // Write the buffered outputs, guarding against the sequence tail.
    #pragma unroll
    for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
    {
        int element_index = seq_offset + it;
        if (element_index < seq_len)
        {
            dst[it] = output[it];
        }
    }
}

// Host-side launcher. Each block of 128 threads covers 128 * BUFFER_SIZE = 4096
// output elements, so the grid is sized at one block per 4096-sample chunk per
// channel per batch.
template <typename input_t, typename output_t, typename acc_t>
void dispatch_anti_alias_activation_forward(
    output_t *dst,
    const input_t *src,
    const input_t *up_ftr,
    const input_t *down_ftr,
    const input_t *alpha,
    const input_t *beta,
    int batch_size,
    int channels,
    int seq_len)
{
    if (seq_len == 0)
    {
        return;
    }
    else
    {
        constexpr int threads_per_block = 128;
        constexpr int seq_len_per_block = 4096;
        int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
        dim3 blocks(blocks_per_seq_len, channels, batch_size);
        dim3 threads(threads_per_block, 1, 1);

        anti_alias_activation_forward<input_t, output_t, acc_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
    }
}

} // end of anonymous namespace

// Extension entry point. Expects a contiguous 3D input of shape
// [batches, channels, seq_len] and returns a tensor of the same shape.
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
    const int batches = input.size(0);
    const int channels = input.size(1);
    const int seq_len = input.size(2);

    // Allocate the output tensor with gradients disabled.
    auto act_options = input.options().requires_grad(false);
    torch::Tensor anti_alias_activation_results =
        torch::empty({batches, channels, seq_len}, act_options);

    void *input_ptr = static_cast<void *>(input.data_ptr());
    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
    void *beta_ptr = static_cast<void *>(beta.data_ptr());
    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());

    // Dispatch on the input's scalar type (float, half, or bfloat16).
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch anti alias activation_forward",
        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
            reinterpret_cast<const scalar_t *>(input_ptr),
            reinterpret_cast<const scalar_t *>(up_filter_ptr),
            reinterpret_cast<const scalar_t *>(down_filter_ptr),
            reinterpret_cast<const scalar_t *>(alpha_ptr),
            reinterpret_cast<const scalar_t *>(beta_ptr),
            batches,
            channels,
            seq_len););

    return anti_alias_activation_results;
}
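
// A minimal sketch of how this entry point might be exposed to Python, assuming
// a companion binding file and an extension named anti_alias_activation_cuda
// (both names are assumptions, not part of this file):
//
//   #include <torch/extension.h>
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("forward", &fwd_cuda, "Anti-alias activation forward (CUDA)");
//   }
//
// Python-side usage would then look like:
//
//   out = anti_alias_activation_cuda.forward(x, up_filter, down_filter, alpha, beta)
//
// with x a CUDA tensor of shape [batch, channels, seq_len].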