conex / espnet /nets /transducer_decoder_interface.py
tobiasc's picture
Initial commit
ad16788
"""Transducer decoder interface module."""
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
@dataclass
class Hypothesis:
    """Default hypothesis definition for beam search.

    Attributes:
        score: Score attached to the hypothesis.
        yseq: Sequence of integer token ids making up the hypothesis.
        dec_state: Decoder state associated with the hypothesis. May be a
            ``(tensor, optional tensor)`` pair, a list of tensors, or a single
            tensor, depending on the decoder implementation.
        lm_state: Language-model state, either a dict or a list. Defaults to
            ``None`` (presumably meaning "no LM state attached" -- confirm
            against the search code that consumes it).

    """

    score: float
    yseq: List[int]
    dec_state: Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[torch.Tensor], torch.Tensor
    ]
    # The default is None, so the annotation must admit None explicitly.
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
@dataclass
class NSCHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search.

    Adds two optional fields on top of the ones inherited from ``Hypothesis``.

    Attributes:
        y: List of tensors attached to the hypothesis (presumably the decoder
            outputs accumulated so far -- confirm against the NSC search
            code). Defaults to ``None``.
        lm_scores: Tensor of language-model scores. Defaults to ``None``.

    """

    # Both fields default to None, so their annotations must admit None.
    y: Optional[List[torch.Tensor]] = None
    lm_scores: Optional[torch.Tensor] = None
class TransducerDecoderInterface:
    """Decoder interface for transducer models.

    Every method below is a stub that raises ``NotImplementedError``; a
    concrete decoder is expected to override each of them.
    """

    def init_state(
        self, batch_size: int, device: torch.device
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Build the initial decoder state.

        Args:
            batch_size: Number of entries the initial state must cover.
            device: Device on which the initial state is placed.

        Returns:
            state: Freshly initialized decoder state.

        Raises:
            NotImplementedError: Always, unless overridden.

        """
        raise NotImplementedError("init_state method is not implemented")

    def score(
        self, hyp: Union[Hypothesis, NSCHypothesis], cache: Dict[str, Any]
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        torch.Tensor,
        List[Optional[torch.Tensor]],
    ]:
        """Run the decoder forward for a single hypothesis.

        NOTE(review): the return annotation is a ``Union`` of state-like
        types, whereas three values are listed under Returns -- confirm the
        intended return shape against the implementing decoders.

        Args:
            hyp: Hypothesis to score.
            cache: Mapping from a token sequence (key) to its (y, state) pair.

        Returns:
            y: Decoder outputs.
            new_state: Updated decoder state.
            lm_tokens: Token id to feed to the LM.

        Raises:
            NotImplementedError: Always, unless overridden.

        """
        raise NotImplementedError("score method is not implemented")

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[NSCHypothesis]],
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        cache: Dict[str, Any],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        torch.Tensor,
        List[Optional[torch.Tensor]],
    ]:
        """Run the decoder forward for a batch of hypotheses.

        NOTE(review): as with ``score``, the return annotation is a ``Union``
        while three values are listed under Returns -- confirm with the
        implementing decoders.

        Args:
            hyps: Hypotheses forming the batch.
            batch_states: Decoder states for the batch.
            cache: Mapping from a token sequence (key) to its (y, state) pair.

        Returns:
            batch_y: Decoder outputs for the batch.
            batch_states: Decoder states for the batch.
            lm_tokens: Token ids to feed to the LM, one batch worth.

        Raises:
            NotImplementedError: Always, unless overridden.

        """
        raise NotImplementedError("batch_score method is not implemented")

    def select_state(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        idx: int,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Pick one decoder state out of a batch of states.

        Args:
            batch_states: Decoder states for the batch.
            idx: Position in the batch of the state to extract.

        Returns:
            state_idx: Decoder state found at position ``idx``.

        Raises:
            NotImplementedError: Always, unless overridden.

        """
        raise NotImplementedError("select_state method is not implemented")

    def create_batch_states(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        l_states: List[
            Union[
                Tuple[torch.Tensor, Optional[torch.Tensor]],
                List[Optional[torch.Tensor]],
            ]
        ],
        l_tokens: List[List[int]],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Assemble a batch of decoder states from individual states.

        Args:
            batch_states: Decoder states for the batch.
            l_states: Individual decoder states, one list entry each.
            l_tokens: Token sequences of the input batch.

        Returns:
            batch_states: Decoder states for the batch.

        Raises:
            NotImplementedError: Always, unless overridden.

        """
        raise NotImplementedError("create_batch_states method is not implemented")