danasone committed on
Commit
a0b78f4
1 Parent(s): b037807
Files changed (4) hide show
  1. classifier.py +147 -0
  2. merger.py +181 -0
  3. requirements.txt +4 -1
  4. ru_errant.py +117 -18
classifier.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from string import punctuation
5
+
6
+ import Levenshtein
7
+ from errant.edit import Edit
8
+
9
+
10
def edit_to_tuple(edit: Edit, idx: int = 0) -> list[int | str]:
    """Represents an Edit as a mutable [o_start, o_end, type, correction, idx] list.

    A list (not a tuple, as the previous annotation claimed) is returned on
    purpose: `classify()` overwrites the correction field (index 3) when one
    Edit expands into several typed edits.
    """
    cor_toks_str = " ".join(tok.text for tok in edit.c_toks)
    return [edit.o_start, edit.o_end, edit.type, cor_toks_str, idx]
13
+
14
+
15
def classify(edit: Edit) -> list[Edit]:
    """Classifies an Edit via updating its `type` attribute."""
    # Exactly one side empty -> pure insertion or pure deletion.
    if bool(edit.o_toks) != bool(edit.c_toks):
        error_cats = get_one_sided_type(edit.o_toks, edit.c_toks)
    elif edit.o_toks != edit.c_toks:
        error_cats = get_two_sided_type(edit.o_toks, edit.c_toks)
    else:
        error_cats = {"NA": edit.c_toks[0].text}
    classified = []
    for error_cat, correct_str in (error_cats or {}).items():
        edit.type = error_cat
        typed_edit = edit_to_tuple(edit)
        typed_edit[3] = correct_str  # the correction for this error type only
        classified.append(typed_edit)
    return classified
32
+
33
+
34
def get_edit_info(toks):
    """Collects POS tags, dependency labels and a merged morphology dict for `toks`."""
    pos_tags = []
    dep_labels = []
    morph_features = {}
    for token in toks:
        pos_tags.append(token.tag_)
        dep_labels.append(token.dep_)
        # spaCy renders morphology as "Key=Val|Key=Val"; later tokens win on clashes.
        for feature in str(token.morph).split("|"):
            feature = feature.strip()
            if feature:
                key, value = feature.split("=")
                morph_features[key] = value
    return pos_tags, dep_labels, morph_features
47
+
48
+
49
def get_one_sided_type(o_toks, c_toks):
    """Classifies a zero-to-one or one-to-zero error based on a token list."""
    present_side = o_toks if o_toks else c_toks
    pos_tags, _, _ = get_edit_info(present_side)
    correction = c_toks[0].text if c_toks else ""
    # Inserted/deleted punctuation or whitespace is a PUNCT error; anything else SPELL.
    is_punct_like = "PUNCT" in pos_tags or "SPACE" in pos_tags
    return {"PUNCT": correction} if is_punct_like else {"SPELL": correction}
55
+
56
+
57
def get_two_sided_type(o_toks, c_toks) -> dict[str, str]:
    """Classifies a one-to-one or one-to-many or many-to-one error based on token lists."""
    punct_or_space = punctuation + " "
    # one-to-one cases
    if len(o_toks) == 1 and len(c_toks) == 1:
        source_w = o_toks[0].text
        correct_w = c_toks[0].text
        if (all(ch in punct_or_space for ch in source_w)
                and all(ch in punct_or_space for ch in correct_w)):
            return {"PUNCT": correct_w}
        if source_w != correct_w:
            # Same case on both sides and no "ё" anywhere -> plain spelling error.
            both_lower = source_w.islower() and correct_w.islower()
            both_upper = source_w.isupper() and correct_w.isupper()
            if (both_lower or both_upper) and "ё" not in source_w + correct_w:
                return {"SPELL": correct_w}
            # Edits with multiple errors (e.g. SPELL + CASE):
            # 1) char-level Levenshtein operations,
            # 2) classify each operation (CASE, YO, SPELL),
            # 3) combine same-typed operations into minimal string pairs.
            char_ops = Levenshtein.editops(source_w, correct_w)
            classified_ops = classify_char_edits(char_ops, source_w, correct_w)
            return get_edit_strings(source_w, correct_w, classified_ops)
    # one-to-many and many-to-one cases
    if all(ch in punct_or_space for ch in o_toks.text + c_toks.text):
        return {"PUNCT": c_toks.text}
    correction = " ".join(tok.text for tok in c_toks)
    correction = correction.replace("- ", "-").replace(" -", "-")
    return {"SPELL": correction}
