backend_demo / src /custom_tasks /winograd_task.py
Shaltiel's picture
Fixed whitespace for prediction
adf0b2e
raw
history blame
2.11 kB
import re
import string
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.metrics import Metrics, MetricCategory
from lighteval.metrics.utils import CorpusLevelMetric, MetricUseCase
from aenum import extend_enum
import numpy as np
from lighteval.tasks.requests import Doc
from Levenshtein import distance
import collections
from lighteval.utils import as_list
# Matches any "<...>" tag-like span (XML/HTML tags and similar markup some models emit).
_XML_TAG_RE = re.compile(r"<[^>]+>")


def winograd_eval_fn(golds: list[str], predictions: list[str], formatted_doc: "Doc | None" = None) -> int:
    """Score one winograd sample as an exact match after light cleanup.

    Some models wrap their answer in extra markup, so every ``<...>`` span is
    removed from the prediction and surrounding whitespace is stripped before
    comparing it to the first gold answer (also stripped).

    Args:
        golds: Reference answers; only ``golds[0]`` is compared.
        predictions: Model outputs for this sample; exactly one is expected.
        formatted_doc: Not used here. Presumably accepted because lighteval
            passes it to sample-level metric functions (TODO confirm).

    Returns:
        1 if the cleaned prediction equals the stripped first gold, else 0.

    Raises:
        ValueError: if ``predictions`` does not contain exactly one item.
        IndexError: if ``golds`` is empty.
    """
    # An empty list used to slip past a `> 1` check and then fail on predictions[0].
    if len(predictions) != 1:
        raise ValueError("Predictions should have one item")
    pred = _XML_TAG_RE.sub("", predictions[0]).strip()
    gold = golds[0].strip()
    return int(pred == gold)
# Accuracy metric for the winograd task: winograd_eval_fn yields a 0/1 score per
# sample, and np.mean is the corpus-level reducer (presumably applied by lighteval
# over the per-sample scores; TODO confirm against the lighteval version in use).
winograd_acc_metric = CorpusLevelMetric(
    metric="winograd_acc",
    sample_level_fn=winograd_eval_fn,
    corpus_level_fn=np.mean,
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.ACCURACY,
    higher_is_better=True,
)

# Add the metric to the Metrics enum under this name so task configs can refer
# to it by string, e.g. metric=["winograd_acc_metric"].
extend_enum(Metrics, "winograd_acc_metric", winograd_acc_metric)
def winograd_prompt_fn(line, task_name: str = None):
    """Convert one dataset row into a lighteval ``Doc``.

    The row's ``"prompt"`` becomes the query and every entry of
    ``"response"`` becomes a candidate choice; both are whitespace-stripped.
    The first response is always treated as the correct answer
    (``gold_index=0``), and the instruction is left empty.
    """
    query_text = line["prompt"].strip()
    stripped_choices = [candidate.strip() for candidate in line["response"]]
    return Doc(
        task_name=task_name,
        query=query_text,
        choices=stripped_choices,
        gold_index=0,
        instruction="",
    )
# Task configuration: a simple task (like hellaswag) with a single subset
# attached to it and a single possible evaluation.
winograd_task = LightevalTaskConfig(
    name="winograd-acc",
    suite=["custom"],
    # Given by name: the function must be defined in this file or imported
    # from src/lighteval/tasks/tasks_prompt_formatting.py.
    prompt_function="winograd_prompt_fn",
    hf_repo="dicta-hebrew-llm-leaderboard/tests",
    hf_subset="default",
    hf_avail_splits=["winograd"],
    evaluation_splits=["winograd"],
    # Name under which winograd_acc_metric is added to the Metrics enum via extend_enum.
    metric=["winograd_acc_metric"],
    # Generation stops at the first newline; the size cap is presumably in tokens (TODO confirm).
    generation_size=32,
    stop_sequence=["\n"],
)