import torch
import custom_ctc_cpp

Tensor = torch.Tensor


def _as_length_tensor(lengths, device):
    """Return ``lengths`` as an int64 tensor on ``device`` if it is a list/tuple.

    Tensors (and any other type) are returned unchanged, so callers that already
    pass tensors see exactly the previous behavior.
    """
    if isinstance(lengths, (list, tuple)):
        # Build directly as int64: going through torch.Tensor(...) would create a
        # float32 tensor first, which cannot represent every large integer exactly.
        return torch.as_tensor(lengths, dtype=torch.long, device=device)
    return lengths


class CustomCTCLossFunction(torch.autograd.Function):
    """Autograd wrapper around the compiled ``custom_ctc_cpp`` CTC extension.

    The loss and its gradients are computed by ``custom_ctc_cpp.forward`` /
    ``custom_ctc_cpp.backward``. This class only normalizes the length
    arguments, applies the requested reduction, and routes gradients.
    Gradients are returned for ``log_probs`` and ``realval``; every other
    input receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        log_probs: Tensor,
        targets: Tensor,
        realval: Tensor,
        targets_realval: Tensor,
        input_lengths: Tensor,
        target_lengths: Tensor,
        sigma: float = 1,
        blank: int = 0,
        blank1: int = 0,
        reduction: str = "mean",
        zero_infinity: bool = False,
    ):
        """Compute the custom CTC loss.

        Args:
            log_probs: Class log-probabilities passed to the extension
                (layout defined by ``custom_ctc_cpp`` -- TODO confirm whether
                it is time-major ``(T, N, C)`` like ``F.ctc_loss``).
            targets: Target class labels.
            realval: Real-valued predictions; receives a gradient.
            targets_realval: Real-valued targets paired with ``realval``.
            input_lengths: Per-sample input lengths (tensor, list or tuple).
            target_lengths: Per-sample target lengths (tensor, list or tuple).
            sigma: Forwarded unchanged to the extension (semantics defined there).
            blank: Blank index forwarded to the extension.
            blank1: Second blank index forwarded to the extension
                (presumably for the real-valued stream -- verify in the C++ code).
            reduction: ``'none'`` returns the per-sample negative log-likelihood;
                ``'mean'`` divides each by its (>= 1 clamped) target length and
                averages the results.
            zero_infinity: Forwarded to the extension.

        Returns:
            The per-sample loss tensor (``'none'``) or a scalar (``'mean'``).

        Raises:
            ValueError: If ``reduction`` is not ``'none'`` or ``'mean'``.
        """
        # An explicit check (not assert) so validation survives ``python -O``;
        # otherwise an unsupported reduction would silently act like 'none'.
        if reduction not in ("none", "mean"):
            raise ValueError(
                f"reduction must be 'none' or 'mean', got {reduction!r}"
            )

        input_lengths = _as_length_tensor(input_lengths, log_probs.device)
        target_lengths = _as_length_tensor(target_lengths, log_probs.device)

        neg_log_likelihood, log_alpha = custom_ctc_cpp.forward(
            log_probs, targets, realval, targets_realval,
            input_lengths, target_lengths,
            sigma, blank, blank1, zero_infinity,
        )

        # log_alpha is kept so the extension's backward can reuse the forward
        # pass state instead of recomputing it.
        ctx.save_for_backward(
            neg_log_likelihood, log_alpha, log_probs, targets,
            realval, targets_realval, input_lengths, target_lengths,
        )
        ctx.blank = blank
        ctx.blank1 = blank1
        ctx.zero_infinity = zero_infinity
        ctx.sigma = sigma
        ctx.reduction = reduction

        if reduction == "mean":
            return (neg_log_likelihood / target_lengths.clamp_min(1)).mean()
        return neg_log_likelihood

    @staticmethod
    def backward(ctx, grad_out):
        """Propagate ``grad_out`` to ``log_probs`` and ``realval``.

        Returns one entry per ``forward`` input (11 in total): the gradients
        for ``log_probs`` and ``realval`` in their positions, ``None`` elsewhere.
        """
        (
            neg_log_likelihood, log_alpha, log_probs, targets,
            realval, targets_realval, input_lengths, target_lengths,
        ) = ctx.saved_tensors

        if ctx.reduction == "mean":
            # forward returned mean_i(nll_i / len_i), so the mean's denominator is
            # the number of nll entries (the batch size), NOT log_probs.size(0),
            # which is the time dimension for a time-major (T, N, C) layout.
            batch_size = neg_log_likelihood.size(0)
            if grad_out.numel() == 0:
                grad_out = torch.ones_like(neg_log_likelihood)
            else:
                # Broadcast the scalar upstream gradient to one value per sample.
                grad_out = grad_out.reshape(1).tile((batch_size,))
            grad_out = grad_out / target_lengths.clamp_min(1) / batch_size

        outputs_cls, outputs_realval = custom_ctc_cpp.backward(
            grad_out, log_probs, targets, realval, targets_realval,
            input_lengths, target_lengths, neg_log_likelihood, log_alpha,
            ctx.sigma, ctx.blank, ctx.blank1, ctx.zero_infinity,
        )

        # Order mirrors forward's inputs: log_probs, targets, realval,
        # targets_realval, input_lengths, target_lengths, sigma, blank,
        # blank1, reduction, zero_infinity.
        return (
            outputs_cls, None, outputs_realval, None, None, None,
            None, None, None, None, None,
        )