Shiyu Zhao committed
Commit c8373c1
Parent: 1d08545

Update space

Files changed (1)
  1. app.py +498 -195
app.py CHANGED
@@ -1,204 +1,507 @@
1
  import gradio as gr
2
- from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
3
  import pandas as pd
4
- from apscheduler.schedulers.background import BackgroundScheduler
5
- from huggingface_hub import snapshot_download
6
-
7
- from src.about import (
8
- CITATION_BUTTON_LABEL,
9
- CITATION_BUTTON_TEXT,
10
- EVALUATION_QUEUE_TEXT,
11
- INTRODUCTION_TEXT,
12
- LLM_BENCHMARKS_TEXT,
13
- TITLE,
14
- )
15
- from src.display.css_html_js import custom_css
16
- from src.display.utils import (
17
- BENCHMARK_COLS,
18
- COLS,
19
- EVAL_COLS,
20
- EVAL_TYPES,
21
- AutoEvalColumn,
22
- ModelType,
23
- fields,
24
- WeightType,
25
- Precision
26
- )
27
- from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN
28
- from src.populate import get_evaluation_queue_df, get_leaderboard_df
29
- from src.submission.submit import add_new_eval
30
-
31
-
32
- def restart_space():
33
- API.restart_space(repo_id=REPO_ID)
34
-
35
- ### Space initialisation
36
- try:
37
- print(EVAL_REQUESTS_PATH)
38
- snapshot_download(
39
- repo_id=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30, token=TOKEN
40
- )
41
- except Exception:
42
- restart_space()
43
- try:
44
- print(EVAL_RESULTS_PATH)
45
- snapshot_download(
46
- repo_id=RESULTS_REPO, local_dir=EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30, token=TOKEN
47
- )
48
- except Exception:
49
- restart_space()
50
-
51
-
52
- LEADERBOARD_DF = get_leaderboard_df(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH, COLS, BENCHMARK_COLS)
53
-
54
- (
55
- finished_eval_queue_df,
56
- running_eval_queue_df,
57
- pending_eval_queue_df,
58
- ) = get_evaluation_queue_df(EVAL_REQUESTS_PATH, EVAL_COLS)
59
-
60
- def init_leaderboard(dataframe):
61
- if dataframe is None or dataframe.empty:
62
- raise ValueError("Leaderboard DataFrame is empty or None.")
63
- return Leaderboard(
64
- value=dataframe,
65
- datatype=[c.type for c in fields(AutoEvalColumn)],
66
- select_columns=SelectColumns(
67
- default_selection=[c.name for c in fields(AutoEvalColumn) if c.displayed_by_default],
68
- cant_deselect=[c.name for c in fields(AutoEvalColumn) if c.never_hidden],
69
- label="Select Columns to Display:",
70
- ),
71
- search_columns=[AutoEvalColumn.model.name, AutoEvalColumn.license.name],
72
- hide_columns=[c.name for c in fields(AutoEvalColumn) if c.hidden],
73
- filter_columns=[
74
- ColumnFilter(AutoEvalColumn.model_type.name, type="checkboxgroup", label="Model types"),
75
- ColumnFilter(AutoEvalColumn.precision.name, type="checkboxgroup", label="Precision"),
76
- ColumnFilter(
77
- AutoEvalColumn.params.name,
78
- type="slider",
79
- min=0.01,
80
- max=150,
81
- label="Select the number of parameters (B)",
82
- ),
83
- ColumnFilter(
84
- AutoEvalColumn.still_on_hub.name, type="boolean", label="Deleted/incomplete", default=True
85
- ),
86
- ],
87
- bool_checkboxgroup_label="Hide models",
88
- interactive=False,
 
89
  )
90
 
91
 
92
- demo = gr.Blocks(css=custom_css)
93
- with demo:
94
- gr.HTML(TITLE)
95
- gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
96
-
97
- with gr.Tabs(elem_classes="tab-buttons") as tabs:
98
- with gr.TabItem("🏅 LLM Benchmark", elem_id="llm-benchmark-tab-table", id=0):
99
- leaderboard = init_leaderboard(LEADERBOARD_DF)
100
-
101
- with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
102
- gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
103
-
104
- with gr.TabItem("🚀 Submit here! ", elem_id="llm-benchmark-tab-table", id=3):
105
- with gr.Column():
106
- with gr.Row():
107
- gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
108
-
109
- with gr.Column():
110
- with gr.Accordion(
111
- f"✅ Finished Evaluations ({len(finished_eval_queue_df)})",
112
- open=False,
113
- ):
114
- with gr.Row():
115
- finished_eval_table = gr.components.Dataframe(
116
- value=finished_eval_queue_df,
117
- headers=EVAL_COLS,
118
- datatype=EVAL_TYPES,
119
- row_count=5,
120
- )
121
- with gr.Accordion(
122
- f"🔄 Running Evaluation Queue ({len(running_eval_queue_df)})",
123
- open=False,
124
- ):
125
- with gr.Row():
126
- running_eval_table = gr.components.Dataframe(
127
- value=running_eval_queue_df,
128
- headers=EVAL_COLS,
129
- datatype=EVAL_TYPES,
130
- row_count=5,
131
- )
132
-
133
- with gr.Accordion(
134
- f"⏳ Pending Evaluation Queue ({len(pending_eval_queue_df)})",
135
- open=False,
136
- ):
137
- with gr.Row():
138
- pending_eval_table = gr.components.Dataframe(
139
- value=pending_eval_queue_df,
140
- headers=EVAL_COLS,
141
- datatype=EVAL_TYPES,
142
- row_count=5,
143
- )
144
- with gr.Row():
145
- gr.Markdown("# ✉️✨ Submit your model here!", elem_classes="markdown-text")
146
-
147
- with gr.Row():
148
- with gr.Column():
149
- model_name_textbox = gr.Textbox(label="Model name")
150
- revision_name_textbox = gr.Textbox(label="Revision commit", placeholder="main")
151
- model_type = gr.Dropdown(
152
- choices=[t.to_str(" : ") for t in ModelType if t != ModelType.Unknown],
153
- label="Model type",
154
- multiselect=False,
155
- value=None,
156
- interactive=True,
157
- )
158
-
159
- with gr.Column():
160
- precision = gr.Dropdown(
161
- choices=[i.value.name for i in Precision if i != Precision.Unknown],
162
- label="Precision",
163
- multiselect=False,
164
- value="float16",
165
- interactive=True,
166
- )
167
- weight_type = gr.Dropdown(
168
- choices=[i.value.name for i in WeightType],
169
- label="Weights type",
170
- multiselect=False,
171
- value="Original",
172
- interactive=True,
173
- )
174
- base_model_name_textbox = gr.Textbox(label="Base model (for delta or adapter weights)")
175
-
176
- submit_button = gr.Button("Submit Eval")
177
- submission_result = gr.Markdown()
178
- submit_button.click(
179
- add_new_eval,
180
- [
181
- model_name_textbox,
182
- base_model_name_textbox,
183
- revision_name_textbox,
184
- precision,
185
- weight_type,
186
- model_type,
187
- ],
188
- submission_result,
189
- )
190
 
191
  with gr.Row():
192
- with gr.Accordion("📙 Citation", open=False):
193
- citation_button = gr.Textbox(
194
- value=CITATION_BUTTON_TEXT,
195
- label=CITATION_BUTTON_LABEL,
196
- lines=20,
197
- elem_id="citation-button",
198
- show_copy_button=True,
199
  )
200
 
201
- scheduler = BackgroundScheduler()
202
- scheduler.add_job(restart_space, "interval", seconds=1800)
203
- scheduler.start()
204
- demo.queue(default_concurrency_limit=40).launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import numpy as np
4
+ import os
5
+ import re
6
+ from datetime import datetime
7
+ import json
8
+ import torch
9
+ from tqdm import tqdm
10
+ from concurrent.futures import ProcessPoolExecutor, as_completed
11
+
12
+ from stark_qa import load_qa
13
+ from stark_qa.evaluator import Evaluator
14
+
15
+
16
+ def process_single_instance(args):
17
+ idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
18
+ query, query_id, answer_ids, meta_info = qa_dataset[idx]
19
+
20
+ try:
21
+ pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
22
+ except IndexError:
23
+ raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
24
+ except Exception as e:
25
+ raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}')
26
+
27
+ if isinstance(pred_rank, str):
28
+ try:
29
+ pred_rank = eval(pred_rank)
30
+ except SyntaxError as e:
31
+ raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
32
+
33
+ if not isinstance(pred_rank, list):
34
+ raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
35
+
36
+ pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
37
+ answer_ids = torch.LongTensor(answer_ids)
38
+ result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
39
+
40
+ result["idx"], result["query_id"] = idx, query_id
41
+ return result
42
+
43
+
44
+ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
45
+ candidate_ids_dict = {
46
+ 'amazon': [i for i in range(957192)],
47
+ 'mag': [i for i in range(1172724, 1872968)],
48
+ 'prime': [i for i in range(129375)]
49
+ }
50
+ try:
51
+ eval_csv = pd.read_csv(csv_path)
52
+ if 'query_id' not in eval_csv.columns:
53
+ raise ValueError('No `query_id` column found in the submitted csv.')
54
+ if 'pred_rank' not in eval_csv.columns:
55
+ raise ValueError('No `pred_rank` column found in the submitted csv.')
56
+
57
+ eval_csv = eval_csv[['query_id', 'pred_rank']]
58
+
59
+ if dataset not in candidate_ids_dict:
60
+ raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
61
+ if split not in ['test', 'test-0.1', 'human_generated_eval']:
62
+ raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
63
+
64
+ evaluator = Evaluator(candidate_ids_dict[dataset])
65
+ eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
66
+ qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
67
+ split_idx = qa_dataset.get_idx_split()
68
+ all_indices = split_idx[split].tolist()
69
+
70
+ results_list = []
71
+ query_ids = []
72
+
73
+ # Prepare args for each worker
74
+ args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
75
+
76
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
77
+ futures = [executor.submit(process_single_instance, arg) for arg in args]
78
+ for future in tqdm(as_completed(futures), total=len(futures)):
79
+ result = future.result() # This will raise an error if the worker encountered one
80
+ results_list.append(result)
81
+ query_ids.append(result['query_id'])
82
+
83
+ # Concatenate results and compute final metrics
84
+ eval_csv = pd.concat([eval_csv, pd.DataFrame(results_list)], ignore_index=True)
85
+ final_results = {
86
+ metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
87
+ }
88
+ return final_results
89
+
90
+ except pd.errors.EmptyDataError:
91
+ return "Error: The CSV file is empty or could not be read. Please check the file and try again."
92
+ except FileNotFoundError:
93
+ return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
94
+ except Exception as error:
95
+ return f"{error}"
96
+
97
+
98
+ # Data dictionaries for leaderboard
99
+ data_synthesized_full = {
100
+ 'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2'],
101
+ 'STARK-AMAZON_Hit@1': [44.94, 15.29, 30.96, 26.56, 39.16, 40.93, 21.74, 42.08, 40.07, 46.10],
102
+ 'STARK-AMAZON_Hit@5': [67.42, 47.93, 51.06, 50.01, 62.73, 64.37, 41.65, 66.87, 64.98, 66.02],
103
+ 'STARK-AMAZON_R@20': [53.77, 44.49, 41.95, 52.05, 53.29, 54.28, 33.22, 56.52, 55.12, 53.44],
104
+ 'STARK-AMAZON_MRR': [55.30, 30.20, 40.66, 37.75, 50.35, 51.60, 31.47, 53.46, 51.55, 55.51],
105
+ 'STARK-MAG_Hit@1': [25.85, 10.51, 21.96, 12.88, 29.08, 30.06, 18.01, 37.90, 25.92, 31.18],
106
+ 'STARK-MAG_Hit@5': [45.25, 35.23, 36.50, 39.01, 49.61, 50.58, 34.85, 56.74, 50.43, 46.42],
107
+ 'STARK-MAG_R@20': [45.69, 42.11, 35.32, 46.97, 48.36, 50.49, 35.46, 46.40, 50.80, 43.94],
108
+ 'STARK-MAG_MRR': [34.91, 21.34, 29.14, 29.12, 38.62, 39.66, 26.10, 47.25, 36.94, 38.39],
109
+ 'STARK-PRIME_Hit@1': [12.75, 4.46, 6.53, 8.85, 12.63, 10.85, 10.10, 15.57, 15.10, 11.75],
110
+ 'STARK-PRIME_Hit@5': [27.92, 21.85, 15.67, 21.35, 31.49, 30.23, 22.49, 33.42, 33.56, 23.85],
111
+ 'STARK-PRIME_R@20': [31.25, 30.13, 16.52, 29.63, 36.00, 37.83, 26.34, 39.09, 38.05, 25.04],
112
+ 'STARK-PRIME_MRR': [19.84, 12.38, 11.05, 14.73, 21.41, 19.99, 16.12, 24.11, 23.49, 17.39]
113
+ }
114
+
115
+ data_synthesized_10 = {
116
+ 'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2', 'Claude3 Reranker', 'GPT4 Reranker'],
117
+ 'STARK-AMAZON_Hit@1': [42.68, 16.46, 30.09, 25.00, 39.02, 43.29, 18.90, 43.29, 40.85, 44.31, 45.49, 44.79],
118
+ 'STARK-AMAZON_Hit@5': [67.07, 50.00, 49.27, 48.17, 64.02, 67.68, 37.80, 71.34, 62.80, 65.24, 71.13, 71.17],
119
+ 'STARK-AMAZON_R@20': [54.48, 42.15, 41.91, 51.65, 49.30, 56.04, 34.73, 56.14, 52.47, 51.00, 53.77, 55.35],
120
+ 'STARK-AMAZON_MRR': [54.02, 30.20, 39.30, 36.87, 50.32, 54.20, 28.76, 55.07, 51.54, 55.07, 55.91, 55.69],
121
+ 'STARK-MAG_Hit@1': [27.81, 11.65, 22.89, 12.03, 28.20, 34.59, 19.17, 38.35, 25.56, 31.58, 36.54, 40.90],
122
+ 'STARK-MAG_Hit@5': [45.48, 36.84, 37.26, 37.97, 52.63, 50.75, 33.46, 58.64, 50.37, 47.36, 53.17, 58.18],
123
+ 'STARK-MAG_R@20': [44.59, 42.30, 44.16, 47.98, 49.25, 50.75, 29.85, 46.38, 53.03, 45.72, 48.36, 48.60],
124
+ 'STARK-MAG_MRR': [35.97, 21.82, 30.00, 28.70, 38.55, 42.90, 26.06, 48.25, 36.82, 38.98, 44.15, 49.00],
125
+ 'STARK-PRIME_Hit@1': [13.93, 5.00, 6.78, 7.14, 15.36, 12.14, 9.29, 16.79, 15.36, 15.00, 17.79, 18.28],
126
+ 'STARK-PRIME_Hit@5': [31.07, 23.57, 16.15, 17.14, 31.07, 31.42, 20.7, 34.29, 32.86, 26.07, 36.90, 37.28],
127
+ 'STARK-PRIME_R@20': [32.84, 30.50, 17.07, 32.95, 37.88, 37.34, 25.54, 41.11, 40.99, 27.78, 35.57, 34.05],
128
+ 'STARK-PRIME_MRR': [21.68, 13.50, 11.42, 16.27, 23.50, 21.23, 15.00, 24.99, 23.70, 19.98, 26.27, 26.55]
129
+ }
130
+
131
+ data_human_generated = {
132
+ 'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2', 'Claude3 Reranker', 'GPT4 Reranker'],
133
+ 'STARK-AMAZON_Hit@1': [27.16, 16.05, 25.93, 22.22, 39.50, 35.80, 29.63, 40.74, 46.91, 33.33, 53.09, 50.62],
134
+ 'STARK-AMAZON_Hit@5': [51.85, 39.51, 54.32, 49.38, 64.19, 62.96, 46.91, 71.60, 72.84, 55.56, 74.07, 75.31],
135
+ 'STARK-AMAZON_R@20': [29.23, 15.23, 23.69, 21.54, 35.46, 33.01, 21.21, 36.30, 40.22, 29.03, 35.46, 35.46],
136
+ 'STARK-AMAZON_MRR': [18.79, 27.21, 37.12, 31.33, 52.65, 47.84, 38.61, 53.21, 58.74, 43.77, 62.11, 61.06],
137
+ 'STARK-MAG_Hit@1': [32.14, 4.72, 25.00, 20.24, 28.57, 22.62, 16.67, 34.52, 23.81, 33.33, 38.10, 36.90],
138
+ 'STARK-MAG_Hit@5': [41.67, 9.52, 30.95, 26.19, 41.67, 36.90, 28.57, 44.04, 41.67, 36.90, 45.24, 46.43],
139
+ 'STARK-MAG_R@20': [32.46, 25.00, 27.24, 28.76, 35.95, 32.44, 21.74, 34.57, 39.85, 30.50, 35.95, 35.95],
140
+ 'STARK-MAG_MRR': [37.42, 7.90, 27.98, 25.53, 35.81, 29.68, 21.59, 38.72, 31.43, 35.97, 42.00, 40.65],
141
+ 'STARK-PRIME_Hit@1': [22.45, 2.04, 7.14, 6.12, 17.35, 16.33, 9.18, 25.51, 24.49, 15.31, 28.57, 28.57],
142
+ 'STARK-PRIME_Hit@5': [41.84, 9.18, 13.27, 13.27, 34.69, 32.65, 21.43, 41.84, 39.80, 26.53, 46.94, 44.90],
143
+ 'STARK-PRIME_R@20': [42.32, 10.69, 11.72, 17.62, 41.09, 39.01, 26.77, 48.10, 47.21, 25.56, 41.61, 41.61],
144
+ 'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
145
+ }
146
+
147
+ # Initialize DataFrames
148
+ df_synthesized_full = pd.DataFrame(data_synthesized_full)
149
+ df_synthesized_10 = pd.DataFrame(data_synthesized_10)
150
+ df_human_generated = pd.DataFrame(data_human_generated)
151
+
152
+ # Model type definitions
153
+ model_types = {
154
+ 'Sparse Retriever': ['BM25'],
155
+ 'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
156
+ 'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
157
+ 'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
158
+ 'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
159
+ }
160
+
161
+ # Submission form validation functions
162
+ def validate_email(email_str):
163
+ """Validate email format(s)"""
164
+ emails = [e.strip() for e in email_str.split(';')]
165
+ email_pattern = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
166
+ return all(email_pattern.match(email) for email in emails)
167
+
168
+ def validate_github_url(url):
169
+ """Validate GitHub URL format"""
170
+ github_pattern = re.compile(
171
+ r'^https?:\/\/(?:www\.)?github\.com\/[\w-]+\/[\w.-]+\/?$'
172
  )
173
+ return bool(github_pattern.match(url))
174
 
175
+ def validate_csv(file_obj):
176
+ """Validate CSV file format and content"""
177
+ try:
178
+ df = pd.read_csv(file_obj.name)
179
+ required_cols = ['query_id', 'pred_rank']
180
+
181
+ if not all(col in df.columns for col in required_cols):
182
+ return False, "CSV must contain 'query_id' and 'pred_rank' columns"
183
+
184
+ try:
185
+ first_rank = eval(df['pred_rank'].iloc[0]) if isinstance(df['pred_rank'].iloc[0], str) else df['pred_rank'].iloc[0]
186
+ if not isinstance(first_rank, list) or len(first_rank) < 20:
187
+ return False, "pred_rank must be a list with at least 20 candidates"
188
+ except:
189
+ return False, "Invalid pred_rank format"
190
+
191
+ return True, "Valid CSV file"
192
+ except Exception as e:
193
+ return False, f"Error processing CSV: {str(e)}"
194
 
195
+ def sanitize_name(name):
196
+ """Sanitize name for file system use"""
197
+ return re.sub(r'[^a-zA-Z0-9]', '_', name)
198
+
199
+ def save_submission(submission_data, csv_file):
200
+ """
201
+ Save submission data and CSV file using model_name_team_name format
202
+
203
+ Args:
204
+ submission_data (dict): Metadata and results for the submission
205
+ csv_file: The uploaded CSV file object
206
+ """
207
+ # Create folder name from model name and team name
208
+ model_name_clean = sanitize_name(submission_data['method_name'])
209
+ team_name_clean = sanitize_name(submission_data['team_name'])
210
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
211
+
212
+ # Create folder name: model_name_team_name
213
+ folder_name = f"{model_name_clean}_{team_name_clean}"
214
+ submission_id = f"{folder_name}_{timestamp}"
215
+
216
+ # Create submission directory structure
217
+ base_dir = "submissions"
218
+ submission_dir = os.path.join(base_dir, folder_name)
219
+ os.makedirs(submission_dir, exist_ok=True)
220
+
221
+ # Save CSV file with timestamp to allow multiple submissions
222
+ csv_filename = f"predictions_{timestamp}.csv"
223
+ csv_path = os.path.join(submission_dir, csv_filename)
224
+ if hasattr(csv_file, 'name'):
225
+ with open(csv_file.name, 'rb') as source, open(csv_path, 'wb') as target:
226
+ target.write(source.read())
227
+
228
+ # Add file paths to submission data
229
+ submission_data.update({
230
+ "csv_path": csv_path,
231
+ "submission_id": submission_id,
232
+ "folder_name": folder_name
233
+ })
234
+
235
+ # Save metadata as JSON with timestamp
236
+ metadata_path = os.path.join(submission_dir, f"metadata_{timestamp}.json")
237
+ with open(metadata_path, 'w') as f:
238
+ json.dump(submission_data, f, indent=4)
239
+
240
+ # Update latest.json to track most recent submission
241
+ latest_path = os.path.join(submission_dir, "latest.json")
242
+ with open(latest_path, 'w') as f:
243
+ json.dump({
244
+ "latest_submission": timestamp,
245
+ "status": "pending_review",
246
+ "method_name": submission_data['method_name']
247
+ }, f, indent=4)
248
+
249
+ return submission_id
250
+
251
+ def update_leaderboard_data(submission_data):
252
+ """
253
+ Update leaderboard data with new submission results
254
+ Only uses model name in the displayed table
255
+ """
256
+ global df_synthesized_full, df_synthesized_10, df_human_generated
257
+
258
+ # Determine which DataFrame to update based on split
259
+ split_to_df = {
260
+ 'test': df_synthesized_full,
261
+ 'test-0.1': df_synthesized_10,
262
+ 'human_generated_eval': df_human_generated
263
+ }
264
+
265
+ df_to_update = split_to_df[submission_data['split']]
266
+
267
+ # Prepare new row data
268
+ new_row = {
269
+ 'Method': submission_data['method_name'], # Only use method name in table
270
+ f'STARK-{submission_data["dataset"].upper()}_Hit@1': submission_data['results']['hit@1'],
271
+ f'STARK-{submission_data["dataset"].upper()}_Hit@5': submission_data['results']['hit@5'],
272
+ f'STARK-{submission_data["dataset"].upper()}_R@20': submission_data['results']['recall@20'],
273
+ f'STARK-{submission_data["dataset"].upper()}_MRR': submission_data['results']['mrr']
274
+ }
275
+
276
+ # Check if method already exists
277
+ method_mask = df_to_update['Method'] == submission_data['method_name']
278
+ if method_mask.any():
279
+ # Update existing row
280
+ for col in new_row:
281
+ df_to_update.loc[method_mask, col] = new_row[col]
282
+ else:
283
+ # Add new row
284
+ df_to_update.loc[len(df_to_update)] = new_row
285
+
286
+ def process_submission(
287
+ method_name, team_name, dataset, split, contact_email,
288
+ code_repo, csv_file, model_description, hardware, paper_link
289
+ ):
290
+ """Process and validate submission"""
291
+ try:
292
+ # [Previous validation code remains the same]
293
+
294
+ # Process CSV file through evaluation pipeline
295
+ results = compute_metrics(
296
+ csv_file.name,
297
+ dataset=dataset.lower(),
298
+ split=split,
299
+ num_workers=4
300
+ )
301
+
302
+ if isinstance(results, str) and results.startswith("Error"):
303
+ return f"Evaluation error: {results}"
304
+
305
+ # Prepare submission data
306
+ submission_data = {
307
+ "method_name": method_name,
308
+ "team_name": team_name,
309
+ "dataset": dataset,
310
+ "split": split,
311
+ "contact_email": contact_email,
312
+ "code_repo": code_repo,
313
+ "model_description": model_description,
314
+ "hardware": hardware,
315
+ "paper_link": paper_link,
316
+ "results": results,
317
+ "status": "pending_review",
318
+ "submission_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
319
+ }
320
+
321
+ # Save submission and get ID
322
+ submission_id = save_submission(submission_data, csv_file)
323
+
324
+ # Update leaderboard data if submission is valid
325
+ update_leaderboard_data(submission_data)
326
+
327
+ return f"""
328
+ Submission successful! Your submission ID is: {submission_id}
329
+
330
+ Evaluation Results:
331
+ Hit@1: {results['hit@1']:.2f}
332
+ Hit@5: {results['hit@5']:.2f}
333
+ Recall@20: {results['recall@20']:.2f}
334
+ MRR: {results['mrr']:.2f}
335
+
336
+ Your submission has been saved and is pending review.
337
+ Once approved, your results will appear in the leaderboard under the method name: {method_name}
338
+ """
339
+
340
+ except Exception as e:
341
+ return f"Error processing submission: {str(e)}"
342
+
343
+ def filter_by_model_type(df, selected_types):
344
+ if not selected_types:
345
+ return df.head(0)
346
+ selected_models = [model for type in selected_types for model in model_types[type]]
347
+ return df[df['Method'].isin(selected_models)]
348
+
349
+ def format_dataframe(df, dataset):
350
+ columns = ['Method'] + [col for col in df.columns if dataset in col]
351
+ filtered_df = df[columns].copy()
352
+ filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
353
+ filtered_df = filtered_df.sort_values('MRR', ascending=False)
354
+ return filtered_df
355
 
356
+ def update_tables(selected_types):
357
+ filtered_df_full = filter_by_model_type(df_synthesized_full, selected_types)
358
+ filtered_df_10 = filter_by_model_type(df_synthesized_10, selected_types)
359
+ filtered_df_human = filter_by_model_type(df_human_generated, selected_types)
360
+
361
+ outputs = []
362
+ for df in [filtered_df_full, filtered_df_10, filtered_df_human]:
363
+ for dataset in ['AMAZON', 'MAG', 'PRIME']:
364
+ outputs.append(format_dataframe(df, f"STARK-{dataset}"))
365
+
366
+ return outputs
367
+
368
+
369
+ css = """
370
+ table > thead {
371
+ white-space: normal
372
+ }
373
+
374
+ table {
375
+ --cell-width-1: 250px
376
+ }
377
+
378
+ table > tbody > tr > td:nth-child(2) > div {
379
+ overflow-x: auto
380
+ }
381
+
382
+ .tab-nav {
383
+ border-bottom: 1px solid rgba(255, 255, 255, 0.1);
384
+ margin-bottom: 1rem;
385
+ }
386
+ """
387
+
388
+ # Main application
389
+ with gr.Blocks(css=css) as demo:
390
+ gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
391
+ gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
392
+
393
+ # Model type filter
394
+ model_type_filter = gr.CheckboxGroup(
395
+ choices=list(model_types.keys()),
396
+ value=list(model_types.keys()),
397
+ label="Model types",
398
+ interactive=True
399
+ )
400
+
401
+ # Initialize dataframes list
402
+ all_dfs = []
403
+
404
+ # Create nested tabs structure
405
+ with gr.Tabs() as outer_tabs:
406
+ with gr.TabItem("Synthesized (full)"):
407
+ with gr.Tabs() as inner_tabs1:
408
+ for dataset in ['AMAZON', 'MAG', 'PRIME']:
409
+ with gr.TabItem(dataset):
410
+ all_dfs.append(gr.DataFrame(interactive=False))
411
+
412
+ with gr.TabItem("Synthesized (10%)"):
413
+ with gr.Tabs() as inner_tabs2:
414
+ for dataset in ['AMAZON', 'MAG', 'PRIME']:
415
+ with gr.TabItem(dataset):
416
+ all_dfs.append(gr.DataFrame(interactive=False))
417
+
418
+ with gr.TabItem("Human-Generated"):
419
+ with gr.Tabs() as inner_tabs3:
420
+ for dataset in ['AMAZON', 'MAG', 'PRIME']:
421
+ with gr.TabItem(dataset):
422
+ all_dfs.append(gr.DataFrame(interactive=False))
423
+
424
+ # Submission section
425
+ gr.Markdown("---")
426
+ gr.Markdown("## Submit Your Results")
427
+ gr.Markdown("""
428
+ Submit your results to be included in the leaderboard. Please ensure your submission meets all requirements.
429
+ For questions, contact stark-qa@cs.stanford.edu
430
+ """)
431
+
432
  with gr.Row():
433
+ with gr.Column():
434
+ method_name = gr.Textbox(
435
+ label="Method Name (max 25 chars)*",
436
+ placeholder="e.g., MyRetrievalModel-v1"
437
  )
438
+ team_name = gr.Textbox(
439
+ label="Team Name (max 25 chars)*",
440
+ placeholder="e.g., Stanford NLP"
441
+ )
442
+ dataset = gr.Dropdown(
443
+ choices=["amazon", "mag", "prime"],
444
+ label="Dataset*",
445
+ value="amazon"
446
+ )
447
+ split = gr.Dropdown(
448
+ choices=["test", "test-0.1", "human_generated_eval"],
449
+ label="Split*",
450
+ value="test"
451
+ )
452
+ contact_email = gr.Textbox(
453
+ label="Contact Email(s)*",
454
+ placeholder="email@example.com; another@example.com"
455
+ )
456
+
457
+ with gr.Column():
458
+ code_repo = gr.Textbox(
459
+ label="Code Repository*",
460
+ placeholder="https://github.com/username/repository"
461
+ )
462
+ csv_file = gr.File(
463
+ label="Prediction CSV*",
464
+ file_types=[".csv"]
465
+ )
466
+ model_description = gr.Textbox(
467
+ label="Model Description*",
468
+ lines=3,
469
+ placeholder="Briefly describe how your retriever model works..."
470
+ )
471
+ hardware = gr.Textbox(
472
+ label="Hardware Specifications*",
473
+ placeholder="e.g., 4x NVIDIA A100 80GB"
474
+ )
475
+ paper_link = gr.Textbox(
476
+ label="Paper Link (Optional)",
477
+ placeholder="https://arxiv.org/abs/..."
478
+ )
479
+
480
+ submit_btn = gr.Button("Submit", variant="primary")
481
+ result = gr.Textbox(label="Submission Status", interactive=False)
482
+
483
+ # Set up event handlers
484
+ model_type_filter.change(
485
+ update_tables,
486
+ inputs=[model_type_filter],
487
+ outputs=all_dfs
488
+ )
489
+
490
+ submit_btn.click(
491
+ process_submission,
492
+ inputs=[
493
+ method_name, team_name, dataset, split, contact_email,
494
+ code_repo, csv_file, model_description, hardware, paper_link
495
+ ],
496
+ outputs=result
497
+ )
498
+
499
+ # Initial table update
500
+ demo.load(
501
+ update_tables,
502
+ inputs=[model_type_filter],
503
+ outputs=all_dfs
504
+ )
505
 
506
+ # Launch the application
507
+ demo.launch()
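For reference, a minimal sketch of exercising the helpers added in this commit, assuming the validate_csv and compute_metrics functions defined above are in scope. The file name predictions.csv, the toy query_id/pred_rank values, and the UploadedFile stand-in are illustrative only and not part of the commit.

import pandas as pd

# Build a toy predictions file: one row per query, with pred_rank holding a
# stringified list of at least 20 candidate ids, which is what validate_csv checks for.
toy = pd.DataFrame({
    "query_id": [0, 1],
    "pred_rank": [str(list(range(20))), str(list(range(20, 40)))],
})
toy.to_csv("predictions.csv", index=False)

class UploadedFile:  # hypothetical stand-in for the gr.File object the app passes around
    name = "predictions.csv"

ok, message = validate_csv(UploadedFile())
print(ok, message)  # prints: True Valid CSV file

# With real predictions whose query_id values match the chosen split (and with
# stark_qa installed), the same file can be scored directly:
# results = compute_metrics("predictions.csv", dataset="amazon", split="test", num_workers=4)
# print(results)  # {'hit@1': ..., 'hit@5': ..., 'recall@20': ..., 'mrr': ...}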