from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable
import logging
import pandas as pd
from functools import partial

from .utils import *
from evaluate import load


def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    """Identity pipeline for smoke tests: echoes each prompt back unchanged."""
    return [prompt for prompt in tqdm(prompts)]


@dataclass
class Task:
    """One evaluation task: a dataset split, an optional prompt, and a metric.

    Attributes:
        dataset_name: a dataset path, or a (path, config) pair for
            ``load_dataset``.
        split: the dataset split that is evaluated.
        metric_name: a metric path, or a (path, config) pair for
            ``evaluate.load``.
        input_column: column holding the model input.
        label_column: column holding the reference answer.
        prompt: a str is treated as a format template with an
            ``input_column`` placeholder; a callable is used directly as a
            ``datasets.map`` function.
        few_shot: number of in-context examples to prepend (0 disables).
        few_shot_from: split to draw the shots from; auto-detected when None.
    """

    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
    split: str = "test"
    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
    input_column: str = "question"
    label_column: str = "answer"
    prompt: Optional[Callable | str] = None
    few_shot: int = 0
    few_shot_from: Optional[str] = None

    def __post_init__(self):
        names = (
            [self.dataset_name]
            if isinstance(self.dataset_name, str)
            else list(self.dataset_name)
        )
        names[0] = names[0].split("/")[-1]
        self.name = "-".join(names) + f"-{self.split}"
        if isinstance(self.prompt, str):
            # Capture the template BEFORE rebinding self.prompt: the original
            # lambda read self.prompt.format, but by call time self.prompt was
            # the lambda itself, so it raised AttributeError.
            template = self.prompt
            self.prompt = lambda example: {
                self.input_column: template.format(
                    input_column=example[self.input_column]
                )
            }

    @cached_property
    def samples(self):
        """The (prompt-formatted, few-shot-prefixed) inputs for the pipeline."""
        return self.dataset[self.input_column]

    @cached_property
    def dataset(self):
        """Load the evaluation split, applying the prompt and few-shot prefix.

        Returns the processed ``datasets.Dataset`` for ``self.split``.
        Raises ValueError when few-shot is requested but no source split can
        be found.
        """
        # A bare string must NOT be star-unpacked: the original
        # `load_dataset(*name if isinstance(...) else name)` applied `*` to
        # the whole conditional and exploded a str into single characters.
        if isinstance(self.dataset_name, tuple):
            ds = load_dataset(*self.dataset_name)
        else:
            ds = load_dataset(self.dataset_name)
        test_ds = ds[self.split]
        if self.prompt is not None:
            test_ds = test_ds.map(self.prompt)

        if self.few_shot:
            if self.few_shot_from is None:
                # Pick the first conventional non-test split that exists.
                for name in ["train", "validation", "val", "dev"]:
                    if name in ds:
                        self.few_shot_from = name
                        break
                if self.few_shot_from is None:
                    raise ValueError(
                        f"no few-shot source split found for {self.name}; "
                        "set few_shot_from explicitly"
                    )
            shots = ds[self.few_shot_from].select(range(self.few_shot))
            if self.prompt is not None:
                shots = shots.map(self.prompt)
            # Each shot is "input + label"; shots are joined with newlines.
            shots = shots.map(
                lambda example: {
                    self.input_column: example[self.input_column]
                    + example[self.label_column],
                }
            )[self.input_column]
            few_shot_prompts = "\n".join(shots)

            test_ds = test_ds.map(
                lambda example: {
                    self.input_column: few_shot_prompts
                    + "\n"
                    + example[self.input_column],
                }
            )
        return test_ds

    @cached_property
    def metric(self):
        """Lazily load the ``evaluate`` metric configured for this task."""
        metric = (
            load(self.metric_name)
            if isinstance(self.metric_name, str)
            else load(*self.metric_name)
        )
        return metric

    def run(self, pipeline):
        """Run ``pipeline`` over the samples and score its outputs.

        Returns the metric result dict, or None when the pipeline yields
        nothing. The raw outputs are kept on ``self.outputs`` for inspection.
        """
        if (outputs := pipeline(self.samples)) is None:
            logging.warning("pipeline returns None")
            return
        self.outputs = outputs
        references = self.dataset[self.label_column]
        try:
            # Prefer the metric's private fast path when it works; fall back
            # to the public API on any failure (original behavior).
            result = self.metric._compute(responses=outputs, references=references)
        except Exception:
            result = self.metric.compute(responses=outputs, references=references)
        return result


