|
import datasets |
|
import re |
|
import string |
|
|
|
def normalize_answer(s):
    """Normalize an answer string before exact-match comparison.

    The steps are applied in this order:

    1. lowercase the text;
    2. drop every character in ``string.punctuation``;
    3. replace each standalone "a", "an" or "the" with a space;
    4. collapse whitespace runs to single spaces and trim both ends.

    The order matters. For example, "the's" becomes "thes", because the
    apostrophe is removed before article matching, so "the" is no longer
    a whole word.

    Args:
        s: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    punctuation = frozenset(string.punctuation)
    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    text = s.lower()
    text = "".join(char for char in text if char not in punctuation)
    text = article_pattern.sub(" ", text)
    return " ".join(text.split())
|
|
|
def compute_exact(a_gold, a_pred):
    """Score one prediction against one gold answer.

    Both strings are passed through ``normalize_answer`` and then compared
    for equality.

    Args:
        a_gold: Gold (reference) answer text.
        a_pred: Predicted answer text.

    Returns:
        ``1`` if the normalized strings are equal, otherwise ``0``.
    """
    gold_normalized = normalize_answer(a_gold)
    pred_normalized = normalize_answer(a_pred)
    return 1 if gold_normalized == pred_normalized else 0
|
|
|
def compute_em(predictions, references):
    """Return the mean exact-match score over aligned prediction/reference pairs.

    Each pair is scored with ``compute_exact`` (1 for a match after
    normalization, 0 otherwise), and the scores are averaged.

    Args:
        predictions: Iterable of predicted answer strings.
        references: Iterable of gold answer strings, aligned index-by-index
            with ``predictions``.

    Returns:
        The fraction of pairs that match exactly, as a float in [0, 1].
        Returns 0.0 when both inputs are empty.

    Raises:
        ValueError: If ``predictions`` and ``references`` have different
            lengths.
    """
    predictions = list(predictions)
    references = list(references)

    # zip() would silently drop the unmatched tail and report a score over
    # only part of the data, so treat a length mismatch as an error.
    if len(predictions) != len(references):
        raise ValueError(
            "predictions and references must have the same length, "
            f"got {len(predictions)} and {len(references)}"
        )

    # Avoid ZeroDivisionError on empty input.
    if not predictions:
        return 0.0

    matches = sum(
        compute_exact(ref, pred) for pred, ref in zip(predictions, references)
    )
    return matches / len(predictions)
|
|
|
class ExactMatch(datasets.Metric):
    """Exact-match metric for paired prediction/reference strings.

    Each pair is compared after normalization by ``normalize_answer``. The
    reported value is the mean of the per-pair 0/1 scores.
    """

    def _info(self):
        """Describe the metric's input features and provenance to ``datasets``."""
        source_url = "https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py"
        input_features = datasets.Features(
            {
                "predictions": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )
        return datasets.MetricInfo(
            description="This will get effective exact match in text data",
            citation="",
            homepage="",
            inputs_description="",
            features=input_features,
            codebase_urls=[source_url],
            reference_urls=[source_url],
        )

    def _compute(self, predictions, references):
        """Return ``{"exact_match": score}``, where score is the mean exact-match rate."""
        score = compute_em(predictions, references)
        return {"exact_match": score}