Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,746 Bytes
8b7a945 3478401 8b7a945 3478401 8b7a945 9c49811 3478401 8b7a945 9c49811 8b7a945 a96f80a 3478401 f30cbcc 8b7a945 3478401 8b7a945 f30cbcc 3478401 32ebf18 3478401 8b7a945 9c49811 3478401 e8879cc 3478401 f30cbcc 32ebf18 3478401 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from dataclasses import dataclass
from enum import Enum
from air_benchmark.tasks.tasks import BenchmarkTable
from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
from src.models import TaskType, get_safe_name
@dataclass
class Benchmark:
    """Record describing a single benchmark entry of the leaderboard."""

    # [domain]_[language]_[metric]; the task_key in the json file.
    name: str
    # The metric_key in the json file.
    metric: str
    # [domain]_[language]; the name displayed in the leaderboard.
    col_name: str
    domain: str
    lang: str
    task: str
# Functions that build, for one benchmark version, the dict of benchmarks from which the Enum classes below are created.
def get_qa_benchmarks_dict(version: str):
    """Collect the QA benchmarks declared for ``version``.

    Walks ``BenchmarkTable[version]`` (task -> domain -> language -> datasets)
    and ignores every task other than ``TaskType.qa.value``. For each
    domain/language pair an entry keyed by ``get_safe_name(f"{domain}_{lang}")``
    is stored whenever one of its datasets has ``"test"`` in its ``"splits"``.

    The key does not include the dataset, so when several datasets of the same
    domain/language pair have a test split, the last one iterated is the one
    that remains in the result.

    Args:
        version: Key into ``BenchmarkTable`` selecting the benchmark version.

    Returns:
        A dict mapping the safe benchmark name to its ``Benchmark``. The same
        safe name is used for both ``name`` and ``col_name``.
    """
    qa_benchmarks = {}
    for task_name, domains in BenchmarkTable[version].items():
        if task_name != TaskType.qa.value:
            continue
        for domain_name, languages in domains.items():
            for language, datasets in languages.items():
                safe_name = get_safe_name(f"{domain_name}_{language}")
                for dataset_key in datasets:
                    if "test" in datasets[dataset_key]["splits"]:
                        # NOTE(review): the dataset key is what ends up in the
                        # ``metric`` field here -- confirm this is intended.
                        qa_benchmarks[safe_name] = Benchmark(
                            name=safe_name,
                            metric=dataset_key,
                            col_name=safe_name,
                            domain=domain_name,
                            lang=language,
                            task=task_name,
                        )
    return qa_benchmarks
def get_doc_benchmarks_dict(version: str):
    """Collect the long-document benchmarks declared for ``version``.

    Walks ``BenchmarkTable[version]`` (task -> domain -> language -> datasets)
    and ignores every task other than ``TaskType.long_doc.value``. Each dataset
    that has ``"test"`` in its ``"splits"`` yields one entry keyed by
    ``get_safe_name(f"{domain}_{lang}_{dataset}")``.

    Args:
        version: Key into ``BenchmarkTable`` selecting the benchmark version.

    Returns:
        A dict mapping the safe benchmark name to its ``Benchmark``. The same
        safe name is used for both ``name`` and ``col_name``. The ``metric``
        field holds the last element of ``METRIC_LIST``; if ``METRIC_LIST`` is
        empty no entry is stored.
    """
    doc_benchmarks = {}
    for task_name, domains in BenchmarkTable[version].items():
        if task_name != TaskType.long_doc.value:
            continue
        for domain_name, languages in domains.items():
            for language, datasets in languages.items():
                for dataset_key in datasets:
                    safe_name = get_safe_name(f"{domain_name}_{language}_{dataset_key}")
                    if "test" in datasets[dataset_key]["splits"]:
                        # Every metric writes to the same key, so each
                        # iteration replaces the previous entry.
                        for metric_name in METRIC_LIST:
                            doc_benchmarks[safe_name] = Benchmark(
                                name=safe_name,
                                metric=metric_name,
                                col_name=safe_name,
                                domain=domain_name,
                                lang=language,
                                task=task_name,
                            )
    return doc_benchmarks
# For every benchmark version, create one Enum of QA benchmarks and one Enum of
# long-document benchmarks, both stored under the sanitized version name.
_qa_benchmark_dict = {}
_doc_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)
    _qa_benchmark_dict[safe_version_name] = Enum(
        f"QABenchmarks_{safe_version_name}",
        get_qa_benchmarks_dict(version),
    )
    _doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}",
        get_doc_benchmarks_dict(version),
    )

# Outer enums: one member per version, whose value is that version's Enum.
QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
LongDocBenchmarks = Enum("LongDocBenchmarks", _doc_benchmark_dict)
|