saxenarohit
added cnn
a0d8a50
raw
history blame
2.71 kB
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean
import datasets
from src.backend.tasks.cnndm import utils
@register_task("cnndm")
class CnnDm(Task):
    """CNN/DailyMail summarization task, registered under the name ``"cnndm"``.

    Uses the ``cnn_dailymail`` dataset (config ``"3.0.0"``) with train,
    validation and test splits. Each article is turned into a
    ``Document: ...\\nSummary:`` prompt, the model generates a continuation,
    and the ``rouge1`` / ``rouge2`` / ``rougeL`` submetrics computed by
    ``utils.process_results`` are averaged across documents with ``mean``.
    """

    VERSION = 0
    DATASET_PATH = "cnn_dailymail"
    DATASET_NAME = "3.0.0"

    # Single source of truth for the submetric names: aggregation() and
    # higher_is_better() are both derived from it so they cannot drift apart.
    # NOTE(review): presumably these must match the keys of the dict returned
    # by utils.process_results — verify against src/backend/tasks/cnndm/utils.
    _METRIC_NAMES = ("rouge1", "rouge2", "rougeL")

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, config=None):
        # Pure pass-through to Task.__init__, kept to pin the keyword signature.
        super().__init__(
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
            config=config,
        )

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        """Return the prompt for *doc*: the article followed by a ``Summary:`` cue."""
        return f'Document: {doc["article"]}\nSummary:'

    @staticmethod
    def should_decontaminate():
        """Opt this task into decontamination checks."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the text used as the decontamination query (the raw article)."""
        return doc["article"]

    def doc_to_target(self, doc):
        """Return the reference summary (the dataset's ``highlights`` field)."""
        # NOTE(review): no leading space is prepended, so if the harness builds
        # few-shot examples as doc_to_text(doc) + doc_to_target(doc) they render
        # as "Summary:<highlights>" — confirm this is intended.
        return doc["highlights"]

    def construct_requests(self, doc, ctx, **kwargs):
        """Build the generation request(s) sent to the LM for *doc*.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param kwargs:
            Forwarded unchanged to ``Instance``.
        :returns: list[Instance]
            A single ``generate_until`` request whose stop sequences are a
            newline and a period.
        """
        # NOTE(review): stopping at "." limits the candidate summary to one
        # sentence, while CNN/DM reference highlights are usually several
        # sentences long — confirm this truncation is intended.
        return [
            Instance(
                request_type="generate_until",
                doc=doc,
                arguments=(ctx, {"until": ["\n", "."]}),
                idx=0,
                **kwargs,
            )
        ]

    def process_results(self, doc, results):
        """Delegate per-document scoring of the generated *results* to ``utils.process_results``."""
        return utils.process_results(doc, results)

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics (here: ``mean`` for all)
        """
        return dict.fromkeys(self._METRIC_NAMES, mean)

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better (``True`` for all)
        """
        return dict.fromkeys(self._METRIC_NAMES, True)