# tlem/tasks.py
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
from .utils import (
NUMERIC_IN_ZH,
extract_choice_ans,
extract_numeric,
get_answer,
is_equiv,
)
from evaluate import load

# A text-generation pipeline maps a batch of prompts to a batch of completions.
TextGenerationPipeline = Callable[[Iterable[str]], list[str]]


def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    # Identity pipeline used as the default in Task.run: it simply echoes the
    # prompts (with a progress bar), which is handy for dry runs and debugging.
    return [prompt for prompt in tqdm(prompts)]
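

# NOTE: illustrative adapter, not part of the original module. A minimal sketch
# of how a Hugging Face `transformers` text-generation pipeline could be wrapped
# to satisfy the TextGenerationPipeline contract above; the model name and
# generation settings are placeholder assumptions, not choices made by tlem.
def hf_text_generation_pipeline(
    model_name: str = "gpt2", max_new_tokens: int = 256
) -> TextGenerationPipeline:
    from transformers import pipeline  # optional dependency, imported lazily

    generator = pipeline("text-generation", model=model_name)

    def generate(prompts: Iterable[str]) -> list[str]:
        # The transformers pipeline returns a list of candidate lists, one per
        # prompt; keep the first candidate's text. `return_full_text=False`
        # drops the prompt from the returned completion.
        return [
            candidates[0]["generated_text"]
            for candidates in generator(
                list(prompts),
                max_new_tokens=max_new_tokens,
                return_full_text=False,
            )
        ]

    return generate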
@dataclass
class Task:
dataset_name: str | tuple[str, str] = ("gsm8k", "main")
split: str = "test"
# metrics: list[str] = field(default_factory=list)
metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
input_column: str = "question"
label_column: str = "answer"
prompt: Optional[Callable | str] = None
@cached_property
def name(self):
return (
self.dataset_name
if isinstance(self.dataset_name, str)
else self.dataset_name[0]
) + f"-{self.split}"
@cached_property
def samples(self):
return self.dataset[self.input_column]
    @cached_property
    def dataset(self):
        # `dataset_name` may be a bare dataset name or a (name, config) tuple;
        # wrap a bare string in a tuple so it is not unpacked character by
        # character.
        ds = load_dataset(
            *(
                self.dataset_name
                if isinstance(self.dataset_name, tuple)
                else (self.dataset_name,)
            ),
            split=self.split,
        )
        if self.prompt is not None:
            # A string prompt is treated as a template with an `{input_column}`
            # placeholder; a callable prompt receives the raw example and must
            # return the updated columns itself.
            ds = ds.map(
                lambda example: (
                    {
                        self.input_column: self.prompt.format(
                            input_column=example[self.input_column]
                        )
                    }
                    if isinstance(self.prompt, str)
                    else self.prompt(example)
                ),
            )
        return ds
@cached_property
def metric(self):
metric = (
load(self.metric_name)
if isinstance(self.metric_name, str)
else load(*self.metric_name)
)
return metric
    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
        # Generate a completion for every sample, then score the completions
        # against the reference answers with the configured metric.
        outputs = pipeline(self.samples)
        return self.metric.compute(
            responses=outputs, references=self.dataset[self.label_column]
        )
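

# NOTE: illustrative sketch, not part of the original module. It shows how a
# Task could be configured with a string prompt template: the template must use
# the literal `{input_column}` placeholder, which the `dataset` property fills
# with each example's input column. The dataset/metric pair below is simply the
# class default spelled out explicitly; the prompt wording is an assumption.
def example_gsm8k_task() -> Task:
    return Task(
        dataset_name=("gsm8k", "main"),
        split="test",
        metric_name=("sustech/tlem", "gsm8k"),
        prompt="Question: {input_column}\nAnswer:",
    )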


class Metrics:
    """Per-task scoring functions; each returns a list of per-sample scores."""

    @staticmethod
    def gsm8k(responses: list[str], answers: list[str | int]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response)
            gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
            scores.append(1.0 * (pred == gold))
        return scores
    @staticmethod
    def MATH(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            # Take the prediction from between the last two `$` characters;
            # responses with two or fewer `$` characters score zero.
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            if len(indices) <= 2:
                scores.append(0.0)
                continue
            result = response[indices[-2] + 1 : indices[-1]]
            gold = get_answer(answer)
            scores.append(1.0 * is_equiv(result, gold))
        return scores
    @staticmethod
    def math23k(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores
    @staticmethod
    def gsm8k_zh(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            # Chinese responses use the Chinese-aware numeric pattern, while the
            # gold answers are extracted with the default pattern.
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer)
            scores.append(1.0 * (pred == gold))
        return scores
    @staticmethod
    def svamp(responses: list[str], answers: list[float]):
        scores = []
        for response, answer in zip(responses, answers):
            # Responses are generated text; the gold answers are numeric, so the
            # extracted prediction is cast to float before comparison.
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = answer
            scores.append(1.0 * (float(pred) == gold))
        return scores
    @staticmethod
    def mmlu(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            # Multiple-choice scoring: compare the extracted option letter with
            # the gold letter case-insensitively.
            pred = extract_choice_ans(response)
            gold = answer.lower()
            scores.append(1.0 * (pred == gold))
        return scores
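

# NOTE: illustrative dry run, not part of the original module. Running the file
# directly evaluates the default gsm8k Task with the echo `fake_pipeline`; this
# downloads the gsm8k test split and loads the sustech/tlem metric from the Hub,
# so treat it as a sketch of the intended call flow rather than a test.
if __name__ == "__main__":
    task = Task()  # defaults: ("gsm8k", "main"), "test" split, sustech/tlem metric
    print(task.name)  # -> "gsm8k-test"
    print(task.run(fake_pipeline))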