mgyigit committed on
Commit 6ba85f3 · verified · 1 Parent(s): edb9d91

Update app.py

Files changed (1): app.py (+80, -131)
app.py CHANGED
@@ -21,7 +21,7 @@ from src.vis_utils import *
 from src.bin.PROBE import run_probe
 
 # ------------------------------------------------------------------
-# Helper functions moved / added here so that UI callbacks can see them
+# Helper functions --------------------------------------------------
 # ------------------------------------------------------------------
 
 def add_new_eval(
@@ -40,7 +40,6 @@ def add_new_eval(
     if any(task in benchmark_types for task in ['similarity', 'family', 'function']) and human_file is None:
         gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
         return -1
-
     if 'affinity' in benchmark_types and skempi_file is None:
         gr.Warning("SKEMPI representations are required for affinity benchmark!")
         return -1
@@ -77,20 +76,17 @@ def refresh_data():
     """Re‑start the space and pull fresh leaderboard CSVs from the HF Hub."""
     api.restart_space(repo_id=repo_id)
     benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]
-
     for benchmark_type in benchmark_types:
         path = f"/tmp/{benchmark_type}_results.csv"
         if os.path.exists(path):
             os.remove(path)
-
     benchmark_types.remove("leaderboard")
     download_from_hub(benchmark_types)
 
 
-# ------- Leaderboard helpers -------------------------------------------------
+# ------- Leaderboard helpers -----------------------------------------------
 
 def update_metrics(selected_benchmarks):
-    """Populate metric selector according to chosen benchmark types."""
     updated_metrics = set()
     for benchmark in selected_benchmarks:
         updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
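Note: the hunk above ends before `update_metrics` returns, so the callback's output is not visible in this diff. A minimal sketch of how such a selector-refresh callback typically completes in Gradio, assuming the common `gr.update` pattern; the mapping below is a placeholder, not the app's real metric list:

```python
import gradio as gr

# Placeholder mapping, only to make the sketch self-contained (not the app's real metrics).
benchmark_metric_mapping = {
    "similarity": ["sim_corr", "sim_rank"],
    "affinity": ["aff_rmse"],
}

def update_metrics(selected_benchmarks):
    updated_metrics = set()
    for benchmark in selected_benchmarks:
        updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
    # Returning gr.update lets a .change event swap the metric CheckboxGroup's choices in place.
    return gr.update(choices=sorted(updated_metrics), value=sorted(updated_metrics))
```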
@@ -98,50 +94,33 @@ def update_metrics(selected_benchmarks):
 
 
 def update_leaderboard(selected_methods, selected_metrics):
-    updated_df = get_baseline_df(selected_methods, selected_metrics)
-    return updated_df
+    return get_baseline_df(selected_methods, selected_metrics)
 
-# ------- Visualisation helpers ----------------------------------------------
+# ------- Visualisation helpers ---------------------------------------------
 
 def get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric):
-    """Return a short natural‑language explanation for the produced plot."""
     if benchmark_type == "similarity":
         return (
-            f"The scatter plot compares models on **{x_metric}** (x‑axis) and "
-            f"**{y_metric}** (y‑axis). Points further to the upper‑right indicate better "
-            "performance on both metrics."
+            f"Scatter plot compares models on **{x_metric}** (x‑axis) and **{y_metric}** (y‑axis). "
+            "Upper‑right points indicate jointly strong performance."
         )
-    elif benchmark_type == "function":
+    if benchmark_type == "function":
         return (
-            f"The heat‑map shows performance of each model (columns) across GO terms "
-            f"for the **{aspect.upper()}** aspect using the **{single_metric}** metric. "
-            "Darker squares correspond to stronger performance; hierarchical clustering "
-            "groups similar models and tasks together."
+            f"Heat‑map shows model scores for **{aspect.upper()}** terms with **{single_metric}**. "
+            "Darker squares → better predictions."
         )
-    elif benchmark_type == "family":
+    if benchmark_type == "family":
         return (
-            f"The horizontal box‑plots summarise cross‑validation performance on the "
-            f"**{dataset}** dataset. Higher median MCC values indicate better family‑"
-            "classification accuracy."
+            f"Box‑plots summarise cross‑fold MCC on **{dataset}**; higher medians are better."
         )
-    elif benchmark_type == "affinity":
+    if benchmark_type == "affinity":
         return (
-            f"Each boxplot shows the distribution of **{single_metric}** scores for every "
-            "model when predicting binding affinity changes. Higher values are better."
+            f"Boxplots display distribution of **{single_metric}** scores for affinity prediction; higher values are better."
        )
     return ""
 
 
-def generate_plot_and_explanation(
-    benchmark_type,
-    methods_selected,
-    x_metric,
-    y_metric,
-    aspect,
-    dataset,
-    single_metric,
-):
-    """Callback wrapper that returns both the image path and a textual explanation."""
+def generate_plot_and_explanation(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
     plot_path = benchmark_plot(
         benchmark_type,
         methods_selected,
@@ -154,10 +133,34 @@
     explanation = get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric)
     return plot_path, explanation
 
-# ------------------------------------------------------------------
+# ---------------------------------------------------------------------------
+# Custom CSS for frozen first column and clearer table styles
+# ---------------------------------------------------------------------------
+CUSTOM_CSS = """
+/* freeze first column */
+#leaderboard-table thead th:first-child,
+#leaderboard-table tbody td:first-child {
+    position: sticky;
+    left: 0;
+    background: white;
+    z-index: 2;
+}
+
+/* striped rows for readability */
+#leaderboard-table tbody tr:nth-child(odd) {
+    background: #fafafa;
+}
+
+/* centre numeric cells */
+#leaderboard-table td:not(:first-child) {
+    text-align: center;
+}
+"""
+
+# ---------------------------------------------------------------------------
 # UI definition
-# ------------------------------------------------------------------
-block = gr.Blocks()
+# ---------------------------------------------------------------------------
+block = gr.Blocks(css=CUSTOM_CSS)
 
 with block:
     gr.Markdown(LEADERBOARD_INTRODUCTION)
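Note: the `CUSTOM_CSS` added above targets `#leaderboard-table`, which only takes effect once a component carries that `elem_id` (done in a later hunk of this commit). A self-contained sketch of that pairing, with placeholder data:

```python
import gradio as gr
import pandas as pd

# Placeholder CSS and data; the selector must match the component's elem_id below.
CSS = "#leaderboard-table td:first-child { position: sticky; left: 0; background: white; }"
df = pd.DataFrame({"Method": ["model-a", "model-b"], "sim_corr": [0.61, 0.58]})

with gr.Blocks(css=CSS) as demo:
    gr.Dataframe(value=df, elem_id="leaderboard-table", interactive=False)

if __name__ == "__main__":
    demo.launch()
```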
@@ -167,23 +170,28 @@ with block:
     # 1️⃣ Leaderboard tab
     # ------------------------------------------------------------------
     with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
-        leaderboard = get_baseline_df(None, None)  # baseline leaderboard without filtering
+        # small workflow figure at top
+        gr.Image(
+            value="./src/data/PROBE_workflow_figure.jpg",
+            show_label=False,
+            height=150,
+            container=False,
+        )
 
+        leaderboard = get_baseline_df(None, None)
         method_names = leaderboard['Method'].unique().tolist()
-        metric_names = leaderboard.columns.tolist()
-        metric_names.remove('Method')  # remove non‑metric column
+        metric_names = leaderboard.columns.tolist(); metric_names.remove('Method')
 
         benchmark_metric_mapping = {
             "similarity": [m for m in metric_names if m.startswith('sim_')],
-            "function": [m for m in metric_names if m.startswith('func')],
-            "family": [m for m in metric_names if m.startswith('fam_')],
-            "affinity": [m for m in metric_names if m.startswith('aff_')],
+            "function": [m for m in metric_names if m.startswith('func')],
+            "family": [m for m in metric_names if m.startswith('fam_')],
+            "affinity": [m for m in metric_names if m.startswith('aff_')],
         }
 
-        # selectors -----------------------------------------------------
         leaderboard_method_selector = gr.CheckboxGroup(
             choices=method_names,
-            label="Select Methods for the Leaderboard",
+            label="Select Methods",
             value=method_names,
             interactive=True,
         )
@@ -197,15 +205,14 @@
 
         leaderboard_metric_selector = gr.CheckboxGroup(
             choices=metric_names,
-            label="Select Metrics for the Leaderboard",
+            label="Select Metrics",
             value=None,
             interactive=True,
         )
 
-        # leaderboard table --------------------------------------------
         baseline_value = get_baseline_df(method_names, metric_names)
         baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
-        baseline_header = ["Method"] + metric_names
+        baseline_header = ["Method"] + metric_names
         baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
 
         with gr.Row(show_progress=True, variant='panel'):
@@ -215,22 +222,21 @@
                 type="pandas",
                 datatype=baseline_datatype,
                 interactive=False,
-                visible=True,
+                elem_id="leaderboard-table",
+                height=600,  # make table longer
             )
 
-        # callbacks -----------------------------------------------------
+        # callbacks
         leaderboard_method_selector.change(
             get_baseline_df,
             inputs=[leaderboard_method_selector, leaderboard_metric_selector],
             outputs=data_component,
         )
-
         benchmark_type_selector_lb.change(
             lambda selected: update_metrics(selected),
             inputs=[benchmark_type_selector_lb],
             outputs=leaderboard_metric_selector,
         )
-
         leaderboard_metric_selector.change(
             get_baseline_df,
             inputs=[leaderboard_method_selector, leaderboard_metric_selector],
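Note: the callbacks above re-run `get_baseline_df` whenever either CheckboxGroup changes and write the result back into the table. A toy version of that filter-on-change pattern, with a stand-in filter function and invented method/metric names:

```python
import gradio as gr
import pandas as pd

# Stand-in for get_baseline_df: filter rows by method and columns by metric (invented data).
DF = pd.DataFrame({"Method": ["model-a", "model-b"],
                   "sim_corr": [0.61, 0.58],
                   "fam_MCC": [0.92, 0.90]})

def filter_df(methods, metrics):
    cols = ["Method"] + [m for m in (metrics or []) if m in DF.columns]
    return DF[DF["Method"].isin(methods or [])][cols]

with gr.Blocks() as demo:
    methods = gr.CheckboxGroup(choices=list(DF["Method"]), value=list(DF["Method"]), label="Methods")
    metrics = gr.CheckboxGroup(choices=["sim_corr", "fam_MCC"], value=["sim_corr"], label="Metrics")
    table = gr.Dataframe(value=filter_df(list(DF["Method"]), ["sim_corr"]), interactive=False)
    methods.change(filter_df, inputs=[methods, metrics], outputs=table)
    metrics.change(filter_df, inputs=[methods, metrics], outputs=table)

if __name__ == "__main__":
    demo.launch()
```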
@@ -238,58 +244,36 @@
         )
 
     # ------------------------------------------------------------------
-    # 2️⃣ Visualisation tab
+    # 2️⃣ Visualisation tab
     # ------------------------------------------------------------------
-    with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
-        # Intro / instructions
+    with gr.TabItem("📊 Visualizations", elem_id="probe-benchmark-tab-visualization", id=2):
         gr.Markdown(
-            """
-            ## **Interactive Visualizations**
-            Select a benchmark type first; context‑specific options will appear automatically.
-            Once your parameters are set, click **Plot** to generate the figure.
-
-            **How to read the plots**
-            * **Similarity (scatter)** – Each point is a model. Points nearer the top‑right perform well on both chosen similarity metrics.
-            * **Function prediction (heat‑map)** – Darker squares denote better scores. Rows/columns are clustered to reveal shared structure.
-            * **Family / Affinity (boxplots)** – Boxes summarise distribution across folds/targets. Higher medians indicate stronger performance.
-            """,
+            """## **Interactive Visualizations**
+            Choose a benchmark type; context‑specific options will appear. Click **Plot** and an explanation will follow the figure.""",
             elem_classes="markdown-text",
         )
-
-        # ------------------------------------------------------------------
-        # selectors specific to visualisation
-        # ------------------------------------------------------------------
         vis_benchmark_type_selector = gr.Dropdown(
             choices=list(benchmark_specific_metrics.keys()),
-            label="Select Benchmark Type",
+            label="Benchmark Type",
            value=None,
         )
-
         with gr.Row():
-            vis_x_metric_selector = gr.Dropdown(choices=[], label="Select X‑axis Metric", visible=False)
-            vis_y_metric_selector = gr.Dropdown(choices=[], label="Select Y‑axis Metric", visible=False)
-            vis_aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
-            vis_dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
-            vis_single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)
-
+            vis_x_metric_selector = gr.Dropdown(choices=[], label="X‑axis Metric", visible=False)
+            vis_y_metric_selector = gr.Dropdown(choices=[], label="Y‑axis Metric", visible=False)
+            vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
+            vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
+            vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)
         vis_method_selector = gr.CheckboxGroup(
             choices=method_names,
-            label="Select methods to visualize",
-            interactive=True,
+            label="Methods",
             value=method_names,
+            interactive=True,
         )
-
         plot_button = gr.Button("Plot")
-
         with gr.Row(show_progress=True, variant='panel'):
             plot_output = gr.Image(label="Plot")
-
-        # textual explanation below the image
         plot_explanation = gr.Markdown(visible=False)
-
-        # ------------------------------------------------------------------
-        # callbacks for visualisation tab
-        # ------------------------------------------------------------------
+        # callbacks
         vis_benchmark_type_selector.change(
             update_metric_choices,
             inputs=[vis_benchmark_type_selector],
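Note: `update_metric_choices` itself is not part of this diff; it fans one dropdown's value out to the five selectors declared above, toggling their visibility and choices. A hedged sketch of that one-input, many-output shape (the choice lists and output ordering are assumptions, not the app's code):

```python
import gradio as gr

def update_metric_choices(benchmark_type):
    # Assumed output order: x metric, y metric, aspect, dataset, single metric.
    hide = gr.update(visible=False)
    if benchmark_type == "similarity":
        show_xy = gr.update(choices=["sim_corr", "sim_rank"], visible=True)  # invented choices
        return show_xy, show_xy, hide, hide, hide
    if benchmark_type == "function":
        return (hide, hide,
                gr.update(choices=["mf", "bp", "cc"], visible=True),
                hide,
                gr.update(choices=["func_f1"], visible=True))
    # family / affinity: only the dataset and single-metric selectors are shown here.
    return (hide, hide, hide,
            gr.update(choices=["dataset_a"], visible=True),
            gr.update(choices=["fam_MCC"], visible=True))
```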
@@ -301,7 +285,6 @@ with block:
                 vis_single_metric_selector,
             ],
         )
-
         plot_button.click(
             generate_plot_and_explanation,
             inputs=[
@@ -335,53 +318,21 @@ with block:
     with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
         with gr.Row():
             gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
-
         with gr.Row():
             gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
-
         with gr.Row():
             with gr.Column():
                 model_name_textbox = gr.Textbox(label="Method name")
                 revision_name_textbox = gr.Textbox(label="Revision Method Name")
-
-                benchmark_types = gr.CheckboxGroup(
-                    choices=TASK_INFO,
-                    label="Benchmark Types",
-                    interactive=True,
-                )
-                similarity_tasks = gr.CheckboxGroup(
-                    choices=similarity_tasks_options,
-                    label="Similarity Tasks",
-                    interactive=True,
-                )
-
-                function_prediction_aspect = gr.Radio(
-                    choices=function_prediction_aspect_options,
-                    label="Function Prediction Aspects",
-                    interactive=True,
-                )
-
-                family_prediction_dataset = gr.CheckboxGroup(
-                    choices=family_prediction_dataset_options,
-                    label="Family Prediction Datasets",
-                    interactive=True,
-                )
-
-                function_dataset = gr.Textbox(
-                    label="Function Prediction Datasets",
-                    visible=False,
-                    value="All_Data_Sets",
-                )
-
-                save_checkbox = gr.Checkbox(
-                    label="Save results for leaderboard and visualization",
-                    value=True,
-                )
-
+                benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
+                similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Tasks", interactive=True)
+                function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
+                family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
+                function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
+                save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)
         with gr.Row():
             human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
             skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
-
         submit_button = gr.Button("Submit Eval")
         submission_result = gr.Markdown()
         submit_button.click(
@@ -400,9 +351,7 @@ with block:
         ],
     )
 
-    # ----------------------------------------------------------------------
-    # global refresh button & citation accordion
-    # ----------------------------------------------------------------------
+    # global refresh + citation ---------------------------------------------
     with gr.Row():
         data_run = gr.Button("Refresh")
         data_run.click(refresh_data, outputs=[data_component])
@@ -415,5 +364,5 @@ with block:
             show_copy_button=True,
         )
 
-# -----------------------------------------------------------------------------
+# ---------------------------------------------------------------------------
 block.launch()
 