JustinLin610
update
8437114
raw history blame
No virus
5.87 kB
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
return (a + b - 1) / b;
}
template <int FS, int SB, int padding_l, typename scalar_t>
__inline__ __device__ void zeroSharedMem(scalar_t* data) {
/*
Given an array of length FS + SB, zero out the first padding_l and last
(FS - padding_l) values in the array
*/
int tid = threadIdx.x;
if (FS < SB) {
// zero all if we have enough threads in a block to do all of them
if (tid < padding_l || tid > SB - FS + padding_l - 1) {
data[tid] = scalar_t(0.0);
}
} else {
// otherwise zero out one block at a time
const int numIterations = divUp<int, int>(FS, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if (tid + offset < padding_l) {
data[tid + offset] = scalar_t(0.0);
} else if (tid + offset < FS) {
data[SB + tid + offset] = scalar_t(0.0);
}
}
}
}
template <typename scalar_t>
__inline__ __device__ scalar_t warpReduce(scalar_t data) {
/*
Reduce an array within each warp. After processing all values in warp will
caontain the sum of all original values in that warp.
data - pointer to data to reduce
*/
data += __shfl_xor_sync(SHFL_MASK, data, 16);
data += __shfl_xor_sync(SHFL_MASK, data, 8);
data += __shfl_xor_sync(SHFL_MASK, data, 4);
data += __shfl_xor_sync(SHFL_MASK, data, 2);
data += __shfl_xor_sync(SHFL_MASK, data, 1);
return data;
}
template <typename scalar_t>
__inline__ __device__ scalar_t blockReduce(scalar_t data) {
/*
Reduce an entire array on the block level. After processing, the
first value in the array will contain the reduced sum.
data - pointer to data to reduce
*/
static __shared__ scalar_t warpSum[32];
const int tid = threadIdx.x;
int wid = tid / 32;
int lane = tid % 32;
__syncthreads();
// reduce each warp then write to shared memory
scalar_t sum = warpReduce(data);
if (lane == 0) {
warpSum[wid] = sum;
}
__syncthreads();
scalar_t v;
// perform final sum of partial warp sums
if (tid < blockDim.x / 32) {
v = warpSum[lane];
} else {
v = scalar_t(0.0);
}
if (wid == 0) {
v = warpReduce(v);
}
__syncthreads();
return v;
}
void checkCudaStatus(cudaError_t status, int lineNumber = -1) {
if (status != cudaSuccess) {
std::cout << cudaGetErrorString(status) << " at line " << lineNumber
<< std::endl;
std::cout << "Exiting" << std::endl;
exit(1);
}
}
template <int FS, int SB, int padding_l, typename scalar_t>
__device__ void load_input_to_shared(
const scalar_t* input, // global memory
int inputOffset,
int sequenceLength,
int iteration,
int numIterations,
bool no_prev,
scalar_t* output /* shared memory */) {
/*
Load a block size of input into shared memory with
right and left overhang of total size FS. If previously
loaded memory, overlap will be shifted over to reduce
global memory access
input - pointer to start of channel sequence
inputOffset - how far in the sequence to start loading
sequenceLength - total length of sequence
iteration - which block of sequence we are loading
numIterations - total number of blocks to load
no_prev - whether to load the whole block if the previous block
wasn't loaded
output - shared memory to write input to
*/
const int tid = threadIdx.x;
// Load the left "overhang" of input
if (iteration > 0) {
if (padding_l < SB) {
// load all at once
if (tid < padding_l) {
output[tid] =
(no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB];
}
} else {
// load in chunks of size SB
int numIterations = divUp<int, int>(padding_l, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if ((tid + offset) < padding_l) {
output[tid + offset] = (no_prev)
? input[inputOffset - padding_l + tid + offset]
: output[tid + offset + SB];
}
}
}
}
// Load the right "overhang" of input
if (iteration < (numIterations - 1)) {
const int elementsLeft = sequenceLength - (iteration + 1) * SB;
if ((FS - padding_l) < SB) {
// load all at once
if (tid < (FS - padding_l)) {
output[padding_l + SB + tid] = (tid < elementsLeft)
? input[inputOffset + SB + tid]
: scalar_t(0.0);
}
} else {
// load in chunks of size SB
int numIterations = divUp<int, int>(FS - padding_l, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if ((tid + offset) < (FS - padding_l)) {
output[padding_l + SB + tid + offset] =
((tid + offset) < elementsLeft)
? input[inputOffset + SB + tid + offset]
: scalar_t(0.0);
}
}
}
}
// We should also clear out the right "overhang"
if (iteration == (numIterations - 1)) {
if ((FS - padding_l) < SB) {
// clear out all at once
if (tid < (FS - padding_l)) {
output[padding_l + SB + tid] = scalar_t(0.0);
}
} else {
// clear in chunks of size SB
int numIterations = divUp<int, int>(FS - padding_l, SB);
for (int i = 0; i < numIterations; i++) {
int offset = i * SB;
if ((tid + offset) < (FS - padding_l)) {
output[padding_l + SB + tid + offset] = scalar_t(0.0);
}
}
}
}
output[tid + padding_l] = ((inputOffset + tid) < sequenceLength)
? input[inputOffset + tid]
: scalar_t(0.0);
}