|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" SARI metric.""" |
|
|
|
|
|
from collections import Counter |
|
|
|
|
|
import datasets |
|
|
import sacrebleu |
|
|
import sacremoses |
|
|
from packaging import version |
|
|
|
|
|
import evaluate |
|
|
|
|
|
|
|
|
# BibTeX entry for the paper that introduced SARI. The work appeared in a
# journal (TACL), so the entry type is ``@article`` and the people are listed
# under the standard ``author`` field (BibTeX silently ignores ``authors``).
_CITATION = """\
@article{xu-etal-2016-optimizing,
title = {Optimizing Statistical Machine Translation for Text Simplification},
author = {Xu, Wei and Napoles, Courtney and Pavlick, Ellie and Chen, Quanze and Callison-Burch, Chris},
journal = {Transactions of the Association for Computational Linguistics},
volume = {4},
year = {2016},
url = {https://www.aclweb.org/anthology/Q16-1029},
pages = {401--415},
}
"""
|
|
|
|
|
# Human-readable summary of the metric; passed to ``MetricInfo`` and prepended
# to the ``Sari`` class docstring by the decorator on that class.
_DESCRIPTION = """\
SARI is a metric used for evaluating automatic text simplification systems.
The metric compares the predicted simplified sentences against the reference
and the source sentences. It explicitly measures the goodness of words that are
added, deleted and kept by the system.
Sari = (F1_add + F1_keep + P_del) / 3
where
F1_add: n-gram F1 score for add operation
F1_keep: n-gram F1 score for keep operation
P_del: n-gram precision score for delete operation
n = 4, as in the original paper.

This implementation is adapted from Tensorflow's tensor2tensor implementation [3].
It has two differences with the original GitHub [1] implementation:
  (1) Defines 0/0=1 instead of 0 to give higher scores for predictions that match
      a target exactly.
  (2) Fixes an alleged bug [2] in the keep score computation.
[1] https://github.com/cocoxu/simplification/blob/master/SARI.py
    (commit 0210f15)
[2] https://github.com/cocoxu/simplification/issues/6
[3] https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/sari_hook.py
"""
|
|
|
|
|
|
|
|
# Usage documentation for ``Sari.compute``: argument description plus a
# doctest-style example. The example lines share one indentation level so the
# expected output stays aligned with its ``>>>`` prompts.
_KWARGS_DESCRIPTION = """
Calculates sari score (between 0 and 100) given a list of source and predicted
sentences, and a list of lists of reference sentences.
Args:
    sources: list of source sentences where each sentence should be a string.
    predictions: list of predicted sentences where each sentence should be a string.
    references: list of lists of reference sentences where each sentence should be a string.
Returns:
    sari: sari score
Examples:
    >>> sources=["About 95 species are currently accepted ."]
    >>> predictions=["About 95 you now get in ."]
    >>> references=[["About 95 species are currently known .","About 95 species are now accepted .","95 species are now accepted ."]]
    >>> sari = evaluate.load("sari")
    >>> results = sari.compute(sources=sources, predictions=predictions, references=references)
    >>> print(results)
    {'sari': 26.953601953601954}
"""
|
|
|
|
|
|
|
|
def _f1(precision, recall):
    """Return the harmonic mean of *precision* and *recall*, or 0 when both are 0."""
    if precision > 0 or recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0