88
+
89
+
90
def classify_char_edits(char_edits, source_w, correct_w):
    """Classifies char-level Levenstein operations into SPELL, YO and CASE."""
    classified = []
    for op in char_edits:
        op_name, src_pos, cor_pos = op[0], op[1], op[2]
        if op_name != "replace":
            # Insertions and deletions are always spelling errors.
            classified.append((*op, "SPELL"))
            continue
        src_char = source_w[src_pos]
        cor_char = correct_w[cor_pos]
        if "ё" in (src_char, cor_char):
            classified.append((*op, "YO"))
        elif src_char.lower() == cor_char.lower():
            classified.append((*op, "CASE"))
        else:
            # Different letters; when the case flips too, emit both error types.
            case_flipped = ((src_char.islower() and cor_char.isupper()) or
                            (src_char.isupper() and cor_char.islower()))
            if case_flipped:
                classified.append((*op, "CASE"))
            classified.append((*op, "SPELL"))
    return classified
109
+
110
+
111
def get_edit_strings(source: str, correction: str,
                     edits_classified: list[tuple]) -> dict[str, str]:
    """
    Applies classified (SPELL, YO and CASE) char operations to source word separately.
    Returns a dict mapping error type to source string with corrections of this type only.
    """
    separated_edits = defaultdict(lambda: source)
    # FIX: the insert/delete position shift must be tracked PER error type.
    # Each type's string accumulates only its own edits, so a SPELL insertion
    # must not displace positions inside the CASE/YO strings (the old single
    # shared `shift` corrupted them).
    shifts = defaultdict(int)
    for edit in edits_classified:
        op, src_pos, cor_pos, edit_type = edit
        curr_src = separated_edits[edit_type]
        shift = shifts[edit_type]
        if edit_type == "CASE":  # SOURCE letter spelled in CORRECTION case
            if correction[cor_pos].isupper():
                correction_char = source[src_pos].upper()
            else:
                correction_char = source[src_pos].lower()
        else:
            if op == "delete":
                correction_char = ""
            elif op == "insert":
                correction_char = correction[cor_pos]
            elif source[src_pos].isupper():
                # Preserve the source char's case for non-CASE replacements.
                correction_char = correction[cor_pos].upper()
            else:
                correction_char = correction[cor_pos].lower()
        if op == "replace":
            separated_edits[edit_type] = (curr_src[:src_pos + shift] + correction_char
                                          + curr_src[src_pos + shift + 1:])
        elif op == "delete":
            separated_edits[edit_type] = (curr_src[:src_pos + shift]
                                          + curr_src[src_pos + shift + 1:])
            shifts[edit_type] -= 1
        elif op == "insert":
            separated_edits[edit_type] = (curr_src[:src_pos + shift] + correction_char
                                          + curr_src[src_pos + shift:])
            shifts[edit_type] += 1
    return dict(separated_edits)
merger.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ import re
5
+ from string import punctuation
6
+
7
+ import Levenshtein
8
+ from errant.alignment import Alignment
9
+ from errant.edit import Edit
10
+
11
+
12
def get_rule_edits(alignment: Alignment) -> list[Edit]:
    """Groups word-level alignment according to merging rules."""
    def to_edit(span):
        return Edit(alignment.orig, alignment.cor, span[1:])

    edits = []
    for op, group in group_alignment(alignment, "new"):
        group = list(group)
        if op == "M":
            # Matches produce no edits.
            continue
        if op == "T":
            # Transpositions are always split, never merged.
            edits.extend(to_edit(span) for span in group)
        else:
            # D, I and S subsequences go through the merging rules.
            edits.extend(to_edit(span) for span in process_seq(group, alignment))
    return edits
