# This module is from [WeNet](https://github.com/wenet-e2e/wenet). # ## Citations # ```bibtex # @inproceedings{yao2021wenet, # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, # booktitle={Proc. Interspeech}, # year={2021}, # address={Brno, Czech Republic }, # organization={IEEE} # } # @article{zhang2022wenet, # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, # journal={arXiv preprint arXiv:2203.15455}, # year={2022} # } # from collections import defaultdict from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence from modules.wenet_extractor.transformer.ctc import CTC from modules.wenet_extractor.transformer.decoder import TransformerDecoder from modules.wenet_extractor.transformer.encoder import TransformerEncoder from modules.wenet_extractor.transformer.label_smoothing_loss import LabelSmoothingLoss from modules.wenet_extractor.utils.common import ( IGNORE_ID, add_sos_eos, log_add, remove_duplicates_and_blank, th_accuracy, reverse_pad_list, ) from modules.wenet_extractor.utils.mask import ( make_pad_mask, mask_finished_preds, mask_finished_scores, subsequent_mask, ) class ASRModel(torch.nn.Module): """CTC-attention hybrid Encoder-Decoder model""" def __init__( self, vocab_size: int, encoder: TransformerEncoder, decoder: TransformerDecoder, ctc: CTC, ctc_weight: float = 0.5, ignore_id: int = IGNORE_ID, reverse_weight: float = 0.0, lsm_weight: float = 0.0, length_normalized_loss: bool = False, lfmmi_dir: str = "", ): assert 0.0 <= ctc_weight <= 1.0, ctc_weight super().__init__() # note that eos is the same as sos (equivalent ID) self.sos = vocab_size - 1 self.eos = vocab_size - 1 self.vocab_size = vocab_size self.ignore_id = ignore_id self.ctc_weight = ctc_weight self.reverse_weight = reverse_weight self.encoder = encoder self.decoder = decoder self.ctc = ctc self.criterion_att = LabelSmoothingLoss( size=vocab_size, padding_idx=ignore_id, smoothing=lsm_weight, normalize_length=length_normalized_loss, ) self.lfmmi_dir = lfmmi_dir if self.lfmmi_dir != "": self.load_lfmmi_resource() def forward( self, speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, ) -> Dict[str, Optional[torch.Tensor]]: """Frontend + Encoder + Decoder + Calc loss Args: speech: (Batch, Length, ...) speech_lengths: (Batch, ) text: (Batch, Length) text_lengths: (Batch,) """ assert text_lengths.dim() == 1, text_lengths.shape # Check that batch_size is unified assert ( speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0] ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) # 1. Encoder encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out_lens = encoder_mask.squeeze(1).sum(1) # 2a. Attention-decoder branch if self.ctc_weight != 1.0: loss_att, acc_att = self._calc_att_loss( encoder_out, encoder_mask, text, text_lengths ) else: loss_att = None # 2b. 
CTC branch or LF-MMI loss if self.ctc_weight != 0.0: if self.lfmmi_dir != "": loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, text) else: loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths) else: loss_ctc = None if loss_ctc is None: loss = loss_att elif loss_att is None: loss = loss_ctc else: loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc} def _calc_att_loss( self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, ys_pad: torch.Tensor, ys_pad_lens: torch.Tensor, ) -> Tuple[torch.Tensor, float]: ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_in_lens = ys_pad_lens + 1 # reverse the seq, used for right to left decoder r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id)) r_ys_in_pad, r_ys_out_pad = add_sos_eos( r_ys_pad, self.sos, self.eos, self.ignore_id ) # 1. Forward decoder decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad, self.reverse_weight, ) # 2. Compute attention loss loss_att = self.criterion_att(decoder_out, ys_out_pad) r_loss_att = torch.tensor(0.0) if self.reverse_weight > 0.0: r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) loss_att = ( loss_att * (1 - self.reverse_weight) + r_loss_att * self.reverse_weight ) acc_att = th_accuracy( decoder_out.view(-1, self.vocab_size), ys_out_pad, ignore_label=self.ignore_id, ) return loss_att, acc_att def _forward_encoder( self, speech: torch.Tensor, speech_lengths: torch.Tensor, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: # Let's assume B = batch_size # 1. Encoder if simulate_streaming and decoding_chunk_size > 0: encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( speech, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks, ) # (B, maxlen, encoder_dim) else: encoder_out, encoder_mask = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks, ) # (B, maxlen, encoder_dim) return encoder_out, encoder_mask def encoder_extractor( self, speech: torch.Tensor, speech_lengths: torch.Tensor, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: # assert speech.shape[0] == speech_lengths[0] assert decoding_chunk_size != 0 batch_size = speech.shape[0] encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) # (B, maxlen, encoder_dim) return encoder_out def recognize( self, speech: torch.Tensor, speech_lengths: torch.Tensor, beam_size: int = 10, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, ) -> torch.Tensor: """Apply beam search on attention decoder Args: speech (torch.Tensor): (batch, max_len, feat_dim) speech_length (torch.Tensor): (batch, ) beam_size (int): beam size for beam search decoding_chunk_size (int): decoding chunk for dynamic chunk trained model. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 
0: used for training, it's prohibited here simulate_streaming (bool): whether do encoder forward in a streaming fashion Returns: torch.Tensor: decoding result, (batch, max_result_len) """ assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 device = speech.device batch_size = speech.shape[0] # Let's assume B = batch_size and N = beam_size # 1. Encoder encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) # (B, maxlen, encoder_dim) maxlen = encoder_out.size(1) encoder_dim = encoder_out.size(2) running_size = batch_size * beam_size encoder_out = ( encoder_out.unsqueeze(1) .repeat(1, beam_size, 1, 1) .view(running_size, maxlen, encoder_dim) ) # (B*N, maxlen, encoder_dim) encoder_mask = ( encoder_mask.unsqueeze(1) .repeat(1, beam_size, 1, 1) .view(running_size, 1, maxlen) ) # (B*N, 1, max_len) hyps = torch.ones([running_size, 1], dtype=torch.long, device=device).fill_( self.sos ) # (B*N, 1) scores = torch.tensor( [0.0] + [-float("inf")] * (beam_size - 1), dtype=torch.float ) scores = ( scores.to(device).repeat([batch_size]).unsqueeze(1).to(device) ) # (B*N, 1) end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device) cache: Optional[List[torch.Tensor]] = None # 2. Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos if end_flag.sum() == running_size: break # 2.1 Forward decoder step hyps_mask = ( subsequent_mask(i).unsqueeze(0).repeat(running_size, 1, 1).to(device) ) # (B*N, i, i) # logp: (B*N, vocab) logp, cache = self.decoder.forward_one_step( encoder_out, encoder_mask, hyps, hyps_mask, cache ) # 2.2 First beam prune: select topk best prob at current time top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp = mask_finished_scores(top_k_logp, end_flag) top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) # 2.3 Second beam prune: select topk score with history scores = scores + top_k_logp # (B*N, N), broadcast add scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) scores, offset_k_index = scores.topk(k=beam_size) # (B, N) # Update cache to be consistent with new topk scores / hyps cache_index = (offset_k_index // beam_size).view(-1) # (B*N) base_cache_index = ( torch.arange(batch_size, device=device) .view(-1, 1) .repeat([1, beam_size]) * beam_size ).view( -1 ) # (B*N) cache_index = base_cache_index + cache_index cache = [torch.index_select(c, dim=0, index=cache_index) for c in cache] scores = scores.view(-1, 1) # (B*N, 1) # 2.4. Compute base index in top_k_index, # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), # then find offset_k_index in top_k_index base_k_index = ( torch.arange(batch_size, device=device) .view(-1, 1) .repeat([1, beam_size]) ) # (B, N) base_k_index = base_k_index * beam_size * beam_size best_k_index = base_k_index.view(-1) + offset_k_index.view(-1) # (B*N) # 2.5 Update best hyps best_k_pred = torch.index_select( top_k_index.view(-1), dim=-1, index=best_k_index ) # (B*N) best_hyps_index = best_k_index // beam_size last_best_k_hyps = torch.index_select( hyps, dim=0, index=best_hyps_index ) # (B*N, i) hyps = torch.cat( (last_best_k_hyps, best_k_pred.view(-1, 1)), dim=1 ) # (B*N, i+1) # 2.6 Update end flag end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1) # 3. 
Select best of best scores = scores.view(batch_size, beam_size) # TODO: length normalization best_scores, best_index = scores.max(dim=-1) best_hyps_index = ( best_index + torch.arange(batch_size, dtype=torch.long, device=device) * beam_size ) best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index) best_hyps = best_hyps[:, 1:] return best_hyps, best_scores def ctc_greedy_search( self, speech: torch.Tensor, speech_lengths: torch.Tensor, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, ) -> List[List[int]]: """Apply CTC greedy search Args: speech (torch.Tensor): (batch, max_len, feat_dim) speech_length (torch.Tensor): (batch, ) beam_size (int): beam size for beam search decoding_chunk_size (int): decoding chunk for dynamic chunk trained model. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 0: used for training, it's prohibited here simulate_streaming (bool): whether do encoder forward in a streaming fashion Returns: List[List[int]]: best path result """ assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 batch_size = speech.shape[0] # Let's assume B = batch_size encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) # (B, maxlen, encoder_dim) maxlen = encoder_out.size(1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) mask = make_pad_mask(encoder_out_lens, maxlen) # (B, maxlen) topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen) hyps = [hyp.tolist() for hyp in topk_index] scores = topk_prob.max(1) hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] return hyps, scores def _ctc_prefix_beam_search( self, speech: torch.Tensor, speech_lengths: torch.Tensor, beam_size: int, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, ) -> Tuple[List[List[int]], torch.Tensor]: """CTC prefix beam search inner implementation Args: speech (torch.Tensor): (batch, max_len, feat_dim) speech_length (torch.Tensor): (batch, ) beam_size (int): beam size for beam search decoding_chunk_size (int): decoding chunk for dynamic chunk trained model. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 0: used for training, it's prohibited here simulate_streaming (bool): whether do encoder forward in a streaming fashion Returns: List[List[int]]: nbest results torch.Tensor: encoder output, (1, max_len, encoder_dim), it will be used for rescoring in attention rescoring mode """ assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 batch_size = speech.shape[0] # For CTC prefix beam search, we only support batch_size=1 assert batch_size == 1 # Let's assume B = batch_size and N = beam_size # 1. Encoder forward and get CTC score encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) # (B, maxlen, encoder_dim) maxlen = encoder_out.size(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) cur_hyps = [(tuple(), (0.0, -float("inf")))] # 2. 
CTC beam search step by step for t in range(0, maxlen): logp = ctc_probs[t] # (vocab_size,) # key: prefix, value (pb, pnb), default value(-inf, -inf) next_hyps = defaultdict(lambda: (-float("inf"), -float("inf"))) # 2.1 First beam prune: select topk best top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) for s in top_k_index: s = s.item() ps = logp[s].item() for prefix, (pb, pnb) in cur_hyps: last = prefix[-1] if len(prefix) > 0 else None if s == 0: # blank n_pb, n_pnb = next_hyps[prefix] n_pb = log_add([n_pb, pb + ps, pnb + ps]) next_hyps[prefix] = (n_pb, n_pnb) elif s == last: # Update *ss -> *s; n_pb, n_pnb = next_hyps[prefix] n_pnb = log_add([n_pnb, pnb + ps]) next_hyps[prefix] = (n_pb, n_pnb) # Update *s-s -> *ss, - is for blank n_prefix = prefix + (s,) n_pb, n_pnb = next_hyps[n_prefix] n_pnb = log_add([n_pnb, pb + ps]) next_hyps[n_prefix] = (n_pb, n_pnb) else: n_prefix = prefix + (s,) n_pb, n_pnb = next_hyps[n_prefix] n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) next_hyps[n_prefix] = (n_pb, n_pnb) # 2.2 Second beam prune next_hyps = sorted( next_hyps.items(), key=lambda x: log_add(list(x[1])), reverse=True ) cur_hyps = next_hyps[:beam_size] hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] return hyps, encoder_out def ctc_prefix_beam_search( self, speech: torch.Tensor, speech_lengths: torch.Tensor, beam_size: int, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, ) -> List[int]: """Apply CTC prefix beam search Args: speech (torch.Tensor): (batch, max_len, feat_dim) speech_length (torch.Tensor): (batch, ) beam_size (int): beam size for beam search decoding_chunk_size (int): decoding chunk for dynamic chunk trained model. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 0: used for training, it's prohibited here simulate_streaming (bool): whether do encoder forward in a streaming fashion Returns: List[int]: CTC prefix beam search nbest results """ hyps, _ = self._ctc_prefix_beam_search( speech, speech_lengths, beam_size, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) return hyps[0] def attention_rescoring( self, speech: torch.Tensor, speech_lengths: torch.Tensor, beam_size: int, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, ctc_weight: float = 0.0, simulate_streaming: bool = False, reverse_weight: float = 0.0, ) -> List[int]: """Apply attention rescoring decoding, CTC prefix beam search is applied first to get nbest, then we resoring the nbest on attention decoder with corresponding encoder out Args: speech (torch.Tensor): (batch, max_len, feat_dim) speech_length (torch.Tensor): (batch, ) beam_size (int): beam size for beam search decoding_chunk_size (int): decoding chunk for dynamic chunk trained model. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 
0: used for training, it's prohibited here simulate_streaming (bool): whether do encoder forward in a streaming fashion reverse_weight (float): right to left decoder weight ctc_weight (float): ctc score weight Returns: List[int]: Attention rescoring result """ assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 if reverse_weight > 0.0: # decoder should be a bitransformer decoder if reverse_weight > 0.0 assert hasattr(self.decoder, "right_decoder") device = speech.device batch_size = speech.shape[0] # For attention rescoring we only support batch_size=1 assert batch_size == 1 # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size hyps, encoder_out = self._ctc_prefix_beam_search( speech, speech_lengths, beam_size, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) assert len(hyps) == beam_size hyps_pad = pad_sequence( [torch.tensor(hyp[0], device=device, dtype=torch.long) for hyp in hyps], True, self.ignore_id, ) # (beam_size, max_hyps_len) ori_hyps_pad = hyps_pad hyps_lens = torch.tensor( [len(hyp[0]) for hyp in hyps], device=device, dtype=torch.long ) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = torch.ones( beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device ) # used for right to left decoder r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id) decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, reverse_weight ) # (beam_size, max_hyps_len, vocab_size) decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) decoder_out = decoder_out.cpu().numpy() # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # conventional transformer decoder. 
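        # Descriptive note (summarizing the code below): the final score of
        # each n-best hypothesis combines three parts, the left-to-right
        # attention-decoder score, the right-to-left decoder score weighted by
        # `reverse_weight`, and the CTC prefix-beam-search score weighted by
        # `ctc_weight`, roughly
        #   score = (1 - reverse_weight) * s_l2r + reverse_weight * s_r2l
        #           + ctc_weight * s_ctc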
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.cpu().numpy()
        # Only use decoder score for rescoring
        best_score = -float("inf")
        best_index = 0
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp[0])][self.eos]
            # add right to left decoder score
            if reverse_weight > 0:
                r_score = 0.0
                for j, w in enumerate(hyp[0]):
                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
                r_score += r_decoder_out[i][len(hyp[0])][self.eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            # add ctc score
            score += hyp[1] * ctc_weight
            if score > best_score:
                best_score = score
                best_index = i
        return hyps[best_index][0], best_score

    # NOTE: `k2`, `MmiTrainingGraphCompiler`, `LFMMILoss`, `get_lattice`,
    # `one_best_decoding`, `get_texts` and `Nbest` used by the LF-MMI and HLG
    # paths below are expected to come from the optional k2/icefall
    # dependencies; they are not imported at the top of this module.
    @torch.jit.unused
    def load_lfmmi_resource(self):
        with open("{}/tokens.txt".format(self.lfmmi_dir), "r") as fin:
            for line in fin:
                arr = line.strip().split()
                if arr[0] == "<sos/eos>":
                    self.sos_eos_id = int(arr[1])
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.graph_compiler = MmiTrainingGraphCompiler(
            self.lfmmi_dir,
            device=device,
            oov="<UNK>",
            sos_id=self.sos_eos_id,
            eos_id=self.sos_eos_id,
        )
        self.lfmmi = LFMMILoss(
            graph_compiler=self.graph_compiler,
            den_scale=1,
            use_pruned_intersect=False,
        )
        self.word_table = {}
        with open("{}/words.txt".format(self.lfmmi_dir), "r") as fin:
            for line in fin:
                arr = line.strip().split()
                assert len(arr) == 2
                self.word_table[int(arr[1])] = arr[0]

    @torch.jit.unused
    def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
        ctc_probs = self.ctc.log_softmax(encoder_out)
        supervision_segments = torch.stack(
            (
                torch.arange(len(encoder_mask)),
                torch.zeros(len(encoder_mask)),
                encoder_mask.squeeze(dim=1).sum(dim=1).to("cpu"),
            ),
            1,
        ).to(torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(
            ctc_probs,
            supervision_segments,
            allow_truncate=3,
        )
        text = [
            " ".join([self.word_table[j.item()] for j in i if j != -1])
            for i in text
        ]
        loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
        return loss

    def load_hlg_resource_if_necessary(self, hlg, word):
        if not hasattr(self, "hlg"):
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
        if not hasattr(self.hlg, "lm_scores"):
            self.hlg.lm_scores = self.hlg.scores.clone()
        if not hasattr(self, "word_table"):
            self.word_table = {}
            with open(word, "r") as fin:
                for line in fin:
                    arr = line.strip().split()
                    assert len(arr) == 2
                    self.word_table[int(arr[1])] = arr[0]

    @torch.no_grad()
    def hlg_onebest(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        hlg: str = "",
        word: str = "",
        symbol_table: Dict[str, int] = None,
    ) -> List[int]:
        self.load_hlg_resource_if_necessary(hlg, word)
        encoder_out, encoder_mask = self._forward_encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming,
        )  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (
                torch.arange(len(encoder_mask)),
                torch.zeros(len(encoder_mask)),
                encoder_mask.squeeze(dim=1).sum(dim=1).cpu(),
            ),
            1,
        ).to(torch.int32)
        lattice = get_lattice(
            nnet_output=ctc_probs,
            decoding_graph=self.hlg,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=7,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=4,
        )
        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
        hyps =
get_texts(best_path) hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] for i in hyps] return hyps @torch.no_grad() def hlg_rescore( self, speech: torch.Tensor, speech_lengths: torch.Tensor, decoding_chunk_size: int = -1, num_decoding_left_chunks: int = -1, simulate_streaming: bool = False, lm_scale: float = 0, decoder_scale: float = 0, r_decoder_scale: float = 0, hlg: str = "", word: str = "", symbol_table: Dict[str, int] = None, ) -> List[int]: self.load_hlg_resource_if_necessary(hlg, word) device = speech.device encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming, ) # (B, maxlen, encoder_dim) ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) supervision_segments = torch.stack( ( torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)), encoder_mask.squeeze(dim=1).sum(dim=1).cpu(), ), 1, ).to(torch.int32) lattice = get_lattice( nnet_output=ctc_probs, decoding_graph=self.hlg, supervision_segments=supervision_segments, search_beam=20, output_beam=7, min_active_states=30, max_active_states=10000, subsampling_factor=4, ) nbest = Nbest.from_lattice( lattice=lattice, num_paths=100, use_double_scores=True, nbest_scale=0.5, ) nbest = nbest.intersect(lattice) assert hasattr(nbest.fsa, "lm_scores") assert hasattr(nbest.fsa, "tokens") assert isinstance(nbest.fsa.tokens, torch.Tensor) tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) tokens = tokens.remove_values_leq(0) hyps = tokens.tolist() # cal attention_score hyps_pad = pad_sequence( [torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps], True, self.ignore_id, ) # (beam_size, max_hyps_len) ori_hyps_pad = hyps_pad hyps_lens = torch.tensor( [len(hyp) for hyp in hyps], device=device, dtype=torch.long ) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining encoder_out_repeat = [] tot_scores = nbest.tot_scores() repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)] for i in range(len(encoder_out)): encoder_out_repeat.append(encoder_out[i : i + 1].repeat(repeats[i], 1, 1)) encoder_out = torch.concat(encoder_out_repeat, dim=0) encoder_mask = torch.ones( encoder_out.size(0), 1, encoder_out.size(1), dtype=torch.bool, device=device ) # used for right to left decoder r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id) reverse_weight = 0.5 decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, reverse_weight ) # (beam_size, max_hyps_len, vocab_size) decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) decoder_out = decoder_out # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # conventional transformer decoder. 
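        # Descriptive note (summarizing the code below): the n-best paths
        # extracted from the HLG lattice are rescored with a weighted sum of
        # four terms, the acoustic (AM) score, the n-gram LM score scaled by
        # `lm_scale`, the left-to-right attention-decoder score scaled by
        # `decoder_scale`, and the right-to-left decoder score scaled by
        # `r_decoder_scale`.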
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) r_decoder_out = r_decoder_out decoder_scores = torch.tensor( [ sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))]) for i in range(len(hyps)) ], device=device, ) r_decoder_scores = [] for i in range(len(hyps)): score = 0 for j in range(len(hyps[i])): score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]] score += r_decoder_out[i, len(hyps[i]), self.eos] r_decoder_scores.append(score) r_decoder_scores = torch.tensor(r_decoder_scores, device=device) am_scores = nbest.compute_am_scores() ngram_lm_scores = nbest.compute_lm_scores() tot_scores = ( am_scores.values + lm_scale * ngram_lm_scores.values + decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() best_path = k2.index_fsa(nbest.fsa, max_indexes) hyps = get_texts(best_path) hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] for i in hyps] return hyps @torch.jit.export def subsampling_rate(self) -> int: """Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate @torch.jit.export def right_context(self) -> int: """Export interface for c++ call, return right_context of the model""" return self.encoder.embed.right_context @torch.jit.export def sos_symbol(self) -> int: """Export interface for c++ call, return sos symbol id of the model""" return self.sos @torch.jit.export def eos_symbol(self) -> int: """Export interface for c++ call, return eos symbol id of the model""" return self.eos @torch.jit.export def forward_encoder_chunk( self, xs: torch.Tensor, offset: int, required_cache_size: int, att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. Args: xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), where `time == (chunk_size - 1) * subsample_rate + \ subsample.right_context + 1` offset (int): current offset in encoder output time stamp required_cache_size (int): cache size required for next chunk compuation >=0: actual cache size <0: means all history cache is required att_cache (torch.Tensor): cache tensor for KEY & VALUE in transformer/conformer attention, with shape (elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * num_decoding_left_chunks`. cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, (elayers, b=1, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1` Returns: torch.Tensor: output of current input xs, with shape (b=1, chunk_size, hidden-dim). torch.Tensor: new attention cache required for next chunk, with dynamic shape (elayers, head, ?, d_k * 2) depending on required_cache_size. torch.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. 
""" return self.encoder.forward_chunk( xs, offset, required_cache_size, att_cache, cnn_cache ) @torch.jit.export def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor: """Export interface for c++ call, apply linear transform and log softmax before ctc Args: xs (torch.Tensor): encoder output Returns: torch.Tensor: activation before ctc """ return self.ctc.log_softmax(xs) @torch.jit.export def is_bidirectional_decoder(self) -> bool: """ Returns: torch.Tensor: decoder output """ if hasattr(self.decoder, "right_decoder"): return True else: return False @torch.jit.export def forward_attention_decoder( self, hyps: torch.Tensor, hyps_lens: torch.Tensor, encoder_out: torch.Tensor, reverse_weight: float = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: hyps (torch.Tensor): hyps from ctc prefix beam search, already pad sos at the begining hyps_lens (torch.Tensor): length of each hyp in hyps encoder_out (torch.Tensor): corresponding encoder output r_hyps (torch.Tensor): hyps from ctc prefix beam search, already pad eos at the begining which is used fo right to left decoder reverse_weight: used for verfing whether used right to left decoder, > 0 will use. Returns: torch.Tensor: decoder output """ assert encoder_out.size(0) == 1 num_hyps = hyps.size(0) assert hyps_lens.size(0) == num_hyps encoder_out = encoder_out.repeat(num_hyps, 1, 1) encoder_mask = torch.ones( num_hyps, 1, encoder_out.size(1), dtype=torch.bool, device=encoder_out.device, ) # input for right to left decoder # this hyps_lens has count token, we need minus it. r_hyps_lens = hyps_lens - 1 # this hyps has included token, so it should be # convert the original hyps. r_hyps = hyps[:, 1:] # >>> r_hyps # >>> tensor([[ 1, 2, 3], # >>> [ 9, 8, 4], # >>> [ 2, -1, -1]]) # >>> r_hyps_lens # >>> tensor([3, 3, 1]) # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used # in `reverse_pad_list` thus we have to refine the below code. # Issue: https://github.com/wenet-e2e/wenet/issues/1113 # Equal to: # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) max_len = torch.max(r_hyps_lens) index_range = torch.arange(0, max_len, 1).to(encoder_out.device) seq_len_expand = r_hyps_lens.unsqueeze(1) seq_mask = seq_len_expand > index_range # (beam, max_len) # >>> seq_mask # >>> tensor([[ True, True, True], # >>> [ True, True, True], # >>> [ True, False, False]]) index = (seq_len_expand - 1) - index_range # (beam, max_len) # >>> index # >>> tensor([[ 2, 1, 0], # >>> [ 2, 1, 0], # >>> [ 0, -1, -2]]) index = index * seq_mask # >>> index # >>> tensor([[2, 1, 0], # >>> [2, 1, 0], # >>> [0, 0, 0]]) r_hyps = torch.gather(r_hyps, 1, index) # >>> r_hyps # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, 2, 2]]) r_hyps = torch.where(seq_mask, r_hyps, self.eos) # >>> r_hyps # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, eos, eos]]) r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) # >>> r_hyps # >>> tensor([[sos, 3, 2, 1], # >>> [sos, 4, 8, 9], # >>> [sos, 2, eos, eos]]) decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight ) # (num_hyps, max_hyps_len, vocab_size) decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) # right to left decoder may be not used during decoding process, # which depends on reverse_weight param. 
        # r_decoder_out will be 0.0, if reverse_weight is 0.0
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        return decoder_out, r_decoder_out
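

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original WeNet
# code). It assumes the vendored TransformerEncoder / TransformerDecoder /
# CTC keep WeNet's constructor signatures, i.e. (input_size, output_size=...),
# (vocab_size, encoder_output_size) and (odim, encoder_output_size); adjust
# the arguments if the local modules differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, feat_dim, enc_dim = 100, 80, 256

    # Assumed constructor arguments, see the respective modules for details.
    encoder = TransformerEncoder(feat_dim, output_size=enc_dim)
    decoder = TransformerDecoder(vocab_size, enc_dim)
    ctc = CTC(vocab_size, enc_dim)
    model = ASRModel(vocab_size, encoder, decoder, ctc, ctc_weight=0.3)

    # Dummy batch: two utterances of 80-dim features with different lengths;
    # label padding uses IGNORE_ID as elsewhere in this module.
    speech = torch.randn(2, 200, feat_dim)
    speech_lengths = torch.tensor([200, 160], dtype=torch.long)
    text = torch.tensor([[1, 2, 3, 4], [5, 6, 7, IGNORE_ID]], dtype=torch.long)
    text_lengths = torch.tensor([4, 3], dtype=torch.long)

    # Joint CTC/attention training loss.
    losses = model(speech, speech_lengths, text, text_lengths)
    print({k: None if v is None else float(v) for k, v in losses.items()})

    # Attention-decoder beam search (batch decoding is supported here).
    model.eval()
    with torch.no_grad():
        best_hyps, best_scores = model.recognize(
            speech, speech_lengths, beam_size=5
        )
    print(best_hyps.shape, best_scores.shape)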