tlem / tasks.py
facat's picture
fix fewshot
075ef98
raw
history blame
No virus
15.6 kB
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
import logging
import pandas as pd
from functools import partial
from .utils import *
from evaluate import load
def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    """Echo every prompt back unchanged while showing a progress bar.

    Stand-in for a real generation pipeline: the "response" to each prompt
    is the prompt itself.
    """
    return list(tqdm(prompts))
@dataclass
class Task:
dataset_name: str | tuple[str, str] = ("gsm8k", "main")
split: str = "test"
# metrics: list[str] = field(default_factory=list)
metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
input_column: str = "question"
label_column: str = "answer"
prompt: Optional[Callable | str] = None
few_shot: int = 0
few_shot_from: Optional[str] = None
# results: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
names = (
[self.dataset_name]
if isinstance(self.dataset_name, str)
else list(self.dataset_name)
)
names[0] = names[0].split("/")[-1]
self.name = "-".join(names) + f"-{self.split}"
if isinstance(self.prompt, str):
self.prompt = lambda example: {
self.input_column: self.prompt.format(
input_column=example[self.input_column]
)
}
@cached_property
def samples(self):
return self.dataset[self.input_column]
@cached_property
def dataset(self):
ds = load_dataset(
*self.dataset_name
if isinstance(self.dataset_name, tuple)
else self.dataset_name,
# split=self.split,
)
test_ds = ds[self.split]
if self.prompt is not None:
test_ds = test_ds.map(self.prompt)
if self.few_shot:
if self.few_shot_from is None:
for name in ["train", "validation", "val", "dev"]:
if name in ds:
self.few_shot_from = name
break
shots = ds[self.few_shot_from].select(range(self.few_shot))
if self.prompt is not None:
shots = shots.map(self.prompt)
shots = shots.map(
lambda example: {
self.input_column: example[self.input_column]
+ example[self.label_column],
}
)[self.input_column]
few_shot_prompts = "\n".join(shots)
test_ds = test_ds.map(
lambda example: {
self.input_column: few_shot_prompts
+ "\n"
+ example[self.input_column],
}
)
return test_ds
@cached_property
def metric(self):
metric = (
load(self.metric_name)
if isinstance(self.metric_name, str)
else load(*self.metric_name)
)
return metric
def run(
self,
pipeline,
):
if (outputs := pipeline(self.samples)) is None:
logging.warning("pipeline returns None")
return
self.outputs = outputs
try:
result = self.metric._compute(
responses=outputs, references=self.dataset[self.label_column]
)
except Exception as e:
result = self.metric.compute(
responses=outputs, references=self.dataset[self.label_column]
)
# if log:
# name = name or pipeline.__name__
# self.results[name] = result
return result
def multichoice(responses: Any, references: list[str]):
    """Normalise multiple-choice responses and pair them with the references.

    When the first response is a string, every response is passed through
    ``extract_choice``; otherwise the whole ``responses`` object is handed to
    ``decode_choice``. ``references`` is returned untouched.

    Returns:
        A ``(choices, references)`` tuple.
    """
    first = responses[0]
    if isinstance(first, str):
        choices = [extract_choice(text) for text in responses]
    else:
        choices = decode_choice(responses)
    return choices, references
