# NOTE: the original capture began with a Hugging Face Spaces page banner
# ("Spaces:" / "Runtime error" x2) — page-scrape residue, not module source.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from modules.wenet_extractor.transducer.predictor import PredictorBase
from modules.wenet_extractor.transducer.search.greedy_search import basic_greedy_search
from modules.wenet_extractor.transducer.search.prefix_beam_search import (
    PrefixBeamSearch,
)
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.label_smoothing_loss import LabelSmoothingLoss
from modules.wenet_extractor.utils.common import (
    IGNORE_ID,
    add_blank,
    add_sos_eos,
    reverse_pad_list,
)
| class Transducer(ASRModel): | |
| """Transducer-ctc-attention hybrid Encoder-Predictor-Decoder model""" | |
| def __init__( | |
| self, | |
| vocab_size: int, | |
| blank: int, | |
| encoder: nn.Module, | |
| predictor: PredictorBase, | |
| joint: nn.Module, | |
| attention_decoder: Optional[ | |
| Union[TransformerDecoder, BiTransformerDecoder] | |
| ] = None, | |
| ctc: Optional[CTC] = None, | |
| ctc_weight: float = 0, | |
| ignore_id: int = IGNORE_ID, | |
| reverse_weight: float = 0.0, | |
| lsm_weight: float = 0.0, | |
| length_normalized_loss: bool = False, | |
| transducer_weight: float = 1.0, | |
| attention_weight: float = 0.0, | |
| ) -> None: | |
| assert attention_weight + ctc_weight + transducer_weight == 1.0 | |
| super().__init__( | |
| vocab_size, | |
| encoder, | |
| attention_decoder, | |
| ctc, | |
| ctc_weight, | |
| ignore_id, | |
| reverse_weight, | |
| lsm_weight, | |
| length_normalized_loss, | |
| ) | |
| self.blank = blank | |
| self.transducer_weight = transducer_weight | |
| self.attention_decoder_weight = 1 - self.transducer_weight - self.ctc_weight | |
| self.predictor = predictor | |
| self.joint = joint | |
| self.bs = None | |
| # Note(Mddct): decoder also means predictor in transducer, | |
| # but here decoder is attention decoder | |
| del self.criterion_att | |
| if attention_decoder is not None: | |
| self.criterion_att = LabelSmoothingLoss( | |
| size=vocab_size, | |
| padding_idx=ignore_id, | |
| smoothing=lsm_weight, | |
| normalize_length=length_normalized_loss, | |
| ) | |
| def forward( | |
| self, | |
| speech: torch.Tensor, | |
| speech_lengths: torch.Tensor, | |
| text: torch.Tensor, | |
| text_lengths: torch.Tensor, | |
| ) -> Dict[str, Optional[torch.Tensor]]: | |
| """Frontend + Encoder + predictor + joint + loss | |
| Args: | |
| speech: (Batch, Length, ...) | |
| speech_lengths: (Batch, ) | |
| text: (Batch, Length) | |
| text_lengths: (Batch,) | |
| """ | |
| assert text_lengths.dim() == 1, text_lengths.shape | |
| # Check that batch_size is unified | |
| assert ( | |
| speech.shape[0] | |
| == speech_lengths.shape[0] | |
| == text.shape[0] | |
| == text_lengths.shape[0] | |
| ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) | |
| # Encoder | |
| encoder_out, encoder_mask = self.encoder(speech, speech_lengths) | |
| encoder_out_lens = encoder_mask.squeeze(1).sum(1) | |
| # predictor | |
| ys_in_pad = add_blank(text, self.blank, self.ignore_id) | |
| predictor_out = self.predictor(ys_in_pad) | |
| # joint | |
| joint_out = self.joint(encoder_out, predictor_out) | |
| # NOTE(Mddct): some loss implementation require pad valid is zero | |
| # torch.int32 rnnt_loss required | |
| rnnt_text = text.to(torch.int64) | |
| rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to( | |
| torch.int32 | |
| ) | |
| rnnt_text_lengths = text_lengths.to(torch.int32) | |
| encoder_out_lens = encoder_out_lens.to(torch.int32) | |
| loss = torchaudio.functional.rnnt_loss( | |
| joint_out, | |
| rnnt_text, | |
| encoder_out_lens, | |
| rnnt_text_lengths, | |
| blank=self.blank, | |
| reduction="mean", | |
| ) | |
| loss_rnnt = loss | |
| loss = self.transducer_weight * loss | |
| # optional attention decoder | |
| loss_att: Optional[torch.Tensor] = None | |
| if self.attention_decoder_weight != 0.0 and self.decoder is not None: | |
| loss_att, _ = self._calc_att_loss( | |
| encoder_out, encoder_mask, text, text_lengths | |
| ) | |
| # optional ctc | |
| loss_ctc: Optional[torch.Tensor] = None | |
| if self.ctc_weight != 0.0 and self.ctc is not None: | |
| loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths) | |
| else: | |
| loss_ctc = None | |
| if loss_ctc is not None: | |
| loss = loss + self.ctc_weight * loss_ctc.sum() | |
| if loss_att is not None: | |
| loss = loss + self.attention_decoder_weight * loss_att.sum() | |
| # NOTE: 'loss' must be in dict | |
| return { | |
| "loss": loss, | |
| "loss_att": loss_att, | |
| "loss_ctc": loss_ctc, | |
| "loss_rnnt": loss_rnnt, | |
| } | |
| def init_bs(self): | |
| if self.bs is None: | |
| self.bs = PrefixBeamSearch( | |
| self.encoder, self.predictor, self.joint, self.ctc, self.blank | |
| ) | |
| def _cal_transducer_score( | |
| self, | |
| encoder_out: torch.Tensor, | |
| encoder_mask: torch.Tensor, | |
| hyps_lens: torch.Tensor, | |
| hyps_pad: torch.Tensor, | |
| ): | |
| # ignore id -> blank, add blank at head | |
| hyps_pad_blank = add_blank(hyps_pad, self.blank, self.ignore_id) | |
| xs_in_lens = encoder_mask.squeeze(1).sum(1).int() | |
| # 1. Forward predictor | |
| predictor_out = self.predictor(hyps_pad_blank) | |
| # 2. Forward joint | |
| joint_out = self.joint(encoder_out, predictor_out) | |
| rnnt_text = hyps_pad.to(torch.int64) | |
| rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to( | |
| torch.int32 | |
| ) | |
| # 3. Compute transducer loss | |
| loss_td = torchaudio.functional.rnnt_loss( | |
| joint_out, | |
| rnnt_text, | |
| xs_in_lens, | |
| hyps_lens.int(), | |
| blank=self.blank, | |
| reduction="none", | |
| ) | |
| return loss_td * -1 | |
| def _cal_attn_score( | |
| self, | |
| encoder_out: torch.Tensor, | |
| encoder_mask: torch.Tensor, | |
| hyps_pad: torch.Tensor, | |
| hyps_lens: torch.Tensor, | |
| ): | |
| # (beam_size, max_hyps_len) | |
| ori_hyps_pad = hyps_pad | |
| # td_score = loss_td * -1 | |
| hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) | |
| hyps_lens = hyps_lens + 1 # Add <sos> at begining | |
| # used for right to left decoder | |
| r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) | |
| r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id) | |
| decoder_out, r_decoder_out, _ = self.decoder( | |
| encoder_out, | |
| encoder_mask, | |
| hyps_pad, | |
| hyps_lens, | |
| r_hyps_pad, | |
| self.reverse_weight, | |
| ) # (beam_size, max_hyps_len, vocab_size) | |
| decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) | |
| decoder_out = decoder_out.cpu().numpy() | |
| # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a | |
| # conventional transformer decoder. | |
| r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) | |
| r_decoder_out = r_decoder_out.cpu().numpy() | |
| return decoder_out, r_decoder_out | |
| def beam_search( | |
| self, | |
| speech: torch.Tensor, | |
| speech_lengths: torch.Tensor, | |
| decoding_chunk_size: int = -1, | |
| beam_size: int = 5, | |
| num_decoding_left_chunks: int = -1, | |
| simulate_streaming: bool = False, | |
| ctc_weight: float = 0.3, | |
| transducer_weight: float = 0.7, | |
| ): | |
| """beam search | |
| Args: | |
| speech (torch.Tensor): (batch=1, max_len, feat_dim) | |
| speech_length (torch.Tensor): (batch, ) | |
| beam_size (int): beam size for beam search | |
| decoding_chunk_size (int): decoding chunk for dynamic chunk | |
| trained model. | |
| <0: for decoding, use full chunk. | |
| >0: for decoding, use fixed chunk size as set. | |
| 0: used for training, it's prohibited here | |
| simulate_streaming (bool): whether do encoder forward in a | |
| streaming fashion | |
| ctc_weight (float): ctc probability weight in transducer | |
| prefix beam search. | |
| final_prob = ctc_weight * ctc_prob + transducer_weight * transducer_prob | |
| transducer_weight (float): transducer probability weight in | |
| prefix beam search | |
| Returns: | |
| List[List[int]]: best path result | |
| """ | |
| self.init_bs() | |
| beam, _ = self.bs.prefix_beam_search( | |
| speech, | |
| speech_lengths, | |
| decoding_chunk_size, | |
| beam_size, | |
| num_decoding_left_chunks, | |
| simulate_streaming, | |
| ctc_weight, | |
| transducer_weight, | |
| ) | |
| return beam[0].hyp[1:], beam[0].score | |
| def transducer_attention_rescoring( | |
| self, | |
| speech: torch.Tensor, | |
| speech_lengths: torch.Tensor, | |
| beam_size: int, | |
| decoding_chunk_size: int = -1, | |
| num_decoding_left_chunks: int = -1, | |
| simulate_streaming: bool = False, | |
| reverse_weight: float = 0.0, | |
| ctc_weight: float = 0.0, | |
| attn_weight: float = 0.0, | |
| transducer_weight: float = 0.0, | |
| search_ctc_weight: float = 1.0, | |
| search_transducer_weight: float = 0.0, | |
| beam_search_type: str = "transducer", | |
| ) -> List[List[int]]: | |
| """beam search | |
| Args: | |
| speech (torch.Tensor): (batch=1, max_len, feat_dim) | |
| speech_length (torch.Tensor): (batch, ) | |
| beam_size (int): beam size for beam search | |
| decoding_chunk_size (int): decoding chunk for dynamic chunk | |
| trained model. | |
| <0: for decoding, use full chunk. | |
| >0: for decoding, use fixed chunk size as set. | |
| 0: used for training, it's prohibited here | |
| simulate_streaming (bool): whether do encoder forward in a | |
| streaming fashion | |
| ctc_weight (float): ctc probability weight using in rescoring. | |
| rescore_prob = ctc_weight * ctc_prob + | |
| transducer_weight * (transducer_loss * -1) + | |
| attn_weight * attn_prob | |
| attn_weight (float): attn probability weight using in rescoring. | |
| transducer_weight (float): transducer probability weight using in | |
| rescoring | |
| search_ctc_weight (float): ctc weight using | |
| in rnnt beam search (seeing in self.beam_search) | |
| search_transducer_weight (float): transducer weight using | |
| in rnnt beam search (seeing in self.beam_search) | |
| Returns: | |
| List[List[int]]: best path result | |
| """ | |
| assert speech.shape[0] == speech_lengths.shape[0] | |
| assert decoding_chunk_size != 0 | |
| if reverse_weight > 0.0: | |
| # decoder should be a bitransformer decoder if reverse_weight > 0.0 | |
| assert hasattr(self.decoder, "right_decoder") | |
| device = speech.device | |
| batch_size = speech.shape[0] | |
| # For attention rescoring we only support batch_size=1 | |
| assert batch_size == 1 | |
| # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size | |
| self.init_bs() | |
| if beam_search_type == "transducer": | |
| beam, encoder_out = self.bs.prefix_beam_search( | |
| speech, | |
| speech_lengths, | |
| decoding_chunk_size=decoding_chunk_size, | |
| beam_size=beam_size, | |
| num_decoding_left_chunks=num_decoding_left_chunks, | |
| ctc_weight=search_ctc_weight, | |
| transducer_weight=search_transducer_weight, | |
| ) | |
| beam_score = [s.score for s in beam] | |
| hyps = [s.hyp[1:] for s in beam] | |
| elif beam_search_type == "ctc": | |
| hyps, encoder_out = self._ctc_prefix_beam_search( | |
| speech, | |
| speech_lengths, | |
| beam_size=beam_size, | |
| decoding_chunk_size=decoding_chunk_size, | |
| num_decoding_left_chunks=num_decoding_left_chunks, | |
| simulate_streaming=simulate_streaming, | |
| ) | |
| beam_score = [hyp[1] for hyp in hyps] | |
| hyps = [hyp[0] for hyp in hyps] | |
| assert len(hyps) == beam_size | |
| # build hyps and encoder output | |
| hyps_pad = pad_sequence( | |
| [torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps], | |
| True, | |
| self.ignore_id, | |
| ) # (beam_size, max_hyps_len) | |
| hyps_lens = torch.tensor( | |
| [len(hyp) for hyp in hyps], device=device, dtype=torch.long | |
| ) # (beam_size,) | |
| encoder_out = encoder_out.repeat(beam_size, 1, 1) | |
| encoder_mask = torch.ones( | |
| beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device | |
| ) | |
| # 2.1 calculate transducer score | |
| td_score = self._cal_transducer_score( | |
| encoder_out, | |
| encoder_mask, | |
| hyps_lens, | |
| hyps_pad, | |
| ) | |
| # 2.2 calculate attention score | |
| decoder_out, r_decoder_out = self._cal_attn_score( | |
| encoder_out, | |
| encoder_mask, | |
| hyps_pad, | |
| hyps_lens, | |
| ) | |
| # Only use decoder score for rescoring | |
| best_score = -float("inf") | |
| best_index = 0 | |
| for i, hyp in enumerate(hyps): | |
| score = 0.0 | |
| for j, w in enumerate(hyp): | |
| score += decoder_out[i][j][w] | |
| score += decoder_out[i][len(hyp)][self.eos] | |
| td_s = td_score[i] | |
| # add right to left decoder score | |
| if reverse_weight > 0: | |
| r_score = 0.0 | |
| for j, w in enumerate(hyp): | |
| r_score += r_decoder_out[i][len(hyp) - j - 1][w] | |
| r_score += r_decoder_out[i][len(hyp)][self.eos] | |
| score = score * (1 - reverse_weight) + r_score * reverse_weight | |
| # add ctc score | |
| score = ( | |
| score * attn_weight | |
| + beam_score[i] * ctc_weight | |
| + td_s * transducer_weight | |
| ) | |
| if score > best_score: | |
| best_score = score | |
| best_index = i | |
| return hyps[best_index], best_score | |
| def greedy_search( | |
| self, | |
| speech: torch.Tensor, | |
| speech_lengths: torch.Tensor, | |
| decoding_chunk_size: int = -1, | |
| num_decoding_left_chunks: int = -1, | |
| simulate_streaming: bool = False, | |
| n_steps: int = 64, | |
| ) -> List[List[int]]: | |
| """greedy search | |
| Args: | |
| speech (torch.Tensor): (batch=1, max_len, feat_dim) | |
| speech_length (torch.Tensor): (batch, ) | |
| beam_size (int): beam size for beam search | |
| decoding_chunk_size (int): decoding chunk for dynamic chunk | |
| trained model. | |
| <0: for decoding, use full chunk. | |
| >0: for decoding, use fixed chunk size as set. | |
| 0: used for training, it's prohibited here | |
| simulate_streaming (bool): whether do encoder forward in a | |
| streaming fashion | |
| Returns: | |
| List[List[int]]: best path result | |
| """ | |
| # TODO(Mddct): batch decode | |
| assert speech.size(0) == 1 | |
| assert speech.shape[0] == speech_lengths.shape[0] | |
| assert decoding_chunk_size != 0 | |
| # TODO(Mddct): forward chunk by chunk | |
| _ = simulate_streaming | |
| # Let's assume B = batch_size | |
| encoder_out, encoder_mask = self.encoder( | |
| speech, | |
| speech_lengths, | |
| decoding_chunk_size, | |
| num_decoding_left_chunks, | |
| ) | |
| encoder_out_lens = encoder_mask.squeeze(1).sum() | |
| hyps = basic_greedy_search(self, encoder_out, encoder_out_lens, n_steps=n_steps) | |
| return hyps | |
| def forward_encoder_chunk( | |
| self, | |
| xs: torch.Tensor, | |
| offset: int, | |
| required_cache_size: int, | |
| att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), | |
| cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| return self.encoder.forward_chunk( | |
| xs, offset, required_cache_size, att_cache, cnn_cache | |
| ) | |
| def forward_predictor_step( | |
| self, xs: torch.Tensor, cache: List[torch.Tensor] | |
| ) -> Tuple[torch.Tensor, List[torch.Tensor]]: | |
| assert len(cache) == 2 | |
| # fake padding | |
| padding = torch.zeros(1, 1) | |
| return self.predictor.forward_step(xs, padding, cache) | |
| def forward_joint_step( | |
| self, enc_out: torch.Tensor, pred_out: torch.Tensor | |
| ) -> torch.Tensor: | |
| return self.joint(enc_out, pred_out) | |
| def forward_predictor_init_state(self) -> List[torch.Tensor]: | |
| return self.predictor.init_state(1, device=torch.device("cpu")) | |