# ru_errant / merger.py
from __future__ import annotations
import itertools
import re
from string import punctuation
import Levenshtein
from errant.alignment import Alignment
from errant.edit import Edit


def get_rule_edits(alignment: Alignment) -> list[Edit]:
    """Groups word-level alignment according to merging rules."""
    edits = []
    # Split the alignment into groups of operations
    alignment_groups = group_alignment(alignment, "new")
    for op, group in alignment_groups:
        group = list(group)
        # Ignore M (match) groups: identical tokens produce no edit
        if op == "M":
            continue
        # T (transposition) groups are always split into single edits
        if op == "T":
            for seq in group:
                edits.append(Edit(alignment.orig, alignment.cor, seq[1:]))
        # Process D, I and S subsequences with the merging rules
        else:
            processed = process_seq(group, alignment)
            # Turn the processed sequence into edits
            for seq in processed:
                edits.append(Edit(alignment.orig, alignment.cor, seq[1:]))
    return edits
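
# A minimal usage sketch for the function above. It assumes the standard
# ERRANT pipeline API (errant.load / annotator.parse) with an installed spaCy
# model; for Russian, the ru_errant annotator would be loaded instead. The
# sentences are illustrative only.
#
#   import errant
#   annotator = errant.load("en")
#   orig = annotator.parse("This are gramatical sentence .")
#   cor = annotator.parse("This is a grammatical sentence .")
#   for e in get_rule_edits(Alignment(orig, cor)):
#       print(e.o_start, e.o_end, e.o_str, "->", e.c_str)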


def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]:
    """
    Does the initial alignment grouping:
    1. Make groups of MDM, MIM or MSM.
    2. In the remaining operations, make groups of Ms, groups of Ts, and D/I/Ss.
       Do not group what was on the sides of M[DIS]M: SSMDMS -> [SS, MDM, S], not [MDM, SSS].
    3. Sort groups by the order in which they appear in the alignment.
    """
    if mode == "new":
        op_groups = []
        # Format the operation type sequence as a string to allow regex search
        # (the first char handles transposition ops like "T3")
        all_ops_seq = "".join([op[0][0] for op in alignment.align_seq])
        # Find M[DIS]M groups and merge them (needed to detect hyphen vs. space spelling)
        ungrouped_ids = list(range(len(alignment.align_seq)))
        for match in re.finditer("M[DIS]M", all_ops_seq):
            start, end = match.start(), match.end()
            # All M[DIS]M groups share the "MSM" label regardless of the middle op
            op_groups.append(("MSM", alignment.align_seq[start:end]))
            for idx in range(start, end):
                ungrouped_ids.remove(idx)
        # Group the remaining operations by the default rules (groups of M, T and the rest)
        if ungrouped_ids:
            def get_group_type(operation):
                return operation if operation in {"M", "T"} else "DIS"
            curr_group = [alignment.align_seq[ungrouped_ids[0]]]
            last_oper_type = get_group_type(curr_group[0][0][0])
            for i, idx in enumerate(ungrouped_ids[1:], start=1):
                operation = alignment.align_seq[idx]
                oper_type = get_group_type(operation[0][0])
                # D/I/S ops only stay in one group while adjacent in the alignment
                if (oper_type == last_oper_type and
                        (idx - ungrouped_ids[i - 1] == 1 or oper_type in {"M", "T"})):
                    curr_group.append(operation)
                else:
                    op_groups.append((last_oper_type, curr_group))
                    curr_group = [operation]
                    last_oper_type = oper_type
            if curr_group:
                op_groups.append((last_oper_type, curr_group))
        # Sort groups by the orig start id of the first group entry
        op_groups = sorted(op_groups, key=lambda x: x[1][0][1])
    else:
        # Default ERRANT-style grouping; non-M/T groups are keyed False here
        grouped = itertools.groupby(alignment.align_seq,
                                    lambda x: x[0][0] if x[0][0] in {"M", "T"} else False)
        op_groups = [(op, list(group)) for op, group in grouped]
    return op_groups
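
# Worked example (illustrative; align_seq entries have the form
# (op, o_start, o_end, c_start, c_end)): for the operation string "SSMDMS",
# mode="new" yields the groups
#   [("DIS", [S, S]), ("MSM", [M, D, M]), ("DIS", [S])]
# in alignment order, whereas the default itertools.groupby path would produce
#   [(False, [S, S]), ("M", [M]), (False, [D]), ("M", [M]), (False, [S])].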


def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]:
    """Applies merging rules to previously formed alignment groups (`seq`)."""
    # Return single alignments
    if len(seq) <= 1:
        return seq
    # Get the ops for the whole sequence
    ops = [op[0] for op in seq]
    # Get indices of all start-end combinations in the seq: 012 = 01, 02, 12
    combos = list(itertools.combinations(range(0, len(seq)), 2))
    # Sort them starting with the largest spans first
    combos.sort(key=lambda x: x[1] - x[0], reverse=True)
    # Loop through the combos
    for start, end in combos:
        # Ignore ranges that do NOT contain a substitution, deletion or insertion
        if not any(type_ in ops[start:end + 1] for type_ in ["D", "I", "S"]):
            continue
        # Merge all D xor I ops. (95% of human multi-token edits contain S.)
        if set(ops[start:end + 1]) == {"D"} or set(ops[start:end + 1]) == {"I"}:
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Get the tokens in orig and cor
        o = alignment.orig[seq[start][1]:seq[end][2]]
        c = alignment.cor[seq[start][3]:seq[end][4]]
        # M[DIS]M groups are always exactly three ops long, so start == 0 here
        if ops[start:end + 1] in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]:
            # Merge hyphens: e.g. [pre - war -> pre war] becomes one edit
            if (o[start + 1].text == "-" or c[start + 1].text == "-") and len(o) != len(c):
                return (process_seq(seq[:start], alignment)
                        + merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # If it is not a hyphen-space edit, return only the middle edit
            return seq[start + 1:end]
        # Merge possessive suffixes: [friends -> friend 's]
        # (tag_ == "POS" is the Penn Treebank possessive-ending tag)
        if o[-1].tag_ == "POS" or c[-1].tag_ == "POS":
            return (process_seq(seq[:end - 1], alignment)
                    + merge_edits(seq[end - 1:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Case changes (.lower is the lowercase lexeme hash, so this is a
        # case-insensitive equality test)
        if o[-1].lower == c[-1].lower:
            # Merge first token I or D: [Cat -> The big cat]
            if (start == 0 and
                    ((len(o) == 1 and c[0].text[0].isupper()) or
                     (len(c) == 1 and o[0].text[0].isupper()))):
                return (merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # Merge with previous punctuation: [, we -> . We], [we -> . We]
            if ((len(o) > 1 and is_punct(o[-2])) or
                    (len(c) > 1 and is_punct(c[-2]))):
                return (process_seq(seq[:end - 1], alignment)
                        + merge_edits(seq[end - 1:end + 1])
                        + process_seq(seq[end + 1:], alignment))
        # Merge whitespace/hyphens: [acat -> a cat], [sub - way -> subway]
        s_str = re.sub("['-]", "", "".join([tok.lower_ for tok in o]))
        t_str = re.sub("['-]", "", "".join([tok.lower_ for tok in c]))
        if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Merge same POS or auxiliary/infinitive/phrasal verbs:
        # [to eat -> eating], [watch -> look at]
        # Use the string tags (pos_), since tok.pos is an integer id
        pos_set = set([tok.pos_ for tok in o] + [tok.pos_ for tok in c])
        if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Split rules take effect when we get down to the smallest chunks
        if end - start < 2:
            # Split adjacent substitutions
            if len(o) == len(c) == 2:
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split similar substitutions at sequence boundaries
            if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or
                    (ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)):
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split final determiners
            if (end == len(seq) - 1 and
                    ((ops[-1] in {"D", "S"} and o[-1].pos_ == "DET") or
                     (ops[-1] in {"I", "S"} and c[-1].pos_ == "DET"))):
                return process_seq(seq[:-1], alignment) + [seq[-1]]
    return seq
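
# Illustrative traces (the spans are hypothetical):
#   - A group of pure deletions, e.g. seq = [("D", 0, 1, 0, 0), ("D", 1, 2, 0, 0)],
#     is collapsed by the D-xor-I rule into a single span [("X", 0, 2, 0, 0)].
#   - Two adjacent substitutions over unrelated tokens (different POS, no case,
#     punctuation or hyphen relation) fall through to the split rules and are
#     returned as two separate edits.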


def is_punct(token) -> bool:
    """True if the token is a (ASCII) punctuation mark."""
    return token.text in punctuation
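
# Caveat worth noting for a Russian port: string.punctuation is ASCII-only,
# so guillemets and dashes common in Russian text are not matched, e.g.:
#
#   >>> "«" in punctuation
#   False
#   >>> "," in punctuation
#   True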


def char_cost(a: str, b: str) -> float:
    """Calculate the cost of character alignment, i.e. character similarity."""
    return Levenshtein.ratio(a, b)
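
# Levenshtein.ratio returns a normalised similarity in [0, 1]; the split rule
# above treats substitutions scoring above 0.75 as spelling variants of the
# same word. Rough examples (values assume python-Levenshtein's indel-based
# ratio):
#
#   >>> char_cost("walked", "walker")   # one letter differs -> kept together
#   0.833...
#   >>> char_cost("cat", "dog")         # unrelated words -> split
#   0.0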


def merge_edits(seq: list[tuple]) -> list[tuple]:
    """Merge the input alignment sequence into a single edit span."""
    if seq:
        return [("X", seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
    return seq
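
# Example: merging a deletion followed by a substitution,
#   merge_edits([("D", 1, 2, 1, 1), ("S", 2, 3, 1, 2)])
# yields [("X", 1, 3, 1, 2)]: the span runs from the first op's orig/cor start
# to the last op's orig/cor end, with "X" marking a merged, untyped edit.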