viklofg's picture
Upload 3 files
93d3903
raw
history blame
7.07 kB
from collections import Counter

import jiwer
# Model and tokenizer used by `generate`. Both are `None` until
# `set_model` installs them.
_MODEL = None
_TOKENIZER = None
def bytelen(string):
    """Return the length of `string` in UTF-8 bytes.

    Characters that cannot be encoded are skipped instead of raising
    (``errors='ignore'``).
    """
    return len(bytes(string, encoding='utf-8', errors='ignore'))


def take_bytes(words: list[str], n_bytes: int) -> tuple[list[str], list[str]]:
    """
    Take up to `n_bytes` worth of words from the front of `words`.

    Arguments:
        words: A list of words
        n_bytes: max size of words to take (in bytes)

    Returns:
        A tuple (head, tail). `head`, joined with single spaces, is at most
        `n_bytes` UTF-8 bytes long and `tail` is the remaining words.
        The one exception is a first word that is by itself longer than
        `n_bytes`: it is returned alone as `head`, so that a caller which
        loops until `tail` is empty always makes progress.
    """
    current_n_bytes = 0
    for i, word in enumerate(words):
        if current_n_bytes + bytelen(word) > n_bytes:
            # Word `i` is the first one that does not fit, so the head is
            # everything before it. Never return an empty head: cutting at
            # 0 would hand the caller back exactly what it passed in.
            cut = max(i, 1)
            return words[:cut], words[cut:]
        current_n_bytes += bytelen(word) + 1  # add 1 to account for space between words
    return words, []
def split(text, max_len, offset=0):
    """Split `text` into space-joined chunks sized by `take_bytes`.

    The text is cut at single spaces only. Every chunk is taken with a
    byte budget of `max_len`, except that a non-zero `offset` is used as
    the budget for the very first chunk, which shifts where the later
    chunk boundaries fall.

    Returns the list of chunks as strings.
    """
    remaining = text.split(' ')
    pieces = []
    # Only the first chunk is affected by a non-zero offset.
    budget = offset or max_len
    while remaining:
        head, remaining = take_bytes(remaining, budget)
        pieces.append(' '.join(head))
        budget = max_len
    return pieces
class TextMerger:
    """Merge candidate corrections of a text by voting token by token.

    The original text is tokenised (into words, or into single characters
    when `char_level` is true) and every candidate text is aligned against
    it with jiwer. Each alignment contributes suggestions to the tokens it
    touches; `combine` keeps, for every token, the suggestion backed by a
    majority of the candidate texts and falls back to the original token
    otherwise.
    """

    # Placeholder for "no token": it fills the gaps between original tokens
    # (where insertions can be suggested) and marks suggested deletions.
    EMPTY = '🗌'
    # Stand-in for a real space in character-level mode, where ' ' is used
    # as the token separator.
    # NOTE(review): `_process_outgoing` turns every SPACE back into ' ', so
    # a literal '_' in a character-level text comes back as a space.
    SPACE = '_'

    def __init__(self, original: str, char_level=False):
        """
        Arguments:
            original: The original text.
            char_level: Vote on single characters instead of on words.
        """
        self.char_level = char_level
        self.word_level = not char_level
        self.original = self._process_incoming(original)
        self.original_padded = self._pad_between_words(self.original)
        # jiwer's word transform drops empty tokens, while `str.split(' ')`
        # keeps them (consecutive spaces). Reference index k reported by
        # jiwer therefore means "the k-th non-empty token", whose position
        # in `original_padded` is recorded here.
        self._word_slots = [
            i for i in range(1, len(self.original_padded), 2)
            if self.original_padded[i]
        ]
        # One list of suggestions per entry of `original_padded`.
        self.candidates = [[] for _ in self.original_padded]
        self.candidate_texts = []
        self.alignments = []

    def _pad_between_words(self, string: str) -> list[str]:
        """
        Insert `EMPTY` constant between words in `string`.
        Used for aligning suggested insertions.

        Example:
            'Hello world' -> [EMPTY, 'Hello', EMPTY, 'world', EMPTY]
        """
        padded = [self.EMPTY]
        for word in string.split(' '):
            padded.append(word)
            padded.append(self.EMPTY)
        return padded

    def _process_incoming(self, text: str) -> str:
        """Turn `text` into the space-separated token string that is aligned."""
        if self.char_level:
            # One token per character; real spaces become SPACE first so
            # they survive as tokens of their own.
            return ' '.join(text.replace(' ', self.SPACE))
        # Put a separator after each newline so the newline stays attached
        # to the word before it and the next word starts a new token.
        return text.replace('\n', '\n ')

    def _process_outgoing(self, words: list[str]) -> str:
        """Join tokens back into text, undoing `_process_incoming`."""
        if self.char_level:
            return ''.join(words).replace(self.SPACE, ' ')
        return ' '.join(words).replace('\n ', '\n')

    def _insertion_slot(self, ref_idx: int) -> int:
        """Return the `EMPTY` slot just before reference token `ref_idx`.

        `ref_idx` equal to the number of reference tokens means "after the
        last token" and maps to the final slot.
        """
        if ref_idx < len(self._word_slots):
            return self._word_slots[ref_idx] - 1
        return len(self.original_padded) - 1

    def add_candidate_texts(self, texts: list[str]):
        """Add every text in `texts` as a candidate correction."""
        for text in texts:
            self.add_candidate_text(text)

    def add_candidate_text(self, text: str):
        """
        Add `text` as a candidate correction of `original`

        The text is aligned against the original with jiwer and every
        non-equal alignment chunk is recorded as a suggestion on the
        original token(s) it refers to.
        """
        # Bookkeeping
        self.candidate_texts.append(text)

        text = self._process_incoming(text)
        to_words = jiwer.Compose([jiwer.ReduceToListOfListOfWords()])
        jiwer_result = jiwer.process_words(
            self.original,
            text,
            reference_transform=to_words,
            hypothesis_transform=to_words)
        self.alignments.append(jiwer_result)

        # Work through the jiwer results and fill in the candidates list.
        # Use jiwer's own token list so the chunk indices are valid for it.
        hyp_words = jiwer_result.hypotheses[0]
        joiner = '' if self.char_level else ' '
        for chunk in jiwer_result.alignments[0]:
            x0, x1 = chunk.ref_start_idx, chunk.ref_end_idx
            y0, y1 = chunk.hyp_start_idx, chunk.hyp_end_idx

            if chunk.type == 'substitute':
                # Append the suggested substitution to the candidate list
                for i in range(x1 - x0):
                    slot = self._word_slots[x0 + i]
                    self.candidates[slot].extend(hyp_words[y0 + i:y0 + i + 1])

            elif chunk.type == 'insert':
                # Insert the suggested insertion as a suggestion to
                # the `EMPTY` item between words in the original
                slot = self._insertion_slot(x0)
                self.candidates[slot].append(joiner.join(hyp_words[y0:y1]))

            elif chunk.type == 'delete':
                # This word is suggested to be deleted, append EMPTY as a candidate
                for i in range(x1 - x0):
                    self.candidates[self._word_slots[x0 + i]].append(self.EMPTY)

    def combine(self) -> str:
        """
        Combine the current candidate texts

        Returns the merged text: for every slot the winning suggestion (or
        the original token), with `EMPTY` placeholders removed.
        """
        merged = [
            self._best_candidate(candidates, original)
            for original, candidates in zip(self.original_padded, self.candidates)
        ]
        return self._process_outgoing([word for word in merged if word != self.EMPTY])

    def _best_candidate(self, candidates, original):
        """
        Return the best candidate out of `candidates`

        Rules, in order:
          * fewer suggestions than a majority of the candidate texts:
            keep `original`;
          * all suggestions identical: take that suggestion;
          * word level, differing suggestions: merge them character by
            character with a nested character-level `TextMerger`;
          * character level, differing suggestions: take the most frequent
            one if it reaches a majority, otherwise keep `original`.
        """
        majority = self._majority()
        if len(candidates) < majority:
            return original

        if len(set(candidates)) == 1:
            return candidates[0]

        if self.word_level:
            tm = TextMerger(original, char_level=True)
            tm.add_candidate_texts(candidates)
            return tm.combine()

        # Ties resolve to the suggestion that was seen first.
        candidate, n_votes = Counter(candidates).most_common(1)[0]
        return candidate if n_votes >= majority else original

    def _majority(self):
        """Smallest number of votes that is more than half of the candidate texts."""
        return 1 + len(self.candidate_texts) // 2
