# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import torch
import torch.nn.functional as F


class CTC(torch.nn.Module):
    """CTC module."""

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        reduce: bool = True,
    ):
        """Construct a CTC module.

        Args:
            odim: dimension of outputs (vocabulary size)
            encoder_output_size: number of encoder projection units
            dropout_rate: dropout rate (0.0 ~ 1.0)
            reduce: reduce the CTC loss into a scalar
        """
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)

        reduction_type = "sum" if reduce else "none"
        self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B,)
            ys_pad: batch of padded character id sequences (B, Lmax)
            ys_lens: batch of lengths of character sequences (B,)
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        # ys_hat: (B, L, D) -> (L, B, D), the layout expected by torch.nn.CTCLoss
        ys_hat = ys_hat.transpose(0, 1)
        ys_hat = ys_hat.log_softmax(2)
        loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
        # Batch-size average
        loss = loss / ys_hat.size(1)
        return loss

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """log_softmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """argmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
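

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original WeNet module): a minimal example of
# driving the CTC head above. The batch size, frame count, projection width,
# vocabulary size, and the random inputs are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch, tmax, eprojs, vocab = 2, 50, 256, 100  # assumed toy dimensions
    ctc = CTC(odim=vocab, encoder_output_size=eprojs)

    # Fake encoder output and labels. torch.nn.CTCLoss treats index 0 as the
    # blank by default, so target ids are drawn from [1, vocab).
    hs_pad = torch.randn(batch, tmax, eprojs)
    hlens = torch.tensor([50, 42])                 # valid frames per utterance
    ys_pad = torch.randint(1, vocab, (batch, 12))  # padded label id sequences
    ys_lens = torch.tensor([12, 9])                # valid labels per utterance

    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
    print("per-utterance CTC loss:", loss.item())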
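
    # Greedy (best-path) decoding sketch: frame-wise argmax, then collapse
    # consecutive repeats and drop blanks. The collapsing logic follows the
    # standard CTC decoding convention; it is not implemented by the module
    # itself, which only exposes the raw frame-wise argmax.
    hyps = ctc.argmax(hs_pad)  # (B, Tmax) token ids
    for b in range(batch):
        prev, decoded = -1, []
        for tok in hyps[b, : int(hlens[b])].tolist():
            if tok != prev and tok != 0:  # skip repeats and blank (id 0)
                decoded.append(tok)
            prev = tok
        print(f"utterance {b}: {decoded}")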