# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """TODO: Add a description here.""" from __future__ import annotations import re from collections import Counter, namedtuple from typing import Iterable from tqdm.auto import tqdm from sage.evaluation.ruerrant_wrapper import classifier from sage.evaluation.ruerrant_wrapper import merger from errant.annotator import Annotator from errant.commands.compare_m2 import process_edits from errant.commands.compare_m2 import evaluate_edits from errant.commands.compare_m2 import merge_dict from errant.edit import Edit import spacy from spacy.tokenizer import Tokenizer from spacy.util import compile_prefix_regex, compile_infix_regex, compile_suffix_regex import evaluate import datasets # TODO: Add BibTeX citation _CITATION = """\ @InProceedings{huggingface:module, title = {A great new module}, authors={huggingface, Inc.}, year={2020} } """ # TODO: Add description of the module here _DESCRIPTION = """\ This new module is designed to solve this great ML task and is crafted with a lot of care. """ # TODO: Add description of the arguments of the module here _KWARGS_DESCRIPTION = """ Calculates how good are predictions given some references, using certain scores Args: predictions: list of predictions to score. Each predictions should be a string with tokens separated by spaces. references: list of reference for each prediction. Each reference should be a string with tokens separated by spaces. Returns: accuracy: description of the first score, another_score: description of the second score, Examples: Examples should be written in doctest format, and should illustrate how to use the function. >>> my_new_module = evaluate.load("my_new_module") >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1]) >>> print(results) {'accuracy': 1.0} """ # TODO: Define external resources urls if needed BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt" def update_spacy_tokenizer(nlp): """ Changes Spacy tokenizer to parse additional patterns. """ infix_re = compile_infix_regex(nlp.Defaults.infixes[:-1] + ["\]\("]) simple_url_re = re.compile(r'''^https?://''') nlp.tokenizer = Tokenizer( nlp.vocab, prefix_search=compile_prefix_regex(nlp.Defaults.prefixes + ['\\\\\"']).search, suffix_search=compile_suffix_regex(nlp.Defaults.suffixes + ['\\\\']).search, infix_finditer=infix_re.finditer, token_match=None, url_match=simple_url_re.match ) return nlp @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) class RuErrant(evaluate.Metric): """TODO: Short description of my evaluation module.""" def _info(self): # TODO: Specifies the evaluate.EvaluationModuleInfo object return evaluate.MetricInfo( # This is the description that will appear on the modules page. module_type="metric", description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, # This defines the format of each prediction and reference features=datasets.Features( { "sources": datasets.Value("string", id="sequence"), "corrections": datasets.Value("string", id="sequence"), "answers": datasets.Value("string", id="sequence"), } ), # Homepage of the module for documentation homepage="http://module.homepage", # Additional links to the codebase or references codebase_urls=["https://github.com/ai-forever/sage"], reference_urls=["http://path.to.reference.url/new_module"] ) def _download_and_prepare(self, dl_manager): self.annotator = Annotator("ru", nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")), merger=merger, classifier=classifier) def annotate_errors(self, orig: str, cor: str, merging: str = "rules") -> list[Edit]: """ Overrides `Annotator.annotate()` function to allow multiple errors per token. This is nesessary to parse combined errors, e.g.: ["werd", "Word"] >>> Errors: ["SPELL", "CASE"] The `classify()` method called inside is implemented in ruerrant_classifier.py (also overrides the original classifier). """ alignment = self.annotator.align(orig, cor, False) edits = self.annotator.merge(alignment, merging) classified_edits = [] for edit in edits: classified_edits.extend(self.annotator.classify(edit)) return sorted(classified_edits, key=lambda x: (x[0], x[2])) def _compute(self, sources, corrections, answers): """ Evaluates iterables of sources, hyp and ref corrections with ERRANT metric. Args: sources (Iterable[str]): an iterable of source texts; corrections (Iterable[str]): an iterable of gold corrections for the source texts; answers (Iterable[str]): an iterable of evaluated corrections for the source texts; Returns: dict[str, tuple[float, ...]]: a dict mapping error categories to the corresponding P, R, F1 metric values. """ best_dict = Counter({"tp": 0, "fp": 0, "fn": 0}) best_cats = {} sents = zip(sources, corrections, answers) for sent_id, sent in enumerate(sents): src = self.annotator.parse(sent[0]) ref = self.annotator.parse(sent[1]) hyp = self.annotator.parse(sent[2]) # Align hyp and ref corrections and annotate errors hyp_edits = self.annotate_errors(src, hyp) ref_edits = self.annotate_errors(src, ref) # Process the edits for detection/correction based on args ProcessingArgs = namedtuple("ProcessingArgs", ["dt", "ds", "single", "multi", "filt", "cse"], defaults=[False, False, False, False, [], True]) processing_args = ProcessingArgs() hyp_dict = process_edits(hyp_edits, processing_args) ref_dict = process_edits(ref_edits, processing_args) # Evaluate edits and get best TP, FP, FN hyp+ref combo. EvaluationArgs = namedtuple("EvaluationArgs", ["beta", "verbose"], defaults=[1.0, False]) evaluation_args = EvaluationArgs() count_dict, cat_dict = evaluate_edits( hyp_dict, ref_dict, best_dict, sent_id, evaluation_args) # Merge these dicts with best_dict and best_cats best_dict += Counter(count_dict) # corpus-level TP, FP, FN best_cats = merge_dict(best_cats, cat_dict) # corpus-level errortype-wise TP, FP, FN cat_prf = {} for cat, values in best_cats.items(): tp, fp, fn = values # fp - extra corrections, fn - missed corrections p = float(tp) / (tp + fp) if tp + fp else 1.0 r = float(tp) / (tp + fn) if tp + fn else 1.0 f = (2 * p * r) / (p + r) if p + r else 0.0 cat_prf[cat] = (p, r, f) for error_category in ["CASE", "PUNCT", "SPELL", "YO"]: if error_category not in cat_prf: cat_prf[error_category] = (1.0, 1.0, 1.0) return cat_prf