# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the label's data distribution is one-hot:

        [0, 1, 2] ->
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]

    In the label-smoothed CE loss, some probability mass is taken from the
    true label (1.0) and distributed evenly among the other labels,
    e.g. with smoothing=0.1:

        [0, 1, 2] ->
        [
            [0.9, 0.05, 0.05],
            [0.05, 0.9, 0.05],
            [0.05, 0.05, 0.9],
        ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool): normalize loss by sequence length if True,
            normalize loss by batch size if False
    """

    def __init__(
        self,
        size: int,
        padding_idx: int,
        smoothing: float,
        normalize_length: bool = False,
    ):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model output and label tensors are flattened to
        (batch * seqlen, class) shape, and a mask is applied to the
        padding positions, which should not contribute to the loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor): target signal masked with
                self.padding_idx (batch, seqlen)

        Returns:
            loss (torch.Tensor): the KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() cannot be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
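

# The class docstring above explains how smoothing redistributes probability
# mass; the snippet below is a minimal usage sketch, not part of the upstream
# WeNet module. The vocabulary size, tensor shapes, and the -1 padding id are
# assumptions made for illustration only.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, batch, seqlen = 10, 2, 5
    criterion = LabelSmoothingLoss(
        size=vocab_size, padding_idx=-1, smoothing=0.1, normalize_length=True
    )
    # Random logits of shape (batch, seqlen, vocab_size).
    logits = torch.randn(batch, seqlen, vocab_size)
    # Targets use -1 to mark padded positions, which the loss ignores.
    targets = torch.randint(0, vocab_size, (batch, seqlen))
    targets[:, -1] = -1
    loss = criterion(logits, targets)
    print(f"label-smoothed loss: {loss.item():.4f}")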