pminervini's picture
update
fd975b0
raw
history blame
No virus
5.99 kB
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean
import torch
import sacrebleu
from rouge_score import rouge_scorer, scoring
def bleu(refs, preds):
"""
Returns `t5` style BLEU scores. See the related implementation:
https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L41
:param refs:
A `list` of `list` of reference `str`s.
:param preds:
A `list` of predicted `str`s.
"""
score = sacrebleu.corpus_bleu(
preds,
refs,
smooth_method="exp",
smooth_value=0.0,
force=False,
lowercase=False,
tokenize="intl",
use_effective_order=False,
).score
return score
def rouge(refs, preds):
"""
Returns `t5` style ROUGE scores. See the related implementation:
https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68
:param refs:
A `list` of reference `strs`.
:param preds:
A `list` of predicted `strs`.
"""
rouge_types = ["rouge1", "rouge2", "rougeLsum"]
scorer = rouge_scorer.RougeScorer(rouge_types)
# Add newlines between sentences to correctly compute `rougeLsum`.
def _prepare_summary(summary):
summary = summary.replace(" . ", ".\n")
return summary
# Accumulate confidence intervals.
aggregator = scoring.BootstrapAggregator()
for ref, pred in zip(refs, preds):
ref = _prepare_summary(ref)
pred = _prepare_summary(pred)
aggregator.add_scores(scorer.score(ref, pred))
result = aggregator.aggregate()
return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
@register_task("cnndm")
class CnnDm(Task):
VERSION = 0
DATASET_PATH = "cnn_dailymail"
DATASET_NAME = "3.0.0"
def __init__(self, data_dir=None, cache_dir=None, download_mode=None, config=None):
super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
self.factkb_tokenizer = None
self.factkb_model = None
def maybe_init_factkb(self):
if self.factkb_tokenizer is None or self.factkb_model is None:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2, device_map="auto")
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
return self.dataset["test"]
def doc_to_text(self, doc):
return f'Document: {doc["article"]}\nSummary:'
@staticmethod
def should_decontaminate():
return True
def doc_to_decontamination_query(self, doc):
return doc["article"]
def doc_to_target(self, doc):
return doc["highlights"]
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return [
Instance(
request_type="generate_until",
doc=doc,
arguments=(ctx, {"until": ["\n"]}),
idx=0,
**kwargs
)
]
def process_results(self, doc, results):
completion = results[0]
# true_refs, false_refs = doc["correct_answers"], doc["incorrect_answers"]
# all_refs = true_refs + false_refs
document = doc["article"]
true_refs = [doc["highlights"]]
all_refs = true_refs
# ROUGE-N
rouge_scores = [rouge([ref], [completion]) for ref in all_refs]
# ROUGE-1
rouge1_scores = [score["rouge1"] for score in rouge_scores]
# ROUGE-2
rouge2_scores = [score["rouge2"] for score in rouge_scores]
# ROUGE-L
rougeL_scores = [score["rougeLsum"] for score in rouge_scores]
self.maybe_init_factkb()
input_factkb = [[completion, document]]
factkb_tokens = self.factkb_tokenizer(input_factkb, return_tensors="pt", padding="max_length", truncation=True).to(self.factkb_model.device)
factkb_logits = self.factkb_model(**factkb_tokens).logits
factkb_res = torch.softmax(factkb_logits, dim=1)
res = {
"rouge1": rouge1_scores[0],
"rouge2": rouge2_scores[0],
"rougeL": rougeL_scores[0],
"factKB": float(factkb_res[0][1])
}
return res
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {k: mean for k in ["rouge1", "rouge2", "rougeL"]}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {k: True for k in ["rouge1", "rouge2", "rougeL"]}