"""Transducer decoder interface module.""" from dataclasses import dataclass from typing import Any from typing import Dict from typing import List from typing import Optional from typing import Tuple from typing import Union import torch @dataclass class Hypothesis: """Default hypothesis definition for beam search.""" score: float yseq: List[int] dec_state: Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[torch.Tensor], torch.Tensor ] lm_state: Union[Dict[str, Any], List[Any]] = None @dataclass class NSCHypothesis(Hypothesis): """Extended hypothesis definition for NSC beam search.""" y: List[torch.Tensor] = None lm_scores: torch.Tensor = None class TransducerDecoderInterface: """Decoder interface for transducer models.""" def init_state( self, batch_size: int, device: torch.device, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] ]: """Initialize decoder states. Args: batch_size: Batch size for initial state device: Device for initial state Returns: state: Initialized state """ raise NotImplementedError("init_state method is not implemented") def score( self, hyp: Union[Hypothesis, NSCHypothesis], cache: Dict[str, Any], ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor, List[Optional[torch.Tensor]], ]: """Forward one hypothesis. Args: hyp: Hypothesis. cache: Pairs of (y, state) for each token sequence (key) Returns: y: Decoder outputs new_state: New decoder state lm_tokens: Token id for LM """ raise NotImplementedError("score method is not implemented") def batch_score( self, hyps: Union[List[Hypothesis], List[NSCHypothesis]], batch_states: Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] ], cache: Dict[str, Any], ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor, List[Optional[torch.Tensor]], ]: """Forward batch of hypotheses. Args: hyps: Batch of hypotheses batch_states: Batch of decoder states cache: pairs of (y, state) for each token sequence (key) Returns: batch_y: Decoder outputs batch_states: Batch of decoder states lm_tokens: Batch of token ids for LM """ raise NotImplementedError("batch_score method is not implemented") def select_state( self, batch_states: Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] ], idx: int, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] ]: """Get decoder state from batch for given id. Args: batch_states: Batch of decoder states idx: Index to extract state from batch Returns: state_idx: Decoder state for given id """ raise NotImplementedError("select_state method is not implemented") def create_batch_states( self, batch_states: Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] ], l_states: List[ Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]], ] ], l_tokens: List[List[int]], ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] ]: """Create batch of decoder states. Args: batch_states: Batch of decoder states l_states: List of decoder states l_tokens: List of token sequences for input batch Returns: batch_states: Batch of decoder states """ raise NotImplementedError("create_batch_states method is not implemented")