import itertools
import math
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import LengthsType, LogprobsType, NeuralType, PredictionsType

| | class _TokensWrapper: |
| | def __init__(self, vocabulary: List[str], tokenizer: TokenizerSpec): |
| | self.vocabulary = vocabulary |
| | self.tokenizer = tokenizer |
| |
|
| | if tokenizer is None: |
| | self.reverse_map = {vocabulary[i]: i for i in range(len(vocabulary))} |
| |
|
| | @property |
| | def blank(self): |
| | return len(self.vocabulary) |
| |
|
| | @property |
| | def unk_id(self): |
| | if (self.tokenizer is not None) and hasattr(self.tokenizer, 'unk_id') and self.tokenizer.unk_id is not None: |
| | return self.tokenizer.unk_id |
| |
|
| | if '<unk>' in self.vocabulary: |
| | return self.token_to_id('<unk>') |
| | else: |
| | return -1 |
| |
|
| | @property |
| | def vocab(self): |
| | return self.vocabulary |
| |
|
| | @property |
| | def vocab_size(self): |
| | |
| | return len(self.vocabulary) + 1 |
| |
|
| | def token_to_id(self, token: str): |
| | if token == self.blank: |
| | return -1 |
| |
|
| | if self.tokenizer is not None: |
| | return self.tokenizer.token_to_id(token) |
| | else: |
| | return self.reverse_map[token] |
| |
|
| |
|
class FlashLightKenLMBeamSearchDecoder(NeuralModule):
    '''
    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {"hypos": NeuralType(('B'), PredictionsType())}
    '''

    def __init__(
        self,
        lm_path: str,
        vocabulary: List[str],
        tokenizer: Optional[TokenizerSpec] = None,
        lexicon_path: Optional[str] = None,
        beam_size: int = 32,
        beam_size_token: int = 32,
        beam_threshold: float = 25.0,
        lm_weight: float = 2.0,
        word_score: float = -1.0,
        unk_weight: float = -math.inf,
        sil_weight: float = 0.0,
        unit_lm: bool = False,
    ):
        """CTC beam-search decoder with a KenLM language model, backed by flashlight.

        Args:
            lm_path: path to the KenLM binary/ARPA language model file.
            vocabulary: acoustic model output labels (without the CTC blank).
            tokenizer: optional tokenizer for token<->id lookups; when None,
                lookups use ``vocabulary`` directly.
            lexicon_path: optional word->spellings lexicon file. When None,
                lexicon-free decoding is used (requires ``unit_lm=True``).
            beam_size: number of hypotheses kept in the beam.
            beam_size_token: number of candidate tokens considered per step.
            beam_threshold: score-pruning threshold for the beam.
            lm_weight: language-model score weight.
            word_score: score added on completing a word.
            unk_weight: score for emitting <unk>; default -inf disables it.
            sil_weight: score for the silence token.
            unit_lm: whether the LM is over the same units as the acoustic model.

        Raises:
            ModuleNotFoundError: if the flashlight python bindings are missing.
        """
        try:
            from flashlight.lib.text.decoder import (
                LM,
                CriterionType,
                KenLM,
                LexiconDecoder,
                LexiconDecoderOptions,
                SmearingMode,
                Trie,
            )
            from flashlight.lib.text.dictionary import create_word_dict, load_words
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "FlashLightKenLMBeamSearchDecoder requires the installation of flashlight python bindings "
                "from https://github.com/flashlight/text. Please follow the build instructions there."
            )

        super().__init__()

        self.criterion_type = CriterionType.CTC
        self.tokenizer_wrapper = _TokensWrapper(vocabulary, tokenizer)
        self.vocab_size = self.tokenizer_wrapper.vocab_size
        self.blank = self.tokenizer_wrapper.blank
        # NOTE(review): the unk id (or -1 when absent) is used as flashlight's
        # silence token index throughout — confirm this matches the AM's units.
        self.silence = self.tokenizer_wrapper.unk_id
        self.unit_lm = unit_lm

        if lexicon_path is not None:
            # Lexicon-based decoding: build a trie of word spellings, each
            # scored against the KenLM start state, then smear scores upward.
            self.lexicon = load_words(lexicon_path)
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index("<unk>")

            self.lm = KenLM(lm_path, self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence)

            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                word_idx = self.word_dict.get_index(word)
                _, score = self.lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idxs = [self.tokenizer_wrapper.token_to_id(token) for token in spelling]
                    # Skip spellings containing tokens the tokenizer cannot map.
                    if self.tokenizer_wrapper.unk_id in spelling_idxs:
                        print(f'tokenizer has unknown id for word[ {word} ] {spelling} {spelling_idxs}', flush=True)
                        continue
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=beam_size,
                beam_size_token=int(beam_size_token),
                beam_threshold=beam_threshold,
                lm_weight=lm_weight,
                word_score=word_score,
                unk_score=unk_weight,
                sil_score=sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            self.decoder = LexiconDecoder(
                self.decoder_opts, self.trie, self.lm, self.silence, self.blank, self.unk_word, [], self.unit_lm,
            )
        else:
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            # Lexicon-free decoding: a trivial one-token-per-word dictionary,
            # ensuring '<unk>' exists so KenLM lookups never fail.
            d = {
                w: [[w]]
                for w in self.tokenizer_wrapper.vocab + ([] if '<unk>' in self.tokenizer_wrapper.vocab else ['<unk>'])
            }
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(lm_path, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=beam_size,
                beam_size_token=int(beam_size_token),
                beam_threshold=beam_threshold,
                lm_weight=lm_weight,
                sil_score=sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm, self.silence, self.blank, [])

    def _get_tokens(self, idxs: List[int]):
        """Normalize tokens by handling CTC blank, ASG replabels, etc.

        Collapses consecutive repeats, then drops blank and silence ids.
        Returns a LongTensor of the remaining token ids.
        """
        idxs = (g[0] for g in itertools.groupby(idxs))
        idxs = filter(lambda x: x != self.blank and x != self.silence, idxs)

        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, token_idxs: List[int]):
        """Returns frame numbers corresponding to every non-blank token.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of decoded tokens.

        Returns
        -------
        List[int]
            Frame numbers corresponding to every non-blank token.
        """
        timesteps = []
        for i, token_idx in enumerate(token_idxs):
            if token_idx == self.blank:
                continue
            # Record a frame only at the start of each run of a repeated token.
            if i == 0 or token_idx != token_idxs[i - 1]:
                timesteps.append(i)

        return timesteps

    @torch.no_grad()
    def forward(self, log_probs: Union[np.ndarray, torch.Tensor]):
        """Decode a batch of CTC log-probabilities into n-best hypotheses.

        Args:
            log_probs: [B, T, V+1] (or [T, V+1] for a single utterance)
                log-probabilities over the vocabulary plus blank.

        Returns:
            A list of length B; each element is the n-best list for that
            utterance as dicts with keys 'tokens', 'score', 'timesteps',
            and 'words'.
        """
        if isinstance(log_probs, np.ndarray):
            log_probs = torch.from_numpy(log_probs).float()
        if log_probs.dim() == 2:
            log_probs = log_probs.unsqueeze(0)

        # The flashlight decoder reads raw float32 memory through a pointer,
        # so the tensor must be float32, on CPU, and contiguous. (Previously
        # a 4-byte element size was hard-coded, which silently corrupted
        # offsets for float64/float16 inputs.)
        emissions = log_probs.float().cpu().contiguous()

        B, T, N = emissions.size()
        hypos = []

        for b in range(B):
            # Raw pointer to utterance b; stride(0) is in elements, so scale
            # by the element size in bytes.
            emissions_ptr = emissions.data_ptr() + emissions.element_size() * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            hypos.append(
                [
                    {
                        "tokens": self._get_tokens(result.tokens),
                        "score": result.score,
                        "timesteps": self._get_timesteps(result.tokens),
                        "words": [self.word_dict.get_entry(x) for x in result.words if x >= 0],
                    }
                    for result in results
                ]
            )

        return hypos
| |
|