edbeeching committed on
Commit
9346f1c
β€’
1 Parent(s): 4ff62ee

creates leaderboard

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +110 -0
  3. requirements.txt +66 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ evals/
2
+ venv/
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import numpy as np
4
+ import gradio as gr
5
+ from huggingface_hub import Repository
6
+ import json
7
+ from apscheduler.schedulers.background import BackgroundScheduler
8
+ import pandas as pd
9
# Clone / pull the lmeh eval data from the Hub.
# H4_TOKEN must be present in the environment to access the dataset repo;
# without it the app runs from whatever is already in ./evals/.
H4_TOKEN = os.environ.get("H4_TOKEN", None)
repo=None
if H4_TOKEN:
    # NOTE(review): the rmtree cleanup was deliberately disabled; Repository()
    # is expected to reuse an existing ./evals/ checkout — confirm this still
    # behaves when the local clone is stale or corrupted.
    # try:
    #     shutil.rmtree("./evals/")
    # except:
    #     pass

    repo = Repository(
        local_dir="./evals/", clone_from="HuggingFaceH4/lmeh_evaluations", use_auth_token=H4_TOKEN, repo_type="dataset"
    )
    # Bring the local checkout up to date at startup.
    repo.git_pull()
22
+
23
+
24
# Parse the results.
# Benchmark identifiers as they appear in the eval result filenames.
BENCHMARKS = ["arc_challenge", "hellaswag", "hendrycks", "truthfulqa_mc"]
# Maps raw benchmark id -> display column name used in the leaderboard table.
BENCH_TO_NAME = {
    "arc_challenge":"ARC",
    "hellaswag":"HellaSwag",
    "hendrycks":"MMLU",
    "truthfulqa_mc":"TruthQA",
}
# Metric read from each benchmark's result file, positionally aligned with
# BENCHMARKS (acc_norm for the first three, mc2 for truthfulqa_mc).
METRICS = ["acc_norm", "acc_norm", "acc_norm", "mc2"]

# One sub-directory per evaluated model inside ./evals/ (hidden entries skipped).
# NOTE(review): this listing runs once at import time, so models added to the
# dataset later are not picked up until the process restarts — confirm intended.
entries = [entry for entry in os.listdir("evals") if not entry.startswith('.')]
model_directories = [entry for entry in entries if os.path.isdir(os.path.join("evals", entry))]
36
+
37
+
38
def make_clickable_model(model_name):
    """Return an HTML anchor linking *model_name* to its Hugging Face page."""
    url = "https://huggingface.co/" + model_name
    return f'<a target="_blank" href="{url}" style="color: blue; text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
44
+
45
def load_results(model, benchmark, metric):
    """Load eval results for *model* on *benchmark* and average *metric*.

    Reads ``evals/<model>/<model>-eval_<benchmark>.json`` and returns
    ``(mean_metric, model_args)``; ``(0.0, None)`` when no result file
    exists for this model/benchmark pair.
    """
    file_path = os.path.join("evals", model, f"{model}-eval_{benchmark}.json")
    if not os.path.exists(file_path):
        return 0.0, None

    with open(file_path) as fp:
        data = json.load(fp)
    # Average the metric over every task in the results file (a benchmark
    # such as MMLU aggregates many sub-tasks).
    scores = np.array([task[metric] for task in data["results"].values()])
    return np.mean(scores), data["config"]["model_args"]
55
+
56
# Leaderboard column order and the matching gradio Dataframe datatypes
# (base_model is rendered as markdown so the HTML link stays clickable).
COLS = ["eval_name", "total", "ARC", "HellaSwag", "MMLU", "TruthQA", "base_model"]
TYPES = ["str", "number", "number", "number", "number", "number","markdown", ]
58
def get_leaderboard():
    """Build the leaderboard DataFrame, one row per model directory.

    Pulls the latest eval data first (when the Hub repo is available), then
    aggregates every benchmark score plus their sum ("total"), and returns
    the table sorted descending by total with columns ordered as COLS.
    """
    if repo:
        repo.git_pull()
    all_data = []
    for model in model_directories:
        # BUGFIX: the original built {"base_model": None} and then immediately
        # replaced the whole dict with {"eval_name": model}, so the
        # "base_model" key was missing whenever every benchmark file for a
        # model was absent, raising KeyError on the lookup below. Initialise
        # both keys in a single dict instead.
        model_data = {"eval_name": model, "base_model": None}

        for benchmark, metric in zip(BENCHMARKS, METRICS):
            value, base_model = load_results(model, benchmark, metric)
            model_data[BENCH_TO_NAME[benchmark]] = value
            if base_model is not None:  # in case the last benchmark failed
                model_data["base_model"] = base_model

        # Simple unweighted sum of the four benchmark scores.
        model_data["total"] = sum(model_data[benchmark] for benchmark in BENCH_TO_NAME.values())

        if model_data["base_model"] is not None:
            model_data["base_model"] = make_clickable_model(model_data["base_model"])
        all_data.append(model_data)

    dataframe = pd.DataFrame.from_records(all_data)
    dataframe = dataframe.sort_values(by=['total'], ascending=False)

    # Enforce the display column order expected by the gradio table.
    dataframe = dataframe[COLS]
    return dataframe
