Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
refactor: reformat with black
Browse files- app.py +143 -125
- src/about.py +1 -1
- src/benchmarks.py +17 -14
- src/display/columns.py +33 -16
- src/display/components.py +13 -16
- src/envs.py +4 -2
- src/loaders.py +28 -26
- src/models.py +21 -11
- src/utils.py +136 -99
- tests/src/display/test_utils.py +14 -7
- tests/src/test_benchmarks.py +1 -2
- tests/src/test_read_evals.py +15 -7
- tests/test_utils.py +64 -34
app.py
CHANGED
@@ -4,39 +4,38 @@ import gradio as gr
|
|
4 |
from apscheduler.schedulers.background import BackgroundScheduler
|
5 |
from huggingface_hub import snapshot_download
|
6 |
|
7 |
-
from src.about import
|
8 |
-
|
9 |
-
)
|
10 |
-
from src.benchmarks import (
|
11 |
-
QABenchmarks,
|
12 |
-
LongDocBenchmarks
|
13 |
-
)
|
14 |
-
from src.display.css_html_js import custom_css
|
15 |
from src.display.components import (
|
16 |
-
|
17 |
-
get_search_bar,
|
18 |
-
get_reranking_dropdown,
|
19 |
-
get_noreranking_dropdown,
|
20 |
-
get_metric_dropdown,
|
21 |
get_domain_dropdown,
|
22 |
get_language_dropdown,
|
23 |
-
|
|
|
|
|
|
|
24 |
get_revision_and_ts_checkbox,
|
25 |
-
|
|
|
26 |
)
|
|
|
27 |
from src.envs import (
|
28 |
API,
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
EVAL_RESULTS_PATH,
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
32 |
)
|
33 |
from src.loaders import load_eval_results
|
34 |
-
from src.utils import
|
35 |
-
update_metric,
|
36 |
-
set_listeners,
|
37 |
-
reset_rank,
|
38 |
-
remove_html, upload_file, submit_results
|
39 |
-
)
|
40 |
|
41 |
|
42 |
def restart_space():
|
@@ -47,11 +46,15 @@ try:
|
|
47 |
if not os.environ.get("LOCAL_MODE", False):
|
48 |
print("Running in local mode")
|
49 |
snapshot_download(
|
50 |
-
repo_id=RESULTS_REPO,
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
)
|
53 |
-
except Exception
|
54 |
-
print(
|
55 |
restart_space()
|
56 |
|
57 |
global data
|
@@ -61,29 +64,39 @@ datastore = data[LATEST_BENCHMARK_VERSION]
|
|
61 |
|
62 |
|
63 |
def update_metric_qa(
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
):
|
72 |
-
return update_metric(
|
73 |
-
|
|
|
74 |
|
75 |
|
76 |
def update_metric_long_doc(
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
):
|
85 |
-
return update_metric(
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
|
89 |
def update_datastore(version):
|
@@ -93,10 +106,8 @@ def update_datastore(version):
|
|
93 |
selected_domains = get_domain_dropdown(QABenchmarks[datastore.slug])
|
94 |
selected_langs = get_language_dropdown(QABenchmarks[datastore.slug])
|
95 |
selected_rerankings = get_reranking_dropdown(datastore.reranking_models)
|
96 |
-
leaderboard_table = get_leaderboard_table(
|
97 |
-
|
98 |
-
hidden_leaderboard_table = get_leaderboard_table(
|
99 |
-
datastore.raw_df_qa, datastore.types_qa, visible=False)
|
100 |
return selected_domains, selected_langs, selected_rerankings, leaderboard_table, hidden_leaderboard_table
|
101 |
|
102 |
|
@@ -107,10 +118,10 @@ def update_datastore_long_doc(version):
|
|
107 |
selected_domains = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
|
108 |
selected_langs = get_language_dropdown(LongDocBenchmarks[datastore.slug])
|
109 |
selected_rerankings = get_reranking_dropdown(datastore.reranking_models)
|
110 |
-
leaderboard_table = get_leaderboard_table(
|
111 |
-
datastore.leaderboard_df_long_doc, datastore.types_long_doc)
|
112 |
hidden_leaderboard_table = get_leaderboard_table(
|
113 |
-
datastore.raw_df_long_doc, datastore.types_long_doc, visible=False
|
|
|
114 |
return selected_domains, selected_langs, selected_rerankings, leaderboard_table, hidden_leaderboard_table
|
115 |
|
116 |
|
@@ -151,16 +162,16 @@ with demo:
|
|
151 |
with gr.Column():
|
152 |
selected_rerankings = get_reranking_dropdown(datastore.reranking_models)
|
153 |
# shown_table
|
154 |
-
lb_table = get_leaderboard_table(
|
155 |
-
datastore.leaderboard_df_qa, datastore.types_qa)
|
156 |
# Dummy leaderboard for handling the case when the user uses backspace key
|
157 |
-
hidden_lb_table = get_leaderboard_table(
|
158 |
-
datastore.raw_df_qa, datastore.types_qa, visible=False)
|
159 |
|
160 |
selected_version.change(
|
161 |
update_datastore,
|
162 |
-
[
|
163 |
-
|
|
|
|
|
164 |
)
|
165 |
|
166 |
set_listeners(
|
@@ -189,7 +200,7 @@ with demo:
|
|
189 |
show_revision_and_timestamp,
|
190 |
],
|
191 |
lb_table,
|
192 |
-
queue=True
|
193 |
)
|
194 |
|
195 |
with gr.TabItem("Retrieval Only", id=11):
|
@@ -200,28 +211,32 @@ with demo:
|
|
200 |
selected_noreranker = get_noreranking_dropdown()
|
201 |
|
202 |
lb_df_retriever = datastore.leaderboard_df_qa[
|
203 |
-
datastore.leaderboard_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
|
|
204 |
lb_df_retriever = reset_rank(lb_df_retriever)
|
205 |
-
lb_table_retriever = get_leaderboard_table(
|
206 |
-
lb_df_retriever, datastore.types_qa)
|
207 |
|
208 |
# Dummy leaderboard for handling the case when the user uses backspace key
|
209 |
hidden_lb_df_retriever = datastore.raw_df_qa[
|
210 |
-
datastore.raw_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
|
|
211 |
hidden_lb_df_retriever = reset_rank(hidden_lb_df_retriever)
|
212 |
-
hidden_lb_table_retriever = get_leaderboard_table(
|
213 |
-
|
|
|
214 |
|
215 |
selected_version.change(
|
216 |
update_datastore,
|
217 |
-
[
|
|
|
|
|
218 |
[
|
219 |
selected_domains,
|
220 |
selected_langs,
|
221 |
selected_noreranker,
|
222 |
lb_table_retriever,
|
223 |
-
hidden_lb_table_retriever
|
224 |
-
]
|
225 |
)
|
226 |
|
227 |
set_listeners(
|
@@ -250,44 +265,43 @@ with demo:
|
|
250 |
show_revision_and_timestamp,
|
251 |
],
|
252 |
lb_table_retriever,
|
253 |
-
queue=True
|
254 |
)
|
255 |
with gr.TabItem("Reranking Only", id=12):
|
256 |
-
lb_df_reranker =
|
257 |
-
datastore.leaderboard_df_qa[
|
258 |
-
|
259 |
-
COL_NAME_RETRIEVAL_MODEL
|
260 |
-
] == BM25_LINK
|
261 |
-
]
|
262 |
lb_df_reranker = reset_rank(lb_df_reranker)
|
263 |
-
reranking_models_reranker =
|
264 |
-
remove_html).unique().tolist()
|
|
|
265 |
with gr.Row():
|
266 |
with gr.Column(scale=1):
|
267 |
selected_rerankings_reranker = get_reranking_dropdown(reranking_models_reranker)
|
268 |
with gr.Column(scale=1):
|
269 |
search_bar_reranker = gr.Textbox(show_label=False, visible=False)
|
270 |
-
lb_table_reranker = get_leaderboard_table(
|
271 |
-
lb_df_reranker, datastore.types_qa)
|
272 |
|
273 |
hidden_lb_df_reranker = datastore.raw_df_qa[
|
274 |
-
datastore.raw_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
|
|
|
275 |
hidden_lb_df_reranker = reset_rank(hidden_lb_df_reranker)
|
276 |
hidden_lb_table_reranker = get_leaderboard_table(
|
277 |
-
hidden_lb_df_reranker,
|
278 |
-
datastore.types_qa, visible=False
|
279 |
)
|
280 |
|
281 |
selected_version.change(
|
282 |
update_datastore,
|
283 |
-
[
|
|
|
|
|
284 |
[
|
285 |
selected_domains,
|
286 |
selected_langs,
|
287 |
selected_rerankings_reranker,
|
288 |
lb_table_reranker,
|
289 |
-
hidden_lb_table_reranker
|
290 |
-
]
|
291 |
)
|
292 |
|
293 |
set_listeners(
|
@@ -315,7 +329,7 @@ with demo:
|
|
315 |
show_revision_and_timestamp,
|
316 |
],
|
317 |
lb_table_reranker,
|
318 |
-
queue=True
|
319 |
)
|
320 |
with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
|
321 |
with gr.Row():
|
@@ -353,14 +367,16 @@ with demo:
|
|
353 |
|
354 |
selected_version.change(
|
355 |
update_datastore_long_doc,
|
356 |
-
[
|
|
|
|
|
357 |
[
|
358 |
selected_domains,
|
359 |
selected_langs,
|
360 |
selected_rerankings,
|
361 |
lb_table_long_doc,
|
362 |
-
hidden_lb_table_long_doc
|
363 |
-
]
|
364 |
)
|
365 |
|
366 |
set_listeners(
|
@@ -386,10 +402,10 @@ with demo:
|
|
386 |
selected_rerankings,
|
387 |
search_bar,
|
388 |
show_anonymous,
|
389 |
-
show_revision_and_timestamp
|
390 |
],
|
391 |
lb_table_long_doc,
|
392 |
-
queue=True
|
393 |
)
|
394 |
with gr.TabItem("Retrieval Only", id=21):
|
395 |
with gr.Row():
|
@@ -399,14 +415,15 @@ with demo:
|
|
399 |
selected_noreranker = get_noreranking_dropdown()
|
400 |
lb_df_retriever_long_doc = datastore.leaderboard_df_long_doc[
|
401 |
datastore.leaderboard_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
402 |
-
|
403 |
lb_df_retriever_long_doc = reset_rank(lb_df_retriever_long_doc)
|
404 |
lb_table_retriever_long_doc = get_leaderboard_table(
|
405 |
-
lb_df_retriever_long_doc, datastore.types_long_doc
|
|
|
406 |
|
407 |
hidden_lb_df_retriever_long_doc = datastore.raw_df_long_doc[
|
408 |
datastore.raw_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
409 |
-
|
410 |
hidden_lb_df_retriever_long_doc = reset_rank(hidden_lb_df_retriever_long_doc)
|
411 |
hidden_lb_table_retriever_long_doc = get_leaderboard_table(
|
412 |
hidden_lb_df_retriever_long_doc, datastore.types_long_doc, visible=False
|
@@ -414,14 +431,16 @@ with demo:
|
|
414 |
|
415 |
selected_version.change(
|
416 |
update_datastore_long_doc,
|
417 |
-
[
|
|
|
|
|
418 |
[
|
419 |
selected_domains,
|
420 |
selected_langs,
|
421 |
selected_noreranker,
|
422 |
lb_table_retriever_long_doc,
|
423 |
-
hidden_lb_table_retriever_long_doc
|
424 |
-
]
|
425 |
)
|
426 |
|
427 |
set_listeners(
|
@@ -449,27 +468,27 @@ with demo:
|
|
449 |
show_revision_and_timestamp,
|
450 |
],
|
451 |
lb_table_retriever_long_doc,
|
452 |
-
queue=True
|
453 |
)
|
454 |
with gr.TabItem("Reranking Only", id=22):
|
455 |
-
lb_df_reranker_ldoc =
|
456 |
-
datastore.leaderboard_df_long_doc[
|
457 |
-
|
458 |
-
COL_NAME_RETRIEVAL_MODEL
|
459 |
-
] == BM25_LINK
|
460 |
-
]
|
461 |
lb_df_reranker_ldoc = reset_rank(lb_df_reranker_ldoc)
|
462 |
-
reranking_models_reranker_ldoc =
|
463 |
-
remove_html).unique().tolist()
|
|
|
464 |
with gr.Row():
|
465 |
with gr.Column(scale=1):
|
466 |
selected_rerankings_reranker_ldoc = get_reranking_dropdown(
|
467 |
-
reranking_models_reranker_ldoc
|
|
|
468 |
with gr.Column(scale=1):
|
469 |
search_bar_reranker_ldoc = gr.Textbox(show_label=False, visible=False)
|
470 |
lb_table_reranker_ldoc = get_leaderboard_table(lb_df_reranker_ldoc, datastore.types_long_doc)
|
471 |
hidden_lb_df_reranker_ldoc = datastore.raw_df_long_doc[
|
472 |
-
datastore.raw_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
|
|
|
473 |
hidden_lb_df_reranker_ldoc = reset_rank(hidden_lb_df_reranker_ldoc)
|
474 |
hidden_lb_table_reranker_ldoc = get_leaderboard_table(
|
475 |
hidden_lb_df_reranker_ldoc, datastore.types_long_doc, visible=False
|
@@ -477,14 +496,16 @@ with demo:
|
|
477 |
|
478 |
selected_version.change(
|
479 |
update_datastore_long_doc,
|
480 |
-
[
|
|
|
|
|
481 |
[
|
482 |
selected_domains,
|
483 |
selected_langs,
|
484 |
selected_rerankings_reranker_ldoc,
|
485 |
lb_table_reranker_ldoc,
|
486 |
-
hidden_lb_table_reranker_ldoc
|
487 |
-
]
|
488 |
)
|
489 |
|
490 |
set_listeners(
|
@@ -511,7 +532,7 @@ with demo:
|
|
511 |
show_revision_and_timestamp,
|
512 |
],
|
513 |
lb_table_reranker_ldoc,
|
514 |
-
queue=True
|
515 |
)
|
516 |
|
517 |
with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
|
@@ -528,23 +549,18 @@ with demo:
|
|
528 |
with gr.Row():
|
529 |
with gr.Column():
|
530 |
reranking_model_name = gr.Textbox(
|
531 |
-
label="Reranking Model name",
|
532 |
-
info="Optional",
|
533 |
-
value="NoReranker"
|
534 |
)
|
535 |
with gr.Column():
|
536 |
-
reranking_model_url = gr.Textbox(
|
537 |
-
label="Reranking Model URL",
|
538 |
-
info="Optional",
|
539 |
-
value=""
|
540 |
-
)
|
541 |
with gr.Row():
|
542 |
with gr.Column():
|
543 |
benchmark_version = gr.Dropdown(
|
544 |
BENCHMARK_VERSION_LIST,
|
545 |
value=LATEST_BENCHMARK_VERSION,
|
546 |
interactive=True,
|
547 |
-
label="AIR-Bench Version"
|
|
|
548 |
with gr.Row():
|
549 |
upload_button = gr.UploadButton("Click to upload search results", file_count="single")
|
550 |
with gr.Row():
|
@@ -553,7 +569,8 @@ with demo:
|
|
553 |
is_anonymous = gr.Checkbox(
|
554 |
label="Nope. I want to submit anonymously 🥷",
|
555 |
value=False,
|
556 |
-
info="Do you want to shown on the leaderboard by default?"
|
|
|
557 |
with gr.Row():
|
558 |
submit_button = gr.Button("Submit")
|
559 |
with gr.Row():
|
@@ -563,7 +580,8 @@ with demo:
|
|
563 |
[
|
564 |
upload_button,
|
565 |
],
|
566 |
-
file_output
|
|
|
567 |
submit_button.click(
|
568 |
submit_results,
|
569 |
[
|
@@ -573,10 +591,10 @@ with demo:
|
|
573 |
reranking_model_name,
|
574 |
reranking_model_url,
|
575 |
benchmark_version,
|
576 |
-
is_anonymous
|
577 |
],
|
578 |
submission_result,
|
579 |
-
show_progress="hidden"
|
580 |
)
|
581 |
|
582 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
|
|
|
4 |
from apscheduler.schedulers.background import BackgroundScheduler
|
5 |
from huggingface_hub import snapshot_download
|
6 |
|
7 |
+
from src.about import BENCHMARKS_TEXT, EVALUATION_QUEUE_TEXT, INTRODUCTION_TEXT, TITLE
|
8 |
+
from src.benchmarks import LongDocBenchmarks, QABenchmarks
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from src.display.components import (
|
10 |
+
get_anonymous_checkbox,
|
|
|
|
|
|
|
|
|
11 |
get_domain_dropdown,
|
12 |
get_language_dropdown,
|
13 |
+
get_leaderboard_table,
|
14 |
+
get_metric_dropdown,
|
15 |
+
get_noreranking_dropdown,
|
16 |
+
get_reranking_dropdown,
|
17 |
get_revision_and_ts_checkbox,
|
18 |
+
get_search_bar,
|
19 |
+
get_version_dropdown,
|
20 |
)
|
21 |
+
from src.display.css_html_js import custom_css
|
22 |
from src.envs import (
|
23 |
API,
|
24 |
+
BENCHMARK_VERSION_LIST,
|
25 |
+
BM25_LINK,
|
26 |
+
COL_NAME_RERANKING_MODEL,
|
27 |
+
COL_NAME_RETRIEVAL_MODEL,
|
28 |
+
DEFAULT_METRIC_LONG_DOC,
|
29 |
+
DEFAULT_METRIC_QA,
|
30 |
EVAL_RESULTS_PATH,
|
31 |
+
LATEST_BENCHMARK_VERSION,
|
32 |
+
METRIC_LIST,
|
33 |
+
REPO_ID,
|
34 |
+
RESULTS_REPO,
|
35 |
+
TOKEN,
|
36 |
)
|
37 |
from src.loaders import load_eval_results
|
38 |
+
from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
def restart_space():
|
|
|
46 |
if not os.environ.get("LOCAL_MODE", False):
|
47 |
print("Running in local mode")
|
48 |
snapshot_download(
|
49 |
+
repo_id=RESULTS_REPO,
|
50 |
+
local_dir=EVAL_RESULTS_PATH,
|
51 |
+
repo_type="dataset",
|
52 |
+
tqdm_class=None,
|
53 |
+
etag_timeout=30,
|
54 |
+
token=TOKEN,
|
55 |
)
|
56 |
+
except Exception:
|
57 |
+
print("failed to download")
|
58 |
restart_space()
|
59 |
|
60 |
global data
|
|
|
64 |
|
65 |
|
66 |
def update_metric_qa(
|
67 |
+
metric: str,
|
68 |
+
domains: list,
|
69 |
+
langs: list,
|
70 |
+
reranking_model: list,
|
71 |
+
query: str,
|
72 |
+
show_anonymous: bool,
|
73 |
+
show_revision_and_timestamp: bool,
|
74 |
):
|
75 |
+
return update_metric(
|
76 |
+
datastore, "qa", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp
|
77 |
+
)
|
78 |
|
79 |
|
80 |
def update_metric_long_doc(
|
81 |
+
metric: str,
|
82 |
+
domains: list,
|
83 |
+
langs: list,
|
84 |
+
reranking_model: list,
|
85 |
+
query: str,
|
86 |
+
show_anonymous: bool,
|
87 |
+
show_revision_and_timestamp,
|
88 |
):
|
89 |
+
return update_metric(
|
90 |
+
datastore,
|
91 |
+
"long-doc",
|
92 |
+
metric,
|
93 |
+
domains,
|
94 |
+
langs,
|
95 |
+
reranking_model,
|
96 |
+
query,
|
97 |
+
show_anonymous,
|
98 |
+
show_revision_and_timestamp,
|
99 |
+
)
|
100 |
|
101 |
|
102 |
def update_datastore(version):
|
|
|
106 |
selected_domains = get_domain_dropdown(QABenchmarks[datastore.slug])
|
107 |
selected_langs = get_language_dropdown(QABenchmarks[datastore.slug])
|
108 |
selected_rerankings = get_reranking_dropdown(datastore.reranking_models)
|
109 |
+
leaderboard_table = get_leaderboard_table(datastore.leaderboard_df_qa, datastore.types_qa)
|
110 |
+
hidden_leaderboard_table = get_leaderboard_table(datastore.raw_df_qa, datastore.types_qa, visible=False)
|
|
|
|
|
111 |
return selected_domains, selected_langs, selected_rerankings, leaderboard_table, hidden_leaderboard_table
|
112 |
|
113 |
|
|
|
118 |
selected_domains = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
|
119 |
selected_langs = get_language_dropdown(LongDocBenchmarks[datastore.slug])
|
120 |
selected_rerankings = get_reranking_dropdown(datastore.reranking_models)
|
121 |
+
leaderboard_table = get_leaderboard_table(datastore.leaderboard_df_long_doc, datastore.types_long_doc)
|
|
|
122 |
hidden_leaderboard_table = get_leaderboard_table(
|
123 |
+
datastore.raw_df_long_doc, datastore.types_long_doc, visible=False
|
124 |
+
)
|
125 |
return selected_domains, selected_langs, selected_rerankings, leaderboard_table, hidden_leaderboard_table
|
126 |
|
127 |
|
|
|
162 |
with gr.Column():
|
163 |
selected_rerankings = get_reranking_dropdown(datastore.reranking_models)
|
164 |
# shown_table
|
165 |
+
lb_table = get_leaderboard_table(datastore.leaderboard_df_qa, datastore.types_qa)
|
|
|
166 |
# Dummy leaderboard for handling the case when the user uses backspace key
|
167 |
+
hidden_lb_table = get_leaderboard_table(datastore.raw_df_qa, datastore.types_qa, visible=False)
|
|
|
168 |
|
169 |
selected_version.change(
|
170 |
update_datastore,
|
171 |
+
[
|
172 |
+
selected_version,
|
173 |
+
],
|
174 |
+
[selected_domains, selected_langs, selected_rerankings, lb_table, hidden_lb_table],
|
175 |
)
|
176 |
|
177 |
set_listeners(
|
|
|
200 |
show_revision_and_timestamp,
|
201 |
],
|
202 |
lb_table,
|
203 |
+
queue=True,
|
204 |
)
|
205 |
|
206 |
with gr.TabItem("Retrieval Only", id=11):
|
|
|
211 |
selected_noreranker = get_noreranking_dropdown()
|
212 |
|
213 |
lb_df_retriever = datastore.leaderboard_df_qa[
|
214 |
+
datastore.leaderboard_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
215 |
+
]
|
216 |
lb_df_retriever = reset_rank(lb_df_retriever)
|
217 |
+
lb_table_retriever = get_leaderboard_table(lb_df_retriever, datastore.types_qa)
|
|
|
218 |
|
219 |
# Dummy leaderboard for handling the case when the user uses backspace key
|
220 |
hidden_lb_df_retriever = datastore.raw_df_qa[
|
221 |
+
datastore.raw_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
222 |
+
]
|
223 |
hidden_lb_df_retriever = reset_rank(hidden_lb_df_retriever)
|
224 |
+
hidden_lb_table_retriever = get_leaderboard_table(
|
225 |
+
hidden_lb_df_retriever, datastore.types_qa, visible=False
|
226 |
+
)
|
227 |
|
228 |
selected_version.change(
|
229 |
update_datastore,
|
230 |
+
[
|
231 |
+
selected_version,
|
232 |
+
],
|
233 |
[
|
234 |
selected_domains,
|
235 |
selected_langs,
|
236 |
selected_noreranker,
|
237 |
lb_table_retriever,
|
238 |
+
hidden_lb_table_retriever,
|
239 |
+
],
|
240 |
)
|
241 |
|
242 |
set_listeners(
|
|
|
265 |
show_revision_and_timestamp,
|
266 |
],
|
267 |
lb_table_retriever,
|
268 |
+
queue=True,
|
269 |
)
|
270 |
with gr.TabItem("Reranking Only", id=12):
|
271 |
+
lb_df_reranker = datastore.leaderboard_df_qa[
|
272 |
+
datastore.leaderboard_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
|
273 |
+
]
|
|
|
|
|
|
|
274 |
lb_df_reranker = reset_rank(lb_df_reranker)
|
275 |
+
reranking_models_reranker = (
|
276 |
+
lb_df_reranker[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
|
277 |
+
)
|
278 |
with gr.Row():
|
279 |
with gr.Column(scale=1):
|
280 |
selected_rerankings_reranker = get_reranking_dropdown(reranking_models_reranker)
|
281 |
with gr.Column(scale=1):
|
282 |
search_bar_reranker = gr.Textbox(show_label=False, visible=False)
|
283 |
+
lb_table_reranker = get_leaderboard_table(lb_df_reranker, datastore.types_qa)
|
|
|
284 |
|
285 |
hidden_lb_df_reranker = datastore.raw_df_qa[
|
286 |
+
datastore.raw_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
|
287 |
+
]
|
288 |
hidden_lb_df_reranker = reset_rank(hidden_lb_df_reranker)
|
289 |
hidden_lb_table_reranker = get_leaderboard_table(
|
290 |
+
hidden_lb_df_reranker, datastore.types_qa, visible=False
|
|
|
291 |
)
|
292 |
|
293 |
selected_version.change(
|
294 |
update_datastore,
|
295 |
+
[
|
296 |
+
selected_version,
|
297 |
+
],
|
298 |
[
|
299 |
selected_domains,
|
300 |
selected_langs,
|
301 |
selected_rerankings_reranker,
|
302 |
lb_table_reranker,
|
303 |
+
hidden_lb_table_reranker,
|
304 |
+
],
|
305 |
)
|
306 |
|
307 |
set_listeners(
|
|
|
329 |
show_revision_and_timestamp,
|
330 |
],
|
331 |
lb_table_reranker,
|
332 |
+
queue=True,
|
333 |
)
|
334 |
with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
|
335 |
with gr.Row():
|
|
|
367 |
|
368 |
selected_version.change(
|
369 |
update_datastore_long_doc,
|
370 |
+
[
|
371 |
+
selected_version,
|
372 |
+
],
|
373 |
[
|
374 |
selected_domains,
|
375 |
selected_langs,
|
376 |
selected_rerankings,
|
377 |
lb_table_long_doc,
|
378 |
+
hidden_lb_table_long_doc,
|
379 |
+
],
|
380 |
)
|
381 |
|
382 |
set_listeners(
|
|
|
402 |
selected_rerankings,
|
403 |
search_bar,
|
404 |
show_anonymous,
|
405 |
+
show_revision_and_timestamp,
|
406 |
],
|
407 |
lb_table_long_doc,
|
408 |
+
queue=True,
|
409 |
)
|
410 |
with gr.TabItem("Retrieval Only", id=21):
|
411 |
with gr.Row():
|
|
|
415 |
selected_noreranker = get_noreranking_dropdown()
|
416 |
lb_df_retriever_long_doc = datastore.leaderboard_df_long_doc[
|
417 |
datastore.leaderboard_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
418 |
+
]
|
419 |
lb_df_retriever_long_doc = reset_rank(lb_df_retriever_long_doc)
|
420 |
lb_table_retriever_long_doc = get_leaderboard_table(
|
421 |
+
lb_df_retriever_long_doc, datastore.types_long_doc
|
422 |
+
)
|
423 |
|
424 |
hidden_lb_df_retriever_long_doc = datastore.raw_df_long_doc[
|
425 |
datastore.raw_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
426 |
+
]
|
427 |
hidden_lb_df_retriever_long_doc = reset_rank(hidden_lb_df_retriever_long_doc)
|
428 |
hidden_lb_table_retriever_long_doc = get_leaderboard_table(
|
429 |
hidden_lb_df_retriever_long_doc, datastore.types_long_doc, visible=False
|
|
|
431 |
|
432 |
selected_version.change(
|
433 |
update_datastore_long_doc,
|
434 |
+
[
|
435 |
+
selected_version,
|
436 |
+
],
|
437 |
[
|
438 |
selected_domains,
|
439 |
selected_langs,
|
440 |
selected_noreranker,
|
441 |
lb_table_retriever_long_doc,
|
442 |
+
hidden_lb_table_retriever_long_doc,
|
443 |
+
],
|
444 |
)
|
445 |
|
446 |
set_listeners(
|
|
|
468 |
show_revision_and_timestamp,
|
469 |
],
|
470 |
lb_table_retriever_long_doc,
|
471 |
+
queue=True,
|
472 |
)
|
473 |
with gr.TabItem("Reranking Only", id=22):
|
474 |
+
lb_df_reranker_ldoc = datastore.leaderboard_df_long_doc[
|
475 |
+
datastore.leaderboard_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
|
476 |
+
]
|
|
|
|
|
|
|
477 |
lb_df_reranker_ldoc = reset_rank(lb_df_reranker_ldoc)
|
478 |
+
reranking_models_reranker_ldoc = (
|
479 |
+
lb_df_reranker_ldoc[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
|
480 |
+
)
|
481 |
with gr.Row():
|
482 |
with gr.Column(scale=1):
|
483 |
selected_rerankings_reranker_ldoc = get_reranking_dropdown(
|
484 |
+
reranking_models_reranker_ldoc
|
485 |
+
)
|
486 |
with gr.Column(scale=1):
|
487 |
search_bar_reranker_ldoc = gr.Textbox(show_label=False, visible=False)
|
488 |
lb_table_reranker_ldoc = get_leaderboard_table(lb_df_reranker_ldoc, datastore.types_long_doc)
|
489 |
hidden_lb_df_reranker_ldoc = datastore.raw_df_long_doc[
|
490 |
+
datastore.raw_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
|
491 |
+
]
|
492 |
hidden_lb_df_reranker_ldoc = reset_rank(hidden_lb_df_reranker_ldoc)
|
493 |
hidden_lb_table_reranker_ldoc = get_leaderboard_table(
|
494 |
hidden_lb_df_reranker_ldoc, datastore.types_long_doc, visible=False
|
|
|
496 |
|
497 |
selected_version.change(
|
498 |
update_datastore_long_doc,
|
499 |
+
[
|
500 |
+
selected_version,
|
501 |
+
],
|
502 |
[
|
503 |
selected_domains,
|
504 |
selected_langs,
|
505 |
selected_rerankings_reranker_ldoc,
|
506 |
lb_table_reranker_ldoc,
|
507 |
+
hidden_lb_table_reranker_ldoc,
|
508 |
+
],
|
509 |
)
|
510 |
|
511 |
set_listeners(
|
|
|
532 |
show_revision_and_timestamp,
|
533 |
],
|
534 |
lb_table_reranker_ldoc,
|
535 |
+
queue=True,
|
536 |
)
|
537 |
|
538 |
with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
|
|
|
549 |
with gr.Row():
|
550 |
with gr.Column():
|
551 |
reranking_model_name = gr.Textbox(
|
552 |
+
label="Reranking Model name", info="Optional", value="NoReranker"
|
|
|
|
|
553 |
)
|
554 |
with gr.Column():
|
555 |
+
reranking_model_url = gr.Textbox(label="Reranking Model URL", info="Optional", value="")
|
|
|
|
|
|
|
|
|
556 |
with gr.Row():
|
557 |
with gr.Column():
|
558 |
benchmark_version = gr.Dropdown(
|
559 |
BENCHMARK_VERSION_LIST,
|
560 |
value=LATEST_BENCHMARK_VERSION,
|
561 |
interactive=True,
|
562 |
+
label="AIR-Bench Version",
|
563 |
+
)
|
564 |
with gr.Row():
|
565 |
upload_button = gr.UploadButton("Click to upload search results", file_count="single")
|
566 |
with gr.Row():
|
|
|
569 |
is_anonymous = gr.Checkbox(
|
570 |
label="Nope. I want to submit anonymously 🥷",
|
571 |
value=False,
|
572 |
+
info="Do you want to shown on the leaderboard by default?",
|
573 |
+
)
|
574 |
with gr.Row():
|
575 |
submit_button = gr.Button("Submit")
|
576 |
with gr.Row():
|
|
|
580 |
[
|
581 |
upload_button,
|
582 |
],
|
583 |
+
file_output,
|
584 |
+
)
|
585 |
submit_button.click(
|
586 |
submit_results,
|
587 |
[
|
|
|
591 |
reranking_model_name,
|
592 |
reranking_model_url,
|
593 |
benchmark_version,
|
594 |
+
is_anonymous,
|
595 |
],
|
596 |
submission_result,
|
597 |
+
show_progress="hidden",
|
598 |
)
|
599 |
|
600 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
|
src/about.py
CHANGED
@@ -8,7 +8,7 @@ INTRODUCTION_TEXT = """
|
|
8 |
"""
|
9 |
|
10 |
# Which evaluations are you running? how can people reproduce what you have?
|
11 |
-
BENCHMARKS_TEXT =
|
12 |
## How the test data are generated?
|
13 |
### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
|
14 |
|
|
|
8 |
"""
|
9 |
|
10 |
# Which evaluations are you running? how can people reproduce what you have?
|
11 |
+
BENCHMARKS_TEXT = """
|
12 |
## How the test data are generated?
|
13 |
### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
|
14 |
|
src/benchmarks.py
CHANGED
@@ -3,16 +3,13 @@ from enum import Enum
|
|
3 |
|
4 |
from air_benchmark.tasks.tasks import BenchmarkTable
|
5 |
|
6 |
-
from src.envs import
|
7 |
|
8 |
|
9 |
def get_safe_name(name: str):
|
10 |
"""Get RFC 1123 compatible safe name"""
|
11 |
-
name = name.replace(
|
12 |
-
return
|
13 |
-
character.lower()
|
14 |
-
for character in name
|
15 |
-
if (character.isalnum() or character == '_'))
|
16 |
|
17 |
|
18 |
@dataclass
|
@@ -39,8 +36,9 @@ def get_benchmarks_enum(benchmark_version, task_type):
|
|
39 |
for metric in dataset_list:
|
40 |
if "test" not in dataset_list[metric]["splits"]:
|
41 |
continue
|
42 |
-
benchmark_dict[benchmark_name] =
|
43 |
-
|
|
|
44 |
elif task_type == "long-doc":
|
45 |
for task, domain_dict in BenchmarkTable[benchmark_version].items():
|
46 |
if task != task_type:
|
@@ -54,21 +52,26 @@ def get_benchmarks_enum(benchmark_version, task_type):
|
|
54 |
if "test" not in dataset_list[dataset]["splits"]:
|
55 |
continue
|
56 |
for metric in METRIC_LIST:
|
57 |
-
benchmark_dict[benchmark_name] =
|
58 |
-
|
|
|
59 |
return benchmark_dict
|
60 |
|
61 |
|
62 |
qa_benchmark_dict = {}
|
63 |
for version in BENCHMARK_VERSION_LIST:
|
64 |
safe_version_name = get_safe_name(version)[-4:]
|
65 |
-
qa_benchmark_dict[safe_version_name] = Enum(
|
|
|
|
|
66 |
|
67 |
long_doc_benchmark_dict = {}
|
68 |
for version in BENCHMARK_VERSION_LIST:
|
69 |
safe_version_name = get_safe_name(version)[-4:]
|
70 |
-
long_doc_benchmark_dict[safe_version_name] = Enum(
|
|
|
|
|
71 |
|
72 |
|
73 |
-
QABenchmarks = Enum(
|
74 |
-
LongDocBenchmarks = Enum(
|
|
|
3 |
|
4 |
from air_benchmark.tasks.tasks import BenchmarkTable
|
5 |
|
6 |
+
from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
|
7 |
|
8 |
|
9 |
def get_safe_name(name: str):
|
10 |
"""Get RFC 1123 compatible safe name"""
|
11 |
+
name = name.replace("-", "_")
|
12 |
+
return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
@dataclass
|
|
|
36 |
for metric in dataset_list:
|
37 |
if "test" not in dataset_list[metric]["splits"]:
|
38 |
continue
|
39 |
+
benchmark_dict[benchmark_name] = Benchmark(
|
40 |
+
benchmark_name, metric, col_name, domain, lang, task
|
41 |
+
)
|
42 |
elif task_type == "long-doc":
|
43 |
for task, domain_dict in BenchmarkTable[benchmark_version].items():
|
44 |
if task != task_type:
|
|
|
52 |
if "test" not in dataset_list[dataset]["splits"]:
|
53 |
continue
|
54 |
for metric in METRIC_LIST:
|
55 |
+
benchmark_dict[benchmark_name] = Benchmark(
|
56 |
+
benchmark_name, metric, col_name, domain, lang, task
|
57 |
+
)
|
58 |
return benchmark_dict
|
59 |
|
60 |
|
61 |
qa_benchmark_dict = {}
|
62 |
for version in BENCHMARK_VERSION_LIST:
|
63 |
safe_version_name = get_safe_name(version)[-4:]
|
64 |
+
qa_benchmark_dict[safe_version_name] = Enum(
|
65 |
+
f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, "qa")
|
66 |
+
)
|
67 |
|
68 |
long_doc_benchmark_dict = {}
|
69 |
for version in BENCHMARK_VERSION_LIST:
|
70 |
safe_version_name = get_safe_name(version)[-4:]
|
71 |
+
long_doc_benchmark_dict[safe_version_name] = Enum(
|
72 |
+
f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, "long-doc")
|
73 |
+
)
|
74 |
|
75 |
|
76 |
+
QABenchmarks = Enum("QABenchmarks", qa_benchmark_dict)
|
77 |
+
LongDocBenchmarks = Enum("LongDocBenchmarks", long_doc_benchmark_dict)
|
src/display/columns.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1 |
from dataclasses import dataclass, make_dataclass
|
2 |
|
3 |
-
from src.envs import
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
def fields(raw_class):
|
@@ -23,16 +32,20 @@ class ColumnContent:
|
|
23 |
def get_default_auto_eval_column_dict():
|
24 |
auto_eval_column_dict = []
|
25 |
# Init
|
|
|
26 |
auto_eval_column_dict.append(
|
27 |
-
[
|
|
|
|
|
|
|
|
|
28 |
)
|
29 |
auto_eval_column_dict.append(
|
30 |
-
[
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, hidden=False, never_hidden=True)]
|
36 |
)
|
37 |
auto_eval_column_dict.append(
|
38 |
["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
|
@@ -40,16 +53,20 @@ def get_default_auto_eval_column_dict():
|
|
40 |
auto_eval_column_dict.append(
|
41 |
["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
|
42 |
)
|
|
|
43 |
auto_eval_column_dict.append(
|
44 |
-
[
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
)
|
50 |
auto_eval_column_dict.append(
|
51 |
-
[
|
52 |
-
|
|
|
|
|
|
|
53 |
)
|
54 |
auto_eval_column_dict.append(
|
55 |
["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
|
|
|
1 |
from dataclasses import dataclass, make_dataclass
|
2 |
|
3 |
+
from src.envs import (
|
4 |
+
COL_NAME_AVG,
|
5 |
+
COL_NAME_IS_ANONYMOUS,
|
6 |
+
COL_NAME_RANK,
|
7 |
+
COL_NAME_RERANKING_MODEL,
|
8 |
+
COL_NAME_RERANKING_MODEL_LINK,
|
9 |
+
COL_NAME_RETRIEVAL_MODEL,
|
10 |
+
COL_NAME_RETRIEVAL_MODEL_LINK,
|
11 |
+
COL_NAME_REVISION,
|
12 |
+
COL_NAME_TIMESTAMP,
|
13 |
+
)
|
14 |
|
15 |
|
16 |
def fields(raw_class):
|
|
|
32 |
def get_default_auto_eval_column_dict():
|
33 |
auto_eval_column_dict = []
|
34 |
# Init
|
35 |
+
auto_eval_column_dict.append(["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)])
|
36 |
auto_eval_column_dict.append(
|
37 |
+
[
|
38 |
+
"retrieval_model",
|
39 |
+
ColumnContent,
|
40 |
+
ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, hidden=False, never_hidden=True),
|
41 |
+
]
|
42 |
)
|
43 |
auto_eval_column_dict.append(
|
44 |
+
[
|
45 |
+
"reranking_model",
|
46 |
+
ColumnContent,
|
47 |
+
ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, hidden=False, never_hidden=True),
|
48 |
+
]
|
|
|
49 |
)
|
50 |
auto_eval_column_dict.append(
|
51 |
["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
|
|
|
53 |
auto_eval_column_dict.append(
|
54 |
["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
|
55 |
)
|
56 |
+
auto_eval_column_dict.append(["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)])
|
57 |
auto_eval_column_dict.append(
|
58 |
+
[
|
59 |
+
"retrieval_model_link",
|
60 |
+
ColumnContent,
|
61 |
+
ColumnContent(COL_NAME_RETRIEVAL_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False),
|
62 |
+
]
|
63 |
)
|
64 |
auto_eval_column_dict.append(
|
65 |
+
[
|
66 |
+
"reranking_model_link",
|
67 |
+
ColumnContent,
|
68 |
+
ColumnContent(COL_NAME_RERANKING_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False),
|
69 |
+
]
|
70 |
)
|
71 |
auto_eval_column_dict.append(
|
72 |
["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
|
src/display/components.py
CHANGED
@@ -8,7 +8,7 @@ def get_version_dropdown():
|
|
8 |
choices=BENCHMARK_VERSION_LIST,
|
9 |
value=LATEST_BENCHMARK_VERSION,
|
10 |
label="Select the version of AIR-Bench",
|
11 |
-
interactive=True
|
12 |
)
|
13 |
|
14 |
|
@@ -16,26 +16,25 @@ def get_search_bar():
|
|
16 |
return gr.Textbox(
|
17 |
placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
|
18 |
show_label=False,
|
19 |
-
info="Search the retrieval methods"
|
20 |
)
|
21 |
|
22 |
|
23 |
def get_reranking_dropdown(model_list):
|
24 |
-
return gr.Dropdown(
|
25 |
-
choices=model_list,
|
26 |
-
label="Select the reranking models",
|
27 |
-
interactive=True,
|
28 |
-
multiselect=True
|
29 |
-
)
|
30 |
|
31 |
|
32 |
def get_noreranking_dropdown():
|
33 |
return gr.Dropdown(
|
34 |
-
choices=[
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
interactive=False,
|
37 |
multiselect=True,
|
38 |
-
visible=False
|
39 |
)
|
40 |
|
41 |
|
@@ -75,7 +74,7 @@ def get_language_dropdown(benchmarks, default_languages=None):
|
|
75 |
value=default_languages,
|
76 |
label="Select the languages",
|
77 |
multiselect=True,
|
78 |
-
interactive=True
|
79 |
)
|
80 |
|
81 |
|
@@ -83,15 +82,13 @@ def get_anonymous_checkbox():
|
|
83 |
return gr.Checkbox(
|
84 |
label="Show anonymous submissions",
|
85 |
value=False,
|
86 |
-
info="The anonymous submissions might have invalid model information."
|
87 |
)
|
88 |
|
89 |
|
90 |
def get_revision_and_ts_checkbox():
|
91 |
return gr.Checkbox(
|
92 |
-
label="Show submission details",
|
93 |
-
value=False,
|
94 |
-
info="Show the revision and timestamp information of submissions"
|
95 |
)
|
96 |
|
97 |
|
|
|
8 |
choices=BENCHMARK_VERSION_LIST,
|
9 |
value=LATEST_BENCHMARK_VERSION,
|
10 |
label="Select the version of AIR-Bench",
|
11 |
+
interactive=True,
|
12 |
)
|
13 |
|
14 |
|
|
|
16 |
return gr.Textbox(
|
17 |
placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
|
18 |
show_label=False,
|
19 |
+
info="Search the retrieval methods",
|
20 |
)
|
21 |
|
22 |
|
23 |
def get_reranking_dropdown(model_list):
|
24 |
+
return gr.Dropdown(choices=model_list, label="Select the reranking models", interactive=True, multiselect=True)
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
def get_noreranking_dropdown():
|
28 |
return gr.Dropdown(
|
29 |
+
choices=[
|
30 |
+
"NoReranker",
|
31 |
+
],
|
32 |
+
value=[
|
33 |
+
"NoReranker",
|
34 |
+
],
|
35 |
interactive=False,
|
36 |
multiselect=True,
|
37 |
+
visible=False,
|
38 |
)
|
39 |
|
40 |
|
|
|
74 |
value=default_languages,
|
75 |
label="Select the languages",
|
76 |
multiselect=True,
|
77 |
+
interactive=True,
|
78 |
)
|
79 |
|
80 |
|
|
|
82 |
return gr.Checkbox(
|
83 |
label="Show anonymous submissions",
|
84 |
value=False,
|
85 |
+
info="The anonymous submissions might have invalid model information.",
|
86 |
)
|
87 |
|
88 |
|
89 |
def get_revision_and_ts_checkbox():
|
90 |
return gr.Checkbox(
|
91 |
+
label="Show submission details", value=False, info="Show the revision and timestamp information of submissions"
|
|
|
|
|
92 |
)
|
93 |
|
94 |
|
src/envs.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import os
|
2 |
-
|
3 |
from huggingface_hub import HfApi
|
4 |
|
|
|
|
|
5 |
# Info to change for your repository
|
6 |
# ----------------------------------
|
7 |
TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
|
@@ -63,7 +65,7 @@ METRIC_LIST = [
|
|
63 |
"mrr_at_5",
|
64 |
"mrr_at_10",
|
65 |
"mrr_at_100",
|
66 |
-
"mrr_at_1000"
|
67 |
]
|
68 |
COL_NAME_AVG = "Average ⬆️"
|
69 |
COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
|
|
|
1 |
import os
|
2 |
+
|
3 |
from huggingface_hub import HfApi
|
4 |
|
5 |
+
from src.display.formatting import model_hyperlink
|
6 |
+
|
7 |
# Info to change for your repository
|
8 |
# ----------------------------------
|
9 |
TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
|
|
|
65 |
"mrr_at_5",
|
66 |
"mrr_at_10",
|
67 |
"mrr_at_100",
|
68 |
+
"mrr_at_1000",
|
69 |
]
|
70 |
COL_NAME_AVG = "Average ⬆️"
|
71 |
COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
|
src/loaders.py
CHANGED
@@ -3,8 +3,14 @@ from typing import List
|
|
3 |
|
4 |
import pandas as pd
|
5 |
|
6 |
-
from src.envs import
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from src.models import FullEvalResult, LeaderboardDataStore
|
9 |
from src.utils import get_default_cols, get_leaderboard_df
|
10 |
|
@@ -23,7 +29,7 @@ def load_raw_eval_results(results_path: str) -> List[FullEvalResult]:
|
|
23 |
# select the latest results
|
24 |
for file in files:
|
25 |
if not (file.startswith("results") and file.endswith(".json")):
|
26 |
-
print(f
|
27 |
continue
|
28 |
model_result_filepaths.append(os.path.join(root, file))
|
29 |
|
@@ -32,10 +38,10 @@ def load_raw_eval_results(results_path: str) -> List[FullEvalResult]:
|
|
32 |
# create evaluation results
|
33 |
try:
|
34 |
eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
|
35 |
-
except UnicodeDecodeError
|
36 |
print(f"loading file failed. {model_result_filepath}")
|
37 |
continue
|
38 |
-
print(f
|
39 |
timestamp = eval_result.timestamp
|
40 |
eval_results[timestamp] = eval_result
|
41 |
|
@@ -52,43 +58,39 @@ def load_raw_eval_results(results_path: str) -> List[FullEvalResult]:
|
|
52 |
|
53 |
def get_safe_name(name: str):
|
54 |
"""Get RFC 1123 compatible safe name"""
|
55 |
-
name = name.replace(
|
56 |
-
return
|
57 |
-
character.lower()
|
58 |
-
for character in name
|
59 |
-
if (character.isalnum() or character == '_'))
|
60 |
|
61 |
|
62 |
def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
|
63 |
slug = get_safe_name(version)[-4:]
|
64 |
lb_data_store = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
|
65 |
lb_data_store.raw_data = load_raw_eval_results(file_path)
|
66 |
-
print(f
|
67 |
|
68 |
-
lb_data_store.raw_df_qa = get_leaderboard_df(
|
69 |
-
|
70 |
-
print(f'QA data loaded: {lb_data_store.raw_df_qa.shape}')
|
71 |
lb_data_store.leaderboard_df_qa = lb_data_store.raw_df_qa.copy()
|
72 |
-
shown_columns_qa, types_qa = get_default_cols(
|
73 |
lb_data_store.types_qa = types_qa
|
74 |
-
lb_data_store.leaderboard_df_qa =
|
75 |
-
|
|
|
76 |
lb_data_store.leaderboard_df_qa.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
77 |
|
78 |
-
lb_data_store.raw_df_long_doc = get_leaderboard_df(
|
79 |
-
|
80 |
-
print(f'Long-Doc data loaded: {len(lb_data_store.raw_df_long_doc)}')
|
81 |
lb_data_store.leaderboard_df_long_doc = lb_data_store.raw_df_long_doc.copy()
|
82 |
-
shown_columns_long_doc, types_long_doc = get_default_cols(
|
83 |
-
'long-doc', lb_data_store.slug, add_fix_cols=True)
|
84 |
lb_data_store.types_long_doc = types_long_doc
|
85 |
-
lb_data_store.leaderboard_df_long_doc =
|
86 |
-
lb_data_store.leaderboard_df_long_doc[
|
87 |
-
|
88 |
lb_data_store.leaderboard_df_long_doc.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
89 |
|
90 |
lb_data_store.reranking_models = sorted(
|
91 |
-
list(frozenset([eval_result.reranking_model for eval_result in lb_data_store.raw_data]))
|
|
|
92 |
return lb_data_store
|
93 |
|
94 |
|
|
|
3 |
|
4 |
import pandas as pd
|
5 |
|
6 |
+
from src.envs import (
|
7 |
+
BENCHMARK_VERSION_LIST,
|
8 |
+
COL_NAME_IS_ANONYMOUS,
|
9 |
+
COL_NAME_REVISION,
|
10 |
+
COL_NAME_TIMESTAMP,
|
11 |
+
DEFAULT_METRIC_LONG_DOC,
|
12 |
+
DEFAULT_METRIC_QA,
|
13 |
+
)
|
14 |
from src.models import FullEvalResult, LeaderboardDataStore
|
15 |
from src.utils import get_default_cols, get_leaderboard_df
|
16 |
|
|
|
29 |
# select the latest results
|
30 |
for file in files:
|
31 |
if not (file.startswith("results") and file.endswith(".json")):
|
32 |
+
print(f"skip {file}")
|
33 |
continue
|
34 |
model_result_filepaths.append(os.path.join(root, file))
|
35 |
|
|
|
38 |
# create evaluation results
|
39 |
try:
|
40 |
eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
|
41 |
+
except UnicodeDecodeError:
|
42 |
print(f"loading file failed. {model_result_filepath}")
|
43 |
continue
|
44 |
+
print(f"file loaded: {model_result_filepath}")
|
45 |
timestamp = eval_result.timestamp
|
46 |
eval_results[timestamp] = eval_result
|
47 |
|
|
|
58 |
|
59 |
def get_safe_name(name: str):
|
60 |
"""Get RFC 1123 compatible safe name"""
|
61 |
+
name = name.replace("-", "_")
|
62 |
+
return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
|
|
|
|
|
|
|
63 |
|
64 |
|
65 |
def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
|
66 |
slug = get_safe_name(version)[-4:]
|
67 |
lb_data_store = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
|
68 |
lb_data_store.raw_data = load_raw_eval_results(file_path)
|
69 |
+
print(f"raw data: {len(lb_data_store.raw_data)}")
|
70 |
|
71 |
+
lb_data_store.raw_df_qa = get_leaderboard_df(lb_data_store, task="qa", metric=DEFAULT_METRIC_QA)
|
72 |
+
print(f"QA data loaded: {lb_data_store.raw_df_qa.shape}")
|
|
|
73 |
lb_data_store.leaderboard_df_qa = lb_data_store.raw_df_qa.copy()
|
74 |
+
shown_columns_qa, types_qa = get_default_cols("qa", lb_data_store.slug, add_fix_cols=True)
|
75 |
lb_data_store.types_qa = types_qa
|
76 |
+
lb_data_store.leaderboard_df_qa = lb_data_store.leaderboard_df_qa[
|
77 |
+
~lb_data_store.leaderboard_df_qa[COL_NAME_IS_ANONYMOUS]
|
78 |
+
][shown_columns_qa]
|
79 |
lb_data_store.leaderboard_df_qa.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
80 |
|
81 |
+
lb_data_store.raw_df_long_doc = get_leaderboard_df(lb_data_store, task="long-doc", metric=DEFAULT_METRIC_LONG_DOC)
|
82 |
+
print(f"Long-Doc data loaded: {len(lb_data_store.raw_df_long_doc)}")
|
|
|
83 |
lb_data_store.leaderboard_df_long_doc = lb_data_store.raw_df_long_doc.copy()
|
84 |
+
shown_columns_long_doc, types_long_doc = get_default_cols("long-doc", lb_data_store.slug, add_fix_cols=True)
|
|
|
85 |
lb_data_store.types_long_doc = types_long_doc
|
86 |
+
lb_data_store.leaderboard_df_long_doc = lb_data_store.leaderboard_df_long_doc[
|
87 |
+
~lb_data_store.leaderboard_df_long_doc[COL_NAME_IS_ANONYMOUS]
|
88 |
+
][shown_columns_long_doc]
|
89 |
lb_data_store.leaderboard_df_long_doc.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
90 |
|
91 |
lb_data_store.reranking_models = sorted(
|
92 |
+
list(frozenset([eval_result.reranking_model for eval_result in lb_data_store.raw_data]))
|
93 |
+
)
|
94 |
return lb_data_store
|
95 |
|
96 |
|
src/models.py
CHANGED
@@ -7,8 +7,15 @@ import pandas as pd
|
|
7 |
|
8 |
from src.benchmarks import get_safe_name
|
9 |
from src.display.formatting import make_clickable_model
|
10 |
-
from src.envs import
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
@dataclass
|
@@ -17,6 +24,7 @@ class EvalResult:
|
|
17 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
|
18 |
domains, languages, and datasets
|
19 |
"""
|
|
|
20 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
|
21 |
retrieval_model: str
|
22 |
reranking_model: str
|
@@ -33,6 +41,7 @@ class FullEvalResult:
|
|
33 |
"""
|
34 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
|
35 |
"""
|
|
|
36 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
|
37 |
retrieval_model: str
|
38 |
reranking_model: str
|
@@ -56,7 +65,6 @@ class FullEvalResult:
|
|
56 |
result_list = []
|
57 |
retrieval_model_link = ""
|
58 |
reranking_model_link = ""
|
59 |
-
revision = ""
|
60 |
for item in model_data:
|
61 |
config = item.get("config", {})
|
62 |
# eval results for different metrics
|
@@ -75,7 +83,7 @@ class FullEvalResult:
|
|
75 |
metric=config["metric"],
|
76 |
timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
|
77 |
revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
|
78 |
-
is_anonymous=config.get("is_anonymous", False)
|
79 |
)
|
80 |
result_list.append(eval_result)
|
81 |
return cls(
|
@@ -87,10 +95,10 @@ class FullEvalResult:
|
|
87 |
results=result_list,
|
88 |
timestamp=result_list[0].timestamp,
|
89 |
revision=result_list[0].revision,
|
90 |
-
is_anonymous=result_list[0].is_anonymous
|
91 |
)
|
92 |
|
93 |
-
def to_dict(self, task=
|
94 |
"""
|
95 |
Convert the results in all the EvalResults over different tasks and metrics.
|
96 |
The output is a list of dict compatible with the dataframe UI
|
@@ -102,10 +110,12 @@ class FullEvalResult:
|
|
102 |
if eval_result.task != task:
|
103 |
continue
|
104 |
results[eval_result.eval_name]["eval_name"] = eval_result.eval_name
|
105 |
-
results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL] = (
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
|
110 |
results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
|
111 |
results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
|
@@ -118,7 +128,7 @@ class FullEvalResult:
|
|
118 |
lang = result["lang"]
|
119 |
dataset = result["dataset"]
|
120 |
value = result["value"] * 100
|
121 |
-
if dataset ==
|
122 |
benchmark_name = f"{domain}_{lang}"
|
123 |
else:
|
124 |
benchmark_name = f"{domain}_{lang}_{dataset}"
|
|
|
7 |
|
8 |
from src.benchmarks import get_safe_name
|
9 |
from src.display.formatting import make_clickable_model
|
10 |
+
from src.envs import (
|
11 |
+
COL_NAME_IS_ANONYMOUS,
|
12 |
+
COL_NAME_RERANKING_MODEL,
|
13 |
+
COL_NAME_RERANKING_MODEL_LINK,
|
14 |
+
COL_NAME_RETRIEVAL_MODEL,
|
15 |
+
COL_NAME_RETRIEVAL_MODEL_LINK,
|
16 |
+
COL_NAME_REVISION,
|
17 |
+
COL_NAME_TIMESTAMP,
|
18 |
+
)
|
19 |
|
20 |
|
21 |
@dataclass
|
|
|
24 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
|
25 |
domains, languages, and datasets
|
26 |
"""
|
27 |
+
|
28 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
|
29 |
retrieval_model: str
|
30 |
reranking_model: str
|
|
|
41 |
"""
|
42 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
|
43 |
"""
|
44 |
+
|
45 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
|
46 |
retrieval_model: str
|
47 |
reranking_model: str
|
|
|
65 |
result_list = []
|
66 |
retrieval_model_link = ""
|
67 |
reranking_model_link = ""
|
|
|
68 |
for item in model_data:
|
69 |
config = item.get("config", {})
|
70 |
# eval results for different metrics
|
|
|
83 |
metric=config["metric"],
|
84 |
timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
|
85 |
revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
|
86 |
+
is_anonymous=config.get("is_anonymous", False),
|
87 |
)
|
88 |
result_list.append(eval_result)
|
89 |
return cls(
|
|
|
95 |
results=result_list,
|
96 |
timestamp=result_list[0].timestamp,
|
97 |
revision=result_list[0].revision,
|
98 |
+
is_anonymous=result_list[0].is_anonymous,
|
99 |
)
|
100 |
|
101 |
+
def to_dict(self, task="qa", metric="ndcg_at_3") -> List:
|
102 |
"""
|
103 |
Convert the results in all the EvalResults over different tasks and metrics.
|
104 |
The output is a list of dict compatible with the dataframe UI
|
|
|
110 |
if eval_result.task != task:
|
111 |
continue
|
112 |
results[eval_result.eval_name]["eval_name"] = eval_result.eval_name
|
113 |
+
results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
|
114 |
+
self.retrieval_model, self.retrieval_model_link
|
115 |
+
)
|
116 |
+
results[eval_result.eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
|
117 |
+
self.reranking_model, self.reranking_model_link
|
118 |
+
)
|
119 |
results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
|
120 |
results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
|
121 |
results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
|
|
|
128 |
lang = result["lang"]
|
129 |
dataset = result["dataset"]
|
130 |
value = result["value"] * 100
|
131 |
+
if dataset == "default":
|
132 |
benchmark_name = f"{domain}_{lang}"
|
133 |
else:
|
134 |
benchmark_name = f"{domain}_{lang}_{dataset}"
|
src/utils.py
CHANGED
@@ -6,11 +6,21 @@ from pathlib import Path
|
|
6 |
|
7 |
import pandas as pd
|
8 |
|
9 |
-
from src.benchmarks import
|
10 |
from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
|
11 |
-
from src.display.formatting import
|
12 |
-
from src.envs import
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
def calculate_mean(row):
|
@@ -22,7 +32,7 @@ def calculate_mean(row):
|
|
22 |
|
23 |
def remove_html(input_str):
|
24 |
# Regular expression for finding HTML tags
|
25 |
-
clean = re.sub(r
|
26 |
return clean
|
27 |
|
28 |
|
@@ -67,7 +77,7 @@ def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tupl
|
|
67 |
elif task == "long-doc":
|
68 |
benchmarks = LongDocBenchmarks[version_slug]
|
69 |
else:
|
70 |
-
raise
|
71 |
cols_list, types_list = get_default_col_names_and_types(benchmarks)
|
72 |
benchmark_list = [c.value.col_name for c in list(benchmarks.value)]
|
73 |
for col_name, col_type in zip(cols_list, types_list):
|
@@ -91,12 +101,12 @@ def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tupl
|
|
91 |
|
92 |
|
93 |
def select_columns(
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
) -> pd.DataFrame:
|
101 |
cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
|
102 |
selected_cols = []
|
@@ -106,7 +116,7 @@ def select_columns(
|
|
106 |
elif task == "long-doc":
|
107 |
eval_col = LongDocBenchmarks[version_slug].value[c].value
|
108 |
else:
|
109 |
-
raise
|
110 |
if eval_col.domain not in domain_query:
|
111 |
continue
|
112 |
if eval_col.lang not in language_query:
|
@@ -127,24 +137,21 @@ def select_columns(
|
|
127 |
|
128 |
def get_safe_name(name: str):
|
129 |
"""Get RFC 1123 compatible safe name"""
|
130 |
-
name = name.replace(
|
131 |
-
return
|
132 |
-
character.lower()
|
133 |
-
for character in name
|
134 |
-
if (character.isalnum() or character == '_'))
|
135 |
|
136 |
|
137 |
def _update_table(
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
):
|
149 |
version_slug = get_safe_name(version)[-4:]
|
150 |
filtered_df = hidden_df.copy()
|
@@ -159,36 +166,43 @@ def _update_table(
|
|
159 |
|
160 |
|
161 |
def update_table_long_doc(
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
):
|
173 |
return _update_table(
|
174 |
"long-doc",
|
175 |
version,
|
176 |
-
hidden_df,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
|
179 |
def update_metric(
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
) -> pd.DataFrame:
|
190 |
# raw_data = datastore.raw_data
|
191 |
-
if task ==
|
192 |
leaderboard_df = get_leaderboard_df(datastore, task=task, metric=metric)
|
193 |
version = datastore.version
|
194 |
return update_table(
|
@@ -199,7 +213,7 @@ def update_metric(
|
|
199 |
reranking_model,
|
200 |
query,
|
201 |
show_anonymous,
|
202 |
-
show_revision_and_timestamp
|
203 |
)
|
204 |
elif task == "long-doc":
|
205 |
leaderboard_df = get_leaderboard_df(datastore, task=task, metric=metric)
|
@@ -212,7 +226,7 @@ def update_metric(
|
|
212 |
reranking_model,
|
213 |
query,
|
214 |
show_anonymous,
|
215 |
-
show_revision_and_timestamp
|
216 |
)
|
217 |
|
218 |
|
@@ -231,15 +245,15 @@ def get_iso_format_timestamp():
|
|
231 |
current_timestamp = current_timestamp.replace(microsecond=0)
|
232 |
|
233 |
# Convert to ISO 8601 format and replace the offset with 'Z'
|
234 |
-
iso_format_timestamp = current_timestamp.isoformat().replace(
|
235 |
-
filename_friendly_timestamp = current_timestamp.strftime(
|
236 |
return iso_format_timestamp, filename_friendly_timestamp
|
237 |
|
238 |
|
239 |
def calculate_file_md5(file_path):
|
240 |
md5 = hashlib.md5()
|
241 |
|
242 |
-
with open(file_path,
|
243 |
while True:
|
244 |
data = f.read(4096)
|
245 |
if not data:
|
@@ -250,13 +264,14 @@ def calculate_file_md5(file_path):
|
|
250 |
|
251 |
|
252 |
def submit_results(
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
260 |
if not filepath.endswith(".zip"):
|
261 |
261       return styled_error(f"file uploading aborted. wrong file type: {filepath}")
262
@@ -269,11 +284,13 @@ def submit_results(
269       if not model_url.startswith("https://") and not model_url.startswith("http://"):
270           # TODO: retrieve the model page and find the model name on the page
271           return styled_error(
272 -             f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
273       if reranking_model != "NoReranker":
274           if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
275               return styled_error(
276 -                 f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
277
278       # rename the uploaded file
279       input_fp = Path(filepath)
@@ -283,14 +300,15 @@ def submit_results(
283       input_folder_path = input_fp.parent
284
285       if not reranking_model:
286 -         reranking_model =
287
288       API.upload_file(
289           path_or_fileobj=filepath,
290           path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
291           repo_id=SEARCH_RESULTS_REPO,
292           repo_type="dataset",
293 -         commit_message=f"feat: submit {model} to evaluate"
294
295       output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
296       output_config = {
@@ -301,7 +319,7 @@ def submit_results(
301           "version": f"{version}",
302           "is_anonymous": is_anonymous,
303           "revision": f"{revision}",
304 -         "timestamp": f"{timestamp_config}"
305       }
306       with open(input_folder_path / output_config_fn, "w") as f:
307           json.dump(output_config, f, indent=4, ensure_ascii=False)
@@ -310,7 +328,8 @@ def submit_results(
310           path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
311           repo_id=SEARCH_RESULTS_REPO,
312           repo_type="dataset",
313 -         commit_message=f"feat: submit {model} + {reranking_model} config"
314       return styled_message(
315           f"Thanks for submission!\n"
316           f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
@@ -327,13 +346,15 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
327       Creates a dataframe from all the individual experiment results
328       """
329       raw_data = datastore.raw_data
330 -     cols = [
331       if task == "qa":
332           benchmarks = QABenchmarks[datastore.slug]
333       elif task == "long-doc":
334           benchmarks = LongDocBenchmarks[datastore.slug]
335       else:
336 -         raise
337       cols_qa, _ = get_default_col_names_and_types(benchmarks)
338       cols += cols_qa
339       benchmark_cols = [t.value.col_name for t in list(benchmarks.value)]
@@ -364,16 +385,16 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
364
365
366   def set_listeners(
367 -
368 -
369 -
370 -
371 -
372 -
373 -
374 -
375 -
376 -
377   ):
378       if task == "qa":
379           update_table_func = update_table
@@ -381,35 +402,51 @@ def set_listeners(
381           update_table_func = update_table_long_doc
382       else:
383           raise NotImplementedError
384 -     selector_list = [
385 -
386 -
387 -
388 -
389 -
390 -
391 -
392 -
393       # Set search_bar listener
394       search_bar.submit(update_table_func, search_bar_args, target_df)
395
396       # Set column-wise listener
397       for selector in selector_list:
398 -         selector.change(
399
400
401   def update_table(
402 -
403 -
404 -
405 -
406 -
407 -
408 -
409 -
410 -
411   ):
412       return _update_table(
413           "qa",
414           version,
415 -         hidden_df,
  6
  7   import pandas as pd
  8
  9 + from src.benchmarks import LongDocBenchmarks, QABenchmarks
 10   from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
 11 + from src.display.formatting import styled_error, styled_message
 12 + from src.envs import (
 13 +     API,
 14 +     COL_NAME_AVG,
 15 +     COL_NAME_IS_ANONYMOUS,
 16 +     COL_NAME_RANK,
 17 +     COL_NAME_RERANKING_MODEL,
 18 +     COL_NAME_RETRIEVAL_MODEL,
 19 +     COL_NAME_REVISION,
 20 +     COL_NAME_TIMESTAMP,
 21 +     LATEST_BENCHMARK_VERSION,
 22 +     SEARCH_RESULTS_REPO,
 23 + )
 24
 25
 26   def calculate_mean(row):

 32
 33   def remove_html(input_str):
 34       # Regular expression for finding HTML tags
 35 +     clean = re.sub(r"<.*?>", "", input_str)
 36       return clean
 37
 38

 77       elif task == "long-doc":
 78           benchmarks = LongDocBenchmarks[version_slug]
 79       else:
 80 +         raise NotImplementedError
 81       cols_list, types_list = get_default_col_names_and_types(benchmarks)
 82       benchmark_list = [c.value.col_name for c in list(benchmarks.value)]
 83       for col_name, col_type in zip(cols_list, types_list):

101
102
103   def select_columns(
104 +     df: pd.DataFrame,
105 +     domain_query: list,
106 +     language_query: list,
107 +     task: str = "qa",
108 +     reset_ranking: bool = True,
109 +     version_slug: str = None,
110   ) -> pd.DataFrame:
111       cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
112       selected_cols = []

116           elif task == "long-doc":
117               eval_col = LongDocBenchmarks[version_slug].value[c].value
118           else:
119 +             raise NotImplementedError
120           if eval_col.domain not in domain_query:
121               continue
122           if eval_col.lang not in language_query:

137
138   def get_safe_name(name: str):
139       """Get RFC 1123 compatible safe name"""
140 +     name = name.replace("-", "_")
141 +     return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
142
143
144   def _update_table(
145 +     task: str,
146 +     version: str,
147 +     hidden_df: pd.DataFrame,
148 +     domains: list,
149 +     langs: list,
150 +     reranking_query: list,
151 +     query: str,
152 +     show_anonymous: bool,
153 +     reset_ranking: bool = True,
154 +     show_revision_and_timestamp: bool = False,
155   ):
156       version_slug = get_safe_name(version)[-4:]
157       filtered_df = hidden_df.copy()

166
167
168   def update_table_long_doc(
169 +     version: str,
170 +     hidden_df: pd.DataFrame,
171 +     domains: list,
172 +     langs: list,
173 +     reranking_query: list,
174 +     query: str,
175 +     show_anonymous: bool,
176 +     show_revision_and_timestamp: bool = False,
177 +     reset_ranking: bool = True,
178   ):
179       return _update_table(
180           "long-doc",
181           version,
182 +         hidden_df,
183 +         domains,
184 +         langs,
185 +         reranking_query,
186 +         query,
187 +         show_anonymous,
188 +         reset_ranking,
189 +         show_revision_and_timestamp,
190 +     )
191
192
193   def update_metric(
194 +     datastore,
195 +     task: str,
196 +     metric: str,
197 +     domains: list,
198 +     langs: list,
199 +     reranking_model: list,
200 +     query: str,
201 +     show_anonymous: bool = False,
202 +     show_revision_and_timestamp: bool = False,
203   ) -> pd.DataFrame:
204       # raw_data = datastore.raw_data
205 +     if task == "qa":
206           leaderboard_df = get_leaderboard_df(datastore, task=task, metric=metric)
207           version = datastore.version
208           return update_table(

213               reranking_model,
214               query,
215               show_anonymous,
216 +             show_revision_and_timestamp,
217           )
218       elif task == "long-doc":
219           leaderboard_df = get_leaderboard_df(datastore, task=task, metric=metric)

226               reranking_model,
227               query,
228               show_anonymous,
229 +             show_revision_and_timestamp,
230           )
231
232

245       current_timestamp = current_timestamp.replace(microsecond=0)
246
247       # Convert to ISO 8601 format and replace the offset with 'Z'
248 +     iso_format_timestamp = current_timestamp.isoformat().replace("+00:00", "Z")
249 +     filename_friendly_timestamp = current_timestamp.strftime("%Y%m%d%H%M%S")
250       return iso_format_timestamp, filename_friendly_timestamp
251
252
253   def calculate_file_md5(file_path):
254       md5 = hashlib.md5()
255
256 +     with open(file_path, "rb") as f:
257           while True:
258               data = f.read(4096)
259               if not data:

264
265
266   def submit_results(
267 +     filepath: str,
268 +     model: str,
269 +     model_url: str,
270 +     reranking_model: str = "",
271 +     reranking_model_url: str = "",
272 +     version: str = LATEST_BENCHMARK_VERSION,
273 +     is_anonymous=False,
274 + ):
275       if not filepath.endswith(".zip"):
276           return styled_error(f"file uploading aborted. wrong file type: {filepath}")
277

284       if not model_url.startswith("https://") and not model_url.startswith("http://"):
285           # TODO: retrieve the model page and find the model name on the page
286           return styled_error(
287 +             f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
288 +         )
289       if reranking_model != "NoReranker":
290           if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
291               return styled_error(
292 +                 f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
293 +             )
294
295       # rename the uploaded file
296       input_fp = Path(filepath)

300       input_folder_path = input_fp.parent
301
302       if not reranking_model:
303 +         reranking_model = "NoReranker"
304
305       API.upload_file(
306           path_or_fileobj=filepath,
307           path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
308           repo_id=SEARCH_RESULTS_REPO,
309           repo_type="dataset",
310 +         commit_message=f"feat: submit {model} to evaluate",
311 +     )
312
313       output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
314       output_config = {

319           "version": f"{version}",
320           "is_anonymous": is_anonymous,
321           "revision": f"{revision}",
322 +         "timestamp": f"{timestamp_config}",
323       }
324       with open(input_folder_path / output_config_fn, "w") as f:
325           json.dump(output_config, f, indent=4, ensure_ascii=False)

328           path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
329           repo_id=SEARCH_RESULTS_REPO,
330           repo_type="dataset",
331 +         commit_message=f"feat: submit {model} + {reranking_model} config",
332 +     )
333       return styled_message(
334           f"Thanks for submission!\n"
335           f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"

346       Creates a dataframe from all the individual experiment results
347       """
348       raw_data = datastore.raw_data
349 +     cols = [
350 +         COL_NAME_IS_ANONYMOUS,
351 +     ]
352       if task == "qa":
353           benchmarks = QABenchmarks[datastore.slug]
354       elif task == "long-doc":
355           benchmarks = LongDocBenchmarks[datastore.slug]
356       else:
357 +         raise NotImplementedError
358       cols_qa, _ = get_default_col_names_and_types(benchmarks)
359       cols += cols_qa
360       benchmark_cols = [t.value.col_name for t in list(benchmarks.value)]

385
386
387   def set_listeners(
388 +     task,
389 +     target_df,
390 +     source_df,
391 +     search_bar,
392 +     version,
393 +     selected_domains,
394 +     selected_langs,
395 +     selected_rerankings,
396 +     show_anonymous,
397 +     show_revision_and_timestamp,
398   ):
399       if task == "qa":
400           update_table_func = update_table

402           update_table_func = update_table_long_doc
403       else:
404           raise NotImplementedError
405 +     selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
406 +     search_bar_args = [
407 +         source_df,
408 +         version,
409 +     ] + selector_list
410 +     selector_args = (
411 +         [version, source_df]
412 +         + selector_list
413 +         + [
414 +             show_revision_and_timestamp,
415 +         ]
416 +     )
417       # Set search_bar listener
418       search_bar.submit(update_table_func, search_bar_args, target_df)
419
420       # Set column-wise listener
421       for selector in selector_list:
422 +         selector.change(
423 +             update_table_func,
424 +             selector_args,
425 +             target_df,
426 +             queue=True,
427 +         )
428
429
430   def update_table(
431 +     version: str,
432 +     hidden_df: pd.DataFrame,
433 +     domains: list,
434 +     langs: list,
435 +     reranking_query: list,
436 +     query: str,
437 +     show_anonymous: bool,
438 +     show_revision_and_timestamp: bool = False,
439 +     reset_ranking: bool = True,
440   ):
441       return _update_table(
442           "qa",
443           version,
444 +         hidden_df,
445 +         domains,
446 +         langs,
447 +         reranking_query,
448 +         query,
449 +         show_anonymous,
450 +         reset_ranking,
451 +         show_revision_and_timestamp,
452 +     )
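For context, a minimal standalone sketch of how the version slug used by `_update_table` above is derived from the `get_safe_name` helper (illustrative only; the version string is the one that also appears in the test fixtures below):

    def get_safe_name(name: str):
        name = name.replace("-", "_")
        return "".join(c.lower() for c in name if (c.isalnum() or c == "_"))

    print(get_safe_name("AIR-Bench_24.04"))       # air_bench_2404
    print(get_safe_name("AIR-Bench_24.04")[-4:])  # 2404, the slug form used to look up QABenchmarks / LongDocBenchmarks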
tests/src/display/test_utils.py
CHANGED
@@ -1,5 +1,13 @@
  1 -
  2 - from src.display.utils import
  3
  4
  5   def test_fields():
@@ -8,13 +16,12 @@ def test_fields():
  8
  9
 10   def test_macro_variables():
 11 -     print(f
 12 -     print(f
 13 -     print(f
 14 -     print(f
 15
 16
 17   def test_get_default_auto_eval_column_dict():
 18       auto_eval_column_dict_list = get_default_auto_eval_column_dict()
 19       assert len(auto_eval_column_dict_list) == 9
 20 -

  1 +
  2 + from src.display.utils import (
  3 +     COLS_LONG_DOC,
  4 +     COLS_QA,
  5 +     TYPES_LONG_DOC,
  6 +     TYPES_QA,
  7 +     AutoEvalColumnQA,
  8 +     fields,
  9 +     get_default_auto_eval_column_dict,
 10 + )
 11
 12
 13   def test_fields():

 16
 17
 18   def test_macro_variables():
 19 +     print(f"COLS_QA: {COLS_QA}")
 20 +     print(f"COLS_LONG_DOC: {COLS_LONG_DOC}")
 21 +     print(f"TYPES_QA: {TYPES_QA}")
 22 +     print(f"TYPES_LONG_DOC: {TYPES_LONG_DOC}")
 23
 24
 25   def test_get_default_auto_eval_column_dict():
 26       auto_eval_column_dict_list = get_default_auto_eval_column_dict()
 27       assert len(auto_eval_column_dict_list) == 9
tests/src/test_benchmarks.py
CHANGED
@@ -1,4 +1,4 @@
  1 - from src.benchmarks import
  2
  3
  4   def test_qabenchmarks():
@@ -11,6 +11,5 @@ def test_qabenchmarks():
 11       print(l)
 12
 13
 14 -
 15   def test_longdocbenchmarks():
 16       print(list(LongDocBenchmarks))

  1 + from src.benchmarks import LongDocBenchmarks, QABenchmarks
  2
  3
  4   def test_qabenchmarks():

 11       print(l)
 12
 13
 14   def test_longdocbenchmarks():
 15       print(list(LongDocBenchmarks))
tests/src/test_read_evals.py
CHANGED
@@ -1,8 +1,8 @@
  1   from pathlib import Path
  2
  3   from src.read_evals import load_raw_eval_results
  4   from src.utils import get_leaderboard_df
  5 - from src.models import FullEvalResult
  6
  7   cur_fp = Path(__file__)
  8
@@ -11,8 +11,7 @@ def test_init_from_json_file():
 11       json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
 12       full_eval_result = FullEvalResult.init_from_json_file(json_fp)
 13       num_different_task_domain_lang_metric_dataset_combination = 6
 14 -     assert len(full_eval_result.results) ==
 15 -         num_different_task_domain_lang_metric_dataset_combination
 16       assert full_eval_result.retrieval_model == "bge-m3"
 17       assert full_eval_result.reranking_model == "bge-reranker-v2-m3"
 18
@@ -20,7 +19,7 @@ def test_init_from_json_file():
 20   def test_to_dict():
 21       json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
 22       full_eval_result = FullEvalResult.init_from_json_file(json_fp)
 23 -     result_list = full_eval_result.to_dict(task=
 24       assert len(result_list) == 1
 25       result_dict = result_list[0]
 26       assert result_dict["Retrieval Model"] == "bge-m3"
@@ -43,7 +42,7 @@ def test_get_raw_eval_results():
 43   def test_get_leaderboard_df():
 44       results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
 45       raw_data = load_raw_eval_results(results_path)
 46 -     df = get_leaderboard_df(raw_data,
 47       assert df.shape[0] == 4
 48       # the results contain only one embedding model
 49       # for i in range(4):
@@ -58,7 +57,7 @@ def test_get_leaderboard_df():
 58   def test_get_leaderboard_df_long_doc():
 59       results_path = cur_fp.parents[2] / "toydata" / "test_results"
 60       raw_data = load_raw_eval_results(results_path)
 61 -     df = get_leaderboard_df(raw_data,
 62       assert df.shape[0] == 2
 63       # the results contain only one embedding model
 64       for i in range(2):
@@ -67,4 +66,13 @@ def test_get_leaderboard_df_long_doc():
 67       assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
 68       assert df["Reranking Model"][1] == "NoReranker"
 69       assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
 70 -     assert

  1   from pathlib import Path
  2
  3 + from src.models import FullEvalResult
  4   from src.read_evals import load_raw_eval_results
  5   from src.utils import get_leaderboard_df
  6
  7   cur_fp = Path(__file__)
  8

 11       json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
 12       full_eval_result = FullEvalResult.init_from_json_file(json_fp)
 13       num_different_task_domain_lang_metric_dataset_combination = 6
 14 +     assert len(full_eval_result.results) == num_different_task_domain_lang_metric_dataset_combination
 15       assert full_eval_result.retrieval_model == "bge-m3"
 16       assert full_eval_result.reranking_model == "bge-reranker-v2-m3"
 17

 19   def test_to_dict():
 20       json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
 21       full_eval_result = FullEvalResult.init_from_json_file(json_fp)
 22 +     result_list = full_eval_result.to_dict(task="qa", metric="ndcg_at_1")
 23       assert len(result_list) == 1
 24       result_dict = result_list[0]
 25       assert result_dict["Retrieval Model"] == "bge-m3"

 42   def test_get_leaderboard_df():
 43       results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
 44       raw_data = load_raw_eval_results(results_path)
 45 +     df = get_leaderboard_df(raw_data, "qa", "ndcg_at_10")
 46       assert df.shape[0] == 4
 47       # the results contain only one embedding model
 48       # for i in range(4):

 57   def test_get_leaderboard_df_long_doc():
 58       results_path = cur_fp.parents[2] / "toydata" / "test_results"
 59       raw_data = load_raw_eval_results(results_path)
 60 +     df = get_leaderboard_df(raw_data, "long-doc", "ndcg_at_1")
 61       assert df.shape[0] == 2
 62       # the results contain only one embedding model
 63       for i in range(2):

 66       assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
 67       assert df["Reranking Model"][1] == "NoReranker"
 68       assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
 69 +     assert (
 70 +         not df[
 71 +             [
 72 +                 "Average ⬆️",
 73 +                 "law_en_lex_files_500k_600k",
 74 +             ]
 75 +         ]
 76 +         .isnull()
 77 +         .values.any()
 78 +     )
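For context, the multi-line assertion added at the end of test_get_leaderboard_df_long_doc simply checks that the selected leaderboard columns contain no missing values; the same pandas pattern in a tiny self-contained example (illustrative only, values made up):

    import pandas as pd

    df = pd.DataFrame({"Average ⬆️": [0.5, 0.4], "law_en_lex_files_500k_600k": [0.3, 0.2]})
    # passes because no cell in the selected columns is null
    assert not df[["Average ⬆️", "law_en_lex_files_500k_600k"]].isnull().values.any()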
tests/test_utils.py
CHANGED
@@ -1,28 +1,33 @@
  1   import pandas as pd
  2   import pytest
  3
  4 - from src.utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp, get_default_cols
  5   from app import update_table
  6 - from src.envs import
  7 -
  8
  9
 10   @pytest.fixture
 11   def toy_df():
 12       return pd.DataFrame(
 13           {
 14 -             "Retrieval Model": [
 15 -
 16 -                 "bge-m3",
 17 -                 "jina-embeddings-v2-base",
 18 -                 "jina-embeddings-v2-base"
 19 -             ],
 20 -             "Reranking Model": [
 21 -                 "bge-reranker-v2-m3",
 22 -                 "NoReranker",
 23 -                 "bge-reranker-v2-m3",
 24 -                 "NoReranker"
 25 -             ],
 26               "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
 27               "wiki_en": [0.8, 0.7, 0.2, 0.1],
 28               "wiki_zh": [0.4, 0.1, 0.4, 0.3],
@@ -36,18 +41,8 @@ def toy_df():
 36   def toy_df_long_doc():
 37       return pd.DataFrame(
 38           {
 39 -             "Retrieval Model": [
 40 -
 41 -                 "bge-m3",
 42 -                 "jina-embeddings-v2-base",
 43 -                 "jina-embeddings-v2-base"
 44 -             ],
 45 -             "Reranking Model": [
 46 -                 "bge-reranker-v2-m3",
 47 -                 "NoReranker",
 48 -                 "bge-reranker-v2-m3",
 49 -                 "NoReranker"
 50 -             ],
 51               "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
 52               "law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
 53               "law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
@@ -55,8 +50,15 @@ def toy_df_long_doc():
 55               "law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
 56           }
 57       )
 58   def test_filter_models(toy_df):
 59 -     df_result = filter_models(
 60       assert len(df_result) == 2
 61       assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
 62
@@ -74,13 +76,33 @@ def test_filter_queries(toy_df):
 74
 75
 76   def test_select_columns(toy_df):
 77 -     df_result = select_columns(
 78       assert len(df_result.columns) == 4
 79 -     assert df_result[
 80
 81
 82   def test_update_table_long_doc(toy_df_long_doc):
 83 -     df_result = update_table_long_doc(
 84       print(df_result)
 85
 86
@@ -108,10 +130,18 @@ def test_update_table():
108           COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
109           COL_NAME_RANK: [1, 2, 3],
110           COL_NAME_AVG: [0.1, 0.2, 0.3],  # unsorted values
111 -         "wiki_en": [0.1, 0.2, 0.3]
112       }
113       )
114 -     results = update_table(
115       # keep the RANK as the same regardless of the unsorted averages
116       assert results[COL_NAME_RANK].to_list() == [1, 2, 3]
117 -

  1   import pandas as pd
  2   import pytest
  3
  4   from app import update_table
  5 + from src.envs import (
  6 +     COL_NAME_AVG,
  7 +     COL_NAME_IS_ANONYMOUS,
  8 +     COL_NAME_RANK,
  9 +     COL_NAME_RERANKING_MODEL,
 10 +     COL_NAME_RETRIEVAL_MODEL,
 11 +     COL_NAME_REVISION,
 12 +     COL_NAME_TIMESTAMP,
 13 + )
 14 + from src.utils import (
 15 +     filter_models,
 16 +     filter_queries,
 17 +     get_default_cols,
 18 +     get_iso_format_timestamp,
 19 +     search_table,
 20 +     select_columns,
 21 +     update_table_long_doc,
 22 + )
 23
 24
 25   @pytest.fixture
 26   def toy_df():
 27       return pd.DataFrame(
 28           {
 29 +             "Retrieval Model": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
 30 +             "Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
 31               "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
 32               "wiki_en": [0.8, 0.7, 0.2, 0.1],
 33               "wiki_zh": [0.4, 0.1, 0.4, 0.3],

 41   def toy_df_long_doc():
 42       return pd.DataFrame(
 43           {
 44 +             "Retrieval Model": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
 45 +             "Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
 46               "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
 47               "law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
 48               "law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],

 50               "law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
 51           }
 52       )
 53 +
 54 +
 55   def test_filter_models(toy_df):
 56 +     df_result = filter_models(
 57 +         toy_df,
 58 +         [
 59 +             "bge-reranker-v2-m3",
 60 +         ],
 61 +     )
 62       assert len(df_result) == 2
 63       assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
 64

 76
 77
 78   def test_select_columns(toy_df):
 79 +     df_result = select_columns(
 80 +         toy_df,
 81 +         [
 82 +             "news",
 83 +         ],
 84 +         [
 85 +             "zh",
 86 +         ],
 87 +     )
 88       assert len(df_result.columns) == 4
 89 +     assert df_result["Average ⬆️"].equals(df_result["news_zh"])
 90
 91
 92   def test_update_table_long_doc(toy_df_long_doc):
 93 +     df_result = update_table_long_doc(
 94 +         toy_df_long_doc,
 95 +         [
 96 +             "law",
 97 +         ],
 98 +         [
 99 +             "en",
100 +         ],
101 +         [
102 +             "bge-reranker-v2-m3",
103 +         ],
104 +         "jina",
105 +     )
106       print(df_result)
107
108

130           COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
131           COL_NAME_RANK: [1, 2, 3],
132           COL_NAME_AVG: [0.1, 0.2, 0.3],  # unsorted values
133 +         "wiki_en": [0.1, 0.2, 0.3],
134       }
135       )
136 +     results = update_table(
137 +         df,
138 +         "wiki",
139 +         "en",
140 +         ["NoReranker"],
141 +         "",
142 +         show_anonymous=False,
143 +         reset_ranking=False,
144 +         show_revision_and_timestamp=False,
145 +     )
146       # keep the RANK as the same regardless of the unsorted averages
147       assert results[COL_NAME_RANK].to_list() == [1, 2, 3]
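A formatting-only commit like this can usually be reproduced or verified locally by running `black .` at the repository root (or checked without rewriting files via `black --check .`) and then re-running the test suite with `pytest tests`, assuming both tools are installed; the reformatting above should leave the tests' behaviour unchanged.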