from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
from src.models import TaskType


def get_safe_name(name: str) -> str:
    """Normalize a name to a safe key: lowercase, hyphens mapped to
    underscores, all other non-alphanumeric characters dropped."""
    name = name.replace("-", "_")
    return "".join(character.lower() for character in name if character.isalnum() or character == "_")


@dataclass
class Benchmark:
    name: str  # [domain]_[language] (QA) or [domain]_[language]_[dataset] (long-doc); task_key in the json file
    metric: str  # e.g. ndcg_at_10; metric_key in the json file
    col_name: str  # [domain]_[language], name displayed in the leaderboard
    domain: str
    lang: str
    task: str
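
# For illustration, a QA entry might look like (field values assumed):
#   Benchmark(name="wiki_en", metric="ndcg_at_10", col_name="wiki_en",
#             domain="wiki", lang="en", task="qa")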


# Build the member map (name -> Benchmark) used below to construct an Enum
# of all benchmarks for the given version and task type.
def get_benchmarks_enum(benchmark_version: str, task_type: TaskType):
    benchmark_dict = {}
    if task_type == TaskType.qa:
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            if task != task_type.value:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    benchmark_name = get_safe_name(f"{domain}_{lang}")
                    col_name = benchmark_name
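                    # Keep only entries with a "test" split; the dict key is
                    # fixed per domain/lang, so the last kept entry wins.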
                    for metric in dataset_list:
                        if "test" not in dataset_list[metric]["splits"]:
                            continue
                        benchmark_dict[benchmark_name] = Benchmark(
                            benchmark_name, metric, col_name, domain, lang, task
                        )
    elif task_type == TaskType.long_doc:
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            if task != task_type.value:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    for dataset in dataset_list:
                        benchmark_name = f"{domain}_{lang}_{dataset}"
                        benchmark_name = get_safe_name(benchmark_name)
                        col_name = benchmark_name
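                        # Skip datasets without a "test" split.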
                        if "test" not in dataset_list[dataset]["splits"]:
                            continue
                        for metric in METRIC_LIST:
                            benchmark_dict[benchmark_name] = Benchmark(
                                benchmark_name, metric, col_name, domain, lang, task
                            )
    return benchmark_dict
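
# How the returned mapping is consumed below (member names assumed for
# illustration): each dict becomes the member map of an Enum, so a Benchmark
# is reached via nested lookups, e.g.
#   QABenchmarks["2404"].value             -> Enum "QABenchmarks_2404"
#   QABenchmarks["2404"].value["wiki_en"]  -> member whose .value is a Benchmark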


qa_benchmark_dict = {}
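# Version slug: last four characters of the safe name, e.g.
# "AIR-Bench_24.04" -> "2404" (example version string assumed).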
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)[-4:]
    qa_benchmark_dict[safe_version_name] = Enum(
        f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.qa)
    )

long_doc_benchmark_dict = {}
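# Same per-version slug mapping, for the long-doc benchmarks.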
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)[-4:]
    long_doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.long_doc)
    )


QABenchmarks = Enum("QABenchmarks", qa_benchmark_dict)
LongDocBenchmarks = Enum("LongDocBenchmarks", long_doc_benchmark_dict)
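

# Minimal smoke test (an illustrative sketch, not part of the leaderboard app;
# assumes the imports above resolve and BENCHMARK_VERSION_LIST is non-empty):
if __name__ == "__main__":
    for version_slug, benchmarks in qa_benchmark_dict.items():
        members = list(benchmarks)  # members of this version's QA Enum
        print(f"{version_slug}: {len(members)} QA benchmark column(s)")
        if members:
            first = members[0].value  # a Benchmark dataclass instance
            print(f"  e.g. {first.col_name} ({first.domain}/{first.lang}, metric={first.metric})")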