File size: 2,401 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
#include <torch/extension.h>
#include <vector>
using torch::Tensor;
using torch::IntArrayRef;
// Forward pass of the custom CTC loss.
// Only declared here; the definition lives in another translation unit
// (presumably the CUDA source built with this extension — confirm in the
// build setup). Keep this prototype in exact agreement with that definition.
//
// Called by custom_ctc_loss_gpu_driver, which supplies input_lengths and
// target_lengths as IntArrayRef views over int64 tensor data.
// NOTE(review): the two returned tensors are presumably
// (neg_log_likelihood, log_alpha), since the backward entry point takes
// tensors with those names — confirm against the definition.
// NOTE(review): the meaning of realval / targets_realval / sigma and of the
// two label indices BLANK / BLANK_1 is not visible in this file — TODO
// document once confirmed from the kernel source.
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu(
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
IntArrayRef input_lengths,
IntArrayRef target_lengths,
double const sigma,
int64_t BLANK,
int64_t BLANK_1
);
// Backward pass of the custom CTC loss.
// Only declared here; the definition lives in another translation unit
// (presumably the CUDA source — confirm). Keep this prototype in exact
// agreement with that definition.
//
// Called by custom_ctc_loss_backward_gpu_driver with the length tensors
// converted to IntArrayRef views. It also receives zero_infinity, forwarded
// unchanged from that driver.
// NOTE(review): neg_log_likelihood / log_alpha presumably are the two tensors
// returned by the forward pass — confirm with the Python caller.
// NOTE(review): the two returned tensors are presumably gradients (e.g. with
// respect to log_probs and realval); this is not verifiable from this file.
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu(
const Tensor& grad,
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
IntArrayRef input_lengths,
IntArrayRef target_lengths,
const Tensor& neg_log_likelihood,
const Tensor& log_alpha,
double const sigma,
int64_t BLANK,
int64_t BLANK_1,
bool zero_infinity
);
// Python-facing forward entry point (bound as "forward").
//
// Takes the per-sample lengths as tensors and turns them into IntArrayRef
// views. Every other argument is passed unchanged to custom_ctc_loss_gpu.
//
// The length tensors are first copied to CPU int64; this is a no-op when they
// already are CPU int64. This conversion matters for two reasons:
//  - IntArrayRef is read on the host, so viewing the storage of a CUDA tensor
//    would make the CPU read device memory.
//  - data_ptr<int64_t>() throws for any dtype other than int64 (e.g. int32).
// Non-integral lengths are rejected explicitly instead of being silently
// truncated by the dtype conversion.
//
// zero_infinity is accepted only so this signature mirrors the backward
// driver; it is intentionally unused here.
//
// Returns the two tensors produced by custom_ctc_loss_gpu.
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu_driver(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& realval,
    const Tensor& targets_realval,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    double const sigma,
    int64_t BLANK,
    int64_t BLANK_1,
    bool zero_infinity
) {
    (void)zero_infinity;
    TORCH_CHECK(c10::isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false),
                "input_lengths must be integral");
    TORCH_CHECK(c10::isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false),
                "target_lengths must be integral");
    // ilc / tlc own the storage that il / tl merely view. Both stay alive until
    // custom_ctc_loss_gpu returns, so the views never dangle.
    const Tensor ilc = input_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous();
    const Tensor tlc = target_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous();
    const IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
    const IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
    return custom_ctc_loss_gpu(log_probs, targets, realval, targets_realval, il, tl, sigma, BLANK, BLANK_1);
}
// Python-facing backward entry point (bound as "backward").
//
// Takes the per-sample lengths as tensors and turns them into IntArrayRef
// views. Every other argument, including zero_infinity, is passed unchanged
// to custom_ctc_loss_backward_gpu.
//
// As in the forward driver, the length tensors are first copied to CPU int64
// (a no-op when they already are CPU int64):
//  - IntArrayRef is read on the host, so it must not view CUDA storage.
//  - data_ptr<int64_t>() throws for any dtype other than int64.
// Non-integral lengths are rejected explicitly instead of being silently
// truncated by the dtype conversion.
//
// Returns the two tensors produced by custom_ctc_loss_backward_gpu.
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu_driver(
    const Tensor& grad,
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& realval,
    const Tensor& targets_realval,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    const Tensor& neg_log_likelihood,
    const Tensor& log_alpha,
    double const sigma,
    int64_t BLANK,
    int64_t BLANK_1,
    bool zero_infinity
) {
    TORCH_CHECK(c10::isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false),
                "input_lengths must be integral");
    TORCH_CHECK(c10::isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false),
                "target_lengths must be integral");
    // ilc / tlc own the storage that il / tl merely view. Both stay alive until
    // custom_ctc_loss_backward_gpu returns, so the views never dangle.
    const Tensor ilc = input_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous();
    const Tensor tlc = target_lengths.to(torch::Device(torch::kCPU), torch::kLong).contiguous();
    const IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
    const IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
    return custom_ctc_loss_backward_gpu(grad, log_probs, targets, realval, targets_realval, il, tl, neg_log_likelihood, log_alpha, sigma, BLANK, BLANK_1, zero_infinity);
}
// Python bindings for this extension: the tensor-length drivers are exposed
// under the names "forward" and "backward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward",
            &custom_ctc_loss_gpu_driver,
            "custom CTC forward (CUDA)");
    mod.def("backward",
            &custom_ctc_loss_backward_gpu_driver,
            "custom CTC backward (CUDA)");
}
|