#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
import numpy as np

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

from funasr_detach.models.transducer.joint_network import JointNetwork

@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Total log-probability.
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
            ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    score: float
    yseq: List[int]
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None

@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Inherits all Hypothesis dataclass fields.

    Args:
        dec_out: Decoder output sequence. (B, D_dec)
        lm_score: Log-probabilities of the LM for the given label. (vocab_size)

    """

    dec_out: Optional[torch.Tensor] = None
    lm_score: Optional[torch.Tensor] = None

class BeamSearchTransducer:
    """Beam search implementation for Transducer models.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        beam_size: Size of the beam.
        lm: LM class.
        lm_weight: LM weight for soft fusion.
        search_type: Search algorithm to use during inference.
        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
        u_max: Maximum expected target sequence length. (ALSD)
        nstep: Number of maximum expansion steps at each time step. (mAES)
        expansion_gamma: Allowed log-probability difference for the
            prune-by-value method. (mAES)
        expansion_beta: Number of additional candidates for expanded
            hypotheses selection. (mAES)
        score_norm: Whether to normalize final scores by length.
        nbest: Number of final hypotheses.
        streaming: Whether to perform chunk-by-chunk beam search.

    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 3,
        u_max: int = 50,
        nstep: int = 2,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = False,
        nbest: int = 1,
        streaming: bool = False,
    ) -> None:
        """Construct a BeamSearchTransducer object."""
        super().__init__()

        self.decoder = decoder
        self.joint_network = joint_network

        self.vocab_size = decoder.vocab_size

        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (beam_size, self.vocab_size)
        )
        self.beam_size = beam_size

        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, (
                "max_sym_exp (%d) should be greater than one." % max_sym_exp
            )
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            assert not streaming, "ALSD is not available in streaming mode."
            assert u_max >= 0, (
                "u_max should be a non-negative integer, a portion of max_T."
            )
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "maes":
            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                "should be smaller than or equal to vocab size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.nstep = nstep
            self.expansion_gamma = expansion_gamma

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )

        self.use_lm = lm is not None

        if self.use_lm:
            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."

            self.sos = self.vocab_size - 1

            self.lm = lm
            self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

        self.reset_inference_cache()

    def __call__(
        self,
        enc_out: torch.Tensor,
        is_final: bool = True,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)
            is_final: Whether enc_out is the final chunk of data.

        Returns:
            nbest_hyps: N-best decoding results.

        """
        self.decoder.set_device(enc_out.device)

        hyps = self.search_algorithm(enc_out)

        if is_final:
            self.reset_inference_cache()

            return self.sort_nbest(hyps)

        self.search_cache = hyps

        return hyps

    def reset_inference_cache(self) -> None:
        """Reset cache for decoder scoring and streaming."""
        self.decoder.score_cache = {}
        self.search_cache = None

    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Sort hypotheses in place by score or by length-normalized score.

        Args:
            hyps: Hypotheses.

        Returns:
            hyps: Sorted hypotheses.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Recombine hypotheses with the same label ID sequence.

        Args:
            hyps: Hypotheses.

        Returns:
            final: Recombined hypotheses.

        """
        final = {}

        for hyp in hyps:
            str_yseq = "_".join(map(str, hyp.yseq))

            if str_yseq in final:
                # Merge duplicate label sequences by log-adding their scores.
                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
            else:
                final[str_yseq] = hyp

        return [*final.values()]

    def select_k_expansions(
        self,
        hyps: List[ExtendedHypothesis],
        topk_idx: torch.Tensor,
        topk_logp: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Return K hypothesis candidates for expansion from a list of hypotheses.

        K candidates are selected according to the extended hypotheses'
        probabilities and a prune-by-value method, where K is equal to
        beam_size + expansion_beta.

        Args:
            hyps: Hypotheses.
            topk_idx: Indices of candidate hypotheses.
            topk_logp: Log-probabilities of candidate hypotheses.

        Returns:
            k_expansions: Best K expansion hypothesis candidates.

        """
        k_expansions = []

        for i, hyp in enumerate(hyps):
            hyp_i = [
                (int(k), hyp.score + float(v))
                for k, v in zip(topk_idx[i], topk_logp[i])
            ]
            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]

            k_expansions.append(
                sorted(
                    filter(
                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
                    ),
                    key=lambda x: x[1],
                    reverse=True,
                )
            )

        return k_expansions

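    # Worked example for the prune-by-value step above (illustrative numbers,
    # not from the source): with expansion_gamma = 2.3 and one hypothesis whose
    # candidate expansion scores are [-1.0, -2.5, -4.0], the best expansion is
    # -1.0, so the threshold is -1.0 - 2.3 = -3.3. The candidates scoring -1.0
    # and -2.5 are kept (sorted best-first), while -4.0 falls below the
    # threshold and is pruned.
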
    def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
        """Make batch of inputs with left padding for LM scoring.

        Args:
            hyps_seq: Hypothesis sequences.

        Returns:
            The left-padded batch of sequences. (B, max_len)

        """
        max_len = max([len(h) for h in hyps_seq])

        # torch.LongTensor does not accept a CUDA device argument, so build the
        # batch with torch.tensor instead.
        return torch.tensor(
            [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
            dtype=torch.long,
            device=self.decoder.device,
        )

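    # Example: for hyps_seq = [[0, 7, 3], [0, 5]], max_len is 3 and the batch is
    #     [[sos, 7, 3],
    #      [sos, 0, 5]]
    # i.e. the leading blank (ID 0) is replaced by <sos> and shorter sequences
    # are left-padded with 0 so the most recent labels align across the batch.
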
    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation without prefix search.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam_k = min(self.beam_size, (self.vocab_size - 1))
        max_t = len(enc_out)

        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            kept_hyps = [
                Hypothesis(
                    score=0.0,
                    yseq=[0],
                    dec_state=self.decoder.init_state(1),
                )
            ]

        for t in range(max_t):
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                label = torch.full(
                    (1, 1),
                    max_hyp.yseq[-1],
                    dtype=torch.long,
                    device=self.decoder.device,
                )
                dec_out, state = self.decoder.score(
                    label,
                    max_hyp.yseq,
                    max_hyp.dec_state,
                )

                logp = torch.log_softmax(
                    self.joint_network(enc_out[t : t + 1, :], dec_out),
                    dim=-1,
                ).squeeze(0)
                top_k = logp[1:].topk(beam_k, dim=-1)

                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0])),
                        yseq=max_hyp.yseq,
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    lm_scores, lm_state = self.lm.score(
                        torch.tensor(
                            [self.sos] + max_hyp.yseq[1:],
                            dtype=torch.long,
                            device=self.decoder.device,
                        ),
                        max_hyp.lm_state,
                        None,
                    )
                else:
                    lm_state = max_hyp.lm_state

                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= self.beam_size:
                    kept_hyps = kept_most_prob

                    break

        return kept_hyps

    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
        final = []

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_enc_out = torch.stack([b[1] for b in B_enc_out])

                beam_dec_out, beam_state = self.decoder.batch_score(B_)

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([b.yseq for b in B_]),
                        [b.lm_state for b in B_],
                        None,
                    )

                for j, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[j][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                            new_hyp.lm_state = beam_lm_states[j]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
                B = self.recombine_hyps(B)

        if final:
            return final

        return B

    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            B = self.search_cache
        else:
            B = [
                Hypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            if self.use_lm:
                B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state = self.decoder.batch_score(C)

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([c.yseq for c in C]),
                            [c.lm_state for c in C],
                            None,
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

        return B

    def modified_adaptive_expansion_search(
        self,
        enc_out: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Modified version of Adaptive Expansion Search (mAES).

        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
        NSC (https://arxiv.org/abs/2201.05420).

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            init_tokens = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            beam_dec_out, beam_state = self.decoder.batch_score(init_tokens)

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                    [h.lm_state for h in init_tokens],
                    None,
                )

                lm_state = beam_lm_states[0]
                lm_score = beam_lm_scores[0]
            else:
                lm_state = None
                lm_score = None

            kept_hyps = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.select_state(beam_state, 0),
                    dec_out=beam_dec_out[0],
                    lm_state=lm_state,
                    lm_score=lm_score,
                )
            ]

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            list_b = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out,
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_score=hyp.lm_score,
                        )

                        if k == 0:
                            # Blank expansion: the label sequence is unchanged.
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                            list_exp.append(new_hyp)

                if not list_exp:
                    kept_hyps = sorted(
                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                    )[: self.beam_size]

                    break
                else:
                    beam_dec_out, beam_state = self.decoder.batch_score(list_exp)

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                            [h.lm_state for h in list_exp],
                            None,
                        )

                    if n < (self.nstep - 1):
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            # Close each expanded hypothesis with a final blank.
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        kept_hyps = sorted(
                            self.recombine_hyps(list_b + list_exp),
                            key=lambda x: x.score,
                            reverse=True,
                        )[: self.beam_size]

        return kept_hyps