Spaces:
Runtime error
Runtime error
胥基
commited on
Commit
·
193b86e
1
Parent(s):
fefeca2
copy gaia-leaderboard
Browse files- .gitattributes +0 -1
- README.md +9 -7
- app.py +254 -0
- content.py +47 -0
- requirements.txt +5 -0
- scorer.py +101 -0
.gitattributes
CHANGED
@@ -25,7 +25,6 @@
|
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
28 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
license:
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: GAIA Leaderboard
|
3 |
+
emoji: 🦾
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.3.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: apache-2.0
|
11 |
+
tags:
|
12 |
+
- leaderboard
|
13 |
---
|
14 |
|
15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import datetime
|
4 |
+
from email.utils import parseaddr
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from datasets import load_dataset
|
11 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
12 |
+
from huggingface_hub import HfApi
|
13 |
+
|
14 |
+
# InfoStrings
|
15 |
+
from scorer import question_scorer
|
16 |
+
from content import format_error, format_warning, format_log, TITLE, INTRODUCTION_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT, model_hyperlink
|
17 |
+
|
18 |
+
TOKEN = os.environ.get("TOKEN", None)
|
19 |
+
|
20 |
+
OWNER="gaia-benchmark"
|
21 |
+
DATA_DATASET = f"{OWNER}/GAIA"
|
22 |
+
INTERNAL_DATA_DATASET = f"{OWNER}/GAIA_internal"
|
23 |
+
SUBMISSION_DATASET = f"{OWNER}/submissions_internal"
|
24 |
+
CONTACT_DATASET = f"{OWNER}/contact_info"
|
25 |
+
RESULTS_DATASET = f"{OWNER}/results_public"
|
26 |
+
LEADERBOARD_PATH = f"{OWNER}/leaderboard"
|
27 |
+
api = HfApi()
|
28 |
+
|
29 |
+
YEAR_VERSION = "2023"
|
30 |
+
|
31 |
+
os.makedirs("scored", exist_ok=True)
|
32 |
+
|
33 |
+
# Display the results
|
34 |
+
eval_results = load_dataset(RESULTS_DATASET, YEAR_VERSION, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
|
35 |
+
contact_infos = load_dataset(CONTACT_DATASET, YEAR_VERSION, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
|
36 |
+
def get_dataframe_from_results(eval_results, split):
|
37 |
+
local_df = eval_results[split]
|
38 |
+
local_df = local_df.map(lambda row: {"model": model_hyperlink(row["url"], row["model"])})
|
39 |
+
local_df = local_df.remove_columns(["system_prompt", "url"])
|
40 |
+
local_df = local_df.rename_column("model", "Model name")
|
41 |
+
local_df = local_df.rename_column("model_family", "Model family")
|
42 |
+
local_df = local_df.rename_column("score", "Average score (%)")
|
43 |
+
for i in [1, 2, 3]:
|
44 |
+
local_df = local_df.rename_column(f"score_level{i}", f"Level {i} score (%)")
|
45 |
+
df = pd.DataFrame(local_df)
|
46 |
+
df = df.sort_values(by=["Average score (%)"], ascending=False)
|
47 |
+
|
48 |
+
numeric_cols = [c for c in local_df.column_names if "score" in c]
|
49 |
+
df[numeric_cols] = df[numeric_cols].multiply(100).round(decimals=2)
|
50 |
+
#df = df.style.format("{:.2%}", subset=numeric_cols)
|
51 |
+
|
52 |
+
return df
|
53 |
+
|
54 |
+
eval_dataframe_val = get_dataframe_from_results(eval_results=eval_results, split="validation")
|
55 |
+
eval_dataframe_test = get_dataframe_from_results(eval_results=eval_results, split="test")
|
56 |
+
|
57 |
+
# Gold answers
|
58 |
+
gold_results = {}
|
59 |
+
gold_dataset = load_dataset(INTERNAL_DATA_DATASET, f"{YEAR_VERSION}_all", token=TOKEN)
|
60 |
+
gold_results = {split: {row["task_id"]: row for row in gold_dataset[split]} for split in ["test", "validation"]}
|
61 |
+
|
62 |
+
|
63 |
+
def restart_space():
|
64 |
+
api.restart_space(repo_id=LEADERBOARD_PATH, token=TOKEN)
|
65 |
+
|
66 |
+
TYPES = ["markdown", "number", "number", "number", "number", "str", "str"]
|
67 |
+
|
68 |
+
def add_new_eval(
|
69 |
+
val_or_test: str,
|
70 |
+
model: str,
|
71 |
+
model_family: str,
|
72 |
+
system_prompt: str,
|
73 |
+
url: str,
|
74 |
+
path_to_file: str,
|
75 |
+
organisation: str,
|
76 |
+
mail: str,
|
77 |
+
):
|
78 |
+
# Very basic email parsing
|
79 |
+
_, parsed_mail = parseaddr(mail)
|
80 |
+
if not "@" in parsed_mail:
|
81 |
+
return format_warning("Please provide a valid email adress.")
|
82 |
+
|
83 |
+
print("Adding new eval")
|
84 |
+
|
85 |
+
# Check if the combination model/org already exists and prints a warning message if yes
|
86 |
+
if model.lower() in set([m.lower() for m in eval_results[val_or_test]["model"]]) and organisation.lower() in set([o.lower() for l in eval_results[val_or_test]["organisation"]]):
|
87 |
+
return format_warning("This model has been already submitted.")
|
88 |
+
|
89 |
+
if path_to_file is None:
|
90 |
+
return format_warning("Please attach a file.")
|
91 |
+
|
92 |
+
# Save submitted file
|
93 |
+
api.upload_file(
|
94 |
+
repo_id=SUBMISSION_DATASET,
|
95 |
+
path_or_fileobj=path_to_file.name,
|
96 |
+
path_in_repo=f"{organisation}/{model}/{YEAR_VERSION}_{val_or_test}_raw_{datetime.datetime.today()}.jsonl",
|
97 |
+
repo_type="dataset",
|
98 |
+
token=TOKEN
|
99 |
+
)
|
100 |
+
|
101 |
+
# Compute score
|
102 |
+
file_path = path_to_file.name
|
103 |
+
scores = {"all": 0, 1: 0, 2: 0, 3: 0}
|
104 |
+
num_questions = {"all": 0, 1: 0, 2: 0, 3: 0}
|
105 |
+
with open(f"scored/{organisation}_{model}.jsonl", "w") as scored_file:
|
106 |
+
with open(file_path, 'r') as f:
|
107 |
+
for ix, line in enumerate(f):
|
108 |
+
try:
|
109 |
+
task = json.loads(line)
|
110 |
+
except Exception:
|
111 |
+
return format_error(f"Line {ix} is incorrectly formatted. Please fix it and resubmit your file.")
|
112 |
+
|
113 |
+
if "model_answer" not in task:
|
114 |
+
raise format_error(f"Line {ix} contains no model_answer key. Please fix it and resubmit your file.")
|
115 |
+
answer = task["model_answer"]
|
116 |
+
task_id = task["task_id"]
|
117 |
+
try:
|
118 |
+
level = int(gold_results[val_or_test][task_id]["Level"])
|
119 |
+
except KeyError:
|
120 |
+
return format_error(f"{task_id} not found in split {val_or_test}. Are you sure you submitted the correct file?")
|
121 |
+
|
122 |
+
score = question_scorer(task['model_answer'], gold_results[val_or_test][task_id]["Final answer"])
|
123 |
+
|
124 |
+
scored_file.write(
|
125 |
+
json.dumps({
|
126 |
+
"id": task_id,
|
127 |
+
"model_answer": answer,
|
128 |
+
"score": score,
|
129 |
+
"level": level
|
130 |
+
}) + "\n"
|
131 |
+
)
|
132 |
+
|
133 |
+
scores["all"] += score
|
134 |
+
scores[level] += score
|
135 |
+
num_questions["all"] += 1
|
136 |
+
num_questions[level] += 1
|
137 |
+
|
138 |
+
# Save scored file
|
139 |
+
api.upload_file(
|
140 |
+
repo_id=SUBMISSION_DATASET,
|
141 |
+
path_or_fileobj=f"scored/{organisation}_{model}.jsonl",
|
142 |
+
path_in_repo=f"{organisation}/{model}/{YEAR_VERSION}_{val_or_test}_scored_{datetime.datetime.today()}.jsonl",
|
143 |
+
repo_type="dataset",
|
144 |
+
token=TOKEN
|
145 |
+
)
|
146 |
+
|
147 |
+
# Actual submission
|
148 |
+
eval_entry = {
|
149 |
+
"model": model,
|
150 |
+
"model_family": model_family,
|
151 |
+
"system_prompt": system_prompt,
|
152 |
+
"url": url,
|
153 |
+
"organisation": organisation,
|
154 |
+
"score": scores["all"]/num_questions["all"],
|
155 |
+
"score_level1": scores[1]/num_questions[1],
|
156 |
+
"score_level2": scores[2]/num_questions[2],
|
157 |
+
"score_level3": scores[3]/num_questions[3],
|
158 |
+
}
|
159 |
+
eval_results[val_or_test] = eval_results[val_or_test].add_item(eval_entry)
|
160 |
+
print(eval_results)
|
161 |
+
eval_results.push_to_hub(RESULTS_DATASET, config_name = YEAR_VERSION, token=TOKEN)
|
162 |
+
|
163 |
+
contact_info = {
|
164 |
+
"model": model,
|
165 |
+
"model_family": model_family,
|
166 |
+
"url": url,
|
167 |
+
"organisation": organisation,
|
168 |
+
"mail": mail,
|
169 |
+
}
|
170 |
+
contact_infos[val_or_test]= contact_infos[val_or_test].add_item(contact_info)
|
171 |
+
contact_infos.push_to_hub(CONTACT_DATASET, config_name = YEAR_VERSION, token=TOKEN)
|
172 |
+
|
173 |
+
return format_log(f"Model {model} submitted by {organisation} successfully. \nPlease refresh the leaderboard, and wait a bit to see the score displayed")
|
174 |
+
|
175 |
+
|
176 |
+
def refresh():
|
177 |
+
eval_results = load_dataset(RESULTS_DATASET, YEAR_VERSION, token=TOKEN, download_mode="force_redownload", ignore_verifications=True)
|
178 |
+
eval_dataframe_val = get_dataframe_from_results(eval_results=eval_results, split="validation")
|
179 |
+
eval_dataframe_test = get_dataframe_from_results(eval_results=eval_results, split="test")
|
180 |
+
return eval_dataframe_val, eval_dataframe_test
|
181 |
+
|
182 |
+
def upload_file(files):
|
183 |
+
file_paths = [file.name for file in files]
|
184 |
+
return file_paths
|
185 |
+
|
186 |
+
|
187 |
+
demo = gr.Blocks()
|
188 |
+
with demo:
|
189 |
+
gr.HTML(TITLE)
|
190 |
+
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
191 |
+
|
192 |
+
with gr.Row():
|
193 |
+
with gr.Accordion("📙 Citation", open=False):
|
194 |
+
citation_button = gr.Textbox(
|
195 |
+
value=CITATION_BUTTON_TEXT,
|
196 |
+
label=CITATION_BUTTON_LABEL,
|
197 |
+
elem_id="citation-button",
|
198 |
+
) #.style(show_copy_button=True)
|
199 |
+
|
200 |
+
with gr.Tab("Results: Test"):
|
201 |
+
leaderboard_table_test = gr.components.Dataframe(
|
202 |
+
value=eval_dataframe_test, datatype=TYPES, interactive=False,
|
203 |
+
column_widths=["20%"]
|
204 |
+
)
|
205 |
+
with gr.Tab("Results: Validation"):
|
206 |
+
leaderboard_table_val = gr.components.Dataframe(
|
207 |
+
value=eval_dataframe_val, datatype=TYPES, interactive=False,
|
208 |
+
column_widths=["20%"]
|
209 |
+
)
|
210 |
+
|
211 |
+
refresh_button = gr.Button("Refresh")
|
212 |
+
refresh_button.click(
|
213 |
+
refresh,
|
214 |
+
inputs=[],
|
215 |
+
outputs=[
|
216 |
+
leaderboard_table_val,
|
217 |
+
leaderboard_table_test,
|
218 |
+
],
|
219 |
+
)
|
220 |
+
with gr.Accordion("Submit a new model for evaluation"):
|
221 |
+
with gr.Row():
|
222 |
+
with gr.Column():
|
223 |
+
level_of_test = gr.Radio(["validation", "test"], value="validation", label="Split")
|
224 |
+
model_name_textbox = gr.Textbox(label="Model name")
|
225 |
+
model_family_textbox = gr.Textbox(label="Model family")
|
226 |
+
system_prompt_textbox = gr.Textbox(label="System prompt example")
|
227 |
+
url_textbox = gr.Textbox(label="Url to model information")
|
228 |
+
with gr.Column():
|
229 |
+
organisation = gr.Textbox(label="Organisation")
|
230 |
+
mail = gr.Textbox(label="Contact email (will be stored privately, & used if there is an issue with your submission)")
|
231 |
+
file_output = gr.File()
|
232 |
+
|
233 |
+
|
234 |
+
submit_button = gr.Button("Submit Eval")
|
235 |
+
submission_result = gr.Markdown()
|
236 |
+
submit_button.click(
|
237 |
+
add_new_eval,
|
238 |
+
[
|
239 |
+
level_of_test,
|
240 |
+
model_name_textbox,
|
241 |
+
model_family_textbox,
|
242 |
+
system_prompt_textbox,
|
243 |
+
url_textbox,
|
244 |
+
file_output,
|
245 |
+
organisation,
|
246 |
+
mail
|
247 |
+
],
|
248 |
+
submission_result,
|
249 |
+
)
|
250 |
+
|
251 |
+
scheduler = BackgroundScheduler()
|
252 |
+
scheduler.add_job(restart_space, "interval", seconds=3600)
|
253 |
+
scheduler.start()
|
254 |
+
demo.launch(debug=True)
|
content.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TITLE = """<h1 align="center" id="space-title">GAIA Leaderboard</h1>"""
|
2 |
+
|
3 |
+
INTRODUCTION_TEXT = """
|
4 |
+
GAIA is a benchmark which aims at evaluating next-generation LLMs (LLMs with augmented capabilities due to added tooling, efficient prompting, access to search, etc). (See our [paper](https://arxiv.org/abs/2311.12983) for more details.)
|
5 |
+
|
6 |
+
## Data
|
7 |
+
GAIA is made of more than 450 non-trivial question with an unambiguous answer, requiring different levels of tooling and autonomy to solve.
|
8 |
+
It is therefore divided in 3 levels, where level 1 should be breakable by very good LLMs, and level 3 indicate a strong jump in model capabilities. Each level is divided into a fully public dev set for validation, and a test set with private answers and metadata.
|
9 |
+
|
10 |
+
GAIA data can be found in [this dataset](https://huggingface.co/datasets/gaia-benchmark/GAIA). Questions are contained in `metadata.jsonl`. Some questions come with an additional file, that can be found in the same folder and whose id is given in the field `file_name`.
|
11 |
+
|
12 |
+
## Submissions
|
13 |
+
Results can be submitted for both validation and test. Scores are expressed as the percentage of correct answers for a given split.
|
14 |
+
|
15 |
+
We expect submissions to be json-line files with the following format. The first two fields are mandatory, `reasoning_trace` is optionnal:
|
16 |
+
```
|
17 |
+
{"task_id": "task_id_1", "model_answer": "Answer 1 from your model", "reasoning_trace": "The different steps by which your model reached answer 1"}
|
18 |
+
{"task_id": "task_id_2", "model_answer": "Answer 2 from your model", "reasoning_trace": "The different steps by which your model reached answer 2"}
|
19 |
+
```
|
20 |
+
Submission made by our team are labelled "GAIA authors". While we report average scores over different runs when possible in our paper, we only report the best run in the leaderboard.
|
21 |
+
|
22 |
+
**Please do not repost the public dev set, nor use it in training data for your models.**
|
23 |
+
"""
|
24 |
+
|
25 |
+
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
|
26 |
+
CITATION_BUTTON_TEXT = r"""@misc{mialon2023gaia,
|
27 |
+
title={GAIA: a benchmark for General AI Assistants},
|
28 |
+
author={Grégoire Mialon and Clémentine Fourrier and Craig Swift and Thomas Wolf and Yann LeCun and Thomas Scialom},
|
29 |
+
year={2023},
|
30 |
+
eprint={2311.12983},
|
31 |
+
archivePrefix={arXiv},
|
32 |
+
primaryClass={cs.CL}
|
33 |
+
}"""
|
34 |
+
|
35 |
+
|
36 |
+
def format_error(msg):
|
37 |
+
return f"<p style='color: red; font-size: 20px; text-align: center;'>{msg}</p>"
|
38 |
+
|
39 |
+
def format_warning(msg):
|
40 |
+
return f"<p style='color: orange; font-size: 20px; text-align: center;'>{msg}</p>"
|
41 |
+
|
42 |
+
def format_log(msg):
|
43 |
+
return f"<p style='color: green; font-size: 20px; text-align: center;'>{msg}</p>"
|
44 |
+
|
45 |
+
def model_hyperlink(link, model_name):
|
46 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
47 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets==2.14.5
|
2 |
+
gradio==4.3.0
|
3 |
+
huggingface-hub==0.18.0
|
4 |
+
numpy==1.24.2
|
5 |
+
APScheduler==3.10.1
|
scorer.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import string
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def normalize_number_str(number_str: str) -> float:
|
10 |
+
# we replace these common units and commas to allow
|
11 |
+
# conversion to float
|
12 |
+
for char in ["$", "%", ","]:
|
13 |
+
number_str = number_str.replace(char, "")
|
14 |
+
try:
|
15 |
+
return float(number_str)
|
16 |
+
except ValueError:
|
17 |
+
print(f"String {number_str} cannot be normalized to number str.")
|
18 |
+
return float("inf")
|
19 |
+
|
20 |
+
|
21 |
+
def split_string(
|
22 |
+
s: str,
|
23 |
+
char_list: list[str] = [",", ";"],
|
24 |
+
) -> list[str]:
|
25 |
+
pattern = f"[{''.join(char_list)}]"
|
26 |
+
return re.split(pattern, s)
|
27 |
+
|
28 |
+
|
29 |
+
def question_scorer(
|
30 |
+
model_answer: str,
|
31 |
+
ground_truth: str,
|
32 |
+
) -> bool:
|
33 |
+
def is_float(element: any) -> bool:
|
34 |
+
try:
|
35 |
+
float(element)
|
36 |
+
return True
|
37 |
+
except ValueError:
|
38 |
+
return False
|
39 |
+
|
40 |
+
# if gt is a number
|
41 |
+
if is_float(ground_truth):
|
42 |
+
print(f"Evaluating {model_answer} as a number.")
|
43 |
+
normalized_answer = normalize_number_str(model_answer)
|
44 |
+
return normalized_answer == float(ground_truth)
|
45 |
+
|
46 |
+
# if gt is a list
|
47 |
+
elif any(char in ground_truth for char in [",", ";"]):
|
48 |
+
print(f"Evaluating {model_answer} as a comma separated list.")
|
49 |
+
# question with the fish: normalization removes punct
|
50 |
+
|
51 |
+
gt_elems = split_string(ground_truth)
|
52 |
+
ma_elems = split_string(model_answer)
|
53 |
+
|
54 |
+
# check length is the same
|
55 |
+
if len(gt_elems) != len(ma_elems):
|
56 |
+
warnings.warn(
|
57 |
+
"Answer lists have different lengths, returning False.", UserWarning
|
58 |
+
)
|
59 |
+
return False
|
60 |
+
|
61 |
+
# compare each element as float or str
|
62 |
+
comparisons = []
|
63 |
+
for ma_elem, gt_elem in zip(ma_elems, gt_elems):
|
64 |
+
if is_float(gt_elem):
|
65 |
+
normalized_ma_elem = normalize_number_str(ma_elem)
|
66 |
+
comparisons.append(normalized_ma_elem == float(gt_elem))
|
67 |
+
else:
|
68 |
+
# we do not remove punct since comparisons can include punct
|
69 |
+
comparisons.append(
|
70 |
+
normalize_str(ma_elem, remove_punct=False)
|
71 |
+
== normalize_str(gt_elem, remove_punct=False)
|
72 |
+
)
|
73 |
+
return all(comparisons)
|
74 |
+
|
75 |
+
# if gt is a str
|
76 |
+
else:
|
77 |
+
print(f"Evaluating {model_answer} as a string.")
|
78 |
+
return normalize_str(model_answer) == normalize_str(ground_truth)
|
79 |
+
|
80 |
+
|
81 |
+
def normalize_str(input_str, remove_punct=True) -> str:
|
82 |
+
"""
|
83 |
+
Normalize a string by:
|
84 |
+
- Removing all white spaces
|
85 |
+
- Optionally removing punctuation (if remove_punct is True)
|
86 |
+
- Converting to lowercase
|
87 |
+
Parameters:
|
88 |
+
- input_str: str, the string to normalize
|
89 |
+
- remove_punct: bool, whether to remove punctuation (default: True)
|
90 |
+
Returns:
|
91 |
+
- str, the normalized string
|
92 |
+
"""
|
93 |
+
# Remove all white spaces. Required e.g for seagull vs. sea gull
|
94 |
+
no_spaces = re.sub(r"\s", "", input_str)
|
95 |
+
|
96 |
+
# Remove punctuation, if specified.
|
97 |
+
if remove_punct:
|
98 |
+
translator = str.maketrans("", "", string.punctuation)
|
99 |
+
return no_spaces.lower().translate(translator)
|
100 |
+
else:
|
101 |
+
return no_spaces.lower()
|