auto_avsr / espnet /nets /scorers /length_bonus.py
mpc001's picture
Upload 125 files
09481f3
raw
history blame
No virus
1.78 kB
"""Length bonus module."""
from typing import Any
from typing import List
from typing import Tuple
import torch
from espnet.nets.scorer_interface import BatchScorerInterface
class LengthBonus(BatchScorerInterface):
    """Stateless beam-search scorer that gives every token the same bonus.

    Each candidate token receives a score of 1, so the value does not depend
    on the prefix or on the encoder output; those inputs are only used to
    pick the device, dtype and batch size of the returned tensor.
    """

    def __init__(self, n_vocab: int):
        """Store the vocabulary size.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search

        """
        self.n = n_vocab

    def score(self, y, state, x):
        """Return the constant bonus for every possible next token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens (not read here).
            state: Scorer state for prefix tokens (not read here).
            x (torch.Tensor): 2D encoder feature that generates ys; only its
                device and dtype are used.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                a tensor of ones with shape `(n_vocab,)` matching the device
                and dtype of `x`, and None as the next state.

        """
        # A single-element tensor broadcast as a view; no per-token storage.
        unit = torch.tensor([1.0], dtype=x.dtype, device=x.device)
        bonus = unit.expand(self.n)
        return bonus, None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Return the constant bonus for every token of every hypothesis.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen);
                only the size of the first dimension is used.
            states (List[Any]): Scorer states for prefix tokens (not read here).
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat);
                only its device and dtype are used.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                a tensor of ones with shape `(n_batch, n_vocab)` matching the
                device and dtype of `xs`, and None in place of a state list.

        """
        n_batch = ys.shape[0]
        # A single-element tensor broadcast as a view over the whole batch.
        unit = torch.tensor([1.0], dtype=xs.dtype, device=xs.device)
        bonus = unit.expand(n_batch, self.n)
        return bonus, None