bhadresh-savani's picture
added training files
2303827
raw
history blame
No virus
1.68 kB
import datasets
import re
import string
def normalize_answer(s):
    """Return a canonical form of ``s`` for answer comparison.

    The steps run in a fixed order: lowercase the text, delete every
    character found in ``string.punctuation``, replace the standalone
    words "a", "an" and "the" with a space, then collapse all runs of
    whitespace into single spaces (also trimming both ends).
    """
    lowered = s.lower()
    # Punctuation is deleted outright rather than replaced by a space,
    # so "well-known" becomes "wellknown".
    without_punctuation = lowered.translate(
        str.maketrans("", "", string.punctuation)
    )
    # Articles are removed after punctuation so that e.g. "the," is
    # already reduced to the bare word "the" by this point.
    without_articles = re.sub(
        r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE
    )
    return " ".join(without_articles.split())
def compute_exact(a_gold, a_pred):
    """Return 1 if both answers are equal after normalization, else 0.

    Each argument is passed through ``normalize_answer`` before the
    comparison, so the two raw strings themselves need not be equal.
    """
    gold_normalized = normalize_answer(a_gold)
    pred_normalized = normalize_answer(a_pred)
    return 1 if gold_normalized == pred_normalized else 0
def compute_em(predictions, references):
    """Return the fraction of predictions that exactly match their reference.

    Predictions and references are paired positionally with ``zip`` and
    each pair is scored 0 or 1 by ``compute_exact``; the result is the
    mean of those scores. Because ``zip`` stops at the shorter input,
    surplus items in the longer one do not contribute to the score.

    Returns 0.0 when there are no pairs to score, instead of raising
    ``ZeroDivisionError``.
    """
    scores = [
        compute_exact(ref, pred)
        for pred, ref in zip(predictions, references)
    ]
    # Guard the mean: with nothing to score there is no match to report.
    if not scores:
        return 0.0
    return sum(scores) / len(scores)
class ExactMatch(datasets.Metric):
    """Metric reporting the exact-match rate between string pairs.

    The score is produced by ``compute_em`` and returned under the
    ``"exact_match"`` key.
    """

    def _info(self):
        """Build the ``MetricInfo`` describing this metric's inputs."""
        # The same upstream file serves as both codebase and reference.
        source_url = "https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py"
        # One string prediction paired with one string reference.
        feature_schema = datasets.Features(
            {
                "predictions": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )
        return datasets.MetricInfo(
            description="This will get effective exact match in text data",
            citation="",
            homepage="",
            inputs_description="",
            features=feature_schema,
            codebase_urls=[source_url],
            reference_urls=[source_url],
        )

    def _compute(self, predictions, references):
        """Return ``{"exact_match": score}`` for the given pairs."""
        score = compute_em(predictions, references)
        return {"exact_match": score}