| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from typing import List |
| |
|
| | import k2 |
| | import torch |
| |
|
| | from icefall.lexicon import Lexicon |
| |
|
| |
|
class CtcTrainingGraphCompiler:
    """Compile transcripts into CTC training graphs using k2.

    Each transcript is turned into a word-level linear FSA, intersected with
    the (arc-sorted) inverse lexicon to obtain a token-level FSA, and finally
    composed with a CTC topology.
    """

    def __init__(
        self,
        lexicon: Lexicon,
        device: torch.device,
        oov: str = "<UNK>",
        need_repeat_flag: bool = False,
    ):
        """
        Args:
          lexicon:
            It is built from `data/lang/lexicon.txt`.
          device:
            The device to use for operations compiling transcripts to FSAs.
          oov:
            Out of vocabulary word. When a word in the transcript
            does not exist in the lexicon, it is replaced with `oov`.
          need_repeat_flag:
            If True, will add an attribute named `_is_repeat_token_` to ctc_topo
            indicating whether this token is a repeat token in ctc graph.
            This attribute is needed to implement delay-penalty for phone-based
            ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
            details. Note: The above change MUST be included in k2 to open this
            flag.
        """
        L_inv = lexicon.L_inv.to(device)
        # The lexicon FSA must not require grad; compile() checks the same
        # property on the final decoding graph.
        assert L_inv.requires_grad is False

        assert (
            oov in lexicon.word_table
        ), f"OOV word {oov!r} is not present in the lexicon's word table"

        # Arc-sorted once here so it can be fed to k2.intersect() repeatedly.
        self.L_inv = k2.arc_sort(L_inv)
        self.oov_id = lexicon.word_table[oov]
        self.word_table = lexicon.word_table

        max_token_id = max(lexicon.tokens)
        ctc_topo = k2.ctc_topo(max_token_id, modified=False)
        self.ctc_topo = ctc_topo.to(device)

        if need_repeat_flag:
            # Arcs whose input label differs from their output label are
            # flagged as repeat-token arcs (see `need_repeat_flag` above).
            self.ctc_topo._is_repeat_token_ = (
                self.ctc_topo.labels != self.ctc_topo.aux_labels
            )

        self.device = device

    def compile(self, texts: List[str]) -> k2.Fsa:
        """Build decoding graphs by composing ctc_topo with
        given transcripts.

        Args:
          texts:
            A list of strings. Each string contains a sentence for an utterance.
            A sentence consists of space-separated words. An example `texts`
            looks like:

                ['hello icefall', 'CTC training with k2']

        Returns:
          An FsaVec, the composition result of `self.ctc_topo` and the
          transcript FSA.
        """
        transcript_fsa = self.convert_transcript_to_fsa(texts)

        # k2.compose() below uses treat_epsilons_specially=False, i.e. epsilon
        # is matched like an ordinary symbol. Epsilon self-loops are therefore
        # added explicitly so that ctc_topo arcs with epsilon output labels
        # can still be matched.
        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
        fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

        decoding_graph = k2.compose(
            self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
        )

        assert decoding_graph.requires_grad is False

        return decoding_graph

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of word IDs.

        Words that are not in the word table are mapped to `self.oov_id`.

        Args:
          texts:
            It is a list of strings. Each string consists of space(s)
            separated words. An example containing two strings is given below:

                ['HELLO ICEFALL', 'HELLO k2']
        Returns:
          Return a list-of-list of word IDs, one inner list per input string.
        """
        word_table = self.word_table
        oov_id = self.oov_id
        return [
            [
                word_table[word] if word in word_table else oov_id
                for word in text.split()
            ]
            for text in texts
        ]

    def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
        """Convert a list of transcript texts to an FsaVec.

        Args:
          texts:
            A list of strings. Each string contains a sentence for an utterance.
            A sentence consists of space-separated words. An example `texts`
            looks like:

                ['hello icefall', 'CTC training with k2']

        Returns:
          Return an FsaVec, whose `shape[0]` equals to `len(texts)`.
        """
        # Single source of truth for word -> ID (with OOV fallback) mapping.
        word_ids_list = self.texts_to_ids(texts)

        word_fsa = k2.linear_fsa(word_ids_list, self.device)

        # Needed because intersect() below treats epsilon as a normal symbol.
        word_fsa_with_self_loops = k2.add_epsilon_self_loops(word_fsa)

        fsa = k2.intersect(
            self.L_inv, word_fsa_with_self_loops, treat_epsilons_specially=False
        )
        # Intersection matches on labels, so `fsa` has word IDs as labels and
        # the L_inv output symbols (presumably token IDs) as aux_labels.
        # Invert so that tokens become the labels consumed by compose().
        ans_fsa = fsa.invert_()
        return k2.arc_sort(ans_fsa)
| |
|