// Source: testapi/training/ocr/custom_ctc_cuda_driver.cc
// Author: Sunday01 — commit 9dce458 ("up")
#include <torch/extension.h>
#include <vector>
using torch::Tensor;
using torch::IntArrayRef;
// Forward pass of the custom CTC loss. Only declared here; the definition
// lives in another translation unit (presumably the companion CUDA source —
// TODO confirm).
//
// The sequence lengths are passed as IntArrayRef views, which are non-owning
// and are read directly from the memory they point at.
// Returns two tensors. Judging by the backward declaration below, these are
// presumably (neg_log_likelihood, log_alpha) — verify against the definition.
// NOTE(review): the meaning of realval / targets_realval / sigma / BLANK /
// BLANK_1 is not visible from this file.
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& realval,
    const Tensor& targets_realval,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    double sigma,  // top-level const is not part of a declaration's signature
    int64_t BLANK,
    int64_t BLANK_1);
// Backward pass of the custom CTC loss. Only declared here; the definition
// lives in another translation unit (presumably the companion CUDA source —
// TODO confirm).
//
// Takes the upstream gradient `grad` and the forward inputs. It also takes
// `neg_log_likelihood` and `log_alpha`, presumably the two tensors returned
// by custom_ctc_loss_gpu — verify against the definition.
// `zero_infinity` is passed through from Python. Its exact handling is
// defined by the implementation, not visible here.
// Returns two tensors (contents not visible from this file).
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu(
    const Tensor& grad,
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& realval,
    const Tensor& targets_realval,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    const Tensor& neg_log_likelihood,
    const Tensor& log_alpha,
    double sigma,  // top-level const is not part of a declaration's signature
    int64_t BLANK,
    int64_t BLANK_1,
    bool zero_infinity);
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu_driver(
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
const Tensor& input_lengths,
const Tensor& target_lengths,
double const sigma,
int64_t BLANK,
int64_t BLANK_1,
bool zero_infinity
) {
(void)zero_infinity;
Tensor ilc = input_lengths.contiguous();
Tensor tlc = target_lengths.contiguous();
IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
return custom_ctc_loss_gpu(log_probs, targets, realval, targets_realval, il, tl, sigma, BLANK, BLANK_1);
}
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu_driver(
const Tensor& grad,
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
const Tensor& input_lengths,
const Tensor& target_lengths,
const Tensor& neg_log_likelihood,
const Tensor& log_alpha,
double const sigma,
int64_t BLANK,
int64_t BLANK_1,
bool zero_infinity
) {
Tensor ilc = input_lengths.contiguous();
Tensor tlc = target_lengths.contiguous();
IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
return custom_ctc_loss_backward_gpu(grad, log_probs, targets, realval, targets_realval, il, tl, neg_log_likelihood, log_alpha, sigma, BLANK, BLANK_1, zero_infinity);
}
// Python bindings for this extension. The module name comes from the
// TORCH_EXTENSION_NAME macro, which torch.utils.cpp_extension defines at
// build time. Exposes the two drivers as `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  // Forward: custom_ctc_loss_gpu_driver.
  mod.def("forward", &custom_ctc_loss_gpu_driver, "custom CTC forward (CUDA)");
  // Backward: custom_ctc_loss_backward_gpu_driver.
  mod.def("backward", &custom_ctc_loss_backward_gpu_driver, "custom CTC backward (CUDA)");
}