class Metrics:
    """Namespace of per-benchmark answer post-processing functions.

    The functions take no ``self``; they are meant to be reached through the
    class itself (``Metrics.gsm8k(...)``), not through an instance.
    """

    cmmlu = multichoice
    mmlu = multichoice

    def gsm8k(responses: list[str], answers: list[str | int]):
        """Extract the numeric answer from each response and each reference.

        Non-string references are converted with ``str``.

        Returns:
            A ``(predictions, golds)`` tuple of equal-length lists.
        """
        predictions = [extract_numeric(text) for text in responses]
        golds = []
        for answer in answers:
            if isinstance(answer, str):
                golds.append(extract_numeric(answer))
            else:
                golds.append(str(answer))
        return predictions, golds

    def MATH(responses: list[str], answers: list[str]):
        """Score each response by the text between its last two ``$`` signs.

        Returns:
            One score per pair: ``1.0 * is_equiv(...)`` when an answer span is
            found, otherwise ``0``.
        """
        scores = []
        for response, answer in zip(responses, answers):
            dollar_positions = [
                index for index, char in enumerate(response) if char == "$"
            ]
            # NOTE(review): a response with exactly two "$" is scored 0 here,
            # although it contains one delimited span - confirm the threshold
            # is intended.
            if len(dollar_positions) > 2:
                start = dollar_positions[-2] + 1
                end = dollar_positions[-1]
                predicted = response[start:end]
                gold = get_answer(answer)
                scores.append(1.0 * is_equiv(predicted, gold))
            else:
                scores.append(0)
        return scores

    def math23k(responses: list[str], answers: list[str]):
        """Return 1.0/0.0 per pair: do the ``NUMERIC_IN_ZH`` extractions match?"""
        return [
            1.0
            * (
                extract_numeric(response, pattern=NUMERIC_IN_ZH)
                == extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            )
            for response, answer in zip(responses, answers)
        ]