def multichoice(responses: Any, references: list[str]):
    """Normalise multiple-choice responses (free text or logits) to letters."""
    if isinstance(responses[0], str):
        responses = [extract_choice(response) for response in responses]
    else:
        responses = decode_choice(responses)
    return responses, references


class Metrics:
    """Registry of per-benchmark response/reference normalisers and scorers."""

    cmmlu = multichoice
    mmlu = multichoice

    def gsm8k(responses: list[str], answers: list[str | int]):
        """Extract the numeric answer from each response and reference."""
        responses = [extract_numeric(response) for response in responses]
        answers = [
            extract_numeric(answer) if isinstance(answer, str) else str(answer)
            for answer in answers
        ]
        return responses, answers

    def MATH(responses: list[str], answers: list[str]):
        """Score MATH: compare the last $...$ span against the gold answer."""
        scores = []
        for response, answer in zip(responses, answers):
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            # NOTE(review): `<= 2` also zero-scores a response with exactly
            # one $...$ span (two dollar signs) — confirm this is intended.
            if len(indices) <= 2:
                scores.append(0)
                continue
            result = response[indices[-2] + 1 : indices[-1]]
            gold = get_answer(answer)
            scores.append(1.0 * is_equiv(result, gold))
        return scores

    def math23k(responses: list[str], answers: list[str]):
        """Score math23k by exact match of extracted Chinese numerals."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores


class CMMLU:
    """CMMLU benchmark: subject taxonomy and task-suite construction."""

    def prompt_cmmlu(example, chat=False):
        """Format one CMMLU example as an A/B/C/D multiple-choice prompt."""
        prefix = (
            "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n"
            if chat
            else "问题:"
        )
        prompt = prefix + example["Question"]
        for choice in list("ABCD"):
            prompt += f"\n{choice}. {example[choice]}"
        prompt += "\n答案:"
        return {"prompt": prompt}

    # subject -> list of fine-grained category tags
    subcategories = {
        "agronomy": ["other"],
        "anatomy": ["biology"],
        "ancient_chinese": ["linguistics", "china specific"],
        "arts": ["arts"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "chinese_civil_service_exam": ["politics", "china specific"],
        "chinese_driving_rule": ["other", "china specific"],
        "chinese_food_culture": ["culture", "china specific"],
        "chinese_foreign_policy": ["politics", "china specific"],
        "chinese_history": ["history", "china specific"],
        "chinese_literature": ["literature", "china specific"],
        "chinese_teacher_qualification": ["education", "china specific"],
        "college_actuarial_science": ["math"],
        "college_education": ["education"],
        "college_engineering_hydrology": ["engineering"],
        "college_law": ["law"],
        "college_mathematics": ["math"],
        "college_medical_statistics": ["statistics"],
        "clinical_knowledge": ["other"],
        "college_medicine": ["other"],
        "computer_science": ["computer science"],
        "computer_security": ["other"],
        "conceptual_physics": ["physics"],
        "construction_project_management": ["other", "china specific"],
        "economics": ["economics"],
        "education": ["education"],
        "elementary_chinese": ["linguistics", "china specific"],
        "elementary_commonsense": ["other", "china specific"],
        "elementary_information_and_technology": ["other"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "ethnology": ["culture", "china specific"],
        "food_science": ["other"],
        "genetics": ["biology"],
        "global_facts": ["global"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_geography": ["geography"],
        "high_school_mathematics": ["math"],
        "high_school_physics": ["physics"],
        "high_school_politics": ["politics", "china specific"],
        "human_sexuality": ["other"],
        "international_law": ["law"],
        "journalism": ["sociology"],
        "jurisprudence": ["law"],
        "legal_and_moral_basis": ["other"],
        "logical": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "marxist_theory": ["philosophy"],
        "modern_chinese": ["linguistics", "china specific"],
        "nutrition": ["other"],
        "philosophy": ["philosophy"],
        "professional_accounting": ["business"],
        "professional_law": ["law"],
        "professional_medicine": ["other"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_study": ["politics"],
        "sociology": ["culture"],
        "sports_science": ["other"],
        "traditional_chinese_medicine": ["other", "china specific"],
        "virology": ["biology"],
        "world_history": ["history"],
        "world_religions": ["global"],
    }

    # coarse category -> list of fine-grained tags it aggregates
    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "statistics",
        ],
        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
        "Social Science": [
            "linguistics",
            "business",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "education",
            "sociology",
        ],
        "Other": ["other"],
        "China specific": ["china specific"],
        "Test": ["computer science"],
    }

    # invert subcategories: fine-grained tag -> list of subjects
    finer_categories = (
        pd.Series(subcategories)  # noqa  # type: ignore
        .explode()
        .reset_index()
        .set_index(0)
        .groupby(0)
        .agg(list)["index"]
        .to_dict()
    )

    @classmethod
    def suite(cls, chat=False):
        """Build the {category: [Task, ...]} mapping covering CMMLU."""
        suite = {}
        for category, tags in cls.categories.items():
            for tag in tags:
                # Accumulate across tags: the original assigned suite[k] on
                # every iteration, silently keeping only the last tag's tasks.
                suite.setdefault(category, []).extend(
                    Task(
                        ("haonan-li/cmmlu", subject),
                        metric_name=("sustech/tlem", "cmmlu"),
                        input_column="prompt",
                        label_column="Answer",
                        prompt=partial(cls.prompt_cmmlu, chat=chat),
                    )
                    for subject in cls.finer_categories[tag]
                )
        return suite
class MMLU:
    """MMLU benchmark: subject taxonomy and task-suite construction."""

    input_column = "prompt"
    label_column = "target"

    @classmethod
    def prompt_mmlu(cls, example, chat=False):
        """Format one MMLU example as an A/B/C/D multiple-choice prompt."""
        prefix = (
            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
            if chat
            else "Question: "
        )
        prompt = prefix + example["input"]
        for choice in list("ABCD"):
            prompt += f"\n{choice}. {example[choice]}"
        prompt += "\nAnswer:"
        return {"prompt": prompt}

    # subject -> list of fine-grained category tags
    subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    # coarse category -> list of fine-grained tags it aggregates
    categories = {
        "Math": [
            "math",
        ],
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
        ],
        "humanities": ["history", "philosophy", "law"],
        "social sciences": [
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
        ],
        "Other": ["other", "business", "health"],
        "All": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "history",
            "philosophy",
            "law",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "other",
            "business",
            "health",
        ],
        "Test": ["culture"],
    }

    @classmethod
    def suite(cls, chat=False):
        """Build the {category: [Task, ...]} mapping covering MMLU.

        Chat models get zero-shot prompts; base models get 5-shot prompts
        drawn from the validation split.
        """
        # invert subcategories: fine-grained tag -> list of subjects
        finer_categories = (
            pd.Series(cls.subcategories)  # noqa  # type: ignore
            .explode()
            .reset_index()
            .set_index(0)
            .groupby(0)
            .agg(list)["index"]
            .to_dict()
        )
        suite = {}
        for category, tags in cls.categories.items():
            for tag in tags:
                # Accumulate across tags: the original assigned suite[k] on
                # every iteration, silently keeping only the last tag's tasks.
                suite.setdefault(category, []).extend(
                    Task(
                        ("lukaemon/mmlu", subject),
                        metric_name=("sustech/tlem", "mmlu"),
                        input_column=cls.input_column,
                        label_column=cls.label_column,
                        prompt=partial(cls.prompt_mmlu, chat=chat),
                        few_shot=0 if chat else 5,
                        few_shot_from="validation",
                    )
                    for subject in finer_categories[tag]
                )
        return suite