def SARIngram(sgrams, cgrams, rgramslist, numref):
    """Compute the keep, delete and add scores for a single n-gram order.

    Args:
        sgrams: n-grams of the source sentence.
        cgrams: n-grams of the candidate (predicted) sentence.
        rgramslist: one list of n-grams per reference sentence.
        numref: number of references. Source and candidate counts are
            multiplied by it so they are comparable with the reference
            counts, which are pooled over all references.

    Returns:
        A tuple ``(keepscore, delscore_precision, addscore)``: F1 of the keep
        operation, precision of the delete operation and F1 of the add
        operation.
    """
    rgramcounter = Counter(rgram for rgrams in rgramslist for rgram in rgrams)
    sgramcounter = Counter(sgrams)
    cgramcounter = Counter(cgrams)
    sgramcounter_rep = Counter({gram: count * numref for gram, count in sgramcounter.items()})
    cgramcounter_rep = Counter({gram: count * numref for gram, count in cgramcounter.items()})

    # KEEP: n-grams present in both source and candidate, judged against the
    # references. Counter `&` takes the per-key minimum and drops empty keys.
    keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
    keepgramcountergood_rep = keepgramcounter_rep & rgramcounter
    keepgramcounterall_rep = sgramcounter_rep & rgramcounter

    # Explicit `+=` accumulation keeps the floating-point result independent
    # of how the built-in `sum` chooses to add floats.
    keep_precision_sum = 0
    for gram, good_count in keepgramcountergood_rep.items():
        keep_precision_sum += good_count / keepgramcounter_rep[gram]
    # Recall is a ratio of count totals rather than a per-n-gram average.
    keep_recall_hits = sum(keepgramcountergood_rep.values())

    # 0/0 is defined as 1, so an empty denominator leaves the score at 1.
    keepscore_precision = 1
    keepscore_recall = 1
    if keepgramcounter_rep:
        keepscore_precision = keep_precision_sum / len(keepgramcounter_rep)
    if keepgramcounterall_rep:
        keepscore_recall = keep_recall_hits / sum(keepgramcounterall_rep.values())
    keepscore = _f1(keepscore_precision, keepscore_recall)

    # DELETE: n-grams of the source that the candidate dropped; a deletion is
    # good when the references dropped that n-gram too. Counter `-` keeps
    # only strictly positive counts. Only precision is reported.
    delgramcounter_rep = sgramcounter_rep - cgramcounter_rep
    delgramcountergood_rep = delgramcounter_rep - rgramcounter
    del_precision_sum = 0
    for gram, good_count in delgramcountergood_rep.items():
        del_precision_sum += good_count / delgramcounter_rep[gram]

    delscore_precision = 1
    if delgramcounter_rep:
        delscore_precision = del_precision_sum / len(delgramcounter_rep)

    # ADD: works on distinct n-grams (sets), not on counts.
    source_grams = set(sgramcounter)
    reference_grams = set(rgramcounter)
    addgramcounter = set(cgramcounter) - source_grams
    addgramcountergood = addgramcounter & reference_grams
    addgramcounterall = reference_grams - source_grams
    good_additions = len(addgramcountergood)

    addscore_precision = 1
    addscore_recall = 1
    if addgramcounter:
        addscore_precision = good_additions / len(addgramcounter)
    if addgramcounterall:
        addscore_recall = good_additions / len(addgramcounterall)
    addscore = _f1(addscore_precision, addscore_recall)

    return (keepscore, delscore_precision, addscore)
|
|
|
|
|
|
|
|
def _ngrams(tokens, order):
    """Return the space-joined n-grams of length *order* in *tokens*, left to right."""
    return [" ".join(tokens[start : start + order]) for start in range(len(tokens) - order + 1)]


def SARIsent(ssent, csent, rsents):
    """Compute the sentence-level SARI score of one candidate.

    Each sentence is split on single spaces, n-grams of order 1 to 4 are
    extracted, and ``SARIngram`` is evaluated once per order. The keep, delete
    and add scores are each averaged over the four orders and the three
    averages are averaged again.

    Args:
        ssent: source sentence (string).
        csent: candidate / predicted sentence (string).
        rsents: sized collection of reference sentences (strings).

    Returns:
        The SARI score of the candidate as a number; the scale is whatever
        ``SARIngram`` returns (the caller multiplies by 100).
    """
    numref = len(rsents)

    # Splitting on " " (not arbitrary whitespace) is intentional: callers are
    # expected to pass already tokenized, single-space separated text.
    source_tokens = ssent.split(" ")
    candidate_tokens = csent.split(" ")
    reference_tokens = [rsent.split(" ") for rsent in rsents]

    keepscores = []
    delscores = []
    addscores = []
    for order in range(1, 5):
        keepscore, delscore, addscore = SARIngram(
            _ngrams(source_tokens, order),
            _ngrams(candidate_tokens, order),
            [_ngrams(tokens, order) for tokens in reference_tokens],
            numref,
        )
        keepscores.append(keepscore)
        delscores.append(delscore)
        addscores.append(addscore)

    avgkeepscore = sum(keepscores) / 4
    avgdelscore = sum(delscores) / 4
    avgaddscore = sum(addscores) / 4
    finalscore = (avgkeepscore + avgdelscore + avgaddscore) / 3
    return finalscore
