Spaces:
Running
Running
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
"""Scorer interface module."""
from abc import ABC
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple

import torch
class ScorerInterface:
    """Scorer interface for beam search.

    A scorer assigns a score to every token in the vocabulary, given the
    prefix of tokens decoded so far.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder
                .Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        The default implementation keeps no state and returns ``None``;
        stateful scorers override it.

        Args:
            x (torch.Tensor): The encoded feature tensor.

        Returns:
            Any: Initial state (``None`` in this default implementation).

        """
        return None

    def select_state(
        self, state: Any, i: int, new_id: Optional[int] = None
    ) -> Any:
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens. ``None`` denotes a
                stateless scorer.
            i (int): Index to select a state in the main beam search.
            new_id (Optional[int]): New label index to select a state if
                necessary. Ignored by this default implementation.

        Returns:
            Any: ``state[i]``, or ``None`` when ``state`` is ``None``.

        """
        # A stateless scorer has nothing to prune.
        return None if state is None else state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): The encoder feature that generates ``y``.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ``y``.

        Raises:
            NotImplementedError: Always; subclasses must override this.

        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens.

        Returns:
            float: Final score (``0.0`` in this default implementation).

        """
        return 0.0
class BatchScorerInterface(ScorerInterface, ABC):
    """Batch scorer interface.

    Provides fallback batch methods built on the per-hypothesis methods of
    ``ScorerInterface``; subclasses may override them with vectorized
    implementations.
    """

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Delegates to ``self.init_state``.

        Args:
            x (torch.Tensor): The encoded feature tensor.

        Returns:
            Any: Initial state, as returned by ``init_state``.

        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        The default implementation loops over the batch and calls
        ``self.score`` once per hypothesis.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch,
                n_vocab)`
                and next state list for ys.

        Raises:
            ValueError: If ``ys``, ``states`` and ``xs`` do not share the
                same batch size.

        """
        n_batch = ys.shape[0]
        # zip() would silently truncate to the shortest input, and the final
        # view() could then reshape too few scores into a wrong-shaped tensor
        # without raising, so the batch sizes are checked explicitly.
        if len(states) != n_batch or len(xs) != n_batch:
            raise ValueError(
                "batch_score expects ys, states and xs to share the batch "
                "size; got {}, {} and {}".format(n_batch, len(states), len(xs))
            )
        scores = []
        outstates = []
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            scores.append(score)
            outstates.append(outstate)
        # Concatenate the per-hypothesis scores and fold them back to one
        # row per hypothesis.
        return torch.cat(scores, 0).view(n_batch, -1), outstates
class PartialScorerInterface(ScorerInterface, ABC):
    """Partial scorer interface for beam search.

    A partial scorer runs after the non-partial (full-vocabulary) scorers
    have finished. It scores only the pre-pruned candidate tokens it is
    handed, because scoring every token would be too expensive.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self,
        y: torch.Tensor,
        next_tokens: torch.Tensor,
        state: Any,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score the given candidate next tokens (required).

        Subclasses must override this; the base version only raises.

        Args:
            y (torch.Tensor): 1D prefix token sequence.
            next_tokens (torch.Tensor): torch.int64 candidate tokens to score.
            state: Decoder state for the prefix tokens.
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]:
                A score tensor for ``y`` of shape ``(len(next_tokens),)``
                together with the next state for ys.

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError()
class BatchPartialScorerInterface(
    BatchScorerInterface, PartialScorerInterface, ABC
):
    """Batch partial scorer interface for beam search.

    Combines the batch-scoring interface with the partial-scoring interface.
    """

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score candidate next tokens for a whole batch (required).

        Subclasses must override this; the base version only raises.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score
                (n_batch, n_token).
            states (List[Any]): Scorer states for the prefix tokens.
            xs (torch.Tensor): The encoder feature that generates ys
                (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                A score tensor for ys of shape ``(n_batch, n_vocab)``
                together with the next states for ys.

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError()