// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Modified by zyddnys
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Classification (CTC) loss.
// We mostly follow Graves.
// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// Note from zyddnys:
// Added regression capability to the CTC loss. Currently we use L2 (Gaussian) regression;
// L1 regression may be added in the future.
// There are two blank labels: BLANK is the usual CTC blank, while BLANK_1 marks targets
// whose regression part is ignored.
// Several kernels are split into multiple smaller kernels to avoid CUDA
// "too many resources requested for launch" errors.
// We use the equations from the above link, but note that [1] has 1-based indexing and we (of course) use 0-based.
// Graves et al call the probabilities y; we use log_probs (also calling them inputs).
// A few optimizations (similar to those here, but also some I didn't take) are described in
// 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/Atomic.cuh>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/MathConstants.h>
#include <type_traits>
#include <numeric>

using namespace c10;
using namespace at;
using namespace at::native;

// log P(x|mu)
template <typename scalar_t>
__device__ inline scalar_t custom_distance_forward_log(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
  return -0.5 * std::log(2.0 * c10::pi<scalar_t>) - std::log(sigma)
         - 0.5 * (x - mu) * (x - mu) / (sigma * sigma);
}

// d(P(x|mu))/dmu
template <typename scalar_t>
__device__ inline scalar_t custom_distance_backward(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
  scalar_t val = 1.0 / (sigma * std::sqrt(2 * c10::pi<scalar_t>))
                 * std::exp(-0.5 * (x - mu) * (x - mu) / (sigma * sigma));
  return val * (x - mu) / (sigma * sigma);
}

// log P(x|mu)
template <typename scalar_t>
__device__ inline scalar_t custom_distance_forward_log_l1(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
  return -std::log(2 * sigma) - std::abs(x - mu) / sigma;
}

template <typename scalar_t>
__device__ inline scalar_t sgn(scalar_t v) noexcept {
  if (std::abs(v) < std::numeric_limits<scalar_t>::epsilon())
    return 0;
  return v / std::abs(v);
}

// d(P(x|mu))/dmu
template <typename scalar_t>
__device__ inline scalar_t custom_distance_backward_l1(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
  return -sgn(mu - x) * std::exp(-std::abs(x - mu) / sigma) / (2 * sigma * sigma);
}

#if 0
// d(log P(x|mu))/dmu
template <typename scalar_t>
__device__ inline scalar_t custom_distance_backward_log(scalar_t x, scalar_t mu) {
  return x - mu;
}

// P(x|mu)
template <typename scalar_t>
__device__ inline scalar_t custom_distance_forward(scalar_t x, scalar_t mu) {
  return 0;
}
#endif

// This ad-hoc function converts from targets (l in [1]) to augmented targets (l' in [1]),
// so if l is l_0 l_1 ... l_(tl-1) then this looks up idx in
// l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// - note that no bound-checking is done
// - it is important to only call it with idx == 0 if the target length is 0
// - __restrict__ impact to be measured, see
//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t get_target_prime(
    const target_t* __restrict__ target,
    int64_t offset,
    int64_t stride,
    int64_t idx,
    int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  } else {
    return target[offset + stride * (idx / 2)];
  }
}

template <typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ?
                      1024 : 896), 1)
#endif
ctc_loss_collect_log_realvalues_gpu_kernel(scalar_t* __restrict__ log_realvalues_data,
    const int64_t* __restrict__ input_lengths,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths,
    const scalar_t* __restrict__ realval_data, int64_t num_realval,
    const scalar_t* __restrict__ targets_realval_data,
    int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
    int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride,
    int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime!

  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t rv_batch_offset = b*rv_batch_stride;
  int64_t rvt_batch_offset = b*rvt_batch_stride;
  int64_t lr_batch_offset = b*lr_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  if (s >= target_length)
    return;

  int64_t target = targets_data[tg_batch_offset + s * tg_target_stride];

  for (int64_t t = 0; t < input_length; t++) {
    scalar_t log_prod_n = 0;
    if (target != BLANK && target != BLANK_1) {
      for (int64_t i = 0; i < num_realval; ++i) {
        log_prod_n += custom_distance_forward_log(
            targets_realval_data[rvt_batch_offset + rvt_input_stride * s + rvt_label_stride * i],
            realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i],
            sigma);
      }
    }
    log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s] = log_prod_n;
  }
}

template <typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ?
                      1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel_phase1(scalar_t* __restrict__ log_alpha_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    const scalar_t* __restrict__ log_realvalues_data,
    scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();

  // bookkeeping
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t lr_batch_offset = b*lr_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  // first row (t=0), the three equations for alpha_1 above eq (6)
  for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
    int64_t s = threadIdx.x + block_s;
    scalar_t la;
    switch (s) {
      case 0:
        la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
        break;
      case 1:
        {
          if (target_length != 0) {
            int64_t tgt = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK);
            scalar_t cur_logprob = log_probs_data[lp_batch_offset + lp_char_stride * tgt];
            //if (tgt != BLANK_1) {
              cur_logprob += log_realvalues_data[lr_batch_offset + lr_input_stride * 0 + lr_target_stride * 0];
            //}
            la = cur_logprob;
          } else {
            la = neginf;
          }
          // la = target_length == 0 ? neginf
          //                         : log_probs_data[lp_batch_offset + lp_char_stride *
          //                               get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
        }
        break;
      default:
        la = neginf;
    }
    if (s < 2*max_target_length+1)
      log_alpha_data[la_batch_offset + /* la_input_stride * 0 */ + la_target_stride * s] = la;
  }
}

template <typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ?
1024 : 896), 1) #endif ctc_loss_log_alpha_gpu_kernel_phase2(scalar_t* __restrict__ log_alpha_data, const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, const scalar_t* __restrict__ log_realvalues_data, scalar_t* __restrict__ neg_log_likelihood_data, int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { constexpr scalar_t neginf = -std::numeric_limits::infinity(); // bookkeeping int64_t b = threadIdx.y + blockIdx.y * blockDim.y; if (b >= batch_size) return; int64_t input_length = input_lengths[b]; int64_t target_length = target_lengths[b]; int64_t lp_batch_offset = b*lp_batch_stride; int64_t la_batch_offset = b*la_batch_stride; int64_t lr_batch_offset = b*lr_batch_stride; int64_t tg_batch_offset = tg_batch_offsets[b]; // first row (t=0), the three equations for alpha_1 above eq (6) for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) { int64_t s = threadIdx.x + block_s; scalar_t la; switch (s) { case 0: la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK]; break; case 1: { if (target_length != 0) { int64_t tgt = get_target_prime( targets_data, tg_batch_offset, tg_target_stride, 1, BLANK); scalar_t cur_logprob = log_probs_data[lp_batch_offset + lp_char_stride * tgt]; //if (tgt != BLANK_1) { cur_logprob += log_realvalues_data[lr_batch_offset + lr_input_stride * 0 + lr_target_stride * 0]; //} la = cur_logprob; } else { la = neginf; } // la = target_length == 0 ? neginf // : log_probs_data // [lp_batch_offset + // lp_char_stride * // get_target_prime( // targets_data, // tg_batch_offset, // tg_target_stride, // 1, // BLANK)]; } break; default: la = neginf; } if (s < 2*max_target_length+1) log_alpha_data[la_batch_offset + /* la_input_stride * 0 */ + la_target_stride * s] = la; } for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) { int64_t s = threadIdx.x + block_s; // These two only depend on s, so we can cache them. int64_t current_char; // l_s in eq (6) bool have_three; // flag which of the two cases in eq (6) we have if (s < 2 * target_length + 1 && target_length > 0) { current_char = get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s, BLANK); have_three = ((s > 1) && (get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s - 2, BLANK) != current_char)); } else { current_char = BLANK; have_three = false; } for (int64_t t=1; t < max_input_length; t++) { __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch if ((t < input_length) && (s < 2 * target_length + 1)) { scalar_t cur_logprob = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_char]; // if (current_char != BLANK_1 && current_char != BLANK) { // for (int64_t i = 0; i < num_realval; ++i) { // cur_logprob += custom_distance_forward_log( // targets_realval_data[rvt_batch_offset + rvt_input_stride * (s / 2) + rvt_label_stride * i], // realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i], // sigma // ); // } // } cur_logprob += (s % 2 == 1) ? 
log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0; // only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands, // lamax is the maximum for the logsumexp trick. scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s]; scalar_t lamax = la1; scalar_t la2, la3; if (s > 0) { la2 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-1)]; if (la2 > lamax) lamax = la2; } else { la2 = neginf; } if (have_three) { la3 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-2)]; if (la3 > lamax) lamax = la3; } else { la3 = neginf; } if (lamax == neginf) // when all are neginf. (then the whole thing is neginf, but we can pretend) lamax = 0; log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax + cur_logprob; } else { // otherwise we just set to neginf if (s < 2*max_target_length+1) log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf; } } } __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch // compute the loss (eq (8)) if (threadIdx.x == 0) { scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)]; scalar_t l2 = target_length > 0 ? log_alpha_data [la_batch_offset + la_input_stride * (input_length - 1) + la_target_stride * (target_length * 2 - 1)] : neginf; scalar_t m = ((l1 > l2) ? l1 : l2); m = ((m == neginf) ? 0 : m); scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m; neg_log_likelihood_data[b] = -log_likelihood; } } template __global__ void #if defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2((std::is_same::value ? 
                      1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel_phase3(scalar_t* __restrict__ log_alpha_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    const scalar_t* __restrict__ realval_data, int64_t num_realval,
    const scalar_t* __restrict__ targets_realval_data,
    scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride,
    int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();

  // bookkeeping
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  if (b >= batch_size)
    return;

  int64_t input_length = input_lengths[b];
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t rv_batch_offset = b*rv_batch_stride;
  int64_t rvt_batch_offset = b*rvt_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];

  // __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch

  // compute the loss (eq (8))
  if (threadIdx.x == 0) {
    scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
    scalar_t l2 = target_length > 0
        ? log_alpha_data[la_batch_offset + la_input_stride * (input_length - 1) + la_target_stride * (target_length * 2 - 1)]
        : neginf;
    scalar_t m = ((l1 > l2) ? l1 : l2);
    m = ((m == neginf) ? 0 : m);
    scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
    neg_log_likelihood_data[b] = -log_likelihood;
  }
}

// The forward computation. Lots of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each row) and
// a 1-dimensional tensor where all targets are concatenated (and we use target_lengths
// to figure out where they begin).
// We return log_alpha (currently; this might change to (log_alpha+log_beta) to be passed to the
// backward). The dispatch function will only return the loss.
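// Illustrative example of the layout conventions above (values invented for clarity,
// not taken from any caller): with BLANK = 0 and a target sequence l = [3, 5, 5]
// (target_length = 3), the augmented sequence used by the alpha/beta recursions is
// l' = [0, 3, 0, 5, 0, 5, 0], i.e. 2*target_length+1 entries, which is why
// log_alpha/log_beta have last dimension 2*max_target_length+1 and why
// get_target_prime maps even s to BLANK. For 1-D concatenated targets with
// target_lengths = [2, 3], tg_batch_offsets becomes [0, 2]; for a 2-D
// (batch x max_target_length) tensor it is [0, targets.stride(0), 2*targets.stride(0), ...].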
template std::tuple custom_ctc_loss_gpu_template( const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, IntArrayRef input_lengths, IntArrayRef target_lengths, scalar_t const sigma, int64_t BLANK, int64_t BLANK_1 ) { constexpr scalar_t neginf = -std::numeric_limits::infinity(); // log_probs: input_len x batch_size x num_labels // targets [int64]: batch_size x target_length OR sum(target_lengths) // realval [float]: batch_size x input_len x num_realval // targets_realval [float]: batch_size x max_target_length x num_realval CheckedFrom c = "custom_ctc_loss_gpu"; using target_t = typename std::conditional::type; auto log_probs_arg = TensorArg(log_probs, "log_probs", 1); auto targets_arg = TensorArg(targets, "targets", 2); auto realval_arg = TensorArg(realval, "realval", 3); auto targets_realval_arg = TensorArg(targets_realval, "targets_realval", 4); checkAllSameGPU(c, {log_probs_arg, targets_arg, realval_arg, targets_realval_arg}); checkScalarType(c, targets_arg, target_scalar_type); checkDim(c, log_probs_arg, 3); checkDim(c, realval_arg, 3); checkDim(c, targets_realval_arg, 3); checkDimRange(c, targets_arg, 1, 3); int64_t batch_size = log_probs.size(0); int64_t num_realvals = realval.size(2); int64_t num_labels = log_probs.size(2); TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range"); TORCH_CHECK((0 <= BLANK_1) && (BLANK_1 < num_labels), "blank1 must be in label range"); TORCH_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size"); TORCH_CHECK(realval.size(2) == targets_realval.size(2), "number of real values must be the same for both realval and targets_realval"); TORCH_CHECK(log_probs.size(1) == realval.size(1), "input_lengths must be the same for both log_probs and realval"); TORCH_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size"); int64_t lp_input_stride = log_probs.stride(1); int64_t lp_char_stride = log_probs.stride(2); int64_t tg_target_stride; int64_t max_target_length = 0; auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong)); auto tg_batch_offsets_data = tg_batch_offsets.data_ptr(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets_data[i] = pos; pos += target_lengths[i]; if (max_target_length < target_lengths[i]) max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(0); checkSize(c, targets_arg, 0, pos); } else { // batch x max_target_length // dim is 2 int64_t tg_batch_stride = targets.stride(0); for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets_data[i] = i * tg_batch_stride; if (max_target_length < target_lengths[i]) max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(1); checkSize(c, targets_arg, 0, batch_size); TORCH_CHECK(targets.size(1) >= max_target_length, "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg, " (while checking arguments for ", c, ")"); } int64_t max_input_length = log_probs.size(1); for (int64_t b = 0; b < batch_size; b++) { TORCH_CHECK(input_lengths[b] <= max_input_length, "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b], " (while checking arguments for ", c, ")"); } auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong)); auto input_lengths_t = at::tensor(input_lengths, 
targets.options().dtype(kLong)); tg_batch_offsets = tg_batch_offsets.cuda(); Tensor log_realvalues = at::zeros({batch_size, log_probs.size(1), std::max(max_target_length, int64_t(1))}, log_probs.options()); Tensor log_alpha = at::empty({batch_size, log_probs.size(1), 2*max_target_length+1}, log_probs.options()); Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options()); log_alpha.fill_(neginf); constexpr int max_threads = std::is_same::value ? 1024 : 896; // we need 72 or so 32 bit registers for double cudaStream_t stream = at::cuda::getCurrentCUDAStream(); { int threads_target = max_threads; while (threads_target / 2 >= max_target_length && threads_target > 1) { threads_target /= 2; } int threads_batch = std::min(max_threads / threads_target, (int) batch_size); dim3 block(threads_target, threads_batch); dim3 grid( std::max( (max_target_length + threads_target - 1) / threads_target, 1), (batch_size + threads_batch - 1) / threads_batch, 1); ctc_loss_collect_log_realvalues_gpu_kernel<<>> (log_realvalues.data_ptr(), input_lengths_t.data_ptr(), targets.data_ptr(), target_lengths_t.data_ptr(), realval.data_ptr(), num_realvals, targets_realval.data_ptr(), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), realval.stride(0), realval.stride(1), realval.stride(2), targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, sigma, BLANK, BLANK_1); C10_CUDA_KERNEL_LAUNCH_CHECK(); } // Very likely, we could be more clever here, e.g. learning (or genralizing and reusing) from SoftMax.cu... int threads_target = max_threads; while (threads_target / 2 >= 2*max_target_length+1) { threads_target /= 2; } int threads_batch = std::min(max_threads / threads_target, (int) batch_size); dim3 block(threads_target, threads_batch); dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch); // ctc_loss_log_alpha_gpu_kernel_phase1<<>>( // log_alpha.data_ptr(), // log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), // targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, // log_realvalues.data_ptr(), // neg_log_likelihood.data_ptr(), // log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), // log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), // log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), // tg_batch_offsets.data_ptr(), tg_target_stride, // batch_size, sigma, BLANK, BLANK_1); // C10_CUDA_KERNEL_LAUNCH_CHECK(); ctc_loss_log_alpha_gpu_kernel_phase2<<>>( log_alpha.data_ptr(), log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, log_realvalues.data_ptr(), neg_log_likelihood.data_ptr(), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, sigma, BLANK, BLANK_1); C10_CUDA_KERNEL_LAUNCH_CHECK(); // ctc_loss_log_alpha_gpu_kernel_phase3<<>>( // log_alpha.data_ptr(), // log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), // targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, // realval.data_ptr(), num_realvals, // targets_realval.data_ptr(), // neg_log_likelihood.data_ptr(), // log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), 
// log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), // realval.stride(0), realval.stride(1), realval.stride(2), // targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), // tg_batch_offsets.data_ptr(), tg_target_stride, // batch_size, sigma, BLANK, BLANK_1); // C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(neg_log_likelihood, log_alpha); } // The second (backward) half of the forward backward algorithm, (10) and (11). This is parallel to the // alpha kernel above. (As mentioned above, it might make sense do the calculation in the alpha kernel.) template __global__ void C10_LAUNCH_BOUNDS_2((std::is_same::value ? 1024 : 896), 1) ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data, const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, const scalar_t* __restrict__ log_realvalues_data, int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { constexpr scalar_t neginf = -std::numeric_limits::infinity(); int64_t b = threadIdx.y + blockIdx.y * blockDim.y; if (b >= batch_size) return; int64_t input_length = input_lengths[b]; int64_t target_length = target_lengths[b]; int64_t lp_batch_offset = b*lp_batch_stride; int64_t lb_batch_offset = b*lb_batch_stride; int64_t lr_batch_offset = b*lr_batch_stride; int64_t tg_batch_offset = tg_batch_offsets[b]; // "first" row, the beta initiaization before eq (10) (t=target_length - differs per batch) for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) { int64_t s = threadIdx.x + block_s; scalar_t lb; if (s == 2*target_length) { lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK]; } else if (s == 2 * target_length - 1) { // false for target_length == 0 int64_t current_target_prime = get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s, BLANK); scalar_t cur_logprob = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime]; lb = cur_logprob + log_realvalues_data[lr_batch_offset + lr_input_stride * (input_length - 1) + lr_target_stride * (target_length - 1)]; } else { lb = neginf; } if (s < 2*max_target_length+1) { log_beta_data[lb_batch_offset + (input_length-1) * lb_input_stride + lb_target_stride * s] = lb; } } // go backward in s for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) { int64_t s = threadIdx.x + block_s; int64_t current_target_prime; bool have_three; if (s < 2 * target_length + 1 && target_length > 0) { current_target_prime = get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s, BLANK); have_three = ((s < 2 * target_length - 1) && (get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s + 2, BLANK) != current_target_prime)); } else { current_target_prime = BLANK; have_three = false; } // now go backward in t. Note that we need to skip the last timestep that we did above. 
for (int64_t t=max_input_length-2; t>=0; t--) { __syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item if ((t < input_length - 1) && (s < 2 * target_length + 1)) { scalar_t cur_logprob = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime]; cur_logprob += (s % 2 == 1) ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0; scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s]; scalar_t lbmax = lb1; scalar_t lb2, lb3; if (s < 2*target_length) { lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+1)]; if (lb2 > lbmax) lbmax = lb2; } else { lb2 = neginf; } if (have_three) { lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+2)]; if (lb3 > lbmax) lbmax = lb3; } else { lb3 = neginf; } if (lbmax == neginf) lbmax = 0; scalar_t lb = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax + cur_logprob; log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb; } else if ( (s < 2 * max_target_length + 1) && (((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) || (t >= input_length))) { log_beta_data [lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = neginf; } } } } // This implements the subtrahend of equation (16) for all *nonblank* characters. // It assumes you have probs in gradient_data when called // and it modifies gradient_data to be, the gradient. // In order to facilitate this inplace update, We don't actually do this in logspace. // (The other variant implemented uses log_space and the differences seem to be // not so problematic at least with unit normal distributed test activations.) // Internally this uses atomicAdd because different threads may write to the same // gradient position. // This is parallelised over b and s again. // Note that for us, the Z of eqn (16) is actually constant for all t and it is the // likelihood - this is why we use the negative log likelihood below. // We also multiply by the input gradient to keep with standard autograd style. // I took this trick from [2], for moderate alphabet sizes a log-space // calculation (with an atomic log add) is similarly in performance, but for large // alphabets the inplace nature is a considerable advantage. template __global__ void #if defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2((std::is_same::value ? 
1024 : 896), 1) #endif ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data, const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride, const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data, const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, const scalar_t* __restrict__ log_realvalues_data, const scalar_t* __restrict__ neg_log_likelihood_data, int64_t gr_batch_stride, int64_t gr_input_stride, int64_t gr_char_stride, int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity) { int64_t b = threadIdx.y + blockIdx.y * blockDim.y; int64_t s = threadIdx.x + blockIdx.x * blockDim.x; // note, this directly indexes into targets, not targets prime! if (b >= batch_size) return; int64_t input_length = input_lengths[b]; int64_t target_length = target_lengths[b]; int64_t gr_batch_offset = b*gr_batch_stride; int64_t lp_batch_offset = b*lp_batch_stride; int64_t la_batch_offset = b*la_batch_stride; int64_t lb_batch_offset = b*lb_batch_stride; int64_t lr_batch_offset = b*lr_batch_stride; int64_t tg_batch_offset = tg_batch_offsets[b]; if (s >= target_length) return; int64_t target = targets_data[tg_batch_offset + s * tg_target_stride]; scalar_t nll = neg_log_likelihood_data[b]; scalar_t gr = grad_out_data[b * grad_out_batch_stride]; if (zero_infinity && nll == std::numeric_limits::infinity()) return; for (int64_t t = 0; t < input_length; t++) { scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * target]; scalar_t log_alpha_beta = log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s*2+1)] + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s*2+1)]; scalar_t log_prod_n = log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s]; scalar_t log_alpha_beta_div_pr = log_alpha_beta - log_prod_n; gpuAtomicAddNoReturn(&gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * target], -std::exp(log_alpha_beta_div_pr + nll - lp) * gr); } } template __global__ void #if defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2((std::is_same::value ? 
1024 : 896), 1) #endif ctc_loss_backward_collect_realvalue_gpu_kernel(scalar_t* __restrict__ gradient_realval_data, const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride, const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data, const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, const scalar_t* __restrict__ realval_data, int64_t num_realval, const scalar_t* __restrict__ targets_realval_data, const scalar_t* __restrict__ log_realvalues_data, const scalar_t* __restrict__ neg_log_likelihood_data, int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride, int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride, int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { //constexpr scalar_t neginf = -std::numeric_limits::infinity(); int64_t b = threadIdx.y + blockIdx.y * blockDim.y; int64_t t = threadIdx.x + blockIdx.x * blockDim.x; if ((t >= max_input_length) || (b >= batch_size)) return; //int64_t input_length = input_lengths[b]; int64_t target_length = target_lengths[b]; int64_t lp_batch_offset = b*lp_batch_stride; int64_t la_batch_offset = b*la_batch_stride; int64_t lb_batch_offset = b*lb_batch_stride; int64_t rv_batch_offset = b*rv_batch_stride; int64_t rvt_batch_offset = b*rvt_batch_stride; int64_t lr_batch_offset = b*lr_batch_stride; int64_t tg_batch_offset = tg_batch_offsets[b]; scalar_t nll = neg_log_likelihood_data[b]; scalar_t gr = grad_out_data[b * grad_out_batch_stride]; // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s] for (int s = 0; s < max_target_length; s++) { if (s < target_length) { int64_t current_target_prime = get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s * 2 + 1, BLANK); scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime]; scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s * 2 + 1)] + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s * 2 + 1)]); scalar_t log_prod_n = log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s]; if (current_target_prime != BLANK && current_target_prime != BLANK_1) { scalar_t log_term1 = log_alpha_beta - lp - 2 * log_prod_n; for (int64_t i = 0; i != num_realval; ++i) { scalar_t log_constant_factors = log_prod_n - custom_distance_forward_log( targets_realval_data[rvt_batch_offset + rvt_input_stride * s + rvt_label_stride * i], realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i], static_cast(sigma) ); scalar_t grad_dp_dmu = std::exp(log_term1 + log_constant_factors + nll) * custom_distance_backward( targets_realval_data[rvt_batch_offset + rvt_input_stride * s + rvt_label_stride * i], realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i], static_cast(sigma) ); gradient_realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i] += -grad_dp_dmu * 
gr; } } } } } // This is the naive implementation of equation (16). It is parallelised in batch and input timestep. // It appears to be faster than the above method for small batch sizes. template __global__ void #if defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2((std::is_same::value ? 1024 : 896), 1) #endif ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data, const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride, const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data, const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, const scalar_t* __restrict__ log_realvalues_data, const scalar_t* __restrict__ neg_log_likelihood_data, int64_t gr_batch_stride, int64_t gr_input_stride, int64_t gr_char_stride, int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity) { constexpr scalar_t neginf = -std::numeric_limits::infinity(); int64_t b = threadIdx.y + blockIdx.y * blockDim.y; int64_t t = threadIdx.x + blockIdx.x * blockDim.x; if ((t >= max_input_length) || (b >= batch_size)) return; int64_t input_length = input_lengths[b]; int64_t target_length = target_lengths[b]; int64_t gr_batch_offset = b*gr_batch_stride; int64_t lp_batch_offset = b*lp_batch_stride; int64_t la_batch_offset = b*la_batch_stride; int64_t lb_batch_offset = b*lb_batch_stride; int64_t lr_batch_offset = b*lr_batch_stride; int64_t tg_batch_offset = tg_batch_offsets[b]; // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s] for (int s = 0; s < 2*max_target_length+1; s++) { if (s < 2 * target_length + 1) { // if target_length == 0, s == 0 int64_t current_target_prime = get_target_prime( targets_data, tg_batch_offset, tg_target_stride, s, BLANK); scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]); scalar_t log_prod_n = s % 2 == 1 ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0; scalar_t log_alpha_beta_div_pr = log_alpha_beta - log_prod_n; scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime]; if (lcab == neginf) { lcab = log_alpha_beta_div_pr; } else { scalar_t max = ((lcab > log_alpha_beta_div_pr) ? lcab : log_alpha_beta_div_pr); lcab = std::log(std::exp(lcab-max)+std::exp(log_alpha_beta_div_pr-max))+max; } } } scalar_t nll = neg_log_likelihood_data[b]; scalar_t gr = grad_out_data[b * grad_out_batch_stride]; for (int64_t c = 0; c < num_labels; c++) { scalar_t& res = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * c]; if (t < input_length && (! 
zero_infinity || nll != std::numeric_limits::infinity())) { scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * c]; res = (std::exp(lp)-std::exp(res + nll - lp)) * gr; } else { res = 0.; } } } // This is to zero gradients which corresponding to the out-of-sequence position // Those gradients should not be used in any model update since the input // elements are padded template __global__ void #if defined (__HIP_PLATFORM_HCC__) C10_LAUNCH_BOUNDS_2((std::is_same::value ? 1024 : 896), 1) #endif ctc_loss_zero_padded_gradients( scalar_t* __restrict__ gradient_data, /* (T, B, D) layout */ const int64_t* __restrict__ input_lengths, /* (B, ) layout */ int64_t gr_batch_stride, int64_t gr_timestep_stride, int64_t gr_label_stride, int64_t batch_size, /* B */ int64_t max_input_length, /* T */ int64_t num_labels /* D */ ) { int64_t b = threadIdx.y + blockIdx.y * blockDim.y; int64_t t = threadIdx.x + blockIdx.x * blockDim.x; if (b >= batch_size || t >= max_input_length) { return; } scalar_t input_length = input_lengths[b]; if (t >= input_length) { for (int l = 0; l < num_labels; l++) gradient_data[ b * gr_batch_stride + t * gr_timestep_stride + l * gr_label_stride] = 0.0f; } } // The backward. It essentially computes eq 16 by using the above kernels. // We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward. template std::tuple custom_ctc_loss_backward_gpu_template( const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, IntArrayRef input_lengths, IntArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, scalar_t const sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity ) { constexpr scalar_t neginf = -std::numeric_limits::infinity(); using target_t = typename std::conditional::type; int64_t batch_size = log_probs.size(0); int64_t num_realvals = realval.size(2); int64_t num_labels = log_probs.size(2); int64_t lp_input_stride = log_probs.stride(1); int64_t lp_char_stride = log_probs.stride(2); int64_t tg_target_stride; int64_t max_target_length; auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong))); auto tg_batch_offsets_data = tg_batch_offsets.data_ptr(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; max_target_length = 0; for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets_data[i] = pos; pos += target_lengths[i]; if (max_target_length < target_lengths[i]) max_target_length = target_lengths[i]; } tg_target_stride = targets.stride(0); } else { // batch x max_target_length // dim is 2 int64_t tg_batch_stride = targets.stride(0); for (int64_t i = 0; i < batch_size; i++) { tg_batch_offsets_data[i] = i * tg_batch_stride; } tg_target_stride = targets.stride(1); max_target_length = log_alpha.size(2)/2; // targets.size(1) might be larger } auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong)); auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong)); tg_batch_offsets = tg_batch_offsets.cuda(); Tensor log_realvalues = at::zeros({batch_size, log_probs.size(1), std::max(max_target_length, int64_t(1))}, log_alpha.options()); Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT); log_beta.fill_(neginf); Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta)) Tensor grad_realval = at::full_like(realval, 0, 
LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for sum (d realvalue) // As above, there may be better configurations to use. constexpr int max_threads = std::is_same::value ? 1024 : 896; // we need 72 or so 32 bit registers for double int threads_target = max_threads; while (threads_target / 2 >= 2*max_target_length+1) { threads_target /= 2; } int threads_batch = std::min(max_threads / threads_target, (int) batch_size); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); { int threads_target = max_threads; while (threads_target / 2 >= max_target_length && threads_target > 1) { threads_target /= 2; } int threads_batch = std::min(max_threads / threads_target, (int) batch_size); dim3 block(threads_target, threads_batch); dim3 grid( std::max( (max_target_length + threads_target - 1) / threads_target, 1), (batch_size + threads_batch - 1) / threads_batch, 1); ctc_loss_collect_log_realvalues_gpu_kernel<<>> (log_realvalues.data_ptr(), input_lengths_t.data_ptr(), targets.data_ptr(), target_lengths_t.data_ptr(), realval.data_ptr(), num_realvals, targets_realval.data_ptr(), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), realval.stride(0), realval.stride(1), realval.stride(2), targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, sigma, BLANK, BLANK_1); C10_CUDA_KERNEL_LAUNCH_CHECK(); } { dim3 block(threads_target, threads_batch); dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch); ctc_loss_backward_log_beta_gpu_kernel<<>> (log_beta.data_ptr(), log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, log_realvalues.data_ptr(), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, sigma, BLANK, BLANK_1); C10_CUDA_KERNEL_LAUNCH_CHECK(); } // Very crude heuristic for what is a small problem., based on linearly regressing problem dimensions on // the (capped) difference of timings. // Note that for OK problems target length <= input length, so we // only consider input length. bool is_large = (2*log_probs.size(1)+(24*batch_size)/10+(2*num_labels)/10) > 450; if (is_large) { // large alphabet, large batch // this computes the probs, minuend in (16) at::exp_out(grad, log_probs); // now we compute the subtrahend for the blanks. It is a straightforward reduction because we know that // blanks are in every other position. // maybe we should kernelize this, too. 
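    // Concretely: the as_strided views below use a doubled last stride to select the even
    // positions s = 0, 2, ..., 2*target_length of l', which are exactly the BLANK entries,
    // so the logsumexp over that dimension is log sum_s alpha_t(s) beta_t(s) restricted to
    // blanks; adding neg_log_likelihood and subtracting the blank log-probability before
    // exponentiating yields the subtrahend of eq (16) for the blank label.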
auto grad_blank = grad.narrow(2, BLANK, 1); grad_blank -= (at::logsumexp(log_alpha.as_strided({batch_size, log_alpha.size(1), max_target_length+1}, {log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2)*2}) + log_beta.as_strided({batch_size, log_beta.size(1), max_target_length+1}, {log_beta.stride(0), log_beta.stride(1), log_beta.stride(2)*2}), 2, true) .add_(neg_log_likelihood.view({batch_size, 1, 1})) .sub_(log_probs.narrow(2, BLANK, 1)) .exp_() ); // scale by output gradient (blanks and first summand of non-blanks) grad *= grad_out.view({batch_size, 1, 1}); if (zero_infinity) { grad = at::where(neg_log_likelihood.view({batch_size, 1, 1}) == Scalar(std::numeric_limits::infinity()), at::zeros({}, grad.options()), grad); } // For the non-blank characters, we use a kernel to compute the subtrahend. // Again we might configure block and grid in a better way. int threads_target = max_threads; while (threads_target / 2 >= max_target_length && threads_target > 1) { threads_target /= 2; } int threads_batch = std::min(max_threads / threads_target, (int) batch_size); dim3 block(threads_target, threads_batch); dim3 grid( std::max( (max_target_length + threads_target - 1) / threads_target, 1), (batch_size + threads_batch - 1) / threads_batch, 1); ctc_loss_backward_collect_nonblank_gpu_kernel<<>> (grad.data_ptr(), grad_out.data_ptr(), grad_out.stride(0), log_alpha.data_ptr(), log_beta.data_ptr(), log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, log_realvalues.data_ptr(), neg_log_likelihood.data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, sigma, BLANK, BLANK_1, zero_infinity); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // small problem, use naive algorithm // Still no block/grid configuration guru... 
int threads_input = max_threads; while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) { threads_input /= 2; } threads_batch = std::min(max_threads / threads_input, (int) batch_size); dim3 block(threads_input, threads_batch); dim3 grid((log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch); ctc_loss_backward_collect_gpu_kernel<<>> (grad.data_ptr(), grad_out.data_ptr(), grad_out.stride(0), log_alpha.data_ptr(), log_beta.data_ptr(), log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, log_realvalues.data_ptr(), neg_log_likelihood.data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, sigma, BLANK, BLANK_1, zero_infinity); C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors } // collect real value gradients { int threads_input = max_threads; while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) { threads_input /= 2; } threads_input = 512; threads_batch = std::min(max_threads / threads_input, (int) batch_size); threads_batch = 1; //threads_batch = threads_batch >> 4; //std::cout << "threads_input=" << threads_input << ",threads_batch=" << threads_batch << "\n"; dim3 block(threads_input, threads_batch); dim3 grid((log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch); ctc_loss_backward_collect_realvalue_gpu_kernel<<>> (grad_realval.data_ptr(), grad_out.data_ptr(), grad_out.stride(0), log_alpha.data_ptr(), log_beta.data_ptr(), log_probs.data_ptr(), input_lengths_t.data_ptr(), log_probs.size(1), targets.data_ptr(), target_lengths_t.data_ptr(), max_target_length, realval.data_ptr(), num_realvals, targets_realval.data_ptr(), log_realvalues.data_ptr(), neg_log_likelihood.data_ptr(), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), realval.stride(0), realval.stride(1), realval.stride(2), targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), tg_batch_offsets.data_ptr(), tg_target_stride, batch_size, num_labels, sigma, BLANK, BLANK_1); C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors } // zero those invalid gradient elements due to padding { int threads_input = max_threads; while (threads_input / 2 >= log_probs.size(1)) { threads_input /= 2; } threads_batch = std::min(max_threads / threads_input, (int) batch_size); dim3 block(threads_input, threads_batch); dim3 grid( (log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch); ctc_loss_zero_padded_gradients<<>>( grad.data_ptr(), input_lengths_t.data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), grad.size(0), grad.size(1), grad.size(2) ); C10_CUDA_KERNEL_LAUNCH_CHECK(); } return std::make_tuple(grad, grad_realval); } std::tuple custom_ctc_loss_gpu( const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, IntArrayRef input_lengths, IntArrayRef target_lengths, double const sigma, int64_t BLANK, 
    int64_t BLANK_1
) {
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "custom_ctc_loss_cuda", [&] {
    if (targets.scalar_type() == kLong) {
      return custom_ctc_loss_gpu_template<scalar_t, kLong>(
          log_probs, targets, realval, targets_realval, input_lengths, target_lengths,
          static_cast<scalar_t>(sigma), BLANK, BLANK_1);
    } else {
      return custom_ctc_loss_gpu_template<scalar_t, kInt>(
          log_probs, targets, realval, targets_realval, input_lengths, target_lengths,
          static_cast<scalar_t>(sigma), BLANK, BLANK_1);
    }
  });
}

std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu(
    const Tensor& grad,
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& realval,
    const Tensor& targets_realval,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    const Tensor& neg_log_likelihood,
    const Tensor& log_alpha,
    double const sigma,
    int64_t BLANK,
    int64_t BLANK_1,
    bool zero_infinity
) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("ctc_loss_backward_gpu");
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "custom_ctc_loss_backward_cuda", [&] {
    if (targets.scalar_type() == kLong) {
      return custom_ctc_loss_backward_gpu_template<scalar_t, kLong>(
          grad, log_probs, targets, realval, targets_realval, input_lengths, target_lengths,
          neg_log_likelihood, log_alpha, static_cast<scalar_t>(sigma), BLANK, BLANK_1, zero_infinity);
    } else {
      return custom_ctc_loss_backward_gpu_template<scalar_t, kInt>(
          grad, log_probs, targets, realval, targets_realval, input_lengths, target_lengths,
          neg_log_likelihood, log_alpha, static_cast<scalar_t>(sigma), BLANK, BLANK_1, zero_infinity);
    }
  });
}
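
// Illustrative, disabled sanity check (not part of the extension build): a minimal
// host-side sketch, under the assumption that one wants to verify that
// custom_distance_backward is the derivative of exp(custom_distance_forward_log)
// with respect to mu. The helper name, test values, and tolerance below are made up
// for illustration; the lambdas re-state the closed-form Gaussian density used above
// rather than calling the __device__ functions.
#if 0
#include <cmath>
#include <cstdio>

static bool check_custom_distance_gradient() {
  const double pi = 3.14159265358979323846;
  // log P(x|mu) for a Gaussian with standard deviation sigma (same formula as above)
  auto log_pdf = [&](double x, double mu, double sigma) {
    return -0.5 * std::log(2.0 * pi) - std::log(sigma)
           - 0.5 * (x - mu) * (x - mu) / (sigma * sigma);
  };
  // d P(x|mu) / d mu, matching custom_distance_backward
  auto d_pdf_d_mu = [&](double x, double mu, double sigma) {
    double val = 1.0 / (sigma * std::sqrt(2.0 * pi))
                 * std::exp(-0.5 * (x - mu) * (x - mu) / (sigma * sigma));
    return val * (x - mu) / (sigma * sigma);
  };
  const double x = 0.3, mu = -0.1, sigma = 0.5, eps = 1e-6;
  // central finite difference of P(x|mu) = exp(log P(x|mu)) in mu
  double numeric = (std::exp(log_pdf(x, mu + eps, sigma))
                    - std::exp(log_pdf(x, mu - eps, sigma))) / (2.0 * eps);
  double analytic = d_pdf_d_mu(x, mu, sigma);
  std::printf("numeric=%g analytic=%g\n", numeric, analytic);
  return std::abs(numeric - analytic) < 1e-6;
}
#endif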