83
+
84
# Build the table once at startup for the initial page render.
leaderboard = get_leaderboard()

block = gr.Blocks()
with block:
    gr.Markdown(f"""
# H4 Model Evaluation leaderboard using the <a href="https://github.com/EleutherAI/lm-evaluation-harness" target="_blank"> LMEH benchmark suite </a>.
Evaluation is performed against 4 popular benchmarks AI2 Reasoning Challenge, HellaSwag, MMLU, and TruthFul QC MC. To run your own benchmarks, refer to the README in the H4 repo.
""")

    with gr.Row():
        leaderboard_table = gr.components.Dataframe(value=leaderboard, headers=COLS,
                                                    datatype=TYPES, max_rows=5)
    with gr.Row():
        refresh_button = gr.Button("Refresh")
        # Clicking "Refresh" re-runs get_leaderboard and pushes the new
        # DataFrame into the table component.
        refresh_button.click(get_leaderboard, inputs=[], outputs=leaderboard_table)


# NOTE(review): launch() blocks the main thread by default, so the scheduler
# configured after this call never starts while the app is serving — confirm
# whether prevent_thread_lock=True (or moving launch() to the end) was intended.
block.launch()
103
+
104
def refresh_leaderboard():
    """Scheduled job: re-pull the eval repo and rebuild the leaderboard.

    NOTE(review): the original bound the rebuilt DataFrame to a local
    variable named like the gradio component, which never updates the live
    table — the UI is only refreshed via the "Refresh" button. This job is
    kept for its side effect (repo.git_pull inside get_leaderboard); confirm
    whether a direct UI push was intended.
    """
    # Run for the side effect; the returned DataFrame is deliberately discarded.
    get_leaderboard()
    print("leaderboard updated")
107
+
108
# Background refresh: pull new eval results every 5 minutes.
# NOTE(review): this runs after block.launch(), which blocks the main thread
# by default, so the scheduler will not start while the app is serving —
# confirm the intended ordering.
scheduler = BackgroundScheduler()
scheduler.add_job(func=refresh_leaderboard, trigger="interval", seconds=300) # refresh every 5 mins
scheduler.start()
requirements.txt ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ anyio==3.6.2
6
+ APScheduler==3.10.1
7
+ async-timeout==4.0.2
8
+ attrs==23.1.0
9
+ certifi==2022.12.7
10
+ charset-normalizer==3.1.0
11
+ click==8.1.3
12
+ contourpy==1.0.7
13
+ cycler==0.11.0
14
+ entrypoints==0.4
15
+ fastapi==0.95.1
16
+ ffmpy==0.3.0
17
+ filelock==3.11.0
18
+ fonttools==4.39.3
19
+ frozenlist==1.3.3
20
+ fsspec==2023.4.0
21
+ gradio==3.27.0
22
+ gradio_client==0.1.3
23
+ h11==0.14.0
24
+ httpcore==0.17.0
25
+ httpx==0.24.0
26
+ huggingface-hub==0.13.4
27
+ idna==3.4
28
+ Jinja2==3.1.2
29
+ jsonschema==4.17.3
30
+ kiwisolver==1.4.4
31
+ linkify-it-py==2.0.0
32
+ markdown-it-py==2.2.0
33
+ MarkupSafe==2.1.2
34
+ matplotlib==3.7.1
35
+ mdit-py-plugins==0.3.3
36
+ mdurl==0.1.2
37
+ multidict==6.0.4
38
+ numpy==1.24.2
39
+ orjson==3.8.10
40
+ packaging==23.1
41
+ pandas==2.0.0
42
+ Pillow==9.5.0
43
+ pydantic==1.10.7
44
+ pydub==0.25.1
45
+ pyparsing==3.0.9
46
+ pyrsistent==0.19.3
47
+ python-dateutil==2.8.2
48
+ python-multipart==0.0.6
49
+ pytz==2023.3
50
+ pytz-deprecation-shim==0.1.0.post0
51
+ PyYAML==6.0
52
+ requests==2.28.2
53
+ semantic-version==2.10.0
54
+ six==1.16.0
55
+ sniffio==1.3.0
56
+ starlette==0.26.1
57
+ toolz==0.12.0
58
+ tqdm==4.65.0
59
+ typing_extensions==4.5.0
60
+ tzdata==2023.3
61
+ tzlocal==4.3
62
+ uc-micro-py==1.0.1
63
+ urllib3==1.26.15
64
+ uvicorn==0.21.1
65
+ websockets==11.0.1
66
+ yarl==1.8.2