def process(text: str, n_candidates: int = 1):
    """Run `text` through `generate` and return the resulting text.

    With `n_candidates == 1` the text is split into chunks with a budget
    of 127 bytes, the chunks are passed to `generate`, and the outputs are
    joined with spaces.

    Otherwise the text is chunked `n_candidates` times, each time with a
    different first-chunk offset (`127 * i // n_candidates`) so that the
    chunk boundaries differ between runs. Every run's joined output is
    added to a `TextMerger` as a candidate and the merged text is returned.
    """
    if n_candidates != 1:
        merger = TextMerger(text)
        chunkings = [
            split(text, 127, 127 * k // n_candidates)
            for k in range(n_candidates)
        ]
        corrected = [generate(chunks) for chunks in chunkings]
        merger.add_candidate_texts([' '.join(lines) for lines in corrected])
        return merger.combine()

    return ' '.join(generate(split(text, 127)))
def generate(texts):
    """Pass `texts` through the module-level tokenizer and model.

    The texts are tokenized as one padded, truncated batch, fed to
    `_MODEL.generate`, and the generated ids are decoded with special
    tokens skipped. Returns whatever `_TOKENIZER.batch_decode` returns.
    """
    encoded = _TOKENIZER(texts, padding=True, truncation=True, return_tensors='pt')
    generated_ids = _MODEL.generate(**encoded)
    decoded = _TOKENIZER.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded
def diff(old: str, new: str):
    """Display the difference between old and new

    The two strings are aligned character by character with jiwer. Equal
    stretches are copied as they are. For any other chunk the old
    characters are emitted as `:red[~~...~~]` (skipped when they are empty
    or whitespace only) followed by the new characters as `:green[...]`.

    Returns the marked-up string.
    """
    result = jiwer.process_characters(old, new)
    # The chunk indices refer to the character sequences jiwer actually
    # aligned, i.e. after its default character transform (which strips
    # surrounding whitespace), not to the raw `old` / `new` strings.
    # Slice those sequences so the indices always line up.
    old_seq = result.references[0]
    new_seq = result.hypotheses[0]

    parts = []
    for chunk in result.alignments[0]:
        old_chars = ''.join(old_seq[chunk.ref_start_idx:chunk.ref_end_idx])
        new_chars = ''.join(new_seq[chunk.hyp_start_idx:chunk.hyp_end_idx])

        if chunk.type == 'equal':
            parts.append(old_chars)
            continue

        if old_chars and not old_chars.isspace():
            parts.append(f':red[~~{old_chars.strip()}~~]')

        parts.append(f':green[{new_chars}]')
    return ''.join(parts)
def set_model(model, tokenizer):
    """Install `model` and `tokenizer` as the module-level pair read by `generate`."""
    global _MODEL, _TOKENIZER
    _MODEL, _TOKENIZER = model, tokenizer