|
"""Length bonus module.""" |
|
from typing import Any |
|
from typing import List |
|
from typing import Tuple |
|
|
|
import torch |
|
|
|
from espnet.nets.scorer_interface import BatchScorerInterface |
|
|
|
|
|
class LengthBonus(BatchScorerInterface):
    """Length bonus in beam search.

    Every vocabulary entry receives the same constant score of 1.0 at each
    decoding step, independent of the prefix tokens, the scorer state and the
    encoder feature values. The scorer keeps no state: both ``score`` and
    ``batch_score`` return ``None`` in the state position.
    """

    def __init__(self, n_vocab: int) -> None:
        """Initialize class.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search.
                It is used as the size of the last dimension of the returned
                score tensors.

        """
        self.n = n_vocab

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens. Not read.
            state: Scorer state for prefix tokens. Not read.
            x (torch.Tensor): 2D encoder feature that generates ys. Only its
                ``device`` and ``dtype`` are read.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token with shape ``(n_vocab,)``, every entry
                equal to 1.0, placed on ``x.device`` with dtype ``x.dtype``
                (not necessarily float32),
                and None (there is no state to carry forward).

        """
        bonus = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        # ``expand`` returns a stride-0 view of the single element rather than
        # allocating ``n_vocab`` values, so the result must not be modified
        # in place; out-of-place arithmetic on it is fine.
        return bonus.expand(self.n), None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
                Only ``ys.shape[0]`` (the batch size) is read.
            states (List[Any]): Scorer states for prefix tokens. Not read.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Only its ``device`` and ``dtype`` are read.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of
                `(n_batch, n_vocab)`, every entry equal to 1.0, placed on
                ``xs.device`` with dtype ``xs.dtype``,
                and None in place of a next-state list (this scorer is
                stateless).

        """
        # NOTE(review): the ``List[Any]`` annotation presumably mirrors the
        # ``BatchScorerInterface`` signature, but ``None`` is what is actually
        # returned for the states; callers are expected to tolerate that.
        bonus = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        # Stride-0 broadcast view over the whole batch; do not write in place.
        return bonus.expand(ys.shape[0], self.n), None
|
|