ru_errant / classifier.py
danasone's picture
test code
a0b78f4
from __future__ import annotations
from collections import defaultdict
from string import punctuation
import Levenshtein
from errant.edit import Edit
def edit_to_tuple(edit: Edit, idx: int = 0) -> tuple[int, int, str, str, int]:
cor_toks_str = " ".join([tok.text for tok in edit.c_toks])
return [edit.o_start, edit.o_end, edit.type, cor_toks_str, idx]
def classify(edit: Edit) -> list[Edit]:
"""Classifies an Edit via updating its `type` attribute."""
# Insertion and deletion
if ((not edit.o_toks and edit.c_toks) or (edit.o_toks and not edit.c_toks)):
error_cats = get_one_sided_type(edit.o_toks, edit.c_toks)
elif edit.o_toks != edit.c_toks:
error_cats = get_two_sided_type(edit.o_toks, edit.c_toks)
else:
error_cats = {"NA": edit.c_toks[0].text}
new_edit_list = []
if error_cats:
for error_cat, correct_str in error_cats.items():
edit.type = error_cat
edit_tuple = edit_to_tuple(edit)
edit_tuple[3] = correct_str
new_edit_list.append(edit_tuple)
return new_edit_list
def get_edit_info(toks):
pos = []
dep = []
morph = dict()
for tok in toks:
pos.append(tok.tag_)
dep.append(tok.dep_)
morphs = str(tok.morph).split('|')
for m in morphs:
if len(m.strip()):
k, v = m.strip().split('=')
morph[k] = v
return pos, dep, morph
def get_one_sided_type(o_toks, c_toks):
"""Classifies a zero-to-one or one-to-zero error based on a token list."""
pos_list, _, _ = get_edit_info(o_toks if o_toks else c_toks)
if "PUNCT" in pos_list or "SPACE" in pos_list:
return {"PUNCT": c_toks[0].text if c_toks else ""}
return {"SPELL": c_toks[0].text if c_toks else ""}
def get_two_sided_type(o_toks, c_toks) -> dict[str, str]:
"""Classifies a one-to-one or one-to-many or many-to-one error based on token lists."""
# one-to-one cases
if len(o_toks) == len(c_toks) == 1:
if (
all(char in punctuation + " " for char in o_toks[0].text) and
all(char in punctuation + " " for char in c_toks[0].text)
):
return {"PUNCT": c_toks[0].text}
source_w, correct_w = o_toks[0].text, c_toks[0].text
if source_w != correct_w:
# if both string are lowercase or both are uppercase,
# and there is no "ั‘" in both, then it may be only "SPELL" error type
if (((source_w.islower() and correct_w.islower()) or
(source_w.isupper() and correct_w.isupper())) and
"ั‘" not in source_w + correct_w):
return {"SPELL": correct_w}
# edits with multiple errors (e.g. SPELL + CASE)
# Step 1. Make char-level Levenstein table
char_edits = Levenshtein.editops(source_w, correct_w)
# Step 2. Classify operations (CASE, YO, SPELL)
edits_classified = classify_char_edits(char_edits, source_w, correct_w)
# Step 3. Combine the same-typed errors into minimal string pairs
separated_edits = get_edit_strings(source_w, correct_w, edits_classified)
return separated_edits
# one-to-many and many-to-one cases
if all(char in punctuation + " " for char in o_toks.text + c_toks.text):
return {"PUNCT": c_toks.text}
joint_corr_str = " ".join([tok.text for tok in c_toks])
joint_corr_str = joint_corr_str.replace("- ", "-").replace(" -", "-")
return {"SPELL": joint_corr_str}
def classify_char_edits(char_edits, source_w, correct_w):
"""Classifies char-level Levenstein operations into SPELL, YO and CASE."""
edits_classified = []
for edit in char_edits:
if edit[0] == "replace":
if "ั‘" in [source_w[edit[1]], correct_w[edit[2]]]:
edits_classified.append((*edit, "YO"))
elif source_w[edit[1]].lower() == correct_w[edit[2]].lower():
edits_classified.append((*edit, "CASE"))
else:
if (
(source_w[edit[1]].islower() and correct_w[edit[2]].isupper()) or
(source_w[edit[1]].isupper() and correct_w[edit[2]].islower())
):
edits_classified.append((*edit, "CASE"))
edits_classified.append((*edit, "SPELL"))
else:
edits_classified.append((*edit, "SPELL"))
return edits_classified
def get_edit_strings(source: str, correction: str,
edits_classified: list[tuple]) -> dict[str, str]:
"""
Applies classified (SPELL, YO and CASE) char operations to source word separately.
Returns a dict mapping error type to source string with corrections of this type only.
"""
separated_edits = defaultdict(lambda: source)
shift = 0 # char position shift to consider on deletions and insertions
for edit in edits_classified:
edit_type = edit[3]
curr_src = separated_edits[edit_type]
if edit_type == "CASE": # SOURCE letter spelled in CORRECTION case
if correction[edit[2]].isupper():
correction_char = source[edit[1]].upper()
else:
correction_char = source[edit[1]].lower()
else:
if edit[0] == "delete":
correction_char = ""
elif edit[0] == "insert":
correction_char = correction[edit[2]]
elif source[edit[1]].isupper():
correction_char = correction[edit[2]].upper()
else:
correction_char = correction[edit[2]].lower()
if edit[0] == "replace":
separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \
curr_src[edit[1]+shift + 1:]
elif edit[0] == "delete":
separated_edits[edit_type] = curr_src[:edit[1] + shift] + \
curr_src[edit[1]+shift + 1:]
shift -= 1
elif edit[0] == "insert":
separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \
curr_src[edit[1]+shift:]
shift += 1
return dict(separated_edits)