from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
from src.models import TaskType


def get_safe_name(name: str):
    """Get RFC 1123 compatible safe name"""
    name = name.replace("-", "_")
    return "".join(
        character.lower()
        for character in name
        if (character.isalnum() or character == "_")
    )

@dataclass
class Benchmark:
    name: str      # [domain]_[language]_[metric]; task key in the results JSON file
    metric: str    # e.g. ndcg_at_1; metric key in the results JSON file
    col_name: str  # [domain]_[language]; column name displayed in the leaderboard
    domain: str
    lang: str
    task: str
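
# Illustrative instance (values are hypothetical, not taken from BenchmarkTable):
#   Benchmark(name="wiki_en", metric="ndcg_at_10", col_name="wiki_en",
#             domain="wiki", lang="en", task="qa")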

# Build a dict of benchmark entries for one benchmark version and task type;
# the dict is later turned into an Enum class containing all the benchmarks.
def get_benchmarks_enum(benchmark_version: str, task_type: TaskType):
    benchmark_dict = {}
    if task_type == TaskType.qa:
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            # Skip entries that belong to the other task type.
            if task != task_type.value:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    benchmark_name = get_safe_name(f"{domain}_{lang}")
                    col_name = benchmark_name
                    for metric in dataset_list:
                        # Only keep entries that provide a "test" split.
                        if "test" not in dataset_list[metric]["splits"]:
                            continue
                        benchmark_dict[benchmark_name] = Benchmark(
                            benchmark_name, metric, col_name, domain, lang, task
                        )
    elif task_type == TaskType.long_doc:
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            # Skip entries that belong to the other task type.
            if task != task_type.value:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    for dataset in dataset_list:
                        benchmark_name = get_safe_name(f"{domain}_{lang}_{dataset}")
                        col_name = benchmark_name
                        # Only keep datasets that provide a "test" split.
                        if "test" not in dataset_list[dataset]["splits"]:
                            continue
                        for metric in METRIC_LIST:
                            benchmark_dict[benchmark_name] = Benchmark(
                                benchmark_name, metric, col_name, domain, lang, task
                            )
    return benchmark_dict
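
# Illustrative return value for the QA task type (keys and values are hypothetical;
# the real contents depend on BenchmarkTable for the given version):
#   {
#       "wiki_en": Benchmark("wiki_en", "ndcg_at_10", "wiki_en", "wiki", "en", "qa"),
#       "news_zh": Benchmark("news_zh", "ndcg_at_10", "news_zh", "news", "zh", "qa"),
#   }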

# For each benchmark version, build a version-specific Enum of benchmarks,
# keyed by the last four characters of the safe version name.
qa_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)[-4:]
    qa_benchmark_dict[safe_version_name] = Enum(
        f"QABenchmarks_{safe_version_name}",
        get_benchmarks_enum(version, TaskType.qa),
    )

long_doc_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)[-4:]
    long_doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}",
        get_benchmarks_enum(version, TaskType.long_doc),
    )

# Top-level Enums whose members map a version suffix to that version's Enum class.
QABenchmarks = Enum("QABenchmarks", qa_benchmark_dict)
LongDocBenchmarks = Enum("LongDocBenchmarks", long_doc_benchmark_dict)
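
# Hypothetical usage sketch (assumes BENCHMARK_VERSION_LIST contains a version such
# as "AIR-Bench_24.04", whose safe-name suffix would be "2404"; adjust the key to
# whatever versions are actually configured):
#   qa_2404 = QABenchmarks["2404"].value      # the version-specific Enum class
#   for member in qa_2404:
#       benchmark = member.value              # a Benchmark dataclass instance
#       print(benchmark.col_name, benchmark.metric)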