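"""Leaderboard demo: collects per-language benchmark results (ARC, HellaSwag,
MMLU, TruthfulQA) from JSON files under evals/ and renders them in a Gradio
table. TITLE, INTRO_TEXT, HOW_TO, and CITATION come from the local content
module."""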
import glob
import json
from collections import defaultdict

import gradio as gr

from content import *
ARC = "arc"
HELLASWAG = "hellaswag"
MMLU = "mmlu"
TRUTHFULQA = "truthfulqa"
BENCHMARKS = [ARC, HELLASWAG, MMLU, TRUTHFULQA]
# One reported metric per benchmark, aligned with BENCHMARKS by index.
METRICS = ["acc_norm", "acc_norm", "acc_norm", "mc2"]
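# Assumed file layout (inferred from the glob and field accesses below):
#   evals/<subdir>/<run>.json, each holding harness-style output such as
#   {"results": {"arc_vi": {"acc_norm": 0.41, ...}, ...},
#    "config": {"model_args": "pretrained=org/model,dtype=float16,..."}}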
def collect_results():
    """Parse every results file and index scores by (model, language)."""
    performance_dict = defaultdict(dict)
    pretrained_models = set()
    for file in glob.glob('evals/*/*.json'):
        with open(file, 'r') as f:
            data = json.load(f)

        # Skip files that don't carry both the scores and the run config.
        if 'results' not in data or 'config' not in data:
            continue
        results = data['results']
        config = data['config']

        # Recover the model name from the model_args string,
        # e.g. "pretrained=org/model,dtype=float16" -> "model".
        if 'model_args' not in config:
            continue
        model_args = config['model_args'].split(',')
        pretrained = [x for x in model_args if x.startswith('pretrained=')]
        if len(pretrained) != 1:
            continue
        pretrained = pretrained[0].split('=')[1]
        pretrained = pretrained.split('/')[-1]
        pretrained_models.add(pretrained)

        # Task keys look like "<benchmark>_<language>", e.g. "arc_vi".
        for lang_task, perfs in results.items():
            task, lang = lang_task.split('_')
            assert task in BENCHMARKS
            if lang and task:
                metric = METRICS[BENCHMARKS.index(task)]
                p = round(perfs[metric] * 100, 1)
                performance_dict[(pretrained, lang)][task] = p
    return performance_dict, pretrained_models
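# Shape of the collected structure (illustrative values, not real results):
#   performance_dict[("my-model", "vi")] == {"arc": 41.2, "hellaswag": 55.0, ...}
#   pretrained_models == {"my-model", ...}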
def get_leaderboard_df(performance_dict, pretrained_models):
    """Build table rows, keeping only models scored on all four benchmarks."""
    df = list()
    for (pretrained, lang), perfs in performance_dict.items():
        arc_perf = perfs.get(ARC, 0.0)
        hellaswag_perf = perfs.get(HELLASWAG, 0.0)
        mmlu_perf = perfs.get(MMLU, 0.0)
        truthfulqa_perf = perfs.get(TRUTHFULQA, 0.0)
        # A zero product means at least one benchmark is missing (or scored 0).
        if arc_perf * hellaswag_perf * mmlu_perf * truthfulqa_perf == 0:
            continue
        avg = round((arc_perf + hellaswag_perf + mmlu_perf + truthfulqa_perf) / 4, 1)
        row = [pretrained, lang, avg, arc_perf, hellaswag_perf, mmlu_perf, truthfulqa_perf]
        df.append(row)
    return df
MODEL_COL = "Model"
LANG_COL = "Language"
AVERAGE_COL = "Average"
ARC_COL = "ARC (25-shot)"
HELLASWAG_COL = "HellaSwag (10-shot)"
MMLU_COL = "MMLU (5-shot)"
TRUTHFULQA_COL = "TruthfulQA (0-shot)"
COLS = [MODEL_COL, LANG_COL, AVERAGE_COL, ARC_COL, HELLASWAG_COL, MMLU_COL, TRUTHFULQA_COL]
# Column datatypes, aligned with COLS by index.
TYPES = ["str", "str", "number", "number", "number", "number", "number"]
# Build the table once at startup.
performance_dict, pretrained_models = collect_results()
leaderboard_df = get_leaderboard_df(performance_dict, pretrained_models)

demo = gr.Blocks()
with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRO_TEXT, elem_classes="markdown-text")
    gr.Markdown(HOW_TO, elem_classes="markdown-text")

    with gr.Box():
        search_bar = gr.Textbox(
            placeholder="Search models...", show_label=False, elem_id="search-bar"
        )

    leaderboard_table = gr.components.Dataframe(
        value=leaderboard_df,
        headers=COLS,
        datatype=TYPES,
        max_rows=5,
        elem_id="leaderboard-table",
    )
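    # The search bar above is otherwise unwired; a minimal sketch of a filter
    # callback (filter_table is a hypothetical helper, not part of the original
    # app): keep rows whose model name contains the query, case-insensitively.
    def filter_table(query):
        if not query:
            return leaderboard_df
        return [row for row in leaderboard_df if query.lower() in row[0].lower()]

    search_bar.change(filter_table, inputs=search_bar, outputs=leaderboard_table)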
    gr.Markdown(CITATION, elem_classes="markdown-text")

demo.launch()