33
+
34
+
35
def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]:
    """
    Does initial alignment grouping:
    1. Make groups of MDM, MIM or MSM.
    2. In remaining operations, make groups of Ms, groups of Ts, and D/I/Ss.
    Do not group what was on the sides of M[DIS]M: SSMDMS -> [SS, MDM, S], not [MDM, SSS].
    3. Sort groups by the order in which they appear in the alignment.

    Returns a list of (group_label, [align_seq entries]) pairs, ordered by the
    original-side start offset of each group's first entry.
    """
    if mode == "new":
        op_groups = []
        # Format operation types sequence as string to use regex sequence search
        all_ops_seq = "".join([op[0][0] for op in alignment.align_seq])
        # Find M[DIS]M groups and merge (need them to detect hyphen vs. space spelling)
        ungrouped_ids = list(range(len(alignment.align_seq)))
        for match in re.finditer("M[DIS]M", all_ops_seq):
            start, end = match.start(), match.end()
            # All three variants get the single label "MSM": downstream code only
            # distinguishes "M"/"T" labels from everything else.
            op_groups.append(("MSM", alignment.align_seq[start:end]))
            for idx in range(start, end):
                ungrouped_ids.remove(idx)
        # Group remaining operations by default rules (groups of M, T and rest)
        if ungrouped_ids:
            def get_group_type(operation):
                # M and T keep their own label; D/I/S collapse into one bucket.
                return operation if operation in {"M", "T"} else "DIS"
            curr_group = [alignment.align_seq[ungrouped_ids[0]]]
            last_oper_type = get_group_type(curr_group[0][0][0])
            for i, idx in enumerate(ungrouped_ids[1:], start=1):
                operation = alignment.align_seq[idx]
                oper_type = get_group_type(operation[0][0])
                # D/I/S runs must be contiguous in the original alignment;
                # M/T runs may skip over ids consumed by M[DIS]M groups above.
                if (oper_type == last_oper_type and
                        (idx - ungrouped_ids[i-1] == 1 or oper_type in {"M", "T"})):
                    curr_group.append(operation)
                else:
                    op_groups.append((last_oper_type, curr_group))
                    curr_group = [operation]
                    last_oper_type = oper_type
            if curr_group:
                op_groups.append((last_oper_type, curr_group))
        # Sort groups by the start id of the first group entry
        op_groups = sorted(op_groups, key=lambda x: x[1][0][1])
    else:
        # Default (original ERRANT) grouping: consecutive runs keyed by M/T,
        # everything else falls into shared False-keyed runs.
        grouped = itertools.groupby(alignment.align_seq,
                                    lambda x: x[0][0] if x[0][0] in {"M", "T"} else False)
        op_groups = [(op, list(group)) for op, group in grouped]
    return op_groups
