meteor / meteor.py
lvwerra's picture
lvwerra HF staff
Update Space (evaluate main: e4a27243)
e955d89
raw
history blame
7.37 kB
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" METEOR metric. """
from dataclasses import dataclass
import datasets
import numpy as np
from datasets.config import importlib_metadata, version
from nltk.translate import meteor_score
import evaluate
# Installed NLTK version, parsed once at import time. It is consulted when
# choosing which NLTK data to download and how inputs are prepared for scoring.
NLTK_VERSION = version.parse(importlib_metadata.version("nltk"))

# `word_tokenize` is only bound when NLTK is at least 3.6.4, so every use of it
# must sit behind an equivalent (or stricter) version check.
if version.Version("3.6.4") <= NLTK_VERSION:
    from nltk import word_tokenize
# BibTeX entry for the original METEOR paper; passed as `citation` to
# evaluate.MetricInfo in Meteor._info.
# NOTE(review): the key `banarjee2005` misspells the first author (Banerjee);
# left unchanged because users may already reference this key.
_CITATION = """\
@inproceedings{banarjee2005,
title = {{METEOR}: An Automatic Metric for {MT} Evaluation with Improved Correlation with Human Judgments},
author = {Banerjee, Satanjeev and Lavie, Alon},
booktitle = {Proceedings of the {ACL} Workshop on Intrinsic and Extrinsic Evaluation Measures for Machine Translation and/or Summarization},
month = jun,
year = {2005},
address = {Ann Arbor, Michigan},
publisher = {Association for Computational Linguistics},
url = {https://www.aclweb.org/anthology/W05-0909},
pages = {65--72},
}
"""
# Human-readable summary of the METEOR metric. It is used as the `description`
# in Meteor._info and is also passed to `add_start_docstrings` on the Meteor class.
_DESCRIPTION = """\
METEOR, an automatic metric for machine translation evaluation
that is based on a generalized concept of unigram matching between the
machine-produced translation and human-produced reference translations.
Unigrams can be matched based on their surface forms, stemmed forms,
and meanings; furthermore, METEOR can be easily extended to include more
advanced matching strategies. Once all generalized unigram matches
between the two strings have been found, METEOR computes a score for
this matching using a combination of unigram-precision, unigram-recall, and
a measure of fragmentation that is designed to directly capture how
well-ordered the matched words in the machine translation are in relation
to the reference.
METEOR gets an R correlation value of 0.347 with human evaluation on the Arabic
data and 0.331 on the Chinese data. This is shown to be an improvement on
using simply unigram-precision, unigram-recall and their harmonic F1
combination.
"""
_KWARGS_DESCRIPTION = """
Computes METEOR score of translated segments against one or more references.
Args:
predictions: list of predictions to score. Each prediction
should be a string with tokens separated by spaces.
references: list of reference for each prediction. Each
reference should be a string with tokens separated by spaces.
alpha: Parameter for controlling relative weights of precision and recall. default: 0.9
beta: Parameter for controlling shape of penalty as a function of fragmentation. default: 3
gamma: Relative weight assigned to fragmentation penalty. default: 0.5
Returns:
'meteor': meteor score.
Examples:
>>> meteor = evaluate.load('meteor')
>>> predictions = ["It is a guide to action which ensures that the military always obeys the commands of the party"]
>>> references = ["It is a guide to action that ensures that the military will forever heed Party commands"]
>>> results = meteor.compute(predictions=predictions, references=references)
>>> print(round(results["meteor"], 4))
0.6944
"""
@dataclass
class MeteorConfig(evaluate.info.Config):
    # Configuration for the METEOR metric. Meteor._compute reads alpha, beta and
    # gamma from here and forwards them to nltk's meteor scoring functions.

    # Config name; Meteor.ALLOWED_CONFIG_NAMES lists the accepted values.
    name: str = "default"
    # Controls the relative weights of precision and recall.
    alpha: float = 0.9
    # Controls the shape of the penalty as a function of fragmentation.
    beta: float = 3.0
    # Relative weight assigned to the fragmentation penalty.
    gamma: float = 0.5
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Meteor(evaluate.Metric):
    # METEOR metric backed by nltk.translate.meteor_score. The alpha/beta/gamma
    # weights come from `self.config` (a MeteorConfig), not from compute() kwargs.
    # (Notes are kept as comments rather than a class docstring so the docstring
    # assembled by add_start_docstrings is left unchanged.)

    CONFIG_CLASS = MeteorConfig
    # NOTE(review): "multilabel" gets no special handling anywhere in this class and
    # looks copied from another metric; kept for backward compatibility -- confirm.
    ALLOWED_CONFIG_NAMES = ["default", "multilabel"]

    def _info(self, config):
        """Describe the metric and the input layouts it accepts.

        Two feature layouts are declared: each prediction string paired with a
        sequence of reference strings, or with a single reference string.
        """
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            config=config,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/nltk/nltk/blob/develop/nltk/translate/meteor_score.py"],
            reference_urls=[
                "https://www.nltk.org/api/nltk.translate.html#module-nltk.translate.meteor_score",
                "https://en.wikipedia.org/wiki/METEOR",
            ],
        )

    def _download_and_prepare(self, dl_manager):
        """Fetch the NLTK data packages used for the installed NLTK version.

        `dl_manager` is unused; downloads go through `nltk.download` instead.
        """
        import nltk

        nltk.download("wordnet")
        if NLTK_VERSION >= version.Version("3.6.5"):
            # Same version threshold at which _compute starts tokenizing with
            # word_tokenize.
            nltk.download("punkt")
        if NLTK_VERSION >= version.Version("3.6.6"):
            # NOTE(review): presumably required by wordnet lookups on newer NLTK.
            nltk.download("omw-1.4")

    def _compute(self, predictions, references):
        """Score each prediction against its own reference(s) and average.

        Args:
            predictions: list of prediction strings.
            references: list with one entry per prediction; each entry is either
                a reference string or a list of reference strings.

        Returns:
            dict `{"meteor": mean score}` over all prediction/reference pairs.
        """
        # The first entry decides the layout for the whole batch.
        multiple_refs = isinstance(references[0], list)
        weights = {"alpha": self.config.alpha, "beta": self.config.beta, "gamma": self.config.gamma}

        if NLTK_VERSION >= version.Version("3.6.5"):
            # The METEOR implementation in NLTK 3.6.5 and later expects
            # pre-tokenized inputs, so tokenize before scoring.
            prepare = word_tokenize
        else:
            # Older NLTK versions receive the raw strings unchanged, in both the
            # single- and multi-reference paths. This also avoids word_tokenize,
            # which is not imported before NLTK 3.6.4.
            def prepare(text):
                return text

        if multiple_refs:
            # Each prediction is scored against its OWN list of references.
            scores = [
                meteor_score.meteor_score([prepare(ref) for ref in refs], prepare(pred), **weights)
                for refs, pred in zip(references, predictions)
            ]
        else:
            scores = [
                meteor_score.single_meteor_score(prepare(ref), prepare(pred), **weights)
                for ref, pred in zip(references, predictions)
            ]

        return {"meteor": np.mean(scores)}