Spaces:
AIR-Bench
/
Running on CPU Upgrade

feat-add-versions-to-benchmarks

#28
by nan - opened
Makefile CHANGED
@@ -3,11 +3,21 @@
3
 
4
  style:
5
  python -m black --line-length 119 .
 
6
  python -m isort .
 
7
  ruff check --fix .
 
8
 
9
 
10
  quality:
11
  python -m black --check --line-length 119 .
 
12
  python -m isort --check-only .
 
13
  ruff check .
 
 
 
 
 
 
3
 
4
  style:
5
  python -m black --line-length 119 .
6
+ python -m black --line-length 119 src
7
  python -m isort .
8
+ python -m isort src
9
  ruff check --fix .
10
+ ruff check --fix src
11
 
12
 
13
  quality:
14
  python -m black --check --line-length 119 .
15
+ python -m black --check --line-length 119 src
16
  python -m isort --check-only .
17
+ python -m isort --check-only src
18
  ruff check .
19
+ ruff check src
20
+
21
+
22
+ test:
23
+ python -m pytest tests
app.py CHANGED
@@ -1,131 +1,141 @@
 
 
1
  import gradio as gr
2
  from apscheduler.schedulers.background import BackgroundScheduler
3
  from huggingface_hub import snapshot_download
4
 
5
- from src.about import (
6
- INTRODUCTION_TEXT,
7
- BENCHMARKS_TEXT,
8
- TITLE,
9
- EVALUATION_QUEUE_TEXT
10
- )
11
- from src.benchmarks import (
12
- DOMAIN_COLS_QA,
13
- LANG_COLS_QA,
14
- DOMAIN_COLS_LONG_DOC,
15
- LANG_COLS_LONG_DOC,
16
- METRIC_LIST,
17
- DEFAULT_METRIC_QA,
18
- DEFAULT_METRIC_LONG_DOC
19
- )
20
- from src.display.css_html_js import custom_css
21
- from src.display.utils import (
22
- COL_NAME_IS_ANONYMOUS,
23
- COL_NAME_REVISION,
24
- COL_NAME_TIMESTAMP,
25
- COL_NAME_RERANKING_MODEL,
26
- COL_NAME_RETRIEVAL_MODEL
27
  )
 
28
  from src.envs import (
29
  API,
 
 
 
30
  EVAL_RESULTS_PATH,
 
 
31
  REPO_ID,
32
  RESULTS_REPO,
33
  TOKEN,
34
- BM25_LINK,
35
- BENCHMARK_VERSION_LIST,
36
- LATEST_BENCHMARK_VERSION
37
- )
38
- from src.read_evals import (
39
- get_raw_eval_results,
40
- get_leaderboard_df
41
- )
42
- from src.utils import (
43
- update_metric,
44
- upload_file,
45
- get_default_cols,
46
- submit_results,
47
- reset_rank,
48
- remove_html
49
- )
50
- from src.display.gradio_formatting import (
51
- get_version_dropdown,
52
- get_search_bar,
53
- get_reranking_dropdown,
54
- get_metric_dropdown,
55
- get_domain_dropdown,
56
- get_language_dropdown,
57
- get_anonymous_checkbox,
58
- get_revision_and_ts_checkbox,
59
- get_leaderboard_table,
60
- get_noreranking_dropdown
61
  )
62
- from src.display.gradio_listener import set_listeners
 
 
 
63
 
64
  def restart_space():
65
  API.restart_space(repo_id=REPO_ID)
66
 
67
 
68
  try:
69
- snapshot_download(
70
- repo_id=RESULTS_REPO, local_dir=EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30,
71
- token=TOKEN
72
- )
73
- except Exception as e:
74
- print(f'failed to download')
 
 
 
 
 
 
75
  restart_space()
76
 
77
- raw_data = get_raw_eval_results(f"{EVAL_RESULTS_PATH}/{LATEST_BENCHMARK_VERSION}")
78
-
79
- original_df_qa = get_leaderboard_df(
80
- raw_data, task='qa', metric=DEFAULT_METRIC_QA)
81
- original_df_long_doc = get_leaderboard_df(
82
- raw_data, task='long-doc', metric=DEFAULT_METRIC_LONG_DOC)
83
- print(f'raw data: {len(raw_data)}')
84
- print(f'QA data loaded: {original_df_qa.shape}')
85
- print(f'Long-Doc data loaded: {len(original_df_long_doc)}')
86
-
87
- leaderboard_df_qa = original_df_qa.copy()
88
- # leaderboard_df_qa = leaderboard_df_qa[has_no_nan_values(df, _benchmark_cols)]
89
- shown_columns_qa, types_qa = get_default_cols(
90
- 'qa', leaderboard_df_qa.columns, add_fix_cols=True)
91
- leaderboard_df_qa = leaderboard_df_qa[~leaderboard_df_qa[COL_NAME_IS_ANONYMOUS]][shown_columns_qa]
92
- leaderboard_df_qa.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
93
-
94
- leaderboard_df_long_doc = original_df_long_doc.copy()
95
- shown_columns_long_doc, types_long_doc = get_default_cols(
96
- 'long-doc', leaderboard_df_long_doc.columns, add_fix_cols=True)
97
- leaderboard_df_long_doc = leaderboard_df_long_doc[~leaderboard_df_long_doc[COL_NAME_IS_ANONYMOUS]][shown_columns_long_doc]
98
- leaderboard_df_long_doc.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
99
-
100
- # select reranking model
101
- reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in raw_data])))
102
-
103
-
104
- def update_metric_qa(
105
- metric: str,
106
- domains: list,
107
- langs: list,
108
- reranking_model: list,
109
- query: str,
110
- show_anonymous: bool,
111
- show_revision_and_timestamp,
112
  ):