|
|
|
|
|
|
|
|
def normalize(sentence, lowercase: bool = True, tokenizer: str = "13a", return_str: bool = True):
    """Lowercase and tokenize a sentence.

    Args:
        sentence: the text to normalize (string).
        lowercase: when true, the text is lowercased before tokenization.
        tokenizer: ``"13a"`` or ``"intl"`` selects the sacrebleu tokenizer of
            that name, ``"moses"`` and ``"penn"`` select the corresponding
            sacremoses ``MosesTokenizer`` methods. Any other value skips
            tokenization and leaves the (possibly lowercased) text as is.
        return_str: when true the result is a single string; otherwise the
            string is split on whitespace and a list of tokens is returned.

    Returns:
        The normalized sentence, as a string or as a list of tokens.
    """
    if lowercase:
        sentence = sentence.lower()

    if tokenizer in ("13a", "intl"):
        # The tokenizer registry is reached differently depending on the
        # installed sacrebleu major version.
        if version.parse(sacrebleu.__version__).major >= 2:
            tokenize = sacrebleu.metrics.bleu._get_tokenizer(tokenizer)()
        else:
            tokenize = sacrebleu.TOKENIZERS[tokenizer]()
        normalized_sent = tokenize(sentence)
    elif tokenizer == "moses":
        normalized_sent = sacremoses.MosesTokenizer().tokenize(sentence, return_str=True, escape=False)
    elif tokenizer == "penn":
        normalized_sent = sacremoses.MosesTokenizer().penn_tokenize(sentence, return_str=True)
    else:
        normalized_sent = sentence

    if not return_str:
        normalized_sent = normalized_sent.split()

    return normalized_sent
|
|
|
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Sari(evaluate.Metric):
    # Corpus-level SARI metric: the sentence-level scores from ``SARIsent``
    # are averaged over all examples and reported on a 0-100 scale.

    def _info(self):
        """Return the metric metadata: documentation, input features and source URLs."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "sources": datasets.Value("string", id="sequence"),
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                }
            ),
            codebase_urls=[
                "https://github.com/cocoxu/simplification/blob/master/SARI.py",
                "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/sari_hook.py",
            ],
            reference_urls=["https://www.aclweb.org/anthology/Q16-1029.pdf"],
        )

    def _compute(self, sources, predictions, references):
        """Average the sentence-level SARI over all examples.

        Args:
            sources: source sentences, one string per example.
            predictions: predicted sentences, one string per example.
            references: for each example, a list of reference sentences.

        Returns:
            ``{"sari": score}`` where the score is the mean sentence-level
            SARI multiplied by 100.

        Raises:
            ValueError: if the three inputs do not have the same length.
            ZeroDivisionError: if the inputs are empty (the mean is undefined).
        """
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError("Sources length must match predictions and references lengths.")

        # Every sentence goes through `normalize` with its default settings
        # before scoring, because `SARIsent` splits on single spaces.
        sari_score = 0
        for src, pred, refs in zip(sources, predictions, references):
            normalized_refs = [normalize(sent) for sent in refs]
            sari_score += SARIsent(normalize(src), normalize(pred), normalized_refs)
        sari_score = sari_score / len(predictions)
        return {"sari": 100 * sari_score}
|
|
|