79
+
80
+
81
def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]:
    """Applies merging rules to previously formed alignment groups (`seq`).

    `seq` is a list of (op, o_start, o_end, c_start, c_end) alignment entries.
    Sub-spans are recursively merged (via `merge_edits`) or split according to
    ERRANT's heuristics; the resulting list of edit spans is returned.
    """
    # Return single alignments
    if len(seq) <= 1:
        return seq
    # Get the ops for the whole sequence
    ops = [op[0] for op in seq]

    # Get indices of all start-end combinations in the seq: 012 = 01, 02, 12
    combos = list(itertools.combinations(range(0, len(seq)), 2))
    # Sort them starting with largest spans first
    combos.sort(key=lambda x: x[1] - x[0], reverse=True)
    # Loop through combos
    for start, end in combos:
        # Ignore ranges that do NOT contain a substitution, deletion or insertion.
        if not any(type_ in ops[start:end + 1] for type_ in ["D", "I", "S"]):
            continue
        # Merge all D xor I ops. (95% of human multi-token edits contain S).
        if set(ops[start:end + 1]) == {"D"} or set(ops[start:end + 1]) == {"I"}:
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Get the tokens in orig and cor.
        o = alignment.orig[seq[start][1]:seq[end][2]]
        c = alignment.cor[seq[start][3]:seq[end][4]]
        if ops[start:end + 1] in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]:
            # Merge hyphens. NOTE: M ops only occur inside dedicated M[DIS]M
            # groups, so this branch is reached with start == 0 in practice.
            if (o[start + 1].text == "-" or c[start + 1].text == "-") and len(o) != len(c):
                return (process_seq(seq[:start], alignment)
                        + merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # if it is not a hyphen-space edit, return only punct edit
            return seq[start + 1: end]
        # Merge possessive suffixes: [friends -> friend 's]
        if o[-1].tag_ == "POS" or c[-1].tag_ == "POS":
            return (process_seq(seq[:end - 1], alignment)
                    + merge_edits(seq[end - 1:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Case changes
        if o[-1].lower == c[-1].lower:
            # Merge first token I or D: [Cat -> The big cat]
            # FIX: parenthesized the OR so both alternatives are guarded by
            # `start == 0` (previously `and` bound tighter than `or`, so the
            # second alternative fired for any start).
            if (start == 0 and
                    ((len(o) == 1 and c[0].text[0].isupper()) or
                     (len(c) == 1 and o[0].text[0].isupper()))):
                return (merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # Merge with previous punctuation: [, we -> . We], [we -> . We]
            if (len(o) > 1 and is_punct(o[-2])) or \
                    (len(c) > 1 and is_punct(c[-2])):
                return (process_seq(seq[:end - 1], alignment)
                        + merge_edits(seq[end - 1:end + 1])
                        + process_seq(seq[end + 1:], alignment))
        # Merge whitespace/hyphens: [acat -> a cat], [sub - way -> subway]
        s_str = re.sub("['-]", "", "".join([tok.lower_ for tok in o]))
        t_str = re.sub("['-]", "", "".join([tok.lower_ for tok in c]))
        if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Merge same POS or auxiliary/infinitive/phrasal verbs:
        # [to eat -> eating], [watch -> look at]
        # FIX: use `pos_` (string label); `pos` is spaCy's integer id, so the
        # comparison against {"AUX", "PART", "VERB"} could never succeed.
        pos_set = set([tok.pos_ for tok in o] + [tok.pos_ for tok in c])
        if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Split rules take effect when we get to smallest chunks
        if end - start < 2:
            # Split adjacent substitutions
            if len(o) == len(c) == 2:
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split similar substitutions at sequence boundaries
            if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or
                    (ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)):
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split final determiners
            # FIX: `pos_` instead of integer `pos` here as well.
            if (end == len(seq) - 1 and
                    ((ops[-1] in {"D", "S"} and o[-1].pos_ == "DET") or
                     (ops[-1] in {"I", "S"} and c[-1].pos_ == "DET"))):
                return process_seq(seq[:-1], alignment) + [seq[-1]]
    return seq
164
+
165
+
166
def is_punct(token) -> bool:
    """Tells whether the token's text occurs inside `string.punctuation`."""
    text = token.text
    return text in punctuation
168
+
169
+
170
def char_cost(a: str, b: str) -> float:
    """Calculate the cost of character alignment; i.e. char similarity.

    Returns Levenshtein.ratio(a, b) in [0.0, 1.0]; 1.0 means identical strings.
    Used by process_seq to decide whether to split boundary substitutions.
    """

    return Levenshtein.ratio(a, b)
174
+
175
+
176
def merge_edits(seq: list[tuple]) -> list[tuple]:
    """Merge the input alignment sequence to a single edit span."""
    if not seq:
        return seq
    first = seq[0]
    last = seq[-1]
    # Merged spans are labelled with the synthetic op name "X".
    return [("X", first[1], last[2], first[3], last[4])]
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- git+https://github.com/huggingface/evaluate@main
 
 
 
 
1
+ git+https://github.com/huggingface/evaluate@main
2
+ git+https://github.com/Askinkaty/errant/@4183e57
3
+ Levenshtein
4
+ ru-core-news-lg @ https://huggingface.co/spacy/ru_core_news_lg/resolve/main/ru_core_news_lg-any-py3-none-any.whl
ru_errant.py CHANGED
@@ -12,11 +12,26 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  """TODO: Add a description here."""
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  import evaluate
17
  import datasets
18
 
19
-
20
  # TODO: Add BibTeX citation
21
  _CITATION = """\
22
  @InProceedings{huggingface:module,
@@ -31,7 +46,6 @@ _DESCRIPTION = """\
31
  This new module is designed to solve this great ML task and is crafted with a lot of care.
32
  """
33
 
34
-
35
  # TODO: Add description of the arguments of the module here
36
  _KWARGS_DESCRIPTION = """
37
  Calculates how good are predictions given some references, using certain scores
@@ -57,6 +71,40 @@ Examples:
57
  BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
  class RuErrant(evaluate.Metric):
62
  """TODO: Short description of my evaluation module."""
@@ -70,26 +118,77 @@ class RuErrant(evaluate.Metric):
70
  citation=_CITATION,
71
  inputs_description=_KWARGS_DESCRIPTION,
72
  # This defines the format of each prediction and reference
73
- features=datasets.Features({
74
- 'predictions': datasets.Value('int64'),
75
- 'references': datasets.Value('int64'),
76
- }),
 
 
 
77
  # Homepage of the module for documentation
78
  homepage="http://module.homepage",
79
  # Additional links to the codebase or references
80
- codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
81
  reference_urls=["http://path.to.reference.url/new_module"]
82
  )
83
 
84
  def _download_and_prepare(self, dl_manager):
85
- """Optional: download external resources useful to compute the scores"""
86
- # TODO: Download external resources if needed
87
- pass
88
-
89
- def _compute(self, predictions, references):
90
- """Returns the scores"""
91
- # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
93
- return {
94
- "accuracy": accuracy,
95
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
+ from __future__ import annotations
16
 
17
+ import re
18
+ from collections import Counter, namedtuple
19
+ from typing import Iterable
20
+ from tqdm.auto import tqdm
21
+
22
+ from errant.annotator import Annotator
23
+ from errant.commands.compare_m2 import process_edits
24
+ from errant.commands.compare_m2 import evaluate_edits
25
+ from errant.commands.compare_m2 import merge_dict
26
+ from errant.edit import Edit
27
+ import spacy
28
+ from spacy.tokenizer import Tokenizer
29
+ from spacy.util import compile_prefix_regex, compile_infix_regex, compile_suffix_regex
30
+ import classifier
31
+ import merger
32
  import evaluate
33
  import datasets
34
 
 
35
  # TODO: Add BibTeX citation
36
  _CITATION = """\
37
  @InProceedings{huggingface:module,
 
46
  This new module is designed to solve this great ML task and is crafted with a lot of care.
47
  """
48
 
 
49
  # TODO: Add description of the arguments of the module here
50
  _KWARGS_DESCRIPTION = """
51
  Calculates how good are predictions given some references, using certain scores
 
71
  BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
72
 
73
 
74
def update_spacy_tokenizer(nlp):
    """
    Changes Spacy tokenizer to parse additional patterns.

    Drops the last default infix rule in favour of a markdown-link "](" infix,
    adds escaped-quote / backslash prefix and suffix rules, and limits URL
    detection to explicit http(s) schemes.
    """
    # FIX: raw strings for regex fragments — "\]\(" and friends were plain
    # strings with invalid escape sequences (SyntaxWarning on Python >= 3.12).
    infix_re = compile_infix_regex(nlp.Defaults.infixes[:-1] + [r"\]\("])
    simple_url_re = re.compile(r"^https?://")
    nlp.tokenizer = Tokenizer(
        nlp.vocab,
        prefix_search=compile_prefix_regex(nlp.Defaults.prefixes + [r'\\"']).search,
        suffix_search=compile_suffix_regex(nlp.Defaults.suffixes + [r'\\']).search,
        infix_finditer=infix_re.finditer,
        token_match=None,
        url_match=simple_url_re.match
    )
    return nlp
89
+
90
+
91
def annotate_errors(self, orig: str, cor: str, merging: str = "rules") -> list[Edit]:
    """
    Overrides `Annotator.annotate()` function to allow multiple errors per token.
    This is necessary to parse combined errors, e.g.:
    ["werd", "Word"] >>> Errors: ["SPELL", "CASE"]
    The `classify()` method called inside is implemented in ruerrant_classifier.py
    (also overrides the original classifier).
    """
    annotator = self.annotator
    alignment = annotator.align(orig, cor, False)
    classified = []
    for merged_edit in annotator.merge(alignment, merging):
        classified.extend(annotator.classify(merged_edit))
    # Order edits by start offset in the original text, then by error type.
    classified.sort(key=lambda e: (e[0], e[2]))
    return classified
106
+
107
+
108
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
109
  class RuErrant(evaluate.Metric):
110
  """TODO: Short description of my evaluation module."""
 
118
  citation=_CITATION,
119
  inputs_description=_KWARGS_DESCRIPTION,
120
  # This defines the format of each prediction and reference
121
+ features=datasets.Features(
122
+ {
123
+ "sources": datasets.Value("string", id="sequence"),
124
+ "corrections": datasets.Value("string", id="sequence"),
125
+ "answers": datasets.Value("string", id="sequence"),
126
+ }
127
+ ),
128
  # Homepage of the module for documentation
129
  homepage="http://module.homepage",
130
  # Additional links to the codebase or references
131
+ codebase_urls=["https://github.com/ai-forever/sage"],
132
  reference_urls=["http://path.to.reference.url/new_module"]
133
  )
134
 
135
    def _download_and_prepare(self, dl_manager):
        """Builds the Russian ERRANT annotator used by `_compute`.

        Called once by `evaluate` before metric computation; plugs the custom
        `merger` and `classifier` modules into ERRANT in place of its defaults.
        """
        self.annotator = Annotator("ru",
                                   nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")),
                                   merger=merger,
                                   classifier=classifier)
140
+
141
    def _compute(self, sources, corrections, answers):
        """
        Evaluates iterables of sources, hyp and ref corrections with ERRANT metric.

        Args:
            sources (Iterable[str]): an iterable of source texts;
            corrections (Iterable[str]): an iterable of gold corrections for the source texts;
            answers (Iterable[str]): an iterable of evaluated corrections for the source texts;

        Returns:
            dict[str, tuple[float, ...]]: a dict mapping error categories to the corresponding
                P, R, F1 metric values.
        """
        # Corpus-level counters: overall TP/FP/FN and per-category TP/FP/FN.
        best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
        best_cats = {}
        sents = zip(sources, corrections, answers)
        # NOTE(review): len(sources) assumes a sized container, not a lazy iterable.
        pb = tqdm(sents, desc="Calculating errant metric", total=len(sources))
        for sent_id, sent in enumerate(pb):
            src = self.annotator.parse(sent[0])
            ref = self.annotator.parse(sent[1])
            hyp = self.annotator.parse(sent[2])
            # Align hyp and ref corrections and annotate errors
            # NOTE(review): annotate_errors is defined at module level with a
            # `self` parameter; presumably attached to this class elsewhere — confirm.
            hyp_edits = self.annotate_errors(src, hyp)
            ref_edits = self.annotate_errors(src, ref)
            # Process the edits for detection/correction based on args
            # (namedtuples mimic the argparse namespace compare_m2 expects).
            ProcessingArgs = namedtuple("ProcessingArgs",
                                        ["dt", "ds", "single", "multi", "filt", "cse"],
                                        defaults=[False, False, False, False, [], True])
            processing_args = ProcessingArgs()
            hyp_dict = process_edits(hyp_edits, processing_args)
            ref_dict = process_edits(ref_edits, processing_args)
            # Evaluate edits and get best TP, FP, FN hyp+ref combo.
            EvaluationArgs = namedtuple("EvaluationArgs",
                                        ["beta", "verbose"],
                                        defaults=[1.0, False])
            evaluation_args = EvaluationArgs()
            count_dict, cat_dict = evaluate_edits(
                hyp_dict, ref_dict, best_dict, sent_id, evaluation_args)
            # Merge these dicts with best_dict and best_cats
            best_dict += Counter(count_dict)  # corpus-level TP, FP, FN
            best_cats = merge_dict(best_cats, cat_dict)  # corpus-level errortype-wise TP, FP, FN
        # Turn per-category counts into precision / recall / F1.
        cat_prf = {}
        for cat, values in best_cats.items():
            tp, fp, fn = values  # fp - extra corrections, fn - missed corrections
            p = float(tp) / (tp + fp) if tp + fp else 1.0
            r = float(tp) / (tp + fn) if tp + fn else 1.0
            f = (2 * p * r) / (p + r) if p + r else 0.0
            cat_prf[cat] = (p, r, f)

        # Categories with no observed edits at all default to a perfect score.
        for error_category in ["CASE", "PUNCT", "SPELL", "YO"]:
            if error_category not in cat_prf:
                cat_prf[error_category] = (1.0, 1.0, 1.0)

        return cat_prf