import jiwer

# The model and tokenizer are injected at runtime via `set_model`.
_MODEL = None
_TOKENIZER = None


def bytelen(string: str) -> int:
    """Return the length of `string` in UTF-8 bytes."""
    return len(string.encode('utf-8', errors='ignore'))
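
# Byte length and character length differ for non-ASCII text, which is why
# the chunking below counts bytes. An illustrative doctest:
#
#     >>> bytelen('Hallå')
#     6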


def take_bytes(words: list[str], n_bytes: int) -> tuple[list[str], list[str]]:
    """
    Take at most `n_bytes` worth of words from the front of `words`.

    Arguments:
        words: A list of words
        n_bytes: Maximum size of words to take (in UTF-8 bytes)

    Returns:
        A tuple (head, tail) where `head` joins to at most `n_bytes`
        and `tail` is the remaining words.
    """
    current_n_bytes = 0
    for i, word in enumerate(words):
        if current_n_bytes + bytelen(word) > n_bytes:
            return words[:i], words[i:]
        current_n_bytes += bytelen(word) + 1  # +1 for the joining space
    return words, []
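
# For example (illustrative values):
#
#     >>> take_bytes(['Hello', 'wide', 'world'], n_bytes=10)
#     (['Hello', 'wide'], ['world'])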


def split(text: str, max_len: int, offset: int = 0) -> list[str]:
    """
    Split `text` into chunks of at most `max_len` UTF-8 bytes.

    If `offset` is non-zero, the first chunk is at most `offset` bytes,
    which staggers the chunk boundaries relative to an unshifted split.
    """
    words = text.split(' ')
    chunks = []

    if offset:
        chunk, words = take_bytes(words, offset)
        chunks.append(' '.join(chunk))

    while words:
        chunk, words = take_bytes(words, max_len)
        if not chunk:
            # A single word longer than `max_len` bytes: emit it on its
            # own rather than looping forever.
            chunk, words = words[:1], words[1:]
        chunks.append(' '.join(chunk))

    return chunks
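
# For example (illustrative values):
#
#     >>> split('one two three four', max_len=9)
#     ['one two', 'three', 'four']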


class TextMerger:
    """Merge candidate corrections of a text by alignment and majority vote."""

    EMPTY = '🗌'
    SPACE = '_'

    def __init__(self, original: str, char_level: bool = False):
        """
        Arguments:
            original: The original text.
            char_level: If True, align and vote on characters rather than words.
        """
        self.char_level = char_level
        self.word_level = not char_level
        self.original = self._process_incoming(original)
        self.original_padded = self._pad_between_words(self.original)
        self.candidates = [[] for _ in self.original_padded]
        self.candidate_texts = []
        self.alignments = []

    def _pad_between_words(self, string: str) -> list[str]:
        """
        Insert the `EMPTY` marker between the words of `string`.
        Used for aligning suggested insertions.

        Example:
            'Hello world' -> [EMPTY, 'Hello', EMPTY, 'world', EMPTY]
        """
        words = string.split(' ')
        padded = [self.EMPTY]
        for word in words:
            padded.append(word)
            padded.append(self.EMPTY)
        return padded

    def _process_incoming(self, text: str) -> str:
        """Normalise `text` into the token stream used for alignment."""
        if self.char_level:
            # Space-separate every character, with SPACE standing in for
            # real spaces so they survive the word split.
            return ' '.join(text.replace(' ', self.SPACE))
        return text.replace('\n', '\n ')

    def _process_outgoing(self, words: list[str]) -> str:
        """Invert `_process_incoming` on a list of merged tokens."""
        if self.char_level:
            return ''.join(words).replace(self.SPACE, ' ')
        return ' '.join(words).replace('\n ', '\n')

    def add_candidate_texts(self, texts: list[str]):
        for text in texts:
            self.add_candidate_text(text)

    def add_candidate_text(self, text: str):
        """Add `text` as a candidate correction of the original text."""
        self.candidate_texts.append(text)
        text = self._process_incoming(text)
        jiwer_result = jiwer.process_words(
            self.original,
            text,
            reference_transform=jiwer.Compose([jiwer.ReduceToListOfListOfWords()]),
            hypothesis_transform=jiwer.Compose([jiwer.ReduceToListOfListOfWords()]))

        self.alignments.append(jiwer_result)

        text = jiwer_result.hypotheses[0]
        for chunk in jiwer_result.alignments[0]:
            x0, x1 = chunk.ref_start_idx, chunk.ref_end_idx
            y0, y1 = chunk.hyp_start_idx, chunk.hyp_end_idx

            if chunk.type == 'substitute':
                # Substitution chunks pair up one-to-one: vote each
                # hypothesis word into the matching word slot.
                for i in range(x1 - x0):
                    self.candidates[2 * (x0 + i) + 1].append(text[y0 + i])

            elif chunk.type == 'insert':
                # Insertions go into the padding slot before the word at x0.
                joiner = '' if self.char_level else ' '
                self.candidates[2 * x0].append(joiner.join(text[y0:y1]))

            elif chunk.type == 'delete':
                # Deletions vote for EMPTY in the matching word slot.
                for i in range(x1 - x0):
                    self.candidates[2 * (x0 + i) + 1].append(self.EMPTY)

    def combine(self) -> str:
        """Combine the original text with the current candidate texts."""
        out = []
        for original, candidates in zip(self.original_padded, self.candidates):
            out.append(self._best_candidate(candidates, original))
        out = [word for word in out if word != self.EMPTY]
        return self._process_outgoing(out)

    def _best_candidate(self, candidates, original):
        """
        Return the best candidate out of `candidates`, falling back to
        `original` when no candidate wins a majority.

        Example:
            _best_candidate(['Hello', 'Hello', 'Hallå'], 'Helo') -> 'Hello'
        """
        if len(candidates) < self._majority():
            return original

        if len(set(candidates)) == 1:
            return candidates[0]

        if self.word_level:
            # No unanimous word: recurse at character level so that
            # near-identical candidates can still agree letter by letter.
            tm = TextMerger(original, char_level=True)
            tm.add_candidate_texts(candidates)
            return tm.combine()

        candidate, n_votes = max(
            ((candidate, candidates.count(candidate)) for candidate in candidates),
            key=lambda x: x[1])
        return candidate if n_votes >= self._majority() else original

    def _majority(self):
        """Smallest number of votes that is a strict majority of all candidates."""
        return 1 + len(self.candidate_texts) // 2
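
# A minimal usage sketch of TextMerger with three candidate corrections
# (illustrative strings; assumes jiwer aligns 'Helo'/'Hello' and
# 'world'/'world!' as single-word substitutions):
#
#     >>> tm = TextMerger('Helo world')
#     >>> tm.add_candidate_texts(['Hello world', 'Hello world', 'Helo world!'])
#     >>> tm.combine()
#     'Hello world'
#
# 'Hello' wins its word slot with two of three votes, while the lone
# 'world!' vote falls short of the majority, so the original 'world' stays.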


def process(text: str, n_candidates: int = 1) -> str:
    """
    Correct `text` with the loaded model.

    With `n_candidates` > 1, the text is split at staggered offsets so
    each pass sees different chunk boundaries, and the outputs are
    merged by majority vote.
    """
    if n_candidates == 1:
        splits = split(text, 127)
        return ' '.join(generate(splits))

    combiner = TextMerger(text)
    splits = [split(text, 127, 127 * i // n_candidates) for i in range(n_candidates)]
    outputs = [generate(lines) for lines in splits]
    for output in outputs:
        combiner.add_candidate_text(' '.join(output))
    return combiner.combine()


def generate(texts: list[str]) -> list[str]:
    """Run the registered model over a batch of texts and decode the output."""
    if _MODEL is None or _TOKENIZER is None:
        raise RuntimeError('No model loaded; call set_model() first.')
    inputs = _TOKENIZER(texts, padding=True, truncation=True, return_tensors='pt')
    output_ids = _MODEL.generate(**inputs)
    return _TOKENIZER.batch_decode(output_ids, skip_special_tokens=True)


def diff(old: str, new: str) -> str:
    """Return the difference between `old` and `new` as Streamlit-style markup."""
    result = jiwer.process_characters(old, new)
    output = ''
    for chunk in result.alignments[0]:
        old_chars = old[chunk.ref_start_idx:chunk.ref_end_idx]
        new_chars = new[chunk.hyp_start_idx:chunk.hyp_end_idx]

        if chunk.type == 'equal':
            output += old_chars
            continue

        if old_chars and not old_chars.isspace():
            output += f':red[~~{old_chars.strip()}~~]'
        output += f':green[{new_chars}]'
    return output
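
# For example (illustrative; the :red[]/:green[] colour tags are the
# markdown extensions Streamlit renders):
#
#     >>> diff('cat', 'car')
#     'ca:red[~~t~~]:green[r]'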


def set_model(model, tokenizer):
    """Register the model and tokenizer used by `generate`."""
    global _MODEL, _TOKENIZER
    _MODEL = model
    _TOKENIZER = tokenizer
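
# A minimal wiring sketch (assumes a Hugging Face seq2seq checkpoint;
# 'my-org/text-correction' is a placeholder model name):
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained('my-org/text-correction')
#     model = AutoModelForSeq2SeqLM.from_pretrained('my-org/text-correction')
#     set_model(model, tokenizer)
#     print(process('Tha quick brwn fox', n_candidates=3))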