import logging from typing import Any, Optional, Protocol, Iterable, Callable from tqdm.auto import tqdm from evaluate.evaluation_suite import EvaluationSuite import evaluate import numpy as np import datasets import pandas as pd from .tasks import * from .utils import * from itertools import chain from copy import deepcopy from . import utils class ReasoningMetric(evaluate.Metric): """TODO: Short description of my evaluation module.""" def _info(self): # if self.config_name in ["cmmlu"]: features = datasets.Features( { "responses": datasets.Value("string"), # "responses": datasets.Sequence(datasets.Value("float")), "references": datasets.Value("string"), } ) # TODO: Specifies the evaluate.EvaluationModuleInfo object return evaluate.EvaluationModuleInfo( # This is the description that will appear on the modules page. # module_type="measurement", description="", citation="", inputs_description="", # This defines the format of each prediction and reference features=features, # Homepage of the module for documentation homepage="http://module.homepage", # Additional links to the codebase or references codebase_urls=["http://github.com/path/to/codebase/of/new_module"], reference_urls=["http://path.to.reference.url/new_module"], ) def _compute(self, responses, references): return_value = getattr(Metrics, self.config_name)(responses, references) match return_value: case extract_responses, extract_references: results = { self.config_name: np.mean( sync_pipe(lambda x, y: x == y)( zip(extract_responses, extract_references) ) ) } case dict(): results = return_value case list(): results = {self.config_name: np.mean(return_value)} case _: raise NotImplementedError return results class Suite(EvaluationSuite): task_class = Task utils = utils supported_datasets = [ "arc", "hellaswag", "mmlu-chat", "winogrande", "gsm8k", "cmmlu-chat", "ceval-chat", "bbh", "drop", "MATH", ] def __getitem__(self, key) -> Task: match key: case str(): return self.suite[key] case slice() | int(): return self.tasks[key] def agg(self, suite): for cate, tasks in suite.items(): if isinstance(tasks, dict): suite[cate] = self.agg(tasks) else: suite[cate] = np.mean([pd.Series(task.result).mean() for task in tasks]) return suite def run( self, model_or_pipeline: Any, ) -> dict[str, float]: self.assert_suite_nonempty() self.suite: dict[str, list[Task]] for task in (bar := tqdm(self.tasks)): bar.desc = f"complete {task.name}." _ = task.run(model_or_pipeline) logging.info(f"{task.name} {task.result=}") return self.agg(deepcopy(self.suite)) def arun(self, model_or_pipeline): async def sync_function(): return await tqdm.gather( *[task.arun(model_or_pipeline) for task in self.tasks], leave=False ) asyncio.run(sync_function()) return self.agg(deepcopy(self.suite)) def get_suite(self, name) -> dict[str, Task]: chat = False suite={} match name: case _ if "chat" in name: chat = True match name: case _ if name.startswith("mmlu"): suite = MMLU.suite(chat=chat) case _ if name.startswith("cmmlu"): suite = CMMLU.suite(chat=chat) case _ if name.startswith("ceval"): suite = CEVAL.suite(chat=chat) case "gsm8k": suite = Task( dataset_name=("gsm8k", "main"), metric_name=("sustech/tlem", "gsm8k"), input_column="question", label_column="answer", ) case "bbh": suite = BBH.suite() case "arc": suite = ARC.suite() case "hellaswag": suite = HellaSwag.suite() case "drop": suite = DROP.suite() case "winogrande": suite = Winogrande.suite() case "truthfulqa_mc1": suite = TruthfulQAMC1.suite() case _ if name.startswith("boolq"): suite = BoolQ.suite(chat=chat) case "mt_bench": suite = Task( dataset_name="SUSTech/mt_bench_judge", split="train", prompt=mt_bench_prompt # metric_name=("sustech/tlem", "gsm8k"), ) case "MATH" | "competition_math": suite = Task( dataset_name="hendrycks/competition_math", prompt="This is a math problem, please think step by step and slove it: {input_column}. Simplify your final answer as much as possible and surround them with '$' in TeX form.", metric_name=("sustech/tlem", "MATH"), input_column="problem", label_column="solution", ) case "open-leaderboard": for name in [ "arc", "hellaswag", "mmlu-chat", "winogrande", "gsm8k", # "truthful_qa", "drop", ]: suite.update(self.get_suite(name)) case "tlem": for name in [ "arc", "hellaswag", "mmlu-chat", "winogrande", "gsm8k", # "truthful_qa", "cmmlu-chat", "ceval-chat", "bbh", ]: suite.update(self.get_suite(name)) case "all": for name in self.supported_datasets: suite.update(self.get_suite(name)) case _: raise NotImplementedError( f"{name} is not supported in {self.supported_datasets}" ) if isinstance(suite, Task): suite = [suite] suite = {name: suite} return suite def singleton(self, task): try: return self.tasks[self.tasks.index(task)] except ValueError: logging.debug(f"add {task.name} to suite.") self.tasks.append(task) logging.debug(self.tasks) return self.tasks[-1] def drop_duplicates(self, suite): for category, tasks in suite.items(): match tasks: case list(): suite[category] = [self.singleton(task) for task in tasks] case dict(): suite[category] = self.drop_duplicates(tasks) case _: raise NotImplementedError return suite def load(self, name): sub_suite = self.get_suite(name) self.suite.update(sub_suite) self.suite = self.drop_duplicates(self.suite) # return self def __init__(self, name="tlem"): super().__init__(name) self.tasks = [] self.suite = {}