113
- return update_metric(raw_data, 'qa', metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
114
-
115
- def update_metric_long_doc(
116
- metric: str,
117
- domains: list,
118
- langs: list,
119
- reranking_model: list,
120
- query: str,
121
- show_anonymous: bool,
 
122
  show_revision_and_timestamp,
 
 
 
 
 
 
 
 
 
 
 
123
  ):
124
- return update_metric(raw_data, "long-doc", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
  demo = gr.Blocks(css=custom_css)
128
 
 
 
129
  with demo:
130
  gr.HTML(TITLE)
131
  gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
@@ -133,25 +143,24 @@ with demo:
133
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
134
  with gr.TabItem("Results", elem_id="results-tab-table"):
135
  with gr.Row():
136
- selected_version = get_version_dropdown()
137
 
138
  with gr.TabItem("QA", elem_id="qa-benchmark-tab-table", id=0):
139
  with gr.Row():
140
  with gr.Column(min_width=320):
141
  # select domain
142
  with gr.Row():
143
- selected_domains = get_domain_dropdown(DOMAIN_COLS_QA, DOMAIN_COLS_QA)
144
  # select language
145
  with gr.Row():
146
- selected_langs = get_language_dropdown(LANG_COLS_QA, LANG_COLS_QA)
147
-
148
  with gr.Column():
149
  # select the metric
150
- selected_metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_QA)
151
  with gr.Row():
152
  show_anonymous = get_anonymous_checkbox()
153
  with gr.Row():
154
- show_revision_and_timestamp = get_revision_and_ts_checkbox()
155
  with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
156
  with gr.TabItem("Retrieval + Reranking", id=10):
157
  with gr.Row():
@@ -160,273 +169,327 @@ with demo:
160
  search_bar = get_search_bar()
161
  # select reranking models
162
  with gr.Column():
163
- selected_rerankings = get_reranking_dropdown(reranking_models)
164
- leaderboard_table = get_leaderboard_table(leaderboard_df_qa, types_qa)
 
165
  # Dummy leaderboard for handling the case when the user uses backspace key
166
- hidden_leaderboard_table_for_search = get_leaderboard_table(original_df_qa, types_qa, visible=False)
 
 
 
 
 
 
 
 
167
 
168
  set_listeners(
169
- "qa",
170
- leaderboard_table,
171
- hidden_leaderboard_table_for_search,
172
  search_bar,
173
- selected_domains,
174
- selected_langs,
175
- selected_rerankings,
 
176
  show_anonymous,
177
- show_revision_and_timestamp,
178
  )
179
 
180
  # set metric listener
181
- selected_metric.change(
182
- update_metric_qa,
183
- [
184
- selected_metric,
185
- selected_domains,
186
- selected_langs,
187
- selected_rerankings,
188
- search_bar,
189
- show_anonymous,
190
- show_revision_and_timestamp,
191
- ],
192
- leaderboard_table,
193
- queue=True
194
  )
 
195
  with gr.TabItem("Retrieval Only", id=11):
196
  with gr.Row():
197
  with gr.Column(scale=1):
198
- search_bar_retriever = get_search_bar()
199
  with gr.Column(scale=1):
200
- selected_noreranker = get_noreranking_dropdown()
201
- lb_df_retriever = leaderboard_df_qa[leaderboard_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"]
202
- lb_df_retriever = reset_rank(lb_df_retriever)
203
- lb_table_retriever = get_leaderboard_table(lb_df_retriever, types_qa)
 
 
204
  # Dummy leaderboard for handling the case when the user uses backspace key
205
- hidden_lb_df_retriever = original_df_qa[original_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"]
206
- hidden_lb_df_retriever = reset_rank(hidden_lb_df_retriever)
207
- hidden_lb_table_retriever = get_leaderboard_table(hidden_lb_df_retriever, types_qa, visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  set_listeners(
210
- "qa",
211
- lb_table_retriever,
212
- hidden_lb_table_retriever,
213
- search_bar_retriever,
214
- selected_domains,
215
- selected_langs,
216
- selected_noreranker,
 
217
  show_anonymous,
218
- show_revision_and_timestamp,
219
  )
220
 
221
- # set metric listener
222
- selected_metric.change(
223
- update_metric_qa,
224
  [
225
- selected_metric,
226
- selected_domains,
227
- selected_langs,
228
- selected_noreranker,
229
- search_bar_retriever,
230
  show_anonymous,
231
- show_revision_and_timestamp,
232
  ],
233
- lb_table_retriever,
234
- queue=True
235
  )
 
236
  with gr.TabItem("Reranking Only", id=12):
237
- lb_df_reranker = leaderboard_df_qa[leaderboard_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
238
- lb_df_reranker = reset_rank(lb_df_reranker)
239
- reranking_models_reranker = lb_df_reranker[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
240
  with gr.Row():
241
  with gr.Column(scale=1):
242
- selected_rerankings_reranker = get_reranking_dropdown(reranking_models_reranker)
243
  with gr.Column(scale=1):
244
- search_bar_reranker = gr.Textbox(show_label=False, visible=False)
245
- lb_table_reranker = get_leaderboard_table(lb_df_reranker, types_qa)
246
- hidden_lb_df_reranker = original_df_qa[original_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
247
- hidden_lb_df_reranker = reset_rank(hidden_lb_df_reranker)
248
- hidden_lb_table_reranker = get_leaderboard_table(
249
- hidden_lb_df_reranker, types_qa, visible=False
 
 
 
 
 
 
 
 
 
250
  )
251
 
252
  set_listeners(
253
- "qa",
254
- lb_table_reranker,
255
- hidden_lb_table_reranker,
256
- search_bar_reranker,
257
- selected_domains,
258
- selected_langs,
259
- selected_rerankings_reranker,
 
260
  show_anonymous,
261
- show_revision_and_timestamp,
262
  )
263
- # set metric listener
264
- selected_metric.change(
265
- update_metric_qa,
266
  [
267
- selected_metric,
268
- selected_domains,
269
- selected_langs,
270
- selected_rerankings_reranker,
271
- search_bar_reranker,
272
  show_anonymous,
273
- show_revision_and_timestamp,
274
  ],
275
- lb_table_reranker,
276
- queue=True
277
  )
278
  with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
279
  with gr.Row():
280
  with gr.Column(min_width=320):
281
  # select domain
282
  with gr.Row():
283
- selected_domains = get_domain_dropdown(DOMAIN_COLS_LONG_DOC, DOMAIN_COLS_LONG_DOC)
284
  # select language
285
  with gr.Row():
286
- selected_langs = get_language_dropdown(
287
- LANG_COLS_LONG_DOC, LANG_COLS_LONG_DOC
288
- )
289
  with gr.Column():
290
  # select the metric
291
  with gr.Row():
292
- selected_metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_LONG_DOC)
293
  with gr.Row():
294
  show_anonymous = get_anonymous_checkbox()
295
  with gr.Row():
296
- show_revision_and_timestamp = get_revision_and_ts_checkbox()
297
- with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
298
  with gr.TabItem("Retrieval + Reranking", id=20):
299
  with gr.Row():
300
  with gr.Column():
301
  search_bar = get_search_bar()
302
- # select reranking model
303
  with gr.Column():
304
- selected_rerankings = get_reranking_dropdown(reranking_models)
305
 
306
- lb_table = get_leaderboard_table(
307
- leaderboard_df_long_doc, types_long_doc
308
- )
309
 
310
  # Dummy leaderboard for handling the case when the user uses backspace key
311
- hidden_lb_table_for_search = get_leaderboard_table(
312
- original_df_long_doc, types_long_doc, visible=False
 
 
 
 
 
 
313
  )
314
 
315
  set_listeners(
316
- "long-doc",
317
- lb_table,
318
- hidden_lb_table_for_search,
319
  search_bar,
320
- selected_domains,
321
- selected_langs,
322
- selected_rerankings,
 
323
  show_anonymous,
324
- show_revision_and_timestamp,
325
  )
326
 
327
  # set metric listener
328
- selected_metric.change(
329
- update_metric_long_doc,
330
  [
331
- selected_metric,
332
- selected_domains,
333
- selected_langs,
334
- selected_rerankings,
335
  search_bar,
336
  show_anonymous,
337
- show_revision_and_timestamp
338
  ],
339
- lb_table,
340
- queue=True
341
  )
342
  with gr.TabItem("Retrieval Only", id=21):
343
  with gr.Row():
344
  with gr.Column(scale=1):
345
- search_bar_retriever = get_search_bar()
346
  with gr.Column(scale=1):
347
- selected_noreranker = get_noreranking_dropdown()
348
- lb_df_retriever_long_doc = leaderboard_df_long_doc[
349
- leaderboard_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
 
350
  ]
351
- lb_df_retriever_long_doc = reset_rank(lb_df_retriever_long_doc)
352
- hidden_lb_db_retriever_long_doc = original_df_long_doc[
353
- original_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
 
 
354
  ]
355
- hidden_lb_db_retriever_long_doc = reset_rank(hidden_lb_db_retriever_long_doc)
356
- lb_table_retriever_long_doc = get_leaderboard_table(
357
- lb_df_retriever_long_doc, types_long_doc)
358
- hidden_lb_table_retriever_long_doc = get_leaderboard_table(
359
- hidden_lb_db_retriever_long_doc, types_long_doc, visible=False
 
 
 
 
360
  )
361
 
362
  set_listeners(
363
- "long-doc",
364
- lb_table_retriever_long_doc,
365
- hidden_lb_table_retriever_long_doc,
366
- search_bar_retriever,
367
- selected_domains,
368
- selected_langs,
369
- selected_noreranker,
 
370
  show_anonymous,
371
- show_revision_and_timestamp,
372
  )
373
 
374
- selected_metric.change(
375
- update_metric_long_doc,
376
  [
377
- selected_metric,
378
- selected_domains,
379
- selected_langs,
380
- selected_noreranker,
381
- search_bar_retriever,
382
  show_anonymous,
383
- show_revision_and_timestamp,
384
  ],
385
- lb_table_retriever_long_doc,
386
- queue=True
387
  )
388
  with gr.TabItem("Reranking Only", id=22):
389
- lb_df_reranker_ldoc = leaderboard_df_long_doc[
390
- leaderboard_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
391
- ]
392
- lb_df_reranker_ldoc = reset_rank(lb_df_reranker_ldoc)
393
- reranking_models_reranker_ldoc = lb_df_reranker_ldoc[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
 
 
394
  with gr.Row():
395
  with gr.Column(scale=1):
396
- selected_rerankings_reranker_ldoc = get_reranking_dropdown(reranking_models_reranker_ldoc)
397
  with gr.Column(scale=1):
398
- search_bar_reranker_ldoc = gr.Textbox(show_label=False, visible=False)
399
- lb_table_reranker_ldoc = get_leaderboard_table(lb_df_reranker_ldoc, types_long_doc)
400
- hidden_lb_df_reranker_ldoc = original_df_long_doc[original_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
401
- hidden_lb_df_reranker_ldoc = reset_rank(hidden_lb_df_reranker_ldoc)
402
- hidden_lb_table_reranker_ldoc = get_leaderboard_table(
403
- hidden_lb_df_reranker_ldoc, types_long_doc, visible=False
 
 
 
 
 
 
 
 
404
  )
405
 
406
  set_listeners(
407
- "long-doc",
408
- lb_table_reranker_ldoc,
409
- hidden_lb_table_reranker_ldoc,
410
- search_bar_reranker_ldoc,
411
- selected_domains,
412
- selected_langs,
413
- selected_rerankings_reranker_ldoc,
 
414
  show_anonymous,
415
- show_revision_and_timestamp,
416
  )
417
- selected_metric.change(
418
- update_metric_long_doc,
 
419
  [
420
- selected_metric,
421
- selected_domains,
422
- selected_langs,
423
- selected_rerankings_reranker_ldoc,
424
- search_bar_reranker_ldoc,
425
  show_anonymous,
426
- show_revision_and_timestamp,
427
  ],
428
- lb_table_reranker_ldoc,
429
- queue=True
430
  )
431
 
432
  with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
@@ -443,23 +506,18 @@ with demo:
443
  with gr.Row():
444
  with gr.Column():
445
  reranking_model_name = gr.Textbox(
446
- label="Reranking Model name",
447
- info="Optional",
448
- value="NoReranker"
449
  )
450
  with gr.Column():
451
- reranking_model_url = gr.Textbox(
452
- label="Reranking Model URL",
453
- info="Optional",
454
- value=""
455
- )
456
  with gr.Row():
457
  with gr.Column():
458
  benchmark_version = gr.Dropdown(
459
  BENCHMARK_VERSION_LIST,
460
  value=LATEST_BENCHMARK_VERSION,
461
  interactive=True,
462
- label="AIR-Bench Version")
 
463
  with gr.Row():
464
  upload_button = gr.UploadButton("Click to upload search results", file_count="single")
465
  with gr.Row():
@@ -468,7 +526,8 @@ with demo:
468
  is_anonymous = gr.Checkbox(
469
  label="Nope. I want to submit anonymously 🥷",
470
  value=False,
471
- info="Do you want to shown on the leaderboard by default?")
 
472
  with gr.Row():
473
  submit_button = gr.Button("Submit")
474
  with gr.Row():
@@ -478,7 +537,8 @@ with demo:
478
  [
479
  upload_button,
480
  ],
481
- file_output)
 
482
  submit_button.click(
483
  submit_results,
484
  [
@@ -488,10 +548,10 @@ with demo:
488
  reranking_model_name,
489
  reranking_model_url,
490
  benchmark_version,
491
- is_anonymous
492
  ],
493
  submission_result,
494
- show_progress="hidden"
495
  )
496
 
497
  with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
 
1
+ import os
2
+
3
  import gradio as gr
4
  from apscheduler.schedulers.background import BackgroundScheduler
5
  from huggingface_hub import snapshot_download
6
 
7
+ from src.about import BENCHMARKS_TEXT, EVALUATION_QUEUE_TEXT, INTRODUCTION_TEXT, TITLE
8
+ from src.benchmarks import LongDocBenchmarks, QABenchmarks
9
+ from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
10
+ from src.components import (
11
+ get_anonymous_checkbox,
12
+ get_domain_dropdown,
13
+ get_language_dropdown,
14
+ get_leaderboard_table,
15
+ get_metric_dropdown,
16
+ get_noreranking_dropdown,
17
+ get_reranking_dropdown,
18
+ get_revision_and_ts_checkbox,
19
+ get_search_bar,
20
+ get_version_dropdown,
 
 
 
 
 
 
 
 
21
  )
22
+ from src.css_html_js import custom_css
23
  from src.envs import (
24
  API,
25
+ BENCHMARK_VERSION_LIST,
26
+ DEFAULT_METRIC_LONG_DOC,
27
+ DEFAULT_METRIC_QA,
28
  EVAL_RESULTS_PATH,
29
+ LATEST_BENCHMARK_VERSION,
30
+ METRIC_LIST,
31
  REPO_ID,
32
  RESULTS_REPO,
33
  TOKEN,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
35
+ from src.loaders import load_eval_results
36
+ from src.models import TaskType, model_hyperlink
37
+ from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file
38
+
39
 
40
  def restart_space():
41
  API.restart_space(repo_id=REPO_ID)
42
 
43
 
44
  try:
45
+ if not os.environ.get("LOCAL_MODE", False):
46
+ print("Running in local mode")
47
+ snapshot_download(
48
+ repo_id=RESULTS_REPO,
49
+ local_dir=EVAL_RESULTS_PATH,
50
+ repo_type="dataset",
51
+ tqdm_class=None,
52
+ etag_timeout=30,
53
+ token=TOKEN,
54
+ )
55
+ except Exception:
56
+ print("failed to download")
57
  restart_space()
58
 
59
+ global ds_dict
60
+ ds_dict = load_eval_results(EVAL_RESULTS_PATH)
61
+ global datastore
62
+ datastore = ds_dict[LATEST_BENCHMARK_VERSION]
63
+
64
+
65
+ def update_qa_metric(
66
+ metric: str,
67
+ domains: list,
68
+ langs: list,
69
+ reranking_model: list,
70
+ query: str,
71
+ show_anonymous: bool,
72
+ show_revision_and_timestamp: bool,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  ):
74
+ global datastore
75
+ return update_metric(
76
+ datastore,
77
+ TaskType.qa,
78
+ metric,
79
+ domains,
80
+ langs,
81
+ reranking_model,
82
+ query,
83
+ show_anonymous,
84
  show_revision_and_timestamp,
85
+ )
86
+
87
+
88
+ def update_doc_metric(
89
+ metric: str,
90
+ domains: list,
91
+ langs: list,
92
+ reranking_model: list,
93
+ query: str,
94
+ show_anonymous: bool,
95
+ show_revision_and_timestamp,
96
  ):
97
+ global datastore
98
+ return update_metric(
99
+ datastore,
100
+ TaskType.long_doc,
101
+ metric,
102
+ domains,
103
+ langs,
104
+ reranking_model,
105
+ query,
106
+ show_anonymous,
107
+ show_revision_and_timestamp,
108
+ )
109
+
110
+
111
+ def update_qa_version(version):
112
+ global datastore
113
+ global ds_dict
114
+ datastore = ds_dict[version]
115
+ domain_elem = get_domain_dropdown(QABenchmarks[datastore.slug])
116
+ lang_elem = get_language_dropdown(QABenchmarks[datastore.slug])
117
+ model_elem = get_reranking_dropdown(datastore.reranking_models)
118
+ df_elem = get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
119
+ hidden_df_elem = get_leaderboard_table(datastore.qa_raw_df, datastore.qa_types, visible=False)
120
+ return domain_elem, lang_elem, model_elem, df_elem, hidden_df_elem
121
+
122
+
123
+ def update_doc_version(version):
124
+ global datastore
125
+ global ds_dict
126
+ datastore = ds_dict[version]
127
+ domain_elem = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
128
+ lang_elem = get_language_dropdown(LongDocBenchmarks[datastore.slug])
129
+ model_elem = get_reranking_dropdown(datastore.reranking_models)
130
+ df_elem = get_leaderboard_table(datastore.doc_fmt_df, datastore.doc_types)
131
+ hidden_df_elem = get_leaderboard_table(datastore.doc_raw_df, datastore.doc_types, visible=False)
132
+ return domain_elem, lang_elem, model_elem, df_elem, hidden_df_elem
133
 
134
 
135
  demo = gr.Blocks(css=custom_css)
136
 
137
+ BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")
138
+
139
  with demo:
140
  gr.HTML(TITLE)
141
  gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
 
143
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
144
  with gr.TabItem("Results", elem_id="results-tab-table"):
145
  with gr.Row():
146
+ version = get_version_dropdown()
147
 
148
  with gr.TabItem("QA", elem_id="qa-benchmark-tab-table", id=0):
149
  with gr.Row():
150
  with gr.Column(min_width=320):
151
  # select domain
152
  with gr.Row():
153
+ domains = get_domain_dropdown(QABenchmarks[datastore.slug])
154
  # select language
155
  with gr.Row():
156
+ langs = get_language_dropdown(QABenchmarks[datastore.slug])
 
157
  with gr.Column():
158
  # select the metric
159
+ metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_QA)
160
  with gr.Row():
161
  show_anonymous = get_anonymous_checkbox()
162
  with gr.Row():
163
+ show_rev_ts = get_revision_and_ts_checkbox()
164
  with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
165
  with gr.TabItem("Retrieval + Reranking", id=10):
166
  with gr.Row():
 
169
  search_bar = get_search_bar()
170
  # select reranking models
171
  with gr.Column():
172
+ models = get_reranking_dropdown(datastore.reranking_models)
173
+ # shown_table
174
+ qa_df_elem_ret_rerank = get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
175
  # Dummy leaderboard for handling the case when the user uses backspace key
176
+ qa_df_elem_ret_rerank_hidden = get_leaderboard_table(
177
+ datastore.qa_raw_df, datastore.qa_types, visible=False
178
+ )
179
+
180
+ version.change(
181
+ update_qa_version,
182
+ version,
183
+ [domains, langs, models, qa_df_elem_ret_rerank, qa_df_elem_ret_rerank_hidden],
184
+ )
185
 
186
  set_listeners(
187
+ TaskType.qa,
188
+ qa_df_elem_ret_rerank,
189
+ qa_df_elem_ret_rerank_hidden,
190
  search_bar,
191
+ version,
192
+ domains,
193
+ langs,
194
+ models,
195
  show_anonymous,
196
+ show_rev_ts,
197
  )
198
 
199
  # set metric listener
200
+ metric.change(
201
+ update_qa_metric,
202
+ [metric, domains, langs, models, search_bar, show_anonymous, show_rev_ts],
203
+ qa_df_elem_ret_rerank,
204
+ queue=True,
 
 
 
 
 
 
 
 
205
  )
206
+
207
  with gr.TabItem("Retrieval Only", id=11):
208
  with gr.Row():
209
  with gr.Column(scale=1):
210
+ search_bar_ret = get_search_bar()
211
  with gr.Column(scale=1):
212
+ models_ret = get_noreranking_dropdown()
213
+
214
+ _qa_df_ret = datastore.qa_fmt_df[datastore.qa_fmt_df[COL_NAME_RERANKING_MODEL] == "NoReranker"]
215
+ _qa_df_ret = reset_rank(_qa_df_ret)
216
+ qa_df_elem_ret = get_leaderboard_table(_qa_df_ret, datastore.qa_types)
217
+
218
  # Dummy leaderboard for handling the case when the user uses backspace key
219
+ _qa_df_ret_hidden = datastore.qa_raw_df[
220
+ datastore.qa_raw_df[COL_NAME_RERANKING_MODEL] == "NoReranker"
221
+ ]
222
+ _qa_df_ret_hidden = reset_rank(_qa_df_ret_hidden)
223
+ qa_df_elem_ret_hidden = get_leaderboard_table(
224
+ _qa_df_ret_hidden, datastore.qa_types, visible=False
225
+ )
226
+
227
+ version.change(
228
+ update_qa_version,
229
+ version,
230
+ [
231
+ domains,
232
+ langs,
233
+ models_ret,
234
+ qa_df_elem_ret,
235
+ qa_df_elem_ret_hidden,
236
+ ],
237
+ )
238
 
239
  set_listeners(
240
+ TaskType.qa,
241
+ qa_df_elem_ret,
242
+ qa_df_elem_ret_hidden,
243
+ search_bar_ret,
244
+ version,
245
+ domains,
246
+ langs,
247
+ models_ret,
248
  show_anonymous,
249
+ show_rev_ts,
250
  )
251
 
252
+ metric.change(
253
+ update_qa_metric,
 
254
  [
255
+ metric,
256
+ domains,
257
+ langs,
258
+ models_ret,
259
+ search_bar_ret,
260
  show_anonymous,
261
+ show_rev_ts,
262
  ],
263
+ qa_df_elem_ret,
264
+ queue=True,
265
  )
266
+
267
  with gr.TabItem("Reranking Only", id=12):
268
+ _qa_df_rerank = datastore.qa_fmt_df[datastore.qa_fmt_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
269
+ _qa_df_rerank = reset_rank(_qa_df_rerank)
270
+ qa_rerank_models = _qa_df_rerank[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
271
  with gr.Row():
272
  with gr.Column(scale=1):
273
+ qa_models_rerank = get_reranking_dropdown(qa_rerank_models)
274
  with gr.Column(scale=1):
275
+ qa_search_bar_rerank = gr.Textbox(show_label=False, visible=False)
276
+ qa_df_elem_rerank = get_leaderboard_table(_qa_df_rerank, datastore.qa_types)
277
+
278
+ _qa_df_rerank_hidden = datastore.qa_raw_df[
279
+ datastore.qa_raw_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
280
+ ]
281
+ _qa_df_rerank_hidden = reset_rank(_qa_df_rerank_hidden)
282
+ qa_df_elem_rerank_hidden = get_leaderboard_table(
283
+ _qa_df_rerank_hidden, datastore.qa_types, visible=False
284
+ )
285
+
286
+ version.change(
287
+ update_qa_version,
288
+ version,
289
+ [domains, langs, qa_models_rerank, qa_df_elem_rerank, qa_df_elem_rerank_hidden],
290
  )
291
 
292
  set_listeners(
293
+ TaskType.qa,
294
+ qa_df_elem_rerank,
295
+ qa_df_elem_rerank_hidden,
296
+ qa_search_bar_rerank,
297
+ version,
298
+ domains,
299
+ langs,
300
+ qa_models_rerank,
301
  show_anonymous,
302
+ show_rev_ts,
303
  )
304
+
305
+ metric.change(
306
+ update_qa_metric,
307
  [
308
+ metric,
309
+ domains,
310
+ langs,
311
+ qa_models_rerank,
312
+ qa_search_bar_rerank,
313
  show_anonymous,
314
+ show_rev_ts,
315
  ],
316
+ qa_df_elem_rerank,
317
+ queue=True,
318
  )
319
  with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
320
  with gr.Row():
321
  with gr.Column(min_width=320):
322
  # select domain
323
  with gr.Row():
324
+ domains = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
325
  # select language
326
  with gr.Row():
327
+ langs = get_language_dropdown(LongDocBenchmarks[datastore.slug])
 
 
328
  with gr.Column():
329
  # select the metric
330
  with gr.Row():
331
+ metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_LONG_DOC)
332
  with gr.Row():
333
  show_anonymous = get_anonymous_checkbox()
334
  with gr.Row():
335
+ show_rev_ts = get_revision_and_ts_checkbox()
336
+ with gr.Tabs(elem_classes="tab-buttons"):
337
  with gr.TabItem("Retrieval + Reranking", id=20):
338
  with gr.Row():
339
  with gr.Column():
340
  search_bar = get_search_bar()
 
341
  with gr.Column():
342
+ models = get_reranking_dropdown(datastore.reranking_models)
343
 
344
+ doc_df_elem_ret_rerank = get_leaderboard_table(datastore.doc_fmt_df, datastore.doc_types)
 
 
345
 
346
  # Dummy leaderboard for handling the case when the user uses backspace key
347
+ doc_df_elem_ret_rerank_hidden = get_leaderboard_table(
348
+ datastore.doc_raw_df, datastore.doc_types, visible=False
349
+ )
350
+
351
+ version.change(
352
+ update_doc_version,
353
+ version,
354
+ [domains, langs, models, doc_df_elem_ret_rerank, doc_df_elem_ret_rerank_hidden],
355
  )
356
 
357
  set_listeners(
358
+ TaskType.long_doc,
359
+ doc_df_elem_ret_rerank,
360
+ doc_df_elem_ret_rerank_hidden,
361
  search_bar,
362
+ version,
363
+ domains,
364
+ langs,
365
+ models,
366
  show_anonymous,
367
+ show_rev_ts,
368
  )
369
 
370
  # set metric listener
371
+ metric.change(
372
+ update_doc_metric,
373
  [
374
+ metric,
375
+ domains,
376
+ langs,
377
+ models,
378
  search_bar,
379
  show_anonymous,
380
+ show_rev_ts,
381
  ],
382
+ doc_df_elem_ret_rerank,
383
+ queue=True,
384
  )
385
  with gr.TabItem("Retrieval Only", id=21):
386
  with gr.Row():
387
  with gr.Column(scale=1):
388
+ search_bar_ret = get_search_bar()
389
  with gr.Column(scale=1):
390
+ models_ret = get_noreranking_dropdown()
391
+
392
+ _doc_df_ret = datastore.doc_fmt_df[
393
+ datastore.doc_fmt_df[COL_NAME_RERANKING_MODEL] == "NoReranker"
394
  ]
395
+ _doc_df_ret = reset_rank(_doc_df_ret)
396
+ doc_df_elem_ret = get_leaderboard_table(_doc_df_ret, datastore.doc_types)
397
+
398
+ _doc_df_ret_hidden = datastore.doc_raw_df[
399
+ datastore.doc_raw_df[COL_NAME_RERANKING_MODEL] == "NoReranker"
400
  ]
401
+ _doc_df_ret_hidden = reset_rank(_doc_df_ret_hidden)
402
+ doc_df_elem_ret_hidden = get_leaderboard_table(
403
+ _doc_df_ret_hidden, datastore.doc_types, visible=False
404
+ )
405
+
406
+ version.change(
407
+ update_doc_version,
408
+ version,
409
+ [domains, langs, models_ret, doc_df_elem_ret, doc_df_elem_ret_hidden],
410
  )
411
 
412
  set_listeners(
413
+ TaskType.long_doc,
414
+ doc_df_elem_ret,
415
+ doc_df_elem_ret_hidden,
416
+ search_bar_ret,
417
+ version,
418
+ domains,
419
+ langs,
420
+ models_ret,
421
  show_anonymous,
422
+ show_rev_ts,
423
  )
424
 
425
+ metric.change(
426
+ update_doc_metric,
427
  [
428
+ metric,
429
+ domains,
430
+ langs,
431
+ models_ret,
432
+ search_bar_ret,
433
  show_anonymous,
434
+ show_rev_ts,
435
  ],
436
+ doc_df_elem_ret,
437
+ queue=True,
438
  )
439
  with gr.TabItem("Reranking Only", id=22):
440
+ _doc_df_rerank = datastore.doc_fmt_df[
441
+ datastore.doc_fmt_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
442
+ ]
443
+ _doc_df_rerank = reset_rank(_doc_df_rerank)
444
+ doc_rerank_models = (
445
+ _doc_df_rerank[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
446
+ )
447
  with gr.Row():
448
  with gr.Column(scale=1):
449
+ doc_models_rerank = get_reranking_dropdown(doc_rerank_models)
450
  with gr.Column(scale=1):
451
+ doc_search_bar_rerank = gr.Textbox(show_label=False, visible=False)
452
+ doc_df_elem_rerank = get_leaderboard_table(_doc_df_rerank, datastore.doc_types)
453
+ _doc_df_rerank_hidden = datastore.doc_raw_df[
454
+ datastore.doc_raw_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
455
+ ]
456
+ _doc_df_rerank_hidden = reset_rank(_doc_df_rerank_hidden)
457
+ doc_df_elem_rerank_hidden = get_leaderboard_table(
458
+ _doc_df_rerank_hidden, datastore.doc_types, visible=False
459
+ )
460
+
461
+ version.change(
462
+ update_doc_version,
463
+ version,
464
+ [domains, langs, doc_models_rerank, doc_df_elem_rerank, doc_df_elem_rerank_hidden],
465
  )
466
 
467
  set_listeners(
468
+ TaskType.long_doc,
469
+ doc_df_elem_rerank,
470
+ doc_df_elem_rerank_hidden,
471
+ doc_search_bar_rerank,
472
+ version,
473
+ domains,
474
+ langs,
475
+ doc_models_rerank,
476
  show_anonymous,
477
+ show_rev_ts,
478
  )
479
+
480
+ metric.change(
481
+ update_doc_metric,
482
  [
483
+ metric,
484
+ domains,
485
+ langs,
486
+ doc_models_rerank,
487
+ doc_search_bar_rerank,
488
  show_anonymous,
489
+ show_rev_ts,
490
  ],
491
+ doc_df_elem_rerank,
492
+ queue=True,
493
  )
494
 
495
  with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
 
506
  with gr.Row():
507
  with gr.Column():
508
  reranking_model_name = gr.Textbox(
509
+ label="Reranking Model name", info="Optional", value="NoReranker"
 
 
510
  )
511
  with gr.Column():
512
+ reranking_model_url = gr.Textbox(label="Reranking Model URL", info="Optional", value="")
 
 
 
 
513
  with gr.Row():
514
  with gr.Column():
515
  benchmark_version = gr.Dropdown(
516
  BENCHMARK_VERSION_LIST,
517
  value=LATEST_BENCHMARK_VERSION,
518
  interactive=True,
519
+ label="AIR-Bench Version",
520
+ )
521
  with gr.Row():
522
  upload_button = gr.UploadButton("Click to upload search results", file_count="single")
523
  with gr.Row():
 
526
  is_anonymous = gr.Checkbox(
527
  label="Nope. I want to submit anonymously 🥷",
528
  value=False,
529
+ info="Do you want to shown on the leaderboard by default?",
530
+ )
531
  with gr.Row():
532
  submit_button = gr.Button("Submit")
533
  with gr.Row():
 
537
  [
538
  upload_button,
539
  ],
540
+ file_output,
541
+ )
542
  submit_button.click(
543
  submit_results,
544
  [
 
548
  reranking_model_name,
549
  reranking_model_url,
550
  benchmark_version,
551
+ is_anonymous,
552
  ],
553
  submission_result,
554
+ show_progress="hidden",
555
  )
556
 
557
  with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
pyproject.toml CHANGED
@@ -1,9 +1,9 @@
1
  [tool.ruff]
2
  # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
3
- select = ["E", "F"]
4
- ignore = ["E501"] # line too long (black is taking care of this)
5
  line-length = 119
6
- fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
7
 
8
  [tool.isort]
9
  profile = "black"
 
1
  [tool.ruff]
2
  # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
3
+ lint.select = ["E", "F"]
4
+ lint.ignore = ["E501"] # line too long (black is taking care of this)
5
  line-length = 119
6
+ lint.fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
7
 
8
  [tool.isort]
9
  profile = "black"
requirements.txt CHANGED
@@ -2,7 +2,7 @@ APScheduler>=3.10.1
2
  black>=23.11.0
3
  click>=8.1.3
4
  datasets>=2.14.5
5
- gradio>=4.29.0
6
  gradio_client>=0.16.1
7
  huggingface-hub>=0.18.0
8
  numpy>=1.24.2
@@ -12,4 +12,4 @@ requests>=2.31.0
12
  tqdm>=4.65.0
13
  accelerate>=0.24.1
14
  socksio>=1.0.0
15
- air-benchmark>=0.0.4
 
2
  black>=23.11.0
3
  click>=8.1.3
4
  datasets>=2.14.5
5
+ gradio<5.0.0
6
  gradio_client>=0.16.1
7
  huggingface-hub>=0.18.0
8
  numpy>=1.24.2
 
12
  tqdm>=4.65.0
13
  accelerate>=0.24.1
14
  socksio>=1.0.0
15
+ air-benchmark>=0.1.0
src/about.py CHANGED
@@ -8,7 +8,7 @@ INTRODUCTION_TEXT = """
8
  """
9
 
10
  # Which evaluations are you running? how can people reproduce what you have?
11
- BENCHMARKS_TEXT = f"""
12
  ## How the test data are generated?
13
  ### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
14
 
 
8
  """
9
 
10
  # Which evaluations are you running? how can people reproduce what you have?
11
+ BENCHMARKS_TEXT = """
12
  ## How the test data are generated?
13
  ### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
14
 
src/benchmarks.py CHANGED
@@ -1,92 +1,71 @@
1
  from dataclasses import dataclass
2
  from enum import Enum
3
- from air_benchmark.tasks.tasks import BenchmarkTable
4
-
5
-
6
- def get_safe_name(name: str):
7
- """Get RFC 1123 compatible safe name"""
8
- name = name.replace('-', '_')
9
- return ''.join(
10
- character.lower()
11
- for character in name
12
- if (character.isalnum() or character == '_'))
13
 
 
14
 
15
- METRIC_LIST = [
16
- "ndcg_at_1",
17
- "ndcg_at_3",
18
- "ndcg_at_5",
19
- "ndcg_at_10",
20
- "ndcg_at_100",
21
- "ndcg_at_1000",
22
- "map_at_1",
23
- "map_at_3",
24
- "map_at_5",
25
- "map_at_10",
26
- "map_at_100",
27
- "map_at_1000",
28
- "recall_at_1",
29
- "recall_at_3",
30
- "recall_at_5",
31
- "recall_at_10",
32
- "recall_at_100",
33
- "recall_at_1000",
34
- "precision_at_1",
35
- "precision_at_3",
36
- "precision_at_5",
37
- "precision_at_10",
38
- "precision_at_100",
39
- "precision_at_1000",
40
- "mrr_at_1",
41
- "mrr_at_3",
42
- "mrr_at_5",
43
- "mrr_at_10",
44
- "mrr_at_100",
45
- "mrr_at_1000"
46
- ]
47
 
48
 
49
  @dataclass
50
  class Benchmark:
51
  name: str # [domain]_[language]_[metric], task_key in the json file,
52
- metric: str # ndcg_at_1 ,metric_key in the json file
53
  col_name: str # [domain]_[language], name to display in the leaderboard
54
  domain: str
55
  lang: str
56
  task: str
57
 
58
 
59
- qa_benchmark_dict = {}
60
- long_doc_benchmark_dict = {}
61
- for task, domain_dict in BenchmarkTable['AIR-Bench_24.04'].items():
62
- for domain, lang_dict in domain_dict.items():
63
- for lang, dataset_list in lang_dict.items():
64
- if task == "qa":
65
- benchmark_name = f"{domain}_{lang}"
66
- benchmark_name = get_safe_name(benchmark_name)
 
67
  col_name = benchmark_name
68
  for metric in dataset_list:
69
- qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
70
- elif task == "long-doc":
 
 
 
 
 
 
 
 
 
 
 
71
  for dataset in dataset_list:
72
  benchmark_name = f"{domain}_{lang}_{dataset}"
73
  benchmark_name = get_safe_name(benchmark_name)
74
  col_name = benchmark_name
 
 
75
  for metric in METRIC_LIST:
76
- long_doc_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain,
77
- lang, task)
 
 
78
 
79
- BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
80
- BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)
81
 
82
- BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
83
- BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]
 
 
84
 
85
- DOMAIN_COLS_QA = list(frozenset([c.domain for c in qa_benchmark_dict.values()]))
86
- LANG_COLS_QA = list(frozenset([c.lang for c in qa_benchmark_dict.values()]))
 
 
 
 
87
 
88
- DOMAIN_COLS_LONG_DOC = list(frozenset([c.domain for c in long_doc_benchmark_dict.values()]))
89
- LANG_COLS_LONG_DOC = list(frozenset([c.lang for c in long_doc_benchmark_dict.values()]))
90
 
91
- DEFAULT_METRIC_QA = "ndcg_at_10"
92
- DEFAULT_METRIC_LONG_DOC = "recall_at_10"
 
1
  from dataclasses import dataclass
2
  from enum import Enum
 
 
 
 
 
 
 
 
 
 
3
 
4
+ from air_benchmark.tasks.tasks import BenchmarkTable
5
 
6
+ from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
7
+ from src.models import TaskType, get_safe_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  @dataclass
11
  class Benchmark:
12
  name: str # [domain]_[language]_[metric], task_key in the json file,
13
+ metric: str # metric_key in the json file
14
  col_name: str # [domain]_[language], name to display in the leaderboard
15
  domain: str
16
  lang: str
17
  task: str
18
 
19
 
20
+ # create a function return an enum class containing all the benchmarks
21
+ def get_qa_benchmarks_dict(version: str):
22
+ benchmark_dict = {}
23
+ for task, domain_dict in BenchmarkTable[version].items():
24
+ if task != TaskType.qa.value:
25
+ continue
26
+ for domain, lang_dict in domain_dict.items():
27
+ for lang, dataset_list in lang_dict.items():
28
+ benchmark_name = get_safe_name(f"{domain}_{lang}")
29
  col_name = benchmark_name
30
  for metric in dataset_list:
31
+ if "test" not in dataset_list[metric]["splits"]:
32
+ continue
33
+ benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
34
+ return benchmark_dict
35
+
36
+
37
+ def get_doc_benchmarks_dict(version: str):
38
+ benchmark_dict = {}
39
+ for task, domain_dict in BenchmarkTable[version].items():
40
+ if task != TaskType.long_doc.value:
41
+ continue
42
+ for domain, lang_dict in domain_dict.items():
43
+ for lang, dataset_list in lang_dict.items():
44
  for dataset in dataset_list:
45
  benchmark_name = f"{domain}_{lang}_{dataset}"
46
  benchmark_name = get_safe_name(benchmark_name)
47
  col_name = benchmark_name
48
+ if "test" not in dataset_list[dataset]["splits"]:
49
+ continue
50
  for metric in METRIC_LIST:
51
+ benchmark_dict[benchmark_name] = Benchmark(
52
+ benchmark_name, metric, col_name, domain, lang, task
53
+ )
54
+ return benchmark_dict
55
 
 
 
56
 
57
+ _qa_benchmark_dict = {}
58
+ for version in BENCHMARK_VERSION_LIST:
59
+ safe_version_name = get_safe_name(version)
60
+ _qa_benchmark_dict[safe_version_name] = Enum(f"QABenchmarks_{safe_version_name}", get_qa_benchmarks_dict(version))
61
 
62
+ _doc_benchmark_dict = {}
63
+ for version in BENCHMARK_VERSION_LIST:
64
+ safe_version_name = get_safe_name(version)
65
+ _doc_benchmark_dict[safe_version_name] = Enum(
66
+ f"LongDocBenchmarks_{safe_version_name}", get_doc_benchmarks_dict(version)
67
+ )
68
 
 
 
69
 
70
+ QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
71
+ LongDocBenchmarks = Enum("LongDocBenchmarks", _doc_benchmark_dict)
src/{display/utils.py → columns.py} RENAMED
@@ -1,9 +1,7 @@
1
  from dataclasses import dataclass, make_dataclass
2
 
3
- from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
4
 
5
-
6
- def fields(raw_class):
7
  return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
8
 
9
 
@@ -19,28 +17,22 @@ class ColumnContent:
19
  never_hidden: bool = False
20
 
21
 
22
- COL_NAME_AVG = "Average ⬆️"
23
- COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
24
- COL_NAME_RERANKING_MODEL = "Reranking Model"
25
- COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
26
- COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
27
- COL_NAME_RANK = "Rank 🏆"
28
- COL_NAME_REVISION = "Revision"
29
- COL_NAME_TIMESTAMP = "Submission Date"
30
- COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
31
-
32
-
33
  def get_default_auto_eval_column_dict():
34
  auto_eval_column_dict = []
35
- # Init
36
  auto_eval_column_dict.append(
37
- ["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)]
 
 
 
 
38
  )
39
  auto_eval_column_dict.append(
40
- ["retrieval_model", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, hidden=False, never_hidden=True)]
41
- )
42
- auto_eval_column_dict.append(
43
- ["reranking_model", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, hidden=False, never_hidden=True)]
 
44
  )
45
  auto_eval_column_dict.append(
46
  ["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
@@ -48,14 +40,30 @@ def get_default_auto_eval_column_dict():
48
  auto_eval_column_dict.append(
49
  ["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
50
  )
 
51
  auto_eval_column_dict.append(
52
- ["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)]
53
- )
54
- auto_eval_column_dict.append(
55
- ["retrieval_model_link", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False)]
 
 
 
 
 
 
56
  )
57
  auto_eval_column_dict.append(
58
- ["reranking_model_link", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False)]
 
 
 
 
 
 
 
 
 
59
  )
60
  auto_eval_column_dict.append(
61
  ["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
@@ -63,10 +71,10 @@ def get_default_auto_eval_column_dict():
63
  return auto_eval_column_dict
64
 
65
 
66
- def make_autoevalcolumn(cls_name="BenchmarksQA", benchmarks=BenchmarksQA):
67
  auto_eval_column_dict = get_default_auto_eval_column_dict()
68
- ## Leaderboard columns
69
- for benchmark in benchmarks:
70
  auto_eval_column_dict.append(
71
  [benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
72
  )
@@ -75,19 +83,24 @@ def make_autoevalcolumn(cls_name="BenchmarksQA", benchmarks=BenchmarksQA):
75
  return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)
76
 
77
 
78
- AutoEvalColumnQA = make_autoevalcolumn(
79
- "AutoEvalColumnQA", BenchmarksQA)
80
- AutoEvalColumnLongDoc = make_autoevalcolumn(
81
- "AutoEvalColumnLongDoc", BenchmarksLongDoc)
 
82
 
83
 
84
- # Column selection
85
- COLS_QA = [c.name for c in fields(AutoEvalColumnQA) if not c.hidden]
86
- COLS_LONG_DOC = [c.name for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
87
- TYPES_QA = [c.type for c in fields(AutoEvalColumnQA) if not c.hidden]
88
- TYPES_LONG_DOC = [c.type for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
89
- COLS_LITE = [c.name for c in fields(AutoEvalColumnQA) if c.displayed_by_default and not c.hidden]
90
 
91
- QA_BENCHMARK_COLS = [t.value.col_name for t in BenchmarksQA]
92
 
93
- LONG_DOC_BENCHMARK_COLS = [t.value.col_name for t in BenchmarksLongDoc]
 
 
 
 
 
 
 
 
 
1
  from dataclasses import dataclass, make_dataclass
2
 
 
3
 
4
+ def _fields(raw_class):
 
5
  return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
6
 
7
 
 
17
  never_hidden: bool = False
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  def get_default_auto_eval_column_dict():
21
  auto_eval_column_dict = []
22
+ auto_eval_column_dict.append(["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)])
23
  auto_eval_column_dict.append(
24
+ [
25
+ "retrieval_model",
26
+ ColumnContent,
27
+ ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, never_hidden=True),
28
+ ]
29
  )
30
  auto_eval_column_dict.append(
31
+ [
32
+ "reranking_model",
33
+ ColumnContent,
34
+ ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, never_hidden=True),
35
+ ]
36
  )
37
  auto_eval_column_dict.append(
38
  ["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
 
40
  auto_eval_column_dict.append(
41
  ["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
42
  )
43
+ auto_eval_column_dict.append(["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)])
44
  auto_eval_column_dict.append(
45
+ [
46
+ "retrieval_model_link",
47
+ ColumnContent,
48
+ ColumnContent(
49
+ COL_NAME_RETRIEVAL_MODEL_LINK,
50
+ "markdown",
51
+ False,
52
+ hidden=True,
53
+ ),
54
+ ]
55
  )
56
  auto_eval_column_dict.append(
57
+ [
58
+ "reranking_model_link",
59
+ ColumnContent,
60
+ ColumnContent(
61
+ COL_NAME_RERANKING_MODEL_LINK,
62
+ "markdown",
63
+ False,
64
+ hidden=True,
65
+ ),
66
+ ]
67
  )
68
  auto_eval_column_dict.append(
69
  ["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
 
71
  return auto_eval_column_dict
72
 
73
 
74
+ def make_autoevalcolumn(cls_name, benchmarks):
75
  auto_eval_column_dict = get_default_auto_eval_column_dict()
76
+ # Leaderboard columns
77
+ for benchmark in list(benchmarks.value):
78
  auto_eval_column_dict.append(
79
  [benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
80
  )
 
83
  return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)
84
 
85
 
86
+ def get_default_col_names_and_types(benchmarks):
87
+ AutoEvalColumn = make_autoevalcolumn("AutoEvalColumn", benchmarks)
88
+ col_names = [c.name for c in _fields(AutoEvalColumn) if not c.hidden]
89
+ col_types = [c.type for c in _fields(AutoEvalColumn) if not c.hidden]
90
+ return col_names, col_types
91
 
92
 
93
+ def get_fixed_col_names_and_types():
94
+ fixed_cols = get_default_auto_eval_column_dict()[:-3]
95
+ return [c.name for _, _, c in fixed_cols], [c.type for _, _, c in fixed_cols]
 
 
 
96
 
 
97
 
98
+ COL_NAME_AVG = "Average ⬆️"
99
+ COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
100
+ COL_NAME_RERANKING_MODEL = "Reranking Model"
101
+ COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
102
+ COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
103
+ COL_NAME_RANK = "Rank 🏆"
104
+ COL_NAME_REVISION = "Revision"
105
+ COL_NAME_TIMESTAMP = "Submission Date"
106
+ COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
src/{display/gradio_formatting.py → components.py} RENAMED
@@ -1,12 +1,14 @@
1
  import gradio as gr
 
2
  from src.envs import BENCHMARK_VERSION_LIST, LATEST_BENCHMARK_VERSION
3
 
 
4
  def get_version_dropdown():
5
  return gr.Dropdown(
6
  choices=BENCHMARK_VERSION_LIST,
7
  value=LATEST_BENCHMARK_VERSION,
8
  label="Select the version of AIR-Bench",
9
- interactive=True
10
  )
11
 
12
 
@@ -14,26 +16,25 @@ def get_search_bar():
14
  return gr.Textbox(
15
  placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
16
  show_label=False,
17
- info="Search the retrieval methods"
18
  )
19
 
20
 
21
  def get_reranking_dropdown(model_list):
22
- return gr.Dropdown(
23
- choices=model_list,
24
- label="Select the reranking models",
25
- interactive=True,
26
- multiselect=True
27
- )
28
 
29
 
30
  def get_noreranking_dropdown():
31
  return gr.Dropdown(
32
- choices=["NoReranker", ],
33
- value=["NoReranker", ],
 
 
 
 
34
  interactive=False,
35
  multiselect=True,
36
- visible=False
37
  )
38
 
39
 
@@ -52,7 +53,10 @@ def get_metric_dropdown(metric_list, default_metrics):
52
  )
53
 
54
 
55
- def get_domain_dropdown(domain_list, default_domains):
 
 
 
56
  return gr.CheckboxGroup(
57
  choices=domain_list,
58
  value=default_domains,
@@ -61,13 +65,16 @@ def get_domain_dropdown(domain_list, default_domains):
61
  )
62
 
63
 
64
- def get_language_dropdown(language_list, default_languages):
 
 
 
65
  return gr.Dropdown(
66
  choices=language_list,
67
- value=language_list,
68
  label="Select the languages",
69
  multiselect=True,
70
- interactive=True
71
  )
72
 
73
 
@@ -75,15 +82,13 @@ def get_anonymous_checkbox():
75
  return gr.Checkbox(
76
  label="Show anonymous submissions",
77
  value=False,
78
- info="The anonymous submissions might have invalid model information."
79
  )
80
 
81
 
82
  def get_revision_and_ts_checkbox():
83
  return gr.Checkbox(
84
- label="Show submission details",
85
- value=False,
86
- info="Show the revision and timestamp information of submissions"
87
  )
88
 
89
 
 
1
  import gradio as gr
2
+
3
  from src.envs import BENCHMARK_VERSION_LIST, LATEST_BENCHMARK_VERSION
4
 
5
+
6
  def get_version_dropdown():
7
  return gr.Dropdown(
8
  choices=BENCHMARK_VERSION_LIST,
9
  value=LATEST_BENCHMARK_VERSION,
10
  label="Select the version of AIR-Bench",
11
+ interactive=True,
12
  )
13
 
14
 
 
16
  return gr.Textbox(
17
  placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
18
  show_label=False,
19
+ info="Search the retrieval methods",
20
  )
21
 
22
 
23
  def get_reranking_dropdown(model_list):
24
+ return gr.Dropdown(choices=model_list, label="Select the reranking models", interactive=True, multiselect=True)
 
 
 
 
 
25
 
26
 
27
  def get_noreranking_dropdown():
28
  return gr.Dropdown(
29
+ choices=[
30
+ "NoReranker",
31
+ ],
32
+ value=[
33
+ "NoReranker",
34
+ ],
35
  interactive=False,
36
  multiselect=True,
37
+ visible=False,
38
  )
39
 
40
 
 
53
  )
54
 
55
 
56
+ def get_domain_dropdown(benchmarks, default_domains=None):
57
+ domain_list = list(frozenset([c.value.domain for c in list(benchmarks.value)]))
58
+ if default_domains is None:
59
+ default_domains = domain_list
60
  return gr.CheckboxGroup(
61
  choices=domain_list,
62
  value=default_domains,
 
65
  )
66
 
67
 
68
+ def get_language_dropdown(benchmarks, default_languages=None):
69
+ language_list = list(frozenset([c.value.lang for c in list(benchmarks.value)]))
70
+ if default_languages is None:
71
+ default_languages = language_list
72
  return gr.Dropdown(
73
  choices=language_list,
74
+ value=default_languages,
75
  label="Select the languages",
76
  multiselect=True,
77
+ interactive=True,
78
  )
79
 
80
 
 
82
  return gr.Checkbox(
83
  label="Show anonymous submissions",
84
  value=False,
85
+ info="The anonymous submissions might have invalid model information.",
86
  )
87
 
88
 
89
  def get_revision_and_ts_checkbox():
90
  return gr.Checkbox(
91
+ label="Show submission details", value=False, info="Show the revision and timestamp information of submissions"
 
 
92
  )
93
 
94
 
src/{display/css_html_js.py → css_html_js.py} RENAMED
File without changes
src/display/formatting.py DELETED
@@ -1,29 +0,0 @@
1
- def model_hyperlink(link, model_name):
2
- return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
3
-
4
-
5
- def make_clickable_model(model_name: str, model_link: str):
6
- # link = f"https://huggingface.co/{model_name}"
7
- if not model_link or not model_link.startswith("https://"):
8
- return model_name
9
- return model_hyperlink(model_link, model_name)
10
-
11
-
12
- def styled_error(error):
13
- return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
14
-
15
-
16
- def styled_warning(warn):
17
- return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"
18
-
19
-
20
- def styled_message(message):
21
- return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
22
-
23
-
24
- def has_no_nan_values(df, columns):
25
- return df[columns].notna().all(axis=1)
26
-
27
-
28
- def has_nan_values(df, columns):
29
- return df[columns].isna().any(axis=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/display/gradio_listener.py DELETED
@@ -1,53 +0,0 @@
1
- from src.utils import update_table, update_table_long_doc
2
-
3
-
4
- def set_listeners(
5
- task,
6
- displayed_leaderboard,
7
- hidden_leaderboard,
8
- search_bar,
9
- selected_domains,
10
- selected_langs,
11
- selected_rerankings,
12
- show_anonymous,
13
- show_revision_and_timestamp,
14
-
15
- ):
16
- if task == "qa":
17
- update_table_func = update_table
18
- elif task == "long-doc":
19
- update_table_func = update_table_long_doc
20
- else:
21
- raise NotImplementedError
22
- # Set search_bar listener
23
- search_bar.submit(
24
- update_table_func,
25
- [
26
- hidden_leaderboard, # hidden_leaderboard_table_for_search,
27
- selected_domains,
28
- selected_langs,
29
- selected_rerankings,
30
- search_bar,
31
- show_anonymous,
32
- ],
33
- displayed_leaderboard
34
- )
35
-
36
- # Set column-wise listener
37
- for selector in [
38
- selected_domains, selected_langs, show_anonymous, show_revision_and_timestamp, selected_rerankings
39
- ]:
40
- selector.change(
41
- update_table_func,
42
- [
43
- hidden_leaderboard,
44
- selected_domains,
45
- selected_langs,
46
- selected_rerankings,
47
- search_bar,
48
- show_anonymous,
49
- show_revision_and_timestamp
50
- ],
51
- displayed_leaderboard,
52
- queue=True,
53
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/envs.py CHANGED
@@ -1,12 +1,14 @@
1
  import os
2
- from src.display.formatting import model_hyperlink
3
  from huggingface_hub import HfApi
4
 
5
  # Info to change for your repository
6
  # ----------------------------------
7
  TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
8
 
9
- OWNER = "AIR-Bench" # "nan" # Change to your org - don't forget to create a results and request dataset, with the correct format!
 
 
10
  # ----------------------------------
11
 
12
  REPO_ID = f"{OWNER}/leaderboard"
@@ -15,7 +17,7 @@ RESULTS_REPO = f"{OWNER}/eval_results"
15
  # repo for submitting the evaluation
16
  SEARCH_RESULTS_REPO = f"{OWNER}/search_results"
17
 
18
- # If you setup a cache later, just change HF_HOME
19
  CACHE_PATH = os.getenv("HF_HOME", ".")
20
 
21
  # Local caches
@@ -23,11 +25,43 @@ EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval_results")
23
 
24
  API = HfApi(token=TOKEN)
25
 
26
- BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")
27
-
28
  BENCHMARK_VERSION_LIST = [
29
  "AIR-Bench_24.04",
30
  "AIR-Bench_24.05",
31
  ]
32
 
33
- LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+
3
  from huggingface_hub import HfApi
4
 
5
  # Info to change for your repository
6
  # ----------------------------------
7
  TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
8
 
9
+ OWNER = (
10
+ "AIR-Bench" # Change to your org - don't forget to create a results and request dataset, with the correct format!
11
+ )
12
  # ----------------------------------
13
 
14
  REPO_ID = f"{OWNER}/leaderboard"
 
17
  # repo for submitting the evaluation
18
  SEARCH_RESULTS_REPO = f"{OWNER}/search_results"
19
 
20
+ # If you set up a cache later, just change HF_HOME
21
  CACHE_PATH = os.getenv("HF_HOME", ".")
22
 
23
  # Local caches
 
25
 
26
  API = HfApi(token=TOKEN)
27
 
 
 
28
  BENCHMARK_VERSION_LIST = [
29
  "AIR-Bench_24.04",
30
  "AIR-Bench_24.05",
31
  ]
32
 
33
+ LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[0]
34
+ DEFAULT_METRIC_QA = "ndcg_at_10"
35
+ DEFAULT_METRIC_LONG_DOC = "recall_at_10"
36
+ METRIC_LIST = [
37
+ "ndcg_at_1",
38
+ "ndcg_at_3",
39
+ "ndcg_at_5",
40
+ "ndcg_at_10",
41
+ "ndcg_at_100",
42
+ "ndcg_at_1000",
43
+ "map_at_1",
44
+ "map_at_3",
45
+ "map_at_5",
46
+ "map_at_10",
47
+ "map_at_100",
48
+ "map_at_1000",
49
+ "recall_at_1",
50
+ "recall_at_3",
51
+ "recall_at_5",
52
+ "recall_at_10",
53
+ "recall_at_100",
54
+ "recall_at_1000",
55
+ "precision_at_1",
56
+ "precision_at_3",
57
+ "precision_at_5",
58
+ "precision_at_10",
59
+ "precision_at_100",
60
+ "precision_at_1000",
61
+ "mrr_at_1",
62
+ "mrr_at_3",
63
+ "mrr_at_5",
64
+ "mrr_at_10",
65
+ "mrr_at_100",
66
+ "mrr_at_1000",
67
+ ]
src/loaders.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from pathlib import Path
3
+ from typing import Dict, List, Union
4
+
5
+ import pandas as pd
6
+
7
+ from src.columns import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP
8
+ from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA
9
+ from src.models import FullEvalResult, LeaderboardDataStore, TaskType, get_safe_name
10
+ from src.utils import get_default_cols, get_leaderboard_df, reset_rank
11
+
12
+ pd.options.mode.copy_on_write = True
13
+
14
+
15
+ def load_raw_eval_results(results_path: Union[Path, str]) -> List[FullEvalResult]:
16
+ """
17
+ Load the evaluation results from a json file
18
+ """
19
+ model_result_filepaths = []
20
+ for root, dirs, files in os.walk(results_path):
21
+ if len(files) == 0:
22
+ continue
23
+
24
+ # select the latest results
25
+ for file in files:
26
+ if not (file.startswith("results") and file.endswith(".json")):
27
+ print(f"skip {file}")
28
+ continue
29
+ model_result_filepaths.append(os.path.join(root, file))
30
+
31
+ eval_results = {}
32
+ for model_result_filepath in model_result_filepaths:
33
+ # create evaluation results
34
+ try:
35
+ eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
36
+ except UnicodeDecodeError:
37
+ print(f"loading file failed. {model_result_filepath}")
38
+ continue
39
+ print(f"file loaded: {model_result_filepath}")
40
+ timestamp = eval_result.timestamp
41
+ eval_results[timestamp] = eval_result
42
+
43
+ results = []
44
+ for k, v in eval_results.items():
45
+ try:
46
+ v.to_dict()
47
+ results.append(v)
48
+ except KeyError:
49
+ print(f"loading failed: {k}")
50
+ continue
51
+ return results
52
+
53
+
54
+ def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
55
+ ds = LeaderboardDataStore(version, get_safe_name(version))
56
+ ds.raw_data = load_raw_eval_results(file_path)
57
+ print(f"raw data: {len(ds.raw_data)}")
58
+
59
+ ds.qa_raw_df = get_leaderboard_df(ds, TaskType.qa, DEFAULT_METRIC_QA)
60
+ print(f"QA data loaded: {ds.qa_raw_df.shape}")
61
+ ds.qa_fmt_df = ds.qa_raw_df.copy()
62
+ qa_cols, ds.qa_types = get_default_cols(TaskType.qa, ds.slug, add_fix_cols=True)
63
+ # by default, drop the anonymous submissions
64
+ ds.qa_fmt_df = ds.qa_fmt_df[~ds.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
65
+ # reset the rank after dropping the anonymous submissions
66
+ ds.qa_fmt_df = reset_rank(ds.qa_fmt_df)
67
+ ds.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
68
+
69
+ ds.doc_raw_df = get_leaderboard_df(ds, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
70
+ print(f"Long-Doc data loaded: {len(ds.doc_raw_df)}")
71
+ ds.doc_fmt_df = ds.doc_raw_df.copy()
72
+ doc_cols, ds.doc_types = get_default_cols(TaskType.long_doc, ds.slug, add_fix_cols=True)
73
+ # by default, drop the anonymous submissions
74
+ ds.doc_fmt_df = ds.doc_fmt_df[~ds.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
75
+ # reset the rank after dropping the anonymous submissions
76
+ ds.doc_fmt_df = reset_rank(ds.doc_fmt_df)
77
+ ds.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
78
+
79
+ ds.reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in ds.raw_data])))
80
+ return ds
81
+
82
+
83
+ def load_eval_results(file_path: Union[str, Path]) -> Dict[str, LeaderboardDataStore]:
84
+ output = {}
85
+ for version in BENCHMARK_VERSION_LIST:
86
+ fn = f"{file_path}/{version}"
87
+ output[version] = load_leaderboard_datastore(fn, version)
88
+ return output
src/{read_evals.py → models.py} RENAMED
@@ -1,38 +1,21 @@
1
  import json
2
- import os.path
3
  from collections import defaultdict
4
  from dataclasses import dataclass
 
5
  from typing import List
6
 
7
  import pandas as pd
8
 
9
- from src.benchmarks import get_safe_name
10
- from src.display.utils import (
11
  COL_NAME_RERANKING_MODEL,
12
- COL_NAME_RETRIEVAL_MODEL,
13
  COL_NAME_RERANKING_MODEL_LINK,
 
14
  COL_NAME_RETRIEVAL_MODEL_LINK,
15
  COL_NAME_REVISION,
16
  COL_NAME_TIMESTAMP,
17
- COL_NAME_IS_ANONYMOUS,
18
- COLS_QA,
19
- QA_BENCHMARK_COLS,
20
- COLS_LONG_DOC,
21
- LONG_DOC_BENCHMARK_COLS,
22
- COL_NAME_AVG,
23
- COL_NAME_RANK
24
  )
25
 
26
- from src.display.formatting import make_clickable_model
27
-
28
- pd.options.mode.copy_on_write = True
29
-
30
- def calculate_mean(row):
31
- if pd.isna(row).any():
32
- return -1
33
- else:
34
- return row.mean()
35
-
36
 
37
  @dataclass
38
  class EvalResult:
@@ -40,6 +23,7 @@ class EvalResult:
40
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
41
  domains, languages, and datasets
42
  """
 
43
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
44
  retrieval_model: str
45
  reranking_model: str
@@ -56,6 +40,7 @@ class FullEvalResult:
56
  """
57
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
58
  """
 
59
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
60
  retrieval_model: str
61
  reranking_model: str
@@ -79,7 +64,6 @@ class FullEvalResult:
79
  result_list = []
80
  retrieval_model_link = ""
81
  reranking_model_link = ""
82
- revision = ""
83
  for item in model_data:
84
  config = item.get("config", {})
85
  # eval results for different metrics
@@ -98,24 +82,26 @@ class FullEvalResult:
98
  metric=config["metric"],
99
  timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
100
  revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
101
- is_anonymous=config.get("is_anonymous", False)
102
  )
103
  result_list.append(eval_result)
 
104
  return cls(
105
- eval_name=f"{result_list[0].retrieval_model}_{result_list[0].reranking_model}",
106
- retrieval_model=result_list[0].retrieval_model,
107
- reranking_model=result_list[0].reranking_model,
108
  retrieval_model_link=retrieval_model_link,
109
  reranking_model_link=reranking_model_link,
110
  results=result_list,
111
- timestamp=result_list[0].timestamp,
112
- revision=result_list[0].revision,
113
- is_anonymous=result_list[0].is_anonymous
114
  )
115
 
116
- def to_dict(self, task='qa', metric='ndcg_at_3') -> List:
117
  """
118
- Convert the results in all the EvalResults over different tasks and metrics. The output is a list of dict compatible with the dataframe UI
 
119
  """
120
  results = defaultdict(dict)
121
  for eval_result in self.results:
@@ -123,106 +109,66 @@ class FullEvalResult:
123
  continue
124
  if eval_result.task != task:
125
  continue
126
- results[eval_result.eval_name]["eval_name"] = eval_result.eval_name
127
- results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL] = (
128
- make_clickable_model(self.retrieval_model, self.retrieval_model_link))
129
- results[eval_result.eval_name][COL_NAME_RERANKING_MODEL] = (
130
- make_clickable_model(self.reranking_model, self.reranking_model_link))
131
- results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
132
- results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
133
- results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
134
- results[eval_result.eval_name][COL_NAME_TIMESTAMP] = self.timestamp
135
- results[eval_result.eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
136
-
137
- # print(f'result loaded: {eval_result.eval_name}')
 
 
138
  for result in eval_result.results:
139
  # add result for each domain, language, and dataset
140
  domain = result["domain"]
141
  lang = result["lang"]
142
  dataset = result["dataset"]
143
  value = result["value"] * 100
144
- if dataset == 'default':
145
  benchmark_name = f"{domain}_{lang}"
146
  else:
147
  benchmark_name = f"{domain}_{lang}_{dataset}"
148
- results[eval_result.eval_name][get_safe_name(benchmark_name)] = value
149
  return [v for v in results.values()]
150
 
151
 
152
- def get_raw_eval_results(results_path: str) -> List[FullEvalResult]:
153
- """
154
- Load the evaluation results from a json file
155
- """
156
- model_result_filepaths = []
157
- for root, dirs, files in os.walk(results_path):
158
- if len(files) == 0:
159
- continue
160
-
161
- # select the latest results
162
- for file in files:
163
- if not (file.startswith("results") and file.endswith(".json")):
164
- print(f'skip {file}')
165
- continue
166
- model_result_filepaths.append(os.path.join(root, file))
167
-
168
- eval_results = {}
169
- for model_result_filepath in model_result_filepaths:
170
- # create evaluation results
171
- try:
172
- eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
173
- except UnicodeDecodeError as e:
174
- print(f"loading file failed. {model_result_filepath}")
175
- continue
176
- print(f'file loaded: {model_result_filepath}')
177
- timestamp = eval_result.timestamp
178
- eval_results[timestamp] = eval_result
179
-
180
- results = []
181
- for k, v in eval_results.items():
182
- try:
183
- v.to_dict()
184
- results.append(v)
185
- except KeyError:
186
- print(f"loading failed: {k}")
187
- continue
188
- return results
189
-
190
-
191
- def get_leaderboard_df(raw_data: List[FullEvalResult], task: str, metric: str) -> pd.DataFrame:
192
- """
193
- Creates a dataframe from all the individual experiment results
194
- """
195
- cols = [COL_NAME_IS_ANONYMOUS, ]
196
- if task == "qa":
197
- cols += COLS_QA
198
- benchmark_cols = QA_BENCHMARK_COLS
199
- elif task == "long-doc":
200
- cols += COLS_LONG_DOC
201
- benchmark_cols = LONG_DOC_BENCHMARK_COLS
202
- else:
203
- raise NotImplemented
204
- all_data_json = []
205
- for v in raw_data:
206
- all_data_json += v.to_dict(task=task, metric=metric)
207
- df = pd.DataFrame.from_records(all_data_json)
208
- # print(f'dataframe created: {df.shape}')
209
-
210
- _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
211
-
212
- # calculate the average score for selected benchmarks
213
- df[COL_NAME_AVG] = df[list(_benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
214
- df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
215
- df.reset_index(inplace=True, drop=True)
216
-
217
- _cols = frozenset(cols).intersection(frozenset(df.columns.to_list()))
218
- df = df[_cols].round(decimals=2)
219
-
220
- # filter out if any of the benchmarks have not been produced
221
- df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
222
-
223
- # shorten the revision
224
- df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
225
-
226
- # # replace "0" with "-" for average score
227
- # df[COL_NAME_AVG] = df[COL_NAME_AVG].replace(0, "-")
228
- return df
 
1
  import json
 
2
  from collections import defaultdict
3
  from dataclasses import dataclass
4
+ from enum import Enum
5
  from typing import List
6
 
7
  import pandas as pd
8
 
9
+ from src.columns import (
10
+ COL_NAME_IS_ANONYMOUS,
11
  COL_NAME_RERANKING_MODEL,
 
12
  COL_NAME_RERANKING_MODEL_LINK,
13
+ COL_NAME_RETRIEVAL_MODEL,
14
  COL_NAME_RETRIEVAL_MODEL_LINK,
15
  COL_NAME_REVISION,
16
  COL_NAME_TIMESTAMP,
 
 
 
 
 
 
 
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @dataclass
21
  class EvalResult:
 
23
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
24
  domains, languages, and datasets
25
  """
26
+
27
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
28
  retrieval_model: str
29
  reranking_model: str
 
40
  """
41
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
42
  """
43
+
44
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
45
  retrieval_model: str
46
  reranking_model: str
 
64
  result_list = []
65
  retrieval_model_link = ""
66
  reranking_model_link = ""
 
67
  for item in model_data:
68
  config = item.get("config", {})
69
  # eval results for different metrics
 
82
  metric=config["metric"],
83
  timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
84
  revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
85
+ is_anonymous=config.get("is_anonymous", False),
86
  )
87
  result_list.append(eval_result)
88
+ eval_result = result_list[0]
89
  return cls(
90
+ eval_name=f"{eval_result.retrieval_model}_{eval_result.reranking_model}",
91
+ retrieval_model=eval_result.retrieval_model,
92
+ reranking_model=eval_result.reranking_model,
93
  retrieval_model_link=retrieval_model_link,
94
  reranking_model_link=reranking_model_link,
95
  results=result_list,
96
+ timestamp=eval_result.timestamp,
97
+ revision=eval_result.revision,
98
+ is_anonymous=eval_result.is_anonymous,
99
  )
100
 
101
+ def to_dict(self, task="qa", metric="ndcg_at_3") -> List:
102
  """
103
+ Convert the results in all the EvalResults over different tasks and metrics.
104
+ The output is a list of dict compatible with the dataframe UI
105
  """
106
  results = defaultdict(dict)
107
  for eval_result in self.results:
 
109
  continue
110
  if eval_result.task != task:
111
  continue
112
+ eval_name = eval_result.eval_name
113
+ results[eval_name]["eval_name"] = eval_name
114
+ results[eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
115
+ self.retrieval_model, self.retrieval_model_link
116
+ )
117
+ results[eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
118
+ self.reranking_model, self.reranking_model_link
119
+ )
120
+ results[eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
121
+ results[eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
122
+ results[eval_name][COL_NAME_REVISION] = self.revision
123
+ results[eval_name][COL_NAME_TIMESTAMP] = self.timestamp
124
+ results[eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
125
+
126
  for result in eval_result.results:
127
  # add result for each domain, language, and dataset
128
  domain = result["domain"]
129
  lang = result["lang"]
130
  dataset = result["dataset"]
131
  value = result["value"] * 100
132
+ if dataset == "default":
133
  benchmark_name = f"{domain}_{lang}"
134
  else:
135
  benchmark_name = f"{domain}_{lang}_{dataset}"
136
+ results[eval_name][get_safe_name(benchmark_name)] = value
137
  return [v for v in results.values()]
138
 
139
 
140
+ @dataclass
141
+ class LeaderboardDataStore:
142
+ version: str
143
+ slug: str
144
+ raw_data: list = None
145
+ qa_raw_df: pd.DataFrame = pd.DataFrame()
146
+ doc_raw_df: pd.DataFrame = pd.DataFrame()
147
+ qa_fmt_df: pd.DataFrame = pd.DataFrame()
148
+ doc_fmt_df: pd.DataFrame = pd.DataFrame()
149
+ reranking_models: list = None
150
+ qa_types: list = None
151
+ doc_types: list = None
152
+
153
+
154
+ # Define an enum class with the name `TaskType`. There are two types of tasks, `qa` and `long-doc`.
155
+ class TaskType(Enum):
156
+ qa = "qa"
157
+ long_doc = "long-doc"
158
+
159
+
160
+ def make_clickable_model(model_name: str, model_link: str):
161
+ # link = f"https://huggingface.co/{model_name}"
162
+ if not model_link or not model_link.startswith("https://"):
163
+ return model_name
164
+ return model_hyperlink(model_link, model_name)
165
+
166
+
167
+ def model_hyperlink(link, model_name):
168
+ return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
169
+
170
+
171
+ def get_safe_name(name: str):
172
+ """Get RFC 1123 compatible safe name"""
173
+ name = name.replace("-", "_")
174
+ return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils.py CHANGED
@@ -1,24 +1,37 @@
1
- import json
2
  import hashlib
 
 
3
  from datetime import datetime, timezone
4
  from pathlib import Path
5
- from typing import List
6
 
7
  import pandas as pd
8
 
9
- from src.benchmarks import BENCHMARK_COLS_QA, BENCHMARK_COLS_LONG_DOC, BenchmarksQA, BenchmarksLongDoc
10
- from src.display.formatting import styled_message, styled_error
11
- from src.display.utils import COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, COL_NAME_RANK, COL_NAME_AVG, \
12
- COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_IS_ANONYMOUS, COL_NAME_TIMESTAMP, COL_NAME_REVISION, get_default_auto_eval_column_dict
13
- from src.envs import API, SEARCH_RESULTS_REPO, LATEST_BENCHMARK_VERSION
14
- from src.read_evals import FullEvalResult, get_leaderboard_df, calculate_mean
15
-
16
- import re
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def remove_html(input_str):
20
  # Regular expression for finding HTML tags
21
- clean = re.sub(r'<.*?>', '', input_str)
22
  return clean
23
 
24
 
@@ -55,160 +68,152 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
55
  return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
56
 
57
 
58
- def get_default_cols(task: str, columns: list=[], add_fix_cols: bool=True) -> list:
59
  cols = []
60
  types = []
61
- if task == "qa":
62
- cols_list = COLS_QA
63
- types_list = TYPES_QA
64
- benchmark_list = BENCHMARK_COLS_QA
65
- elif task == "long-doc":
66
- cols_list = COLS_LONG_DOC
67
- types_list = TYPES_LONG_DOC
68
- benchmark_list = BENCHMARK_COLS_LONG_DOC
69
  else:
70
- raise NotImplemented
 
 
71
  for col_name, col_type in zip(cols_list, types_list):
72
  if col_name not in benchmark_list:
73
  continue
74
- if len(columns) > 0 and col_name not in columns:
75
- continue
76
  cols.append(col_name)
77
  types.append(col_type)
78
-
79
  if add_fix_cols:
80
  _cols = []
81
  _types = []
 
82
  for col_name, col_type in zip(cols, types):
83
- if col_name in FIXED_COLS:
84
  continue
85
  _cols.append(col_name)
86
  _types.append(col_type)
87
- cols = FIXED_COLS + _cols
88
- types = FIXED_COLS_TYPES + _types
89
  return cols, types
90
 
91
 
92
- fixed_cols = get_default_auto_eval_column_dict()[:-3]
93
-
94
- FIXED_COLS = [c.name for _, _, c in fixed_cols]
95
- FIXED_COLS_TYPES = [c.type for _, _, c in fixed_cols]
96
-
97
-
98
- def select_columns(
99
- df: pd.DataFrame,
100
- domain_query: list,
101
- language_query: list,
102
- task: str = "qa",
103
- reset_ranking: bool = True
104
- ) -> pd.DataFrame:
105
- cols, _ = get_default_cols(task=task, columns=df.columns, add_fix_cols=False)
106
  selected_cols = []
107
  for c in cols:
108
- if task == "qa":
109
- eval_col = BenchmarksQA[c].value
110
- elif task == "long-doc":
111
- eval_col = BenchmarksLongDoc[c].value
112
- if eval_col.domain not in domain_query:
 
 
113
  continue
114
- if eval_col.lang not in language_query:
115
  continue
116
  selected_cols.append(c)
117
  # We use COLS to maintain sorting
118
- filtered_df = df[FIXED_COLS + selected_cols]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  if reset_ranking:
120
  filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
121
  filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
122
  filtered_df.reset_index(inplace=True, drop=True)
123
  filtered_df = reset_rank(filtered_df)
124
-
125
  return filtered_df
126
 
127
 
128
- def _update_table(
129
- task: str,
130
- hidden_df: pd.DataFrame,
131
- domains: list,
132
- langs: list,
133
- reranking_query: list,
134
- query: str,
135
- show_anonymous: bool,
136
- reset_ranking: bool = True,
137
- show_revision_and_timestamp: bool = False
 
138
  ):
139
- filtered_df = hidden_df.copy()
140
  if not show_anonymous:
141
  filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
142
  filtered_df = filter_models(filtered_df, reranking_query)
143
  filtered_df = filter_queries(query, filtered_df)
144
- filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking)
145
  if not show_revision_and_timestamp:
146
  filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
147
  return filtered_df
148
 
149
 
150
- def update_table(
151
- hidden_df: pd.DataFrame,
152
- domains: list,
153
- langs: list,
154
- reranking_query: list,
155
- query: str,
156
- show_anonymous: bool,
157
- show_revision_and_timestamp: bool = False,
158
- reset_ranking: bool = True
159
- ):
160
- return _update_table(
161
- "qa", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
162
-
163
-
164
- def update_table_long_doc(
165
- hidden_df: pd.DataFrame,
166
- domains: list,
167
- langs: list,
168
- reranking_query: list,
169
- query: str,
170
- show_anonymous: bool,
171
- show_revision_and_timestamp: bool = False,
172
- reset_ranking: bool = True
173
-
174
  ):
175
- return _update_table(
176
- "long-doc", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  def update_metric(
180
- raw_data: List[FullEvalResult],
181
- task: str,
182
- metric: str,
183
- domains: list,
184
- langs: list,
185
- reranking_model: list,
186
- query: str,
187
- show_anonymous: bool = False,
188
- show_revision_and_timestamp: bool = False,
189
  ) -> pd.DataFrame:
190
- if task == 'qa':
191
- leaderboard_df = get_leaderboard_df(raw_data, task=task, metric=metric)
192
- return update_table(
193
- leaderboard_df,
194
- domains,
195
- langs,
196
- reranking_model,
197
- query,
198
- show_anonymous,
199
- show_revision_and_timestamp
200
- )
201
- elif task == "long-doc":
202
- leaderboard_df = get_leaderboard_df(raw_data, task=task, metric=metric)
203
- return update_table_long_doc(
204
- leaderboard_df,
205
- domains,
206
- langs,
207
- reranking_model,
208
- query,
209
- show_anonymous,
210
- show_revision_and_timestamp
211
- )
212
 
213
 
214
  def upload_file(filepath: str):
@@ -218,7 +223,6 @@ def upload_file(filepath: str):
218
  return filepath
219
 
220
 
221
-
222
  def get_iso_format_timestamp():
223
  # Get the current timestamp with UTC as the timezone
224
  current_timestamp = datetime.now(timezone.utc)
@@ -227,15 +231,15 @@ def get_iso_format_timestamp():
227
  current_timestamp = current_timestamp.replace(microsecond=0)
228
 
229
  # Convert to ISO 8601 format and replace the offset with 'Z'
230
- iso_format_timestamp = current_timestamp.isoformat().replace('+00:00', 'Z')
231
- filename_friendly_timestamp = current_timestamp.strftime('%Y%m%d%H%M%S')
232
  return iso_format_timestamp, filename_friendly_timestamp
233
 
234
 
235
  def calculate_file_md5(file_path):
236
  md5 = hashlib.md5()
237
 
238
- with open(file_path, 'rb') as f:
239
  while True:
240
  data = f.read(4096)
241
  if not data:
@@ -246,13 +250,14 @@ def calculate_file_md5(file_path):
246
 
247
 
248
  def submit_results(
249
- filepath: str,
250
- model: str,
251
- model_url: str,
252
- reranking_model: str="",
253
- reranking_model_url: str="",
254
- version: str=LATEST_BENCHMARK_VERSION,
255
- is_anonymous=False):
 
256
  if not filepath.endswith(".zip"):
257
  return styled_error(f"file uploading aborted. wrong file type: {filepath}")
258
 
@@ -265,11 +270,13 @@ def submit_results(
265
  if not model_url.startswith("https://") and not model_url.startswith("http://"):
266
  # TODO: retrieve the model page and find the model name on the page
267
  return styled_error(
268
- f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}")
 
269
  if reranking_model != "NoReranker":
270
  if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
271
  return styled_error(
272
- f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}")
 
273
 
274
  # rename the uploaded file
275
  input_fp = Path(filepath)
@@ -279,14 +286,15 @@ def submit_results(
279
  input_folder_path = input_fp.parent
280
 
281
  if not reranking_model:
282
- reranking_model = 'NoReranker'
283
-
284
  API.upload_file(
285
  path_or_fileobj=filepath,
286
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
287
  repo_id=SEARCH_RESULTS_REPO,
288
  repo_type="dataset",
289
- commit_message=f"feat: submit {model} to evaluate")
 
290
 
291
  output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
292
  output_config = {
@@ -297,7 +305,7 @@ def submit_results(
297
  "version": f"{version}",
298
  "is_anonymous": is_anonymous,
299
  "revision": f"{revision}",
300
- "timestamp": f"{timestamp_config}"
301
  }
302
  with open(input_folder_path / output_config_fn, "w") as f:
303
  json.dump(output_config, f, indent=4, ensure_ascii=False)
@@ -306,7 +314,8 @@ def submit_results(
306
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
307
  repo_id=SEARCH_RESULTS_REPO,
308
  repo_type="dataset",
309
- commit_message=f"feat: submit {model} + {reranking_model} config")
 
310
  return styled_message(
311
  f"Thanks for submission!\n"
312
  f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
@@ -316,3 +325,125 @@ def submit_results(
316
  def reset_rank(df):
317
  df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
318
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import hashlib
2
+ import json
3
+ import re
4
  from datetime import datetime, timezone
5
  from pathlib import Path
 
6
 
7
  import pandas as pd
8
 
9
+ from src.benchmarks import LongDocBenchmarks, QABenchmarks
10
+ from src.columns import (
11
+ COL_NAME_AVG,
12
+ COL_NAME_IS_ANONYMOUS,
13
+ COL_NAME_RANK,
14
+ COL_NAME_RERANKING_MODEL,
15
+ COL_NAME_RETRIEVAL_MODEL,
16
+ COL_NAME_REVISION,
17
+ COL_NAME_TIMESTAMP,
18
+ get_default_col_names_and_types,
19
+ get_fixed_col_names_and_types,
20
+ )
21
+ from src.envs import API, LATEST_BENCHMARK_VERSION, SEARCH_RESULTS_REPO
22
+ from src.models import TaskType, get_safe_name
23
+
24
+
25
+ def calculate_mean(row):
26
+ if pd.isna(row).any():
27
+ return -1
28
+ else:
29
+ return row.mean()
30
 
31
 
32
  def remove_html(input_str):
33
  # Regular expression for finding HTML tags
34
+ clean = re.sub(r"<.*?>", "", input_str)
35
  return clean
36
 
37
 
 
68
  return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
69
 
70
 
71
+ def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) -> tuple:
72
  cols = []
73
  types = []
74
+ if task == TaskType.qa:
75
+ benchmarks = QABenchmarks[version_slug]
76
+ elif task == TaskType.long_doc:
77
+ benchmarks = LongDocBenchmarks[version_slug]
 
 
 
 
78
  else:
79
+ raise NotImplementedError
80
+ cols_list, types_list = get_default_col_names_and_types(benchmarks)
81
+ benchmark_list = [c.value.col_name for c in list(benchmarks.value)]
82
  for col_name, col_type in zip(cols_list, types_list):
83
  if col_name not in benchmark_list:
84
  continue
 
 
85
  cols.append(col_name)
86
  types.append(col_type)
 
87
  if add_fix_cols:
88
  _cols = []
89
  _types = []
90
+ fixed_cols, fixed_cols_types = get_fixed_col_names_and_types()
91
  for col_name, col_type in zip(cols, types):
92
+ if col_name in fixed_cols:
93
  continue
94
  _cols.append(col_name)
95
  _types.append(col_type)
96
+ cols = fixed_cols + _cols
97
+ types = fixed_cols_types + _types
98
  return cols, types
99
 
100
 
101
+ def get_selected_cols(task, version_slug, domains, languages):
102
+ cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
 
 
 
 
 
 
 
 
 
 
 
 
103
  selected_cols = []
104
  for c in cols:
105
+ if task == TaskType.qa:
106
+ eval_col = QABenchmarks[version_slug].value[c].value
107
+ elif task == TaskType.long_doc:
108
+ eval_col = LongDocBenchmarks[version_slug].value[c].value
109
+ else:
110
+ raise NotImplementedError
111
+ if eval_col.domain not in domains:
112
  continue
113
+ if eval_col.lang not in languages:
114
  continue
115
  selected_cols.append(c)
116
  # We use COLS to maintain sorting
117
+ return selected_cols
118
+
119
+
120
+ def select_columns(
121
+ df: pd.DataFrame,
122
+ domains: list,
123
+ languages: list,
124
+ task: TaskType = TaskType.qa,
125
+ reset_ranking: bool = True,
126
+ version_slug: str = None,
127
+ ) -> pd.DataFrame:
128
+ selected_cols = get_selected_cols(task, version_slug, domains, languages)
129
+ fixed_cols, _ = get_fixed_col_names_and_types()
130
+ filtered_df = df[fixed_cols + selected_cols]
131
+ filtered_df.replace({"": pd.NA}, inplace=True)
132
  if reset_ranking:
133
  filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
134
  filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
135
  filtered_df.reset_index(inplace=True, drop=True)
136
  filtered_df = reset_rank(filtered_df)
 
137
  return filtered_df
138
 
139
 
140
+ def _update_df_elem(
141
+ task: TaskType,
142
+ version: str,
143
+ source_df: pd.DataFrame,
144
+ domains: list,
145
+ langs: list,
146
+ reranking_query: list,
147
+ query: str,
148
+ show_anonymous: bool,
149
+ reset_ranking: bool = True,
150
+ show_revision_and_timestamp: bool = False,
151
  ):
152
+ filtered_df = source_df.copy()
153
  if not show_anonymous:
154
  filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
155
  filtered_df = filter_models(filtered_df, reranking_query)
156
  filtered_df = filter_queries(query, filtered_df)
157
+ filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking, get_safe_name(version))
158
  if not show_revision_and_timestamp:
159
  filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
160
  return filtered_df
161
 
162
 
163
+ def update_doc_df_elem(
164
+ version: str,
165
+ hidden_df: pd.DataFrame,
166
+ domains: list,
167
+ langs: list,
168
+ reranking_query: list,
169
+ query: str,
170
+ show_anonymous: bool,
171
+ show_revision_and_timestamp: bool = False,
172
+ reset_ranking: bool = True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  ):
174
+ return _update_df_elem(
175
+ TaskType.long_doc,
176
+ version,
177
+ hidden_df,
178
+ domains,
179
+ langs,
180
+ reranking_query,
181
+ query,
182
+ show_anonymous,
183
+ reset_ranking,
184
+ show_revision_and_timestamp,
185
+ )
186
 
187
 
188
  def update_metric(
189
+ datastore,
190
+ task: TaskType,
191
+ metric: str,
192
+ domains: list,
193
+ langs: list,
194
+ reranking_model: list,
195
+ query: str,
196
+ show_anonymous: bool = False,
197
+ show_revision_and_timestamp: bool = False,
198
  ) -> pd.DataFrame:
199
+ if task == TaskType.qa:
200
+ update_func = update_qa_df_elem
201
+ elif task == TaskType.long_doc:
202
+ update_func = update_doc_df_elem
203
+ else:
204
+ raise NotImplementedError
205
+ df_elem = get_leaderboard_df(datastore, task=task, metric=metric)
206
+ version = datastore.version
207
+ return update_func(
208
+ version,
209
+ df_elem,
210
+ domains,
211
+ langs,
212
+ reranking_model,
213
+ query,
214
+ show_anonymous,
215
+ show_revision_and_timestamp,
216
+ )
 
 
 
 
217
 
218
 
219
  def upload_file(filepath: str):
 
223
  return filepath
224
 
225
 
 
226
  def get_iso_format_timestamp():
227
  # Get the current timestamp with UTC as the timezone
228
  current_timestamp = datetime.now(timezone.utc)
 
231
  current_timestamp = current_timestamp.replace(microsecond=0)
232
 
233
  # Convert to ISO 8601 format and replace the offset with 'Z'
234
+ iso_format_timestamp = current_timestamp.isoformat().replace("+00:00", "Z")
235
+ filename_friendly_timestamp = current_timestamp.strftime("%Y%m%d%H%M%S")
236
  return iso_format_timestamp, filename_friendly_timestamp
237
 
238
 
239
  def calculate_file_md5(file_path):
240
  md5 = hashlib.md5()
241
 
242
+ with open(file_path, "rb") as f:
243
  while True:
244
  data = f.read(4096)
245
  if not data:
 
250
 
251
 
252
  def submit_results(
253
+ filepath: str,
254
+ model: str,
255
+ model_url: str,
256
+ reranking_model: str = "",
257
+ reranking_model_url: str = "",
258
+ version: str = LATEST_BENCHMARK_VERSION,
259
+ is_anonymous=False,
260
+ ):
261
  if not filepath.endswith(".zip"):
262
  return styled_error(f"file uploading aborted. wrong file type: {filepath}")
263
 
 
270
  if not model_url.startswith("https://") and not model_url.startswith("http://"):
271
  # TODO: retrieve the model page and find the model name on the page
272
  return styled_error(
273
+ f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
274
+ )
275
  if reranking_model != "NoReranker":
276
  if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
277
  return styled_error(
278
+ f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
279
+ )
280
 
281
  # rename the uploaded file
282
  input_fp = Path(filepath)
 
286
  input_folder_path = input_fp.parent
287
 
288
  if not reranking_model:
289
+ reranking_model = "NoReranker"
290
+
291
  API.upload_file(
292
  path_or_fileobj=filepath,
293
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
294
  repo_id=SEARCH_RESULTS_REPO,
295
  repo_type="dataset",
296
+ commit_message=f"feat: submit {model} to evaluate",
297
+ )
298
 
299
  output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
300
  output_config = {
 
305
  "version": f"{version}",
306
  "is_anonymous": is_anonymous,
307
  "revision": f"{revision}",
308
+ "timestamp": f"{timestamp_config}",
309
  }
310
  with open(input_folder_path / output_config_fn, "w") as f:
311
  json.dump(output_config, f, indent=4, ensure_ascii=False)
 
314
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
315
  repo_id=SEARCH_RESULTS_REPO,
316
  repo_type="dataset",
317
+ commit_message=f"feat: submit {model} + {reranking_model} config",
318
+ )
319
  return styled_message(
320
  f"Thanks for submission!\n"
321
  f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
 
325
  def reset_rank(df):
326
  df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
327
  return df
328
+
329
+
330
+ def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
331
+ """
332
+ Creates a dataframe from all the individual experiment results
333
+ """
334
+ # load the selected metrics into a DataFrame from the raw json
335
+ all_data_json = []
336
+ for v in datastore.raw_data:
337
+ all_data_json += v.to_dict(task=task.value, metric=metric)
338
+ df = pd.DataFrame.from_records(all_data_json)
339
+
340
+ # calculate the average scores for selected task
341
+ if task == TaskType.qa:
342
+ benchmarks = QABenchmarks[datastore.slug]
343
+ elif task == TaskType.long_doc:
344
+ benchmarks = LongDocBenchmarks[datastore.slug]
345
+ else:
346
+ raise NotImplementedError
347
+ valid_cols = frozenset(df.columns.to_list())
348
+ benchmark_cols = []
349
+ for t in list(benchmarks.value):
350
+ if t.value.col_name not in valid_cols:
351
+ continue
352
+ benchmark_cols.append(t.value.col_name)
353
+
354
+ # filter out the columns that are not in the data
355
+ df[COL_NAME_AVG] = df[list(benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
356
+ df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
357
+ df.reset_index(inplace=True, drop=True)
358
+
359
+ # filter out columns that are not in the data
360
+ display_cols = [COL_NAME_IS_ANONYMOUS, COL_NAME_AVG]
361
+ default_cols, _ = get_default_col_names_and_types(benchmarks)
362
+ for col in default_cols:
363
+ if col in valid_cols:
364
+ display_cols.append(col)
365
+ df = df[display_cols].round(decimals=2)
366
+
367
+ # rank the scores
368
+ df = reset_rank(df)
369
+
370
+ # shorten the revision
371
+ df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
372
+
373
+ return df
374
+
375
+
376
+ def set_listeners(
377
+ task: TaskType,
378
+ target_df,
379
+ source_df,
380
+ search_bar,
381
+ version,
382
+ selected_domains,
383
+ selected_langs,
384
+ selected_rerankings,
385
+ show_anonymous,
386
+ show_revision_and_timestamp,
387
+ ):
388
+ if task == TaskType.qa:
389
+ update_table_func = update_qa_df_elem
390
+ elif task == TaskType.long_doc:
391
+ update_table_func = update_doc_df_elem
392
+ else:
393
+ raise NotImplementedError
394
+ selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
395
+ search_bar_args = [
396
+ source_df,
397
+ version,
398
+ ] + selector_list
399
+ selector_args = (
400
+ [version, source_df]
401
+ + selector_list
402
+ + [
403
+ show_revision_and_timestamp,
404
+ ]
405
+ )
406
+ # Set search_bar listener
407
+ search_bar.submit(update_table_func, search_bar_args, target_df)
408
+
409
+ # Set column-wise listener
410
+ for selector in selector_list:
411
+ selector.change(
412
+ update_table_func,
413
+ selector_args,
414
+ target_df,
415
+ queue=True,
416
+ )
417
+
418
+
419
+ def update_qa_df_elem(
420
+ version: str,
421
+ hidden_df: pd.DataFrame,
422
+ domains: list,
423
+ langs: list,
424
+ reranking_query: list,
425
+ query: str,
426
+ show_anonymous: bool,
427
+ show_revision_and_timestamp: bool = False,
428
+ reset_ranking: bool = True,
429
+ ):
430
+ return _update_df_elem(
431
+ TaskType.qa,
432
+ version,
433
+ hidden_df,
434
+ domains,
435
+ langs,
436
+ reranking_query,
437
+ query,
438
+ show_anonymous,
439
+ reset_ranking,
440
+ show_revision_and_timestamp,
441
+ )
442
+
443
+
444
+ def styled_error(error):
445
+ return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
446
+
447
+
448
+ def styled_message(message):
449
+ return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
tests/src/display/test_utils.py DELETED
@@ -1,23 +0,0 @@
1
- import pytest
2
- from src.display.utils import fields, AutoEvalColumnQA, COLS_QA, COLS_LONG_DOC, COLS_LITE, TYPES_QA, TYPES_LONG_DOC, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS, get_default_auto_eval_column_dict
3
-
4
-
5
- def test_fields():
6
- for c in fields(AutoEvalColumnQA):
7
- print(c)
8
-
9
-
10
- def test_macro_variables():
11
- print(f'COLS_QA: {COLS_QA}')
12
- print(f'COLS_LONG_DOC: {COLS_LONG_DOC}')
13
- print(f'COLS_LITE: {COLS_LITE}')
14
- print(f'TYPES_QA: {TYPES_QA}')
15
- print(f'TYPES_LONG_DOC: {TYPES_LONG_DOC}')
16
- print(f'QA_BENCHMARK_COLS: {QA_BENCHMARK_COLS}')
17
- print(f'LONG_DOC_BENCHMARK_COLS: {LONG_DOC_BENCHMARK_COLS}')
18
-
19
-
20
- def test_get_default_auto_eval_column_dict():
21
- auto_eval_column_dict_list = get_default_auto_eval_column_dict()
22
- assert len(auto_eval_column_dict_list) == 9
23
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_benchmarks.py CHANGED
@@ -1,9 +1,33 @@
1
- from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
2
 
 
 
3
 
4
- def test_qabenchmarks():
5
- print(list(BenchmarksQA))
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
- def test_longdocbenchmarks():
9
- print(list(BenchmarksLongDoc))
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
 
3
+ from src.benchmarks import LongDocBenchmarks, QABenchmarks
4
+ from src.envs import BENCHMARK_VERSION_LIST
5
 
6
+ # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
7
+ # 24.05
8
+ # | Task | dev | test |
9
+ # | ---- | --- | ---- |
10
+ # | Long-Doc | 4 | 11 |
11
+ # | QA | 54 | 53 |
12
+ #
13
+ # 24.04
14
+ # | Task | test |
15
+ # | ---- | ---- |
16
+ # | Long-Doc | 15 |
17
+ # | QA | 13 |
18
 
19
 
20
+ @pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 13, "air_bench_2405": 53}])
21
+ def test_qa_benchmarks(num_datasets_dict):
22
+ assert len(QABenchmarks) == len(BENCHMARK_VERSION_LIST)
23
+ for benchmark_list in list(QABenchmarks):
24
+ version_slug = benchmark_list.name
25
+ assert num_datasets_dict[version_slug] == len(benchmark_list.value)
26
+
27
+
28
+ @pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 15, "air_bench_2405": 11}])
29
+ def test_doc_benchmarks(num_datasets_dict):
30
+ assert len(LongDocBenchmarks) == len(BENCHMARK_VERSION_LIST)
31
+ for benchmark_list in list(LongDocBenchmarks):
32
+ version_slug = benchmark_list.name
33
+ assert num_datasets_dict[version_slug] == len(benchmark_list.value)
tests/src/test_columns.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from src.benchmarks import LongDocBenchmarks, QABenchmarks
4
+ from src.columns import (
5
+ COL_NAME_AVG,
6
+ COL_NAME_RANK,
7
+ COL_NAME_RERANKING_MODEL,
8
+ COL_NAME_RETRIEVAL_MODEL,
9
+ COL_NAME_REVISION,
10
+ COL_NAME_TIMESTAMP,
11
+ get_default_auto_eval_column_dict,
12
+ get_default_col_names_and_types,
13
+ get_fixed_col_names_and_types,
14
+ make_autoevalcolumn,
15
+ )
16
+
17
+ # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
18
+ # 24.05
19
+ # | Task | dev | test |
20
+ # | ---- | --- | ---- |
21
+ # | Long-Doc | 4 | 11 |
22
+ # | QA | 54 | 53 |
23
+ #
24
+ # 24.04
25
+ # | Task | test |
26
+ # | ---- | ---- |
27
+ # | Long-Doc | 15 |
28
+ # | QA | 13 |
29
+
30
+
31
+ @pytest.fixture()
32
+ def expected_col_names():
33
+ return [
34
+ "rank",
35
+ "retrieval_model",
36
+ "reranking_model",
37
+ "revision",
38
+ "timestamp",
39
+ "average",
40
+ "retrieval_model_link",
41
+ "reranking_model_link",
42
+ "is_anonymous",
43
+ ]
44
+
45
+
46
+ @pytest.fixture()
47
+ def expected_hidden_col_names():
48
+ return [
49
+ "retrieval_model_link",
50
+ "reranking_model_link",
51
+ "is_anonymous",
52
+ ]
53
+
54
+
55
+ def test_get_default_auto_eval_column_dict(expected_col_names, expected_hidden_col_names):
56
+ col_list = get_default_auto_eval_column_dict()
57
+ assert len(col_list) == 9
58
+ hidden_cols = []
59
+ for col_tuple, expected_col in zip(col_list, expected_col_names):
60
+ col, _, col_content = col_tuple
61
+ assert col == expected_col
62
+ if col_content.hidden:
63
+ hidden_cols.append(col)
64
+ assert hidden_cols == expected_hidden_col_names
65
+
66
+
67
+ def test_get_fixed_col_names_and_types():
68
+ col_names, col_types = get_fixed_col_names_and_types()
69
+ assert len(col_names) == 6
70
+ assert len(col_types) == 6
71
+ expected_col_and_type = [
72
+ (COL_NAME_RANK, "number"),
73
+ (COL_NAME_RETRIEVAL_MODEL, "markdown"),
74
+ (COL_NAME_RERANKING_MODEL, "markdown"),
75
+ (COL_NAME_REVISION, "markdown"),
76
+ (COL_NAME_TIMESTAMP, "date"),
77
+ (COL_NAME_AVG, "number"),
78
+ ]
79
+ for col_name, col_type, (c_name, c_type) in zip(col_names, col_types, expected_col_and_type):
80
+ assert col_name == c_name
81
+ assert col_type == c_type
82
+
83
+
84
+ @pytest.mark.parametrize(
85
+ "benchmarks, expected_benchmark_len",
86
+ [
87
+ (QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
88
+ (LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
89
+ ],
90
+ )
91
+ def test_make_autoevalcolumn(benchmarks, expected_benchmark_len, expected_col_names):
92
+ expected_default_attrs = frozenset(expected_col_names)
93
+ for benchmark in benchmarks:
94
+ TestEvalColumn = make_autoevalcolumn("TestEvalColumn", benchmark)
95
+ attrs = []
96
+ for k, v in TestEvalColumn.__dict__.items():
97
+ if not k.startswith("__"):
98
+ attrs.append(k)
99
+ attrs = frozenset(attrs)
100
+ assert expected_default_attrs.issubset(attrs)
101
+ benchmark_attrs = attrs.difference(expected_default_attrs)
102
+ assert len(benchmark_attrs) == expected_benchmark_len[benchmark.name]
103
+
104
+
105
+ @pytest.mark.parametrize(
106
+ "benchmarks, expected_benchmark_len",
107
+ [
108
+ (QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
109
+ (LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
110
+ ],
111
+ )
112
+ def test_get_default_col_names_and_types(
113
+ benchmarks, expected_benchmark_len, expected_col_names, expected_hidden_col_names
114
+ ):
115
+ default_col_len = len(expected_col_names)
116
+ hidden_col_len = len(expected_hidden_col_names)
117
+ for benchmark in benchmarks:
118
+ col_names, col_types = get_default_col_names_and_types(benchmark)
119
+ assert len(col_names) == expected_benchmark_len[benchmark.name] + default_col_len - hidden_col_len
tests/src/test_envs.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from air_benchmark.tasks import BenchmarkTable
2
+
3
+ from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA, METRIC_LIST
4
+
5
+
6
+ def test_benchmark_version_list():
7
+ leaderboard_versions = frozenset(BENCHMARK_VERSION_LIST)
8
+ available_versions = frozenset([k for k in BenchmarkTable.keys()])
9
+ assert leaderboard_versions.issubset(available_versions)
10
+
11
+
12
+ def test_default_metrics():
13
+ assert DEFAULT_METRIC_QA in METRIC_LIST
14
+ assert DEFAULT_METRIC_LONG_DOC in METRIC_LIST
tests/src/test_loaders.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+ import pytest
5
+
6
+ from src.loaders import load_eval_results, load_leaderboard_datastore, load_raw_eval_results
7
+
8
+ cur_fp = Path(__file__)
9
+
10
+
11
+ @pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
12
+ def test_load_raw_eval_results(version):
13
+ raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
14
+ assert len(raw_data) == 1
15
+ full_eval_result = raw_data[0]
16
+ expected_attr = [
17
+ "eval_name",
18
+ "retrieval_model",
19
+ "reranking_model",
20
+ "retrieval_model_link",
21
+ "reranking_model_link",
22
+ "results",
23
+ "timestamp",
24
+ "revision",
25
+ "is_anonymous",
26
+ ]
27
+ result_attr = [k for k in full_eval_result.__dict__.keys() if k[:2] != "__" and k[-2:] != "__"]
28
+ assert sorted(expected_attr) == sorted(result_attr)
29
+
30
+
31
+ @pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
32
+ def test_load_leaderboard_datastore(version):
33
+ file_path = cur_fp.parents[1] / f"toydata/eval_results/{version}"
34
+ datastore = load_leaderboard_datastore(file_path, version)
35
+ for k, v in datastore.__dict__.items():
36
+ if k[:2] != "__" and k[-2:] != "__":
37
+ if isinstance(v, list):
38
+ assert v
39
+ elif isinstance(v, pd.DataFrame):
40
+ assert not v.empty
41
+
42
+
43
+ def test_load_eval_results():
44
+ file_path = cur_fp.parents[1] / "toydata/eval_results/"
45
+ datastore_dict = load_eval_results(file_path)
46
+ assert len(datastore_dict) == 2
tests/src/test_models.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+
5
+ from src.models import EvalResult, FullEvalResult
6
+
7
+ cur_fp = Path(__file__)
8
+
9
+
10
+ # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
11
+ # 24.05
12
+ # | Task | dev | test |
13
+ # | ---- | --- | ---- |
14
+ # | Long-Doc | 4 | 11 |
15
+ # | QA | 54 | 53 |
16
+ #
17
+ # 24.04
18
+ # | Task | test |
19
+ # | ---- | ---- |
20
+ # | Long-Doc | 15 |
21
+ # | QA | 13 |
22
+ NUM_QA_BENCHMARKS_24_05 = 53
23
+ NUM_DOC_BENCHMARKS_24_05 = 11
24
+ NUM_QA_BENCHMARKS_24_04 = 13
25
+ NUM_DOC_BENCHMARKS_24_04 = 15
26
+
27
+
28
+ def test_eval_result():
29
+ EvalResult(
30
+ eval_name="eval_name",
31
+ retrieval_model="bge-m3",
32
+ reranking_model="NoReranking",
33
+ results=[{"domain": "law", "lang": "en", "dataset": "lex_files_500K-600K", "value": 0.45723}],
34
+ task="qa",
35
+ metric="ndcg_at_3",
36
+ timestamp="2024-05-14T03:09:08Z",
37
+ revision="1e243f14bd295ccdea7a118fe847399d",
38
+ is_anonymous=True,
39
+ )
40
+
41
+
42
+ @pytest.mark.parametrize(
43
+ "file_path",
44
+ [
45
+ "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
46
+ "AIR-Bench_24.05/bge-m3/NoReranker/results.json",
47
+ ],
48
+ )
49
+ def test_full_eval_result_init_from_json_file(file_path):
50
+ json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
51
+ full_eval_result = FullEvalResult.init_from_json_file(json_fp)
52
+ assert json_fp.parents[0].stem == full_eval_result.reranking_model
53
+ assert json_fp.parents[1].stem == full_eval_result.retrieval_model
54
+ assert len(full_eval_result.results) == 70
55
+
56
+
57
+ @pytest.mark.parametrize(
58
+ "file_path, task, expected_num_results",
59
+ [
60
+ ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
61
+ (
62
+ "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
63
+ "long-doc",
64
+ NUM_DOC_BENCHMARKS_24_04,
65
+ ),
66
+ ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
67
+ ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
68
+ ],
69
+ )
70
+ def test_full_eval_result_to_dict(file_path, task, expected_num_results):
71
+ json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
72
+ full_eval_result = FullEvalResult.init_from_json_file(json_fp)
73
+ result_dict_list = full_eval_result.to_dict(task)
74
+ assert len(result_dict_list) == 1
75
+ result = result_dict_list[0]
76
+ attr_list = frozenset(
77
+ [
78
+ "eval_name",
79
+ "Retrieval Method",
80
+ "Reranking Model",
81
+ "Retrieval Model LINK",
82
+ "Reranking Model LINK",
83
+ "Revision",
84
+ "Submission Date",
85
+ "Anonymous Submission",
86
+ ]
87
+ )
88
+ result_cols = list(result.keys())
89
+ assert len(result_cols) == (expected_num_results + len(attr_list))
tests/src/test_read_evals.py DELETED
@@ -1,68 +0,0 @@
1
- from pathlib import Path
2
-
3
- from src.read_evals import FullEvalResult, get_raw_eval_results, get_leaderboard_df
4
-
5
- cur_fp = Path(__file__)
6
-
7
-
8
- def test_init_from_json_file():
9
- json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
10
- full_eval_result = FullEvalResult.init_from_json_file(json_fp)
11
- num_different_task_domain_lang_metric_dataset_combination = 6
12
- assert len(full_eval_result.results) == \
13
- num_different_task_domain_lang_metric_dataset_combination
14
- assert full_eval_result.retrieval_model == "bge-m3"
15
- assert full_eval_result.reranking_model == "bge-reranker-v2-m3"
16
-
17
-
18
- def test_to_dict():
19
- json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
20
- full_eval_result = FullEvalResult.init_from_json_file(json_fp)
21
- result_list = full_eval_result.to_dict(task='qa', metric='ndcg_at_1')
22
- assert len(result_list) == 1
23
- result_dict = result_list[0]
24
- assert result_dict["Retrieval Model"] == "bge-m3"
25
- assert result_dict["Reranking Model"] == "bge-reranker-v2-m3"
26
- assert result_dict["wiki_en"] is not None
27
- assert result_dict["wiki_zh"] is not None
28
-
29
-
30
- def test_get_raw_eval_results():
31
- results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
32
- results = get_raw_eval_results(results_path)
33
- # only load the latest results
34
- assert len(results) == 4
35
- assert results[0].eval_name == "bge-base-en-v1.5_NoReranker"
36
- assert len(results[0].results) == 70
37
- assert results[0].eval_name == "bge-base-en-v1.5_bge-reranker-v2-m3"
38
- assert len(results[1].results) == 70
39
-
40
-
41
- def test_get_leaderboard_df():
42
- results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
43
- raw_data = get_raw_eval_results(results_path)
44
- df = get_leaderboard_df(raw_data, 'qa', 'ndcg_at_10')
45
- assert df.shape[0] == 4
46
- # the results contain only one embedding model
47
- # for i in range(4):
48
- # assert df["Retrieval Model"][i] == "bge-m3"
49
- # # the results contain only two reranking model
50
- # assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
51
- # assert df["Reranking Model"][1] == "NoReranker"
52
- # assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
53
- # assert not df[['Average ⬆️', 'wiki_en', 'wiki_zh', ]].isnull().values.any()
54
-
55
-
56
- def test_get_leaderboard_df_long_doc():
57
- results_path = cur_fp.parents[2] / "toydata" / "test_results"
58
- raw_data = get_raw_eval_results(results_path)
59
- df = get_leaderboard_df(raw_data, 'long-doc', 'ndcg_at_1')
60
- assert df.shape[0] == 2
61
- # the results contain only one embedding model
62
- for i in range(2):
63
- assert df["Retrieval Model"][i] == "bge-m3"
64
- # the results contains only two reranking model
65
- assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
66
- assert df["Reranking Model"][1] == "NoReranker"
67
- assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
68
- assert not df[['Average ⬆️', 'law_en_lex_files_500k_600k', ]].isnull().values.any()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_utils.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+ import pytest
5
+
6
+ from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
7
+ from src.models import TaskType, model_hyperlink
8
+ from src.utils import (
9
+ _update_df_elem,
10
+ calculate_mean,
11
+ filter_models,
12
+ filter_queries,
13
+ get_default_cols,
14
+ get_leaderboard_df,
15
+ get_selected_cols,
16
+ remove_html,
17
+ select_columns,
18
+ )
19
+
20
+ cur_fp = Path(__file__)
21
+
22
+ NUM_QA_BENCHMARKS_24_05 = 53
23
+ NUM_DOC_BENCHMARKS_24_05 = 11
24
+ NUM_QA_BENCHMARKS_24_04 = 13
25
+ NUM_DOC_BENCHMARKS_24_04 = 15
26
+
27
+
28
+ @pytest.fixture
29
+ def toy_df():
30
+ return pd.DataFrame(
31
+ {
32
+ "Retrieval Method": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
33
+ "Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
34
+ "Rank 🏆": [1, 2, 3, 4],
35
+ "Revision": ["123", "234", "345", "456"],
36
+ "Submission Date": ["", "", "", ""],
37
+ "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
38
+ "wiki_en": [0.8, 0.7, 0.2, 0.1],
39
+ "wiki_zh": [0.4, 0.1, 0.4, 0.3],
40
+ "news_en": [0.8, 0.7, 0.2, 0.1],
41
+ "news_zh": [0.4, 0.1, 0.2, 0.3],
42
+ "Anonymous Submission": [False, False, False, True],
43
+ }
44
+ )
45
+
46
+
47
+ def test_remove_html():
48
+ model_name = "jina-embeddings-v3"
49
+ html_str = model_hyperlink("https://jina.ai", model_name)
50
+ output_str = remove_html(html_str)
51
+ assert output_str == model_name
52
+
53
+
54
+ def test_calculate_mean():
55
+ valid_row = [1, 3]
56
+ invalid_row = [2, pd.NA]
57
+ df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
58
+ result = list(df.apply(calculate_mean, axis=1))
59
+ assert result[0] == sum(valid_row) / 2
60
+ assert result[1] == -1
61
+
62
+
63
+ @pytest.mark.parametrize(
64
+ "models, expected",
65
+ [
66
+ (["model1", "model3"], 2),
67
+ (["model1", "model_missing"], 1),
68
+ (["model1", "model2", "model3"], 3),
69
+ (
70
+ [
71
+ "model1",
72
+ ],
73
+ 1,
74
+ ),
75
+ ([], 3),
76
+ ],
77
+ )
78
+ def test_filter_models(models, expected):
79
+ df = pd.DataFrame(
80
+ {
81
+ COL_NAME_RERANKING_MODEL: [
82
+ "model1",
83
+ "model2",
84
+ "model3",
85
+ ],
86
+ "col2": [1, 2, 3],
87
+ }
88
+ )
89
+ output_df = filter_models(df, models)
90
+ assert len(output_df) == expected
91
+
92
+
93
+ @pytest.mark.parametrize(
94
+ "query, expected",
95
+ [
96
+ ("model1;model3", 2),
97
+ ("model1;model4", 1),
98
+ ("model1;model2;model3", 3),
99
+ ("model1", 1),
100
+ ("", 3),
101
+ ],
102
+ )
103
+ def test_filter_queries(query, expected):
104
+ df = pd.DataFrame(
105
+ {
106
+ COL_NAME_RETRIEVAL_MODEL: [
107
+ "model1",
108
+ "model2",
109
+ "model3",
110
+ ],
111
+ COL_NAME_RERANKING_MODEL: [
112
+ "model4",
113
+ "model5",
114
+ "model6",
115
+ ],
116
+ }
117
+ )
118
+ output_df = filter_queries(query, df)
119
+ assert len(output_df) == expected
120
+
121
+
122
+ @pytest.mark.parametrize(
123
+ "task_type, slug, add_fix_cols, expected",
124
+ [
125
+ (TaskType.qa, "air_bench_2404", True, NUM_QA_BENCHMARKS_24_04),
126
+ (TaskType.long_doc, "air_bench_2404", True, NUM_DOC_BENCHMARKS_24_04),
127
+ (TaskType.qa, "air_bench_2405", False, NUM_QA_BENCHMARKS_24_05),
128
+ (TaskType.long_doc, "air_bench_2405", False, NUM_DOC_BENCHMARKS_24_05),
129
+ ],
130
+ )
131
+ def test_get_default_cols(task_type, slug, add_fix_cols, expected):
132
+ attr_cols = ["Rank 🏆", "Retrieval Method", "Reranking Model", "Revision", "Submission Date", "Average ⬆️"]
133
+ cols, types = get_default_cols(task_type, slug)
134
+ cols_set = frozenset(cols)
135
+ attrs_set = frozenset(attr_cols)
136
+ if add_fix_cols:
137
+ assert attrs_set.issubset(cols_set)
138
+ benchmark_cols = list(cols_set.difference(attrs_set))
139
+ assert len(benchmark_cols) == expected
140
+
141
+
142
+ @pytest.mark.parametrize(
143
+ "task_type, domains, languages, expected",
144
+ [
145
+ (
146
+ TaskType.qa,
147
+ ["wiki", "news"],
148
+ [
149
+ "zh",
150
+ ],
151
+ ["wiki_zh", "news_zh"],
152
+ ),
153
+ (
154
+ TaskType.qa,
155
+ [
156
+ "law",
157
+ ],
158
+ ["zh", "en"],
159
+ ["law_en"],
160
+ ),
161
+ (
162
+ TaskType.long_doc,
163
+ ["healthcare"],
164
+ ["zh", "en"],
165
+ [
166
+ "healthcare_en_pubmed_100k_200k_1",
167
+ "healthcare_en_pubmed_100k_200k_2",
168
+ "healthcare_en_pubmed_100k_200k_3",
169
+ "healthcare_en_pubmed_40k_50k_5_merged",
170
+ "healthcare_en_pubmed_30k_40k_10_merged",
171
+ ],
172
+ ),
173
+ ],
174
+ )
175
+ def test_get_selected_cols(task_type, domains, languages, expected):
176
+ slug = "air_bench_2404"
177
+ cols = get_selected_cols(task_type, slug, domains, languages)
178
+ assert sorted(cols) == sorted(expected)
179
+
180
+
181
+ @pytest.mark.parametrize("reset_rank", [False])
182
+ def test_select_columns(toy_df, reset_rank):
183
+ expected = [
184
+ "Rank 🏆",
185
+ "Retrieval Method",
186
+ "Reranking Model",
187
+ "Revision",
188
+ "Submission Date",
189
+ "Average ⬆️",
190
+ "news_zh",
191
+ ]
192
+ df_result = select_columns(toy_df, ["news"], ["zh"], version_slug="air_bench_2404", reset_ranking=reset_rank)
193
+ assert len(df_result.columns) == len(expected)
194
+ if reset_rank:
195
+ assert df_result["Average ⬆️"].equals(df_result["news_zh"])
196
+ else:
197
+ assert df_result["Average ⬆️"].equals(toy_df["Average ⬆️"])
198
+
199
+
200
+ @pytest.mark.parametrize(
201
+ "reset_rank, show_anony",
202
+ [
203
+ (False, True),
204
+ (True, True),
205
+ (True, False),
206
+ ],
207
+ )
208
+ def test__update_df_elem(toy_df, reset_rank, show_anony):
209
+ df = _update_df_elem(TaskType.qa, "AIR-Bench_24.04", toy_df, ["news"], ["zh"], [], "", show_anony, reset_rank)
210
+ if show_anony:
211
+ assert df.shape[0] == 4
212
+ else:
213
+ assert df.shape[0] == 3
214
+ if show_anony:
215
+ if reset_rank:
216
+ assert df["Average ⬆️"].equals(df["news_zh"])
217
+ else:
218
+ assert df["Average ⬆️"].equals(toy_df["Average ⬆️"])
219
+
220
+
221
+ @pytest.mark.parametrize(
222
+ "version, task_type",
223
+ [
224
+ ("AIR-Bench_24.04", TaskType.qa),
225
+ ("AIR-Bench_24.04", TaskType.long_doc),
226
+ ("AIR-Bench_24.05", TaskType.qa),
227
+ ("AIR-Bench_24.05", TaskType.long_doc),
228
+ ],
229
+ )
230
+ def test_get_leaderboard_df(version, task_type):
231
+ from src.loaders import load_raw_eval_results
232
+ from src.models import LeaderboardDataStore, get_safe_name
233
+
234
+ raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
235
+ ds = LeaderboardDataStore(version, get_safe_name(version), raw_data=raw_data)
236
+ df = get_leaderboard_df(ds, task_type, "ndcg_at_10")
237
+ assert df.shape[0] == 1
tests/test_utils.py DELETED
@@ -1,115 +0,0 @@
1
- import pandas as pd
2
- import pytest
3
-
4
- from src.utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp, get_default_cols, update_table
5
- from src.display.utils import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RANK, COL_NAME_AVG
6
-
7
-
8
- @pytest.fixture
9
- def toy_df():
10
- return pd.DataFrame(
11
- {
12
- "Retrieval Model": [
13
- "bge-m3",
14
- "bge-m3",
15
- "jina-embeddings-v2-base",
16
- "jina-embeddings-v2-base"
17
- ],
18
- "Reranking Model": [
19
- "bge-reranker-v2-m3",
20
- "NoReranker",
21
- "bge-reranker-v2-m3",
22
- "NoReranker"
23
- ],
24
- "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
25
- "wiki_en": [0.8, 0.7, 0.2, 0.1],
26
- "wiki_zh": [0.4, 0.1, 0.4, 0.3],
27
- "news_en": [0.8, 0.7, 0.2, 0.1],
28
- "news_zh": [0.4, 0.1, 0.4, 0.3],
29
- }
30
- )
31
-
32
-
33
- @pytest.fixture
34
- def toy_df_long_doc():
35
- return pd.DataFrame(
36
- {
37
- "Retrieval Model": [
38
- "bge-m3",
39
- "bge-m3",
40
- "jina-embeddings-v2-base",
41
- "jina-embeddings-v2-base"
42
- ],
43
- "Reranking Model": [
44
- "bge-reranker-v2-m3",
45
- "NoReranker",
46
- "bge-reranker-v2-m3",
47
- "NoReranker"
48
- ],
49
- "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
50
- "law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
51
- "law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
52
- "law_en_lex_files_500k_600k": [0.8, 0.7, 0.2, 0.1],
53
- "law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
54
- }
55
- )
56
- def test_filter_models(toy_df):
57
- df_result = filter_models(toy_df, ["bge-reranker-v2-m3", ])
58
- assert len(df_result) == 2
59
- assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
60
-
61
-
62
- def test_search_table(toy_df):
63
- df_result = search_table(toy_df, "jina")
64
- assert len(df_result) == 2
65
- assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
66
-
67
-
68
- def test_filter_queries(toy_df):
69
- df_result = filter_queries("jina", toy_df)
70
- assert len(df_result) == 2
71
- assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
72
-
73
-
74
- def test_select_columns(toy_df):
75
- df_result = select_columns(toy_df, ['news',], ['zh',])
76
- assert len(df_result.columns) == 4
77
- assert df_result['Average ⬆️'].equals(df_result['news_zh'])
78
-
79
-
80
- def test_update_table_long_doc(toy_df_long_doc):
81
- df_result = update_table_long_doc(toy_df_long_doc, ['law',], ['en',], ["bge-reranker-v2-m3", ], "jina")
82
- print(df_result)
83
-
84
-
85
- def test_get_iso_format_timestamp():
86
- timestamp_config, timestamp_fn = get_iso_format_timestamp()
87
- assert len(timestamp_fn) == 14
88
- assert len(timestamp_config) == 20
89
- assert timestamp_config[-1] == "Z"
90
-
91
-
92
- def test_get_default_cols():
93
- cols, types = get_default_cols("qa")
94
- for c, t in zip(cols, types):
95
- print(f"type({c}): {t}")
96
- assert len(frozenset(cols)) == len(cols)
97
-
98
-
99
- def test_update_table():
100
- df = pd.DataFrame(
101
- {
102
- COL_NAME_IS_ANONYMOUS: [False, False, False],
103
- COL_NAME_REVISION: ["a1", "a2", "a3"],
104
- COL_NAME_TIMESTAMP: ["2024-05-12T12:24:02Z"] * 3,
105
- COL_NAME_RERANKING_MODEL: ["NoReranker"] * 3,
106
- COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
107
- COL_NAME_RANK: [1, 2, 3],
108
- COL_NAME_AVG: [0.1, 0.2, 0.3], # unsorted values
109
- "wiki_en": [0.1, 0.2, 0.3]
110
- }
111
- )
112
- results = update_table(df, "wiki", "en", ["NoReranker"], "", show_anonymous=False, reset_ranking=False, show_revision_and_timestamp=False)
113
- # keep the RANK as the same regardless of the unsorted averages
114
- assert results[COL_NAME_RANK].to_list() == [1, 2, 3]
115
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json ADDED
The diff for this file is too large to render. See raw diff
 
tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json ADDED
The diff for this file is too large to render. See raw diff
 
tests/toydata/test_data.json DELETED
@@ -1,98 +0,0 @@
1
- [
2
- {
3
- "config": {
4
- "retrieval_model": "bge-m3",
5
- "reranking_model": "bge-reranker-v2-m3",
6
- "task": "long_doc",
7
- "metric": "ndcg_at_1"
8
- },
9
- "results": [
10
- {
11
- "domain": "law",
12
- "lang": "en",
13
- "dataset": "lex_files_500K-600K",
14
- "value": 0.75723
15
- }
16
- ]
17
- },
18
- {
19
- "config": {
20
- "retrieval_model": "bge-m3",
21
- "reranking_model": "bge-reranker-v2-m3",
22
- "task": "long_doc",
23
- "metric": "ndcg_at_3"
24
- },
25
- "results": [
26
- {
27
- "domain": "law",
28
- "lang": "en",
29
- "dataset": "lex_files_500K-600K",
30
- "value": 0.69909
31
- }
32
- ]
33
- },
34
- {
35
- "config": {
36
- "retrieval_model": "bge-m3",
37
- "reranking_model": "bge-reranker-v2-m3",
38
- "task": "qa",
39
- "metric": "ndcg_at_1"
40
- },
41
- "results": [
42
- {
43
- "domain": "wiki",
44
- "lang": "en",
45
- "dataset": "unknown",
46
- "value": 0.69083
47
- }
48
- ]
49
- },
50
- {
51
- "config": {
52
- "retrieval_model": "bge-m3",
53
- "reranking_model": "bge-reranker-v2-m3",
54
- "task": "qa",
55
- "metric": "ndcg_at_3"
56
- },
57
- "results": [
58
- {
59
- "domain": "wiki",
60
- "lang": "en",
61
- "dataset": "unknown",
62
- "value": 0.73359
63
- }
64
- ]
65
- },
66
- {
67
- "config": {
68
- "retrieval_model": "bge-m3",
69
- "reranking_model": "bge-reranker-v2-m3",
70
- "task": "qa",
71
- "metric": "ndcg_at_1"
72
- },
73
- "results": [
74
- {
75
- "domain": "wiki",
76
- "lang": "zh",
77
- "dataset": "unknown",
78
- "value": 0.78358
79
- }
80
- ]
81
- },
82
- {
83
- "config": {
84
- "retrieval_model": "bge-m3",
85
- "reranking_model": "bge-reranker-v2-m3",
86
- "task": "qa",
87
- "metric": "ndcg_at_3"
88
- },
89
- "results": [
90
- {
91
- "domain": "wiki",
92
- "lang": "zh",
93
- "dataset": "unknown",
94
- "value": 0.78358
95
- }
96
- ]
97
- }
98
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/toydata/test_results/bge-m3/NoReranker/results_2023-11-21T18-10-08.json DELETED
@@ -1,98 +0,0 @@
1
- [
2
- {
3
- "config": {
4
- "retrieval_model": "bge-m3",
5
- "reranking_model": "NoReranker",
6
- "task": "long_doc",
7
- "metric": "ndcg_at_1"
8
- },
9
- "results": [
10
- {
11
- "domain": "law",
12
- "lang": "en",
13
- "dataset": "lex_files_500K-600K",
14
- "value": 0.45723
15
- }
16
- ]
17
- },
18
- {
19
- "config": {
20
- "retrieval_model": "bge-m3",
21
- "reranking_model": "NoReranker",
22
- "task": "long_doc",
23
- "metric": "ndcg_at_3"
24
- },
25
- "results": [
26
- {
27
- "domain": "law",
28
- "lang": "en",
29
- "dataset": "lex_files_500K-600K",
30
- "value": 0.49909
31
- }
32
- ]
33
- },
34
- {
35
- "config": {
36
- "retrieval_model": "bge-m3",
37
- "reranking_model": "NoReranker",
38
- "task": "qa",
39
- "metric": "ndcg_at_1"
40
- },
41
- "results": [
42
- {
43
- "domain": "wiki",
44
- "lang": "en",
45
- "dataset": "unknown",
46
- "value": 0.49083
47
- }
48
- ]
49
- },
50
- {
51
- "config": {
52
- "retrieval_model": "bge-m3",
53
- "reranking_model": "NoReranker",
54
- "task": "qa",
55
- "metric": "ndcg_at_3"
56
- },
57
- "results": [
58
- {
59
- "domain": "wiki",
60
- "lang": "en",
61
- "dataset": "unknown",
62
- "value": 0.43359
63
- }
64
- ]
65
- },
66
- {
67
- "config": {
68
- "retrieval_model": "bge-m3",
69
- "reranking_model": "NoReranker",
70
- "task": "qa",
71
- "metric": "ndcg_at_1"
72
- },
73
- "results": [
74
- {
75
- "domain": "wiki",
76
- "lang": "zh",
77
- "dataset": "unknown",
78
- "value": 0.78358
79
- }
80
- ]
81
- },
82
- {
83
- "config": {
84
- "retrieval_model": "bge-m3",
85
- "reranking_model": "NoReranker",
86
- "task": "qa",
87
- "metric": "ndcg_at_3"
88
- },
89
- "results": [
90
- {
91
- "domain": "wiki",
92
- "lang": "zh",
93
- "dataset": "unknown",
94
- "value": 0.78358
95
- }
96
- ]
97
- }
98
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/toydata/test_results/bge-m3/bge-reranker-v2-m3/results_2023-11-21T18-10-08.json DELETED
@@ -1,98 +0,0 @@
1
- [
2
- {
3
- "config": {
4
- "retrieval_model": "bge-m3",
5
- "reranking_model": "bge-reranker-v2-m3",
6
- "task": "long_doc",
7
- "metric": "ndcg_at_1"
8
- },
9
- "results": [
10
- {
11
- "domain": "law",
12
- "lang": "en",
13
- "dataset": "lex_files_500K-600K",
14
- "value": 0.75723
15
- }
16
- ]
17
- },
18
- {
19
- "config": {
20
- "retrieval_model": "bge-m3",
21
- "reranking_model": "bge-reranker-v2-m3",
22
- "task": "long_doc",
23
- "metric": "ndcg_at_3"
24
- },
25
- "results": [
26
- {
27
- "domain": "law",
28
- "lang": "en",
29
- "dataset": "lex_files_500K-600K",
30
- "value": 0.69909
31
- }
32
- ]
33
- },
34
- {
35
- "config": {
36
- "retrieval_model": "bge-m3",
37
- "reranking_model": "bge-reranker-v2-m3",
38
- "task": "qa",
39
- "metric": "ndcg_at_1"
40
- },
41
- "results": [
42
- {
43
- "domain": "wiki",
44
- "lang": "en",
45
- "dataset": "unknown",
46
- "value": 0.69083
47
- }
48
- ]
49
- },
50
- {
51
- "config": {
52
- "retrieval_model": "bge-m3",
53
- "reranking_model": "bge-reranker-v2-m3",
54
- "task": "qa",
55
- "metric": "ndcg_at_3"
56
- },
57
- "results": [
58
- {
59
- "domain": "wiki",
60
- "lang": "en",
61
- "dataset": "unknown",
62
- "value": 0.73359
63
- }
64
- ]
65
- },
66
- {
67
- "config": {
68
- "retrieval_model": "bge-m3",
69
- "reranking_model": "bge-reranker-v2-m3",
70
- "task": "qa",
71
- "metric": "ndcg_at_1"
72
- },
73
- "results": [
74
- {
75
- "domain": "wiki",
76
- "lang": "zh",
77
- "dataset": "unknown",
78
- "value": 0.78358
79
- }
80
- ]
81
- },
82
- {
83
- "config": {
84
- "retrieval_model": "bge-m3",
85
- "reranking_model": "bge-reranker-v2-m3",
86
- "task": "qa",
87
- "metric": "ndcg_at_3"
88
- },
89
- "results": [
90
- {
91
- "domain": "wiki",
92
- "lang": "zh",
93
- "dataset": "unknown",
94
- "value": 0.78358
95
- }
96
- ]
97
- }
98
- ]