# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

"""Scorer interface module."""

from abc import ABC
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple

import torch


class ScorerInterface:
    """Scorer interface for beam search.

    A scorer assigns a score to every token in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        The default implementation is stateless and returns ``None``.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return None

    def select_state(self, state: Any, i: int, new_id: Optional[int] = None) -> Any:
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int, optional): New label index to select a state if
                necessary. Ignored by this default implementation.

        Returns:
            state: pruned state (``None`` if ``state`` is ``None``,
                otherwise ``state[i]``)

        """
        return None if state is None else state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: Always; subclasses must override this.

        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score (``0.0`` in this default implementation)

        """
        return 0.0


class BatchScorerInterface(ScorerInterface, ABC):
    """Batch scorer interface."""

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        The default implementation delegates to :meth:`init_state`.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        The default implementation calls :meth:`score` once per hypothesis
        in a Python loop. Override it with a vectorized version where
        possible.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batched scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        scores = []
        outstates = []
        # Each hypothesis is scored independently, so no index is needed.
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            scores.append(score)
            outstates.append(outstate)
        # Concatenate the per-hypothesis score tensors along dim 0, then reshape
        # so that there is one row per batch entry.
        batch_scores = torch.cat(scores, 0).view(ys.shape[0], -1)
        return batch_scores, outstates


class PartialScorerInterface(ScorerInterface, ABC):
    """Partial scorer interface for beam search.

    A partial scorer runs after the non-partial scorers have finished. It
    receives a pre-pruned set of next tokens to score, because scoring every
    token in the vocabulary would be too expensive.

    Examples:
         * Prefix search for connectionist-temporal-classification models
             * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D prefix token
            next_tokens (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape
                `(len(next_tokens),)` and next state for ys

        Raises:
            NotImplementedError: Always; subclasses must override this.

        """
        raise NotImplementedError


class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface, ABC):
    """Batch partial scorer interface for beam search."""

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys

        Raises:
            NotImplementedError: Always; subclasses must override this.

        """
        raise NotImplementedError