# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic },
#   organization={IEEE}
# }
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
#

from typing import Optional

import numpy as np
import six  # noqa: F401  (no longer used below; import kept for compatibility)
import torch


def sequence_mask(
    lengths,
    maxlen: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a mask that is non-zero for positions smaller than each length.

    For every entry ``n`` of ``lengths`` the result holds a row of ``maxlen``
    values whose first ``n`` positions are true (``1`` after the cast to
    ``dtype``) and whose remaining positions are false (``0``).

    :param lengths: tensor of lengths; a trailing axis of size ``maxlen`` is
        appended to its shape in the result.
    :param maxlen: size of the appended axis. Defaults to ``lengths.max()``.
    :param dtype: dtype the boolean mask is cast to before it is returned.
    :param device: if given, the mask is moved to this device; otherwise it
        stays on ``lengths.device``.
    :return: mask of shape ``lengths.shape + (maxlen,)`` with dtype ``dtype``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Allocate the position indices directly on the device of `lengths`
    # instead of creating them on the CPU and copying them over afterwards.
    positions = torch.arange(0, maxlen, 1, device=lengths.device)
    # Broadcast (maxlen,) against (..., 1) -> (..., maxlen).
    mask = (positions < torch.unsqueeze(lengths, dim=-1)).detach()
    mask = mask.type(dtype)
    if device is not None:
        mask = mask.to(device)
    return mask


def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
    """Decide whether decoding can stop (end detection).

    Described in Eq. (50) of S. Watanabe et al,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    Returns ``True`` only if, for each of the ``M`` lengths
    ``i, i - 1, ..., i - M + 1``, at least one ended hypothesis of exactly
    that length exists and the best score among them minus the best score
    over all ended hypotheses is smaller than ``d_end``.

    :param ended_hyps: sequence of finished hypotheses; each item is indexed
        with ``"score"`` (comparable number) and ``"yseq"`` (sized sequence).
    :param i: largest hypothesis length to inspect (presumably the current
        decoding step -- confirm against the caller).
    :param M: number of consecutive lengths, counting down from ``i``, that
        must all satisfy the score condition.
    :param d_end: threshold on the score difference; the default evaluates
        to approximately ``-10``.
    :return: ``True`` if the end condition holds, ``False`` otherwise
        (always ``False`` when ``ended_hyps`` is empty).
    """
    if len(ended_hyps) == 0:
        return False

    # Only the top score is needed, so take a maximum instead of sorting.
    best_score = max(hyp["score"] for hyp in ended_hyps)

    for m in range(M):
        # Scores of the ended hypotheses whose length is exactly i - m.
        hyp_length = i - m
        scores_same_length = [
            hyp["score"] for hyp in ended_hyps if len(hyp["yseq"]) == hyp_length
        ]
        # Every one of the M lengths has to pass, so the first length that
        # is missing or not sufficiently worse than the best settles it.
        if not scores_same_length:
            return False
        if not (max(scores_same_length) - best_score < d_end):
            return False

    return True