import jiwer

_MODEL = None
_TOKENIZER = None


def bytelen(string):
    """Return the length of `string` in UTF-8 bytes."""
    return len(bytes(string, encoding='utf-8', errors='ignore'))


def take_bytes(words: list[str], n_bytes: int) -> tuple[list[str], list[str]]:
    """
    Take at most `n_bytes` worth of words from the list `words`.

    Arguments:
        words: A list of words
        n_bytes: Maximum size of words to take (in bytes)

    Returns:
        A tuple (head, tail) where `head` holds as many leading words as fit in
        `n_bytes` (always at least one) and `tail` holds the remaining words.
    """
    current_n_bytes = 0
    for i, word in enumerate(words):
        if current_n_bytes + bytelen(word) > n_bytes:
            # Always take at least one word so that callers make progress even
            # when a single word exceeds the byte budget.
            cut = max(i, 1)
            return words[:cut], words[cut:]
        current_n_bytes += bytelen(word) + 1  # add 1 to account for the space between words
    return words, []


def split(text, max_len, offset=0):
    """
    Split `text` into chunks of at most `max_len` UTF-8 bytes.

    If `offset` is non-zero, the first chunk is limited to `offset` bytes,
    which shifts all subsequent chunk boundaries.
    """
    words = text.split(' ')
    chunks = []
    if offset:
        chunk, words = take_bytes(words, offset)
        chunks.append(' '.join(chunk))
    while words:
        chunk, words = take_bytes(words, max_len)
        chunks.append(' '.join(chunk))
    return chunks


class TextMerger:
    """Merge candidate corrections of a text by aligning them against the original."""

    EMPTY = '🗌'
    SPACE = '_'

    def __init__(self, original: str, char_level=False):
        """
        Arguments:
            original: The original text.
            char_level: If True, align and vote on characters instead of words.
        """
        self.char_level = char_level
        self.word_level = not char_level
        self.original = self._process_incoming(original)
        self.original_padded = self._pad_between_words(self.original)
        self.candidates = [[] for _ in self.original_padded]
        self.candidate_texts = []
        self.alignments = []

    def _pad_between_words(self, string: str) -> list[str]:
        """
        Insert the `EMPTY` constant between the words in `string`.
        Used for aligning suggested insertions.

        Example: 'Hello world' -> [EMPTY, 'Hello', EMPTY, 'world', EMPTY]
        """
        words = string.split(' ')
        padded = [self.EMPTY]
        for word in words:
            padded.append(word)
            padded.append(self.EMPTY)
        return padded

    def _process_incoming(self, text):
        if self.char_level:
            return ' '.join(text.replace(' ', self.SPACE))
        return text.replace('\n', '\n ')

    def _process_outgoing(self, words: list[str]):
        if self.char_level:
            return ''.join(words).replace(self.SPACE, ' ')
        return ' '.join(words).replace('\n ', '\n')

    def add_candidate_texts(self, texts: list[str]):
        for text in texts:
            self.add_candidate_text(text)

    def add_candidate_text(self, text: str):
        """Add `text` as a candidate correction of the original text."""
        # Bookkeeping
        self.candidate_texts.append(text)
        text = self._process_incoming(text)
        jiwer_result = jiwer.process_words(
            self.original,
            text,
            reference_transform=jiwer.Compose([jiwer.ReduceToListOfListOfWords()]),
            hypothesis_transform=jiwer.Compose([jiwer.ReduceToListOfListOfWords()]))
        self.alignments.append(jiwer_result)

        # Work through the jiwer alignment and fill in the candidates list
        text = jiwer_result.hypotheses[0]  # already tokenized into a list of words by jiwer
        for chunk in jiwer_result.alignments[0]:
            x0, x1 = chunk.ref_start_idx, chunk.ref_end_idx
            y0, y1 = chunk.hyp_start_idx, chunk.hyp_end_idx
            if chunk.type == 'substitute':
                # Append the suggested substitution to the candidate list
                for i in range(x1 - x0):
                    self.candidates[2 * (x0 + i) + 1].extend(text[y0 + i:y0 + i + 1])
            elif chunk.type == 'insert':
                # Attach the suggested insertion as a candidate for the `EMPTY`
                # item between words in the original
                if self.char_level:
                    self.candidates[2 * x0].append(''.join(text[y0:y1]))
                else:
                    self.candidates[2 * x0].append(' '.join(text[y0:y1]))
            elif chunk.type == 'delete':
                # This word is suggested to be deleted, append EMPTY as a candidate
                for i in range(x1 - x0):
                    self.candidates[2 * (x0 + i) + 1].append(self.EMPTY)
texts """ out = [] for original, candidates in zip(self.original_padded, self.candidates): correction_candidate = self._best_candidate(candidates, original) out.append(correction_candidate) out = [word for word in out if word != self.EMPTY] return self._process_outgoing(out) def _best_candidate(self, candidates, original): """ Return the best candidate out of `candidates` Uses majority vote to determine the best candidate. Example: best_candidate(['Hello', 'Hello', 'Hallå']) -> 'Hello' """ if len(candidates) < self._majority(): return original if len(set(candidates)) == 1: return candidates[0] if self.word_level: tm = TextMerger(original, char_level=True) tm.add_candidate_texts(candidates) return tm.combine() else: candidate, n_votes = max(((candidate, candidates.count(candidate)) for candidate in candidates), key=lambda x: x[1]) return candidate if n_votes >= self._majority() else original def _majority(self): return 1 + len(self.candidate_texts) // 2 def process(text: str, n_candidates: int = 1): if n_candidates == 1: splits = split(text, 127) return ' '.join(generate(splits)) combiner = TextMerger(text) splits = [split(text, 127, 127 * i // n_candidates) for i in range(n_candidates)] outputs = [generate(lines) for lines in splits] for output in outputs: combiner.add_candidate_text(' '.join(output)) return combiner.combine() def generate(texts): inputs = _TOKENIZER(texts, padding=True, truncation=True, return_tensors='pt') output_ids = _MODEL.generate(**inputs) return _TOKENIZER.batch_decode(output_ids, skip_special_tokens=True) def diff(old: str, new: str): """Display the difference between old and new""" result = jiwer.process_characters(old, new) output = '' for chunk in result.alignments[0]: old_chars = ''.join(old[chunk.ref_start_idx:chunk.ref_end_idx]) new_chars = ''.join(new[chunk.hyp_start_idx:chunk.hyp_end_idx]) if chunk.type == 'equal': output += old_chars continue if old_chars and not old_chars.isspace(): output += f':red[~~{old_chars.strip()}~~]' output += f':green[{new_chars}]' return output def set_model(model, tokenizer): global _MODEL, _TOKENIZER _MODEL = model _TOKENIZER = tokenizer