# Source: auto_avsr/espnet/nets/scorer_interface.py
# Uploaded by mpc001 ("Upload 125 files", commit 09481f3), 5.94 kB.
"""Scorer interface module."""
from typing import Any
from typing import List
from typing import Tuple
import torch
import warnings
class ScorerInterface:
    """Base interface every beam-search scorer implements.

    A scorer assigns a score to each token of the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Build the state decoding starts from (optional to override).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns:
            The initial state. This base implementation is stateless and
            always returns ``None``.

        """
        return None

    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
        """Pick the state that belongs to one hypothesis of the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary.
                This base implementation does not use it.

        Returns:
            The pruned state: ``state[i]``, or ``None`` when ``state`` is ``None``.

        """
        # Stateless scorers carry None around; there is nothing to index into.
        if state is None:
            return None
        return state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score the next token given a prefix (must be overridden).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score the end-of-sequence step (optional to override).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score; this base implementation contributes ``0.0``.

        """
        return 0.0
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface.

    Adds batched counterparts of the single-hypothesis scorer methods. Both
    methods ship with a fallback built on the single-hypothesis API, so they
    work without being overridden, but the ``batch_score`` fallback is a
    plain Python loop and warns about it on use.
    """

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns:
            initial state; by default whatever ``init_state(x)`` returns.

        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        This fallback calls ``score`` once per batch entry and emits a
        warning because the loop is not parallelized.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        Raises:
            ValueError: If ``states`` or ``xs`` does not contain exactly one
                entry per row of ``ys``.

        """
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                self.__class__.__name__
            )
        )
        n_batch = ys.shape[0]
        # zip() stops at the shortest input, so a length mismatch would drop
        # hypotheses silently and the reshape below would then either fail
        # with an unrelated message or return a wrongly shaped score matrix.
        if len(states) != n_batch or len(xs) != n_batch:
            raise ValueError(
                "batch size mismatch: ys has {} entries, states has {}, "
                "xs has {}".format(n_batch, len(states), len(xs))
            )
        per_hyp_scores = []
        next_states = []
        for y, state, x in zip(ys, states, xs):
            score, next_state = self.score(y, state, x)
            per_hyp_scores.append(score)
            next_states.append(next_state)
        # Concatenate along dim 0, then fold the result back into one row
        # per batch entry.
        batch_scores = torch.cat(per_hyp_scores, 0).view(n_batch, -1)
        return batch_scores, next_states
class PartialScorerInterface(ScorerInterface):
    """Interface for scorers that only score a pre-pruned set of next tokens.

    A partial scorer runs after the non-partial scorers have finished. It is
    handed the candidates that survived their pruning, because scoring every
    token would be too heavy for it.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self,
        y: torch.Tensor,
        next_tokens: torch.Tensor,
        state: Any,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score the given candidate tokens for one prefix (must be overridden).

        Args:
            y (torch.Tensor): 1D prefix token
            next_tokens (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys

        Raises:
            NotImplementedError: Always, in this interface class.

        """
        raise NotImplementedError
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Interface for partial scorers that score a whole batch at once.

    Combines the batch scorer and partial scorer interfaces for beam search.
    """

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score the given candidate tokens for a batch (must be overridden).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys

        Raises:
            NotImplementedError: Always, in this interface class.

        """
        raise NotImplementedError