class CMMLU:
    """Task definitions for the CMMLU benchmark (``haonan-li/cmmlu``)."""

    @staticmethod
    def prompt_cmmlu(example, chat=False):
        """Format one CMMLU row as a multiple-choice prompt.

        Args:
            example: row with "Question" and the option columns "A".."D".
            chat: use the long instruction prefix instead of the short one.

        Returns:
            ``{"prompt": <formatted text>}``, suitable for ``Dataset.map``.
        """
        prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
        lines = [prefix + example["Question"]]
        lines.extend(f"{choice}. {example[choice]}" for choice in "ABCD")
        lines.append("答案:")
        return {"prompt": "\n".join(lines)}

    # Dataset config name -> subcategory labels it belongs to.
    subcategories = {
        "agronomy": ["other"],
        "anatomy": ["biology"],
        "ancient_chinese": ["linguistics", "china specific"],
        "arts": ["arts"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "chinese_civil_service_exam": ["politics", "china specific"],
        "chinese_driving_rule": ["other", "china specific"],
        "chinese_food_culture": ["culture", "china specific"],
        "chinese_foreign_policy": ["politics", "china specific"],
        "chinese_history": ["history", "china specific"],
        "chinese_literature": ["literature", "china specific"],
        "chinese_teacher_qualification": ["education", "china specific"],
        "college_actuarial_science": ["math"],
        "college_education": ["education"],
        "college_engineering_hydrology": ["engineering"],
        "college_law": ["law"],
        "college_mathematics": ["math"],
        "college_medical_statistics": ["statistics"],
        "clinical_knowledge": ["other"],
        "college_medicine": ["other"],
        "computer_science": ["computer science"],
        "computer_security": ["other"],
        "conceptual_physics": ["physics"],
        "construction_project_management": ["other", "china specific"],
        "economics": ["economics"],
        "education": ["education"],
        "elementary_chinese": ["linguistics", "china specific"],
        "elementary_commonsense": ["other", "china specific"],
        "elementary_information_and_technology": ["other"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "ethnology": ["culture", "china specific"],
        "food_science": ["other"],
        "genetics": ["biology"],
        "global_facts": ["global"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_geography": ["geography"],
        "high_school_mathematics": ["math"],
        "high_school_physics": ["physics"],
        "high_school_politics": ["politics", "china specific"],
        "human_sexuality": ["other"],
        "international_law": ["law"],
        "journalism": ["sociology"],
        "jurisprudence": ["law"],
        "legal_and_moral_basis": ["other"],
        "logical": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "marxist_theory": ["philosophy"],
        "modern_chinese": ["linguistics", "china specific"],
        "nutrition": ["other"],
        "philosophy": ["philosophy"],
        "professional_accounting": ["business"],
        "professional_law": ["law"],
        "professional_medicine": ["other"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_study": ["politics"],
        "sociology": ["culture"],
        "sports_science": ["other"],
        "traditional_chinese_medicine": ["other", "china specific"],
        "virology": ["biology"],
        "world_history": ["history"],
        "world_religions": ["global"],
    }

    # Suite name -> subcategory labels it aggregates.
    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "statistics",
        ],
        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
        "Social Science": [
            "linguistics",
            "business",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "education",
            "sociology",
        ],
        "Other": ["other"],
        "China specific": ["china specific"],
        "Test": ["computer science"],
    }

    # Inverse of ``subcategories``: subcategory label -> list of dataset
    # config names carrying that label.
    finer_categories = (
        pd.Series(subcategories)  # noqa  # type: ignore
        .explode()
        .reset_index()
        .set_index(0)
        .groupby(0)
        .agg(list)["index"]
        .to_dict()
    )

    @classmethod
    def suite(cls, chat=False):
        """Build the tasks of every category.

        Args:
            chat: forwarded to ``prompt_cmmlu`` to select the prompt prefix.

        Returns:
            Dict mapping each key of ``categories`` to the list of ``Task``
            objects for ALL of its subcategory labels, one task per dataset
            config.
        """
        suite = {}
        for category, subjects in cls.categories.items():
            # Gather across every subject of the category; assigning inside a
            # per-subject loop would keep only the last subject's tasks.
            suite[category] = [
                Task(
                    ("haonan-li/cmmlu", dataset_config),
                    metric_name=("sustech/tlem", "cmmlu"),
                    input_column="prompt",
                    label_column="Answer",
                    prompt=partial(cls.prompt_cmmlu, chat=chat),
                )
                for subject in subjects
                for dataset_config in cls.finer_categories[subject]
            ]
        return suite
class MMLU:
    """Task definitions for the MMLU benchmark (``lukaemon/mmlu``)."""

    # Column written by ``prompt_mmlu`` and read as the model input.
    input_column = "prompt"
    # Column holding the reference answer.
    label_column = "target"

    @classmethod
    def prompt_mmlu(cls, example, chat=False):
        """Format one MMLU row as a multiple-choice prompt.

        Args:
            example: row with "input" and the option columns "A".."D".
            chat: use the long instruction prefix instead of "Question: ".

        Returns:
            ``{"prompt": <formatted text>}``, suitable for ``Dataset.map``.
        """
        prefix = (
            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
            if chat
            else "Question: "
        )
        lines = [prefix + example["input"]]
        lines.extend(f"{choice}. {example[choice]}" for choice in "ABCD")
        lines.append("Answer:")
        return {"prompt": "\n".join(lines)}

    # Dataset config name -> subcategory labels it belongs to.
    subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    # Suite name -> subcategory labels it aggregates.
    categories = {
        "Math": [
            "math",
        ],
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
        ],
        "humanities": ["history", "philosophy", "law"],
        "social sciences": [
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
        ],
        "Other": ["other", "business", "health"],
        "All": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "history",
            "philosophy",
            "law",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "other",
            "business",
            "health",
        ],
        "Test": ["culture"],
    }

    @classmethod
    def suite(cls, chat=False):
        """Build the tasks of every category.

        Args:
            chat: selects the prompt prefix; few-shot examples (5, taken from
                the "validation" split) are used only when ``chat`` is false.

        Returns:
            Dict mapping each key of ``categories`` to the list of ``Task``
            objects for ALL of its subcategory labels, one task per dataset
            config.
        """
        # Inverse of ``subcategories``: subcategory label -> dataset configs.
        finer_categories = (
            pd.Series(cls.subcategories)  # noqa  # type: ignore
            .explode()
            .reset_index()
            .set_index(0)
            .groupby(0)
            .agg(list)["index"]
            .to_dict()
        )
        suite = {}
        for category, subjects in cls.categories.items():
            # Gather across every subject of the category; assigning inside a
            # per-subject loop would keep only the last subject's tasks.
            suite[category] = [
                Task(
                    ("lukaemon/mmlu", dataset_config),
                    metric_name=("sustech/tlem", "mmlu"),
                    input_column=cls.input_column,
                    label_column=cls.label_column,
                    prompt=partial(cls.prompt_mmlu, chat=chat),
                    few_shot=0 if chat else 5,
                    few_shot_from="validation",
                )
                for subject in subjects
                for dataset_config in finer_categories[subject]
            ]
        return suite