#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cub/cub.cuh>

#include "alignment_train_cuda.h"
#include "utils.h"

namespace {

// Thread block size along x (the innermost, src dimension).
constexpr int BLOCK_DIM_X = 128;
// Thread block size along y (the tgt dimension).
constexpr int BLOCK_DIM_Y = 8;
// Thread block size for the block-wide scan kernels.
constexpr int SCAN_BLOCK = 512;

#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

inline void
gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
  if (code != cudaSuccess) {
    fprintf(
        stderr,
        "\nGPUassert: %s %s %d\n",
        cudaGetErrorString(code),
        file,
        line);
    if (abort)
      exit(code);
  }
}

template <typename T>
struct Prod {
  // Product operator: returns a * b.
  __host__ __device__ __forceinline__ T
  operator()(const T& a, const T& b) const {
    return a * b;
  }
};

template <typename T>
struct BlockPrefixProdCallbackOp {
  // Running product of all tiles scanned so far.
  T running_total;

  __device__ BlockPrefixProdCallbackOp(T running_total)
      : running_total(running_total) {}

  // Callback invoked by the first warp of the block with the aggregate of
  // the current tile. Returns the prefix that seeds the block-wide scan,
  // then folds the aggregate into the running product.
  __device__ T operator()(const T block_aggregate) {
    T old_prefix = running_total;
    running_total *= block_aggregate;
    return old_prefix;
  }
};

template <typename T>
struct BlockPrefixSumCallbackOp {
  // Running sum of all tiles scanned so far.
  T running_total;

  __device__ BlockPrefixSumCallbackOp(T running_total)
      : running_total(running_total) {}

  // Callback invoked by the first warp of the block with the aggregate of
  // the current tile. Returns the prefix that seeds the block-wide scan,
  // then folds the aggregate into the running sum.
  __device__ T operator()(const T block_aggregate) {
    T old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

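// Elementwise cumprod_1mp[b, tgt, src] = 1 - p_choose[b, tgt, src].
// Grid-strides over batches; each block tiles the (tgt, src) plane.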
template <typename T>
__global__ void oneMinusPKernel(
    const T* __restrict__ p_choose,
    T* __restrict__ cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
      for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        cumprod_1mp[idx] = 1 - p_choose[idx];
      }
    }
  }
}

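// In-place exclusive cumulative product of cumprod_1mp along the innermost
// (src) dimension: one (batch, tgt) row per block, using cub::BlockScan
// tiled over src with a carry propagated between tiles.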
template <typename T, int TPB>
__global__ void innermostScanKernel(
    T* __restrict__ cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  for (uint32_t b = blockIdx.y; b < bsz; b += gridDim.y) {
    for (uint32_t tgt = blockIdx.x; tgt < tgt_len; tgt += gridDim.x) {
      // Specialize BlockScan for a 1D block of TPB threads on type T.
      typedef cub::BlockScan<T, TPB> BlockScan;
      // Shared memory for the block-wide scan.
      __shared__ typename BlockScan::TempStorage temp_storage;
      // Running product carried across tiles, seeded with the identity.
      BlockPrefixProdCallbackOp<T> prefix_op(1);

      const uint32_t tid = threadIdx.x;
      for (uint32_t block_src = 0; block_src < src_len;
           block_src += blockDim.x) {
        uint32_t src = block_src + tid;
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        // Out-of-range lanes occur only in the last tile; their scanned
        // values are discarded below and the carry is never used again,
        // so the padding value does not matter.
        T thread_data = (src < src_len) ? cumprod_1mp[idx] : (T)0;

        // Block-wide exclusive prefix product over this tile.
        BlockScan(temp_storage)
            .ExclusiveScan(thread_data, thread_data, Prod<T>(), prefix_op);
        __syncthreads();

        // Write the scanned value back in place.
        if (src < src_len) {
          cumprod_1mp[idx] = thread_data;
        }
      }
    }
  }
}

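// Clamp cumprod_1mp into [min_val, max_val], writing to a separate buffer
// so the unclamped values remain available for the final product.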
template <typename T>
__global__ void clampKernel(
    const T* __restrict__ cumprod_1mp,
    T* __restrict__ cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
      for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        if (cumprod_1mp[idx] < min_val) {
          cumprod_1mp_clamp[idx] = min_val;
        } else if (cumprod_1mp[idx] > max_val) {
          cumprod_1mp_clamp[idx] = max_val;
        } else {
          cumprod_1mp_clamp[idx] = cumprod_1mp[idx];
        }
      }
    }
  }
}

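// Set alpha[b, 0, 0] = 1 for every batch. The recurrence kernel below reads
// all of row 0 at the first target step, so it relies on the caller passing
// a zero-initialized alpha, making row 0 effectively one-hot.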
template <typename T>
__global__ void initAlphaCUDAKernel(
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  // alpha[:, 0, 0] = 1.0
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    alpha[b * tgt_len * src_len] = (T)1.0;
  }
}

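// One step of the alignment recurrence for target position tgt:
//   alpha[:, tgt, :] = p_choose[:, tgt, :] * cumprod_1mp[:, tgt, :]
//       * cumsum_src(alpha[:, tgt - 1, :] / cumprod_1mp_clamp[:, tgt, :])
// where, for tgt == 0, the "previous" row is the one-hot seed stored in
// alpha[:, 0, :] itself. The cumulative sum over src is a block-wide
// inclusive scan tiled over src, one batch per block.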
template <typename T, int TPB>
__global__ void alignmentTrainCUDAKernel(
    const T* __restrict__ p_choose,
    const T* __restrict__ cumprod_1mp,
    const T* __restrict__ cumprod_1mp_clamp,
    T* __restrict__ alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    uint32_t tgt) {
  for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
    // Specialize BlockScan for a 1D block of TPB threads on type T.
    typedef cub::BlockScan<T, TPB> BlockScan;
    // Shared memory for the block-wide scan.
    __shared__ typename BlockScan::TempStorage temp_storage;
    // Running sum carried across tiles, seeded with zero.
    BlockPrefixSumCallbackOp<T> prefix_op(0);

    uint32_t b_offset = b * tgt_len * src_len;
    const uint32_t tid = threadIdx.x;
    for (uint32_t block_src = 0; block_src < src_len; block_src += blockDim.x) {
      uint32_t src = block_src + tid;
      uint32_t inout_idx, alpha_idx;
      if (tgt == 0) {
        // First target step: read the one-hot seed row in place.
        alpha_idx = b_offset + src;
      } else {
        // Read the previous target step.
        alpha_idx = b_offset + (tgt - 1) * src_len + src;
      }
      inout_idx = b_offset + tgt * src_len + src;
      T thread_data = (T)0;
      if (src < src_len) {
        thread_data = alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];
      }

      // Block-wide inclusive prefix sum over this tile.
      BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
      __syncthreads();

      if (src < src_len) {
        T out = thread_data * p_choose[inout_idx] * cumprod_1mp[inout_idx];
        // Clamp the result into [0, 1] for numerical stability.
        alpha[inout_idx] = std::min<T>(std::max<T>(out, (T)0.0), (T)1.0);
      }
    }
  }
}

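// Exclusive cumulative product of (1 - p_choose) along the src dimension:
// first fill cumprod_1mp with 1 - p_choose, then scan each (batch, tgt)
// row in place.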
template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    uint32_t max_grid_x,
    uint32_t max_grid_y,
    cudaStream_t& stream) {
  // cumprod_1mp = 1 - p_choose
  dim3 grid(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
  dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
  oneMinusPKernel<T><<<grid, block, 0, stream>>>(
      p_choose, cumprod_1mp, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());

  // Exclusive cumprod along the innermost (src) dimension, one
  // (batch, tgt) row per block.
  dim3 grid_scan(
      std::min<uint32_t>(max_grid_x, tgt_len),
      std::min<uint32_t>(max_grid_y, bsz),
      1);
  innermostScanKernel<T, SCAN_BLOCK><<<grid_scan, SCAN_BLOCK, 0, stream>>>(
      cumprod_1mp, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());
}

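// Full forward pass: compute the exclusive cumprod of (1 - p_choose) and
// its clamped copy, seed alpha, then apply the recurrence one target step
// at a time, since each step depends on the previous alpha row.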
template <typename T>
void alignmentTrainCUDAImpl(
    const T* p_choose,
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    float eps) {
  // p_choose: (bsz, tgt_len, src_len)
  // cumprod_1mp: (bsz, tgt_len, src_len)
  // cumprod_1mp_clamp: (bsz, tgt_len, src_len)
  // alpha: (bsz, tgt_len, src_len)
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  uint32_t max_grid_x = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  uint32_t max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];

  // cumprod_1mp = exclusive cumprod of (1 - p_choose) along src.
  // PyTorch provides cumprod, but no exclusive mode:
  //   inclusive: [x1, x1*x2, ..., prod_{i=1}^{n} x_i]
  //   exclusive: [1, x1, x1*x2, ..., prod_{i=1}^{n-1} x_i]
  uint32_t elements = bsz * tgt_len * src_len;
  T* cumprod_1mp;
  gpuErrchk(cudaMalloc(&cumprod_1mp, elements * sizeof(T)));
  exclusiveCumprod<T>(
      p_choose,
      cumprod_1mp,
      bsz,
      tgt_len,
      src_len,
      max_grid_x,
      max_grid_y,
      stream);

  // cumprod_1mp_clamp = clamp(cumprod_1mp, eps, 1.0)
  T* cumprod_1mp_clamp;
  gpuErrchk(cudaMalloc(&cumprod_1mp_clamp, elements * sizeof(T)));
  dim3 grid_clamp(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
  dim3 block_clamp(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
  clampKernel<T><<<grid_clamp, block_clamp, 0, stream>>>(
      cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);
  gpuErrchk(cudaGetLastError());

  // alpha[:, 0, 0] = 1.0
  dim3 grid_init(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
  initAlphaCUDAKernel<T>
      <<<grid_init, 1, 0, stream>>>(alpha, bsz, tgt_len, src_len);
  gpuErrchk(cudaGetLastError());

  const uint32_t grid = std::min(bsz, max_grid_x);

  // The recurrence is sequential in tgt: each step reads alpha[:, tgt-1, :].
  for (uint32_t i = 0; i < tgt_len; i++) {
    alignmentTrainCUDAKernel<T, SCAN_BLOCK><<<grid, SCAN_BLOCK, 0, stream>>>(
        p_choose,
        cumprod_1mp,
        cumprod_1mp_clamp,
        alpha,
        bsz,
        tgt_len,
        src_len,
        i);
    gpuErrchk(cudaGetLastError());
  }

  gpuErrchk(cudaFree(cumprod_1mp));
  gpuErrchk(cudaFree(cumprod_1mp_clamp));
}

} // namespace

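// Entry point: computes the alignment tensor alpha in place from the
// selection probabilities p_choose, both shaped (bsz, tgt_len, src_len),
// dispatching on the floating-point dtype of p_choose.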
void alignmentTrainCUDAWrapper(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  // p_choose and alpha: (bsz, tgt_len, src_len)
  uint32_t bsz = p_choose.size(0);
  uint32_t tgt_len = p_choose.size(1);
  uint32_t src_len = p_choose.size(2);

  gpuErrchk(cudaSetDevice(p_choose.get_device()));

  AT_DISPATCH_FLOATING_TYPES_AND2(
      torch::ScalarType::Half,
      torch::ScalarType::BFloat16,
      p_choose.scalar_type(),
      "alignmentTrainCUDAImpl",
      [&]() {
        alignmentTrainCUDAImpl<scalar_t>(
            p_choose.data_ptr<scalar_t>(),
            alpha.data_ptr<scalar_t>(),
            bsz,
            tgt_len,
            src_len,
            eps);
      });
}