magilogi committed on
Commit
5542fa4
•
1 Parent(s): cbf54c8

layout and adjusted score

Browse files
Files changed (1) hide show
  1. app.py +120 -54
app.py CHANGED
@@ -3,34 +3,84 @@ import gradio as gr
3
  import plotly.express as px
4
  import plotly.graph_objects as go
5
 
 
6
  explanation_data = {
7
- "Accuracy Scores (rename for clarity)": [
8
- "b4bqa",
9
- "b4b",
10
- "medmcqa_g2b",
11
- "medmcqa_orig_filtered",
12
- "medmcqa_diff",
13
- "medqa_4options_g2b",
14
- "medqa_4options_orig_filtered",
15
- "medqa_diff"
16
  ],
17
  "Description": [
18
- "Model accuracy on the [Come up with a fitting name] task.",
19
- "[How do we best explain this?]",
20
  "G2B Refers to the 'Generic' to 'Brand' name swap. This is model accuracy on MedMCQA task where generic drug names are substituted with brand names.",
21
  "Model accuracy on MedMCQA task with original data. (Only includes questions that overlap with the g2b dataset)",
22
  "Difference in MedMCQA accuracy for swapped and non-swapped datasets, highlighting the impact of G2B drug name substitution on performance.",
23
  "Model accuracy on MedQA (4 options) task where generic drug names are substituted with brand names.",
24
  "Model accuracy on MedQA (4 options) task with original data. (Only includes questions that overlap with the g2b dataset)",
25
- "Difference in MedMCQA accuracy for swapped and non-swapped datasets, highlighting the impact of G2B drug name substitution on performance."
 
26
  ]
27
  }
28
  explanation_df = pd.DataFrame(explanation_data)
29
 
 
 
 
 
 
30
  df = pd.read_csv("data/csv/models_data.csv")
31
  df['average_g2b'] = df[['medmcqa_g2b', 'medqa_4options_g2b']].mean(axis=1).round(2)
32
- df['average_orginal_acc'] = df[['medmcqa_orig_filtered', 'medqa_4options_orig_filtered']].mean(axis=1).round(2)
33
  df['average_diff'] = df[['medmcqa_diff', 'medqa_diff']].mean(axis=1).round(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  filter_mapping = {
36
  "all": "all",
@@ -69,7 +119,7 @@ def create_scatter_plot(df, x_col, y_col, title, x_title, y_title):
69
  return fig
70
 
71
  def create_lm_plot(df, x_col, y_col, title, x_title, y_title):
72
- fig = px.scatter(df, x=x_col, y=y_col, color='Model', title=title, color_discrete_sequence=px.colors.sequential.solar, trendline='ols')
73
 
74
  fig.update_layout(
75
  xaxis_title=x_title,
@@ -80,6 +130,7 @@ def create_lm_plot(df, x_col, y_col, title, x_title, y_title):
80
  return fig
81
 
82
  def create_bar_plot(df, col, title):
 
83
  sorted_df = df.sort_values(by=col, ascending=True)
84
  fig = px.bar(sorted_df,
85
  x=col,
@@ -87,15 +138,32 @@ def create_bar_plot(df, col, title):
87
  orientation='h',
88
  title=title,
89
  color=col,
90
- color_continuous_scale='solar')
91
  fig.update_layout(xaxis_title=col, yaxis_title='Model', height=600, coloraxis_showscale=False)
 
92
  return fig
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  with gr.Blocks(css="custom.css") as demo:
95
  with gr.Column():
96
  gr.Markdown(
97
- """<div style="text-align: center;"><h1> <span style='color: #E6B800;'>🐰 RABBITS</span>: <span style='color: #E6B800;'>R</span>obust <span style='color: #E6B800;'>A</span>ssessment of <span style='color: #E6B800;'>B</span>iomedical <span style='color: #E6B800;'>B</span>enchmarks <span style='color: #E6B800;'>I</span>nvolving drug
98
- <span style='color: #E6B800;'>T</span>erm <span style='color: #E6B800;'>S</span>ubstitutions for Language Models <span style='color: #E6B800;'></span></h1></div>"""
99
  )
100
  with gr.Row():
101
  gr.Markdown(""" """)
@@ -107,20 +175,17 @@ with gr.Blocks(css="custom.css") as demo:
107
  )
108
  with gr.Row():
109
  gr.Markdown(""" """)
110
- with gr.Row():
111
- gr.Image(value="workflow-1-2.svg", width=200, height=450)
112
- gr.Image(value="workflow-3-4.svg", width=200, height=450)
113
 
114
  with gr.Row():
115
  gr.Markdown(""" """)
116
 
117
  with gr.Row():
118
  bar1 = gr.Plot(
119
- value=create_bar_plot(df, "medmcqa_diff", "Impact of Generic2Brand swap on MedMCQA Accuracy"),
120
  elem_id="bar1"
121
  )
122
  bar2 = gr.Plot(
123
- value=create_bar_plot(df, "medqa_diff", "Impact of Generic2Brand swap on MedQA Accuracy"),
124
  elem_id="bar2"
125
  )
126
 
@@ -131,7 +196,7 @@ with gr.Blocks(css="custom.css") as demo:
131
  with gr.Row():
132
  gr.Markdown(""" """)
133
 
134
- default_visible_columns = ['T', 'Model', 'average_original_acc', 'average_g2b','average_diff']
135
 
136
  with gr.Tabs(elem_classes="tab-buttons"):
137
  with gr.TabItem("🔍 Evaluation table"):
@@ -199,31 +264,37 @@ with gr.Blocks(css="custom.css") as demo:
199
  with gr.Column():
200
  with gr.Row():
201
  scatter1 = gr.Plot(
202
- value=create_scatter_plot(df, "medmcqa_orig_filtered", "medmcqa_g2b",
203
- "MedMCQA: Orig vs G2B", "medmcqa_orig_filtered", "medmcqa_g2b"),
204
  elem_id="scatter1"
205
  )
206
  scatter2 = gr.Plot(
207
- value=create_scatter_plot(df, "medqa_4options_orig_filtered", "medqa_4options_g2b",
208
- "MedQA: Orig vs G2B", "medqa_4options_orig_filtered", "medqa_4options_g2b"),
209
  elem_id="scatter2"
210
  )
211
- with gr.Row():
212
- scatter3 = gr.Plot(
213
- value=create_scatter_plot(df, "b4bqa", "b4b",
214
- "b4bqa vs b4b", "b4bqa", "b4b"),
215
- elem_id="scatter3"
216
- )
217
 
218
  with gr.TabItem("📝 About"):
219
- gr.Markdown(
220
- """<div style="text-align: center;">
221
- <h2>About RABBITS LLM Leaderboard</h2>
222
- <p>This leaderboard ...</p>
223
- <p>It is designed to ...</p>
224
- </div>""",
225
- elem_classes="markdown-text"
226
- )
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  with gr.TabItem("🚀 Submit Here!"):
229
  gr.Markdown(
@@ -238,25 +309,20 @@ with gr.Blocks(css="custom.css") as demo:
238
  elem_classes="markdown-text"
239
  )
240
 
241
- with gr.Row():
242
- gr.Dataframe(
243
- value=explanation_df,
244
- headers="keys",
245
- datatype=["str", "str"],
246
- interactive=False,
247
- label="Explanation of Scores"
248
- )
249
 
250
  with gr.Row():
251
  bar3 = gr.Plot(
252
- value=create_bar_plot(df, "b4bqa", "Which LLMs are best at matching brand names to generic drug names? (Results from custom task)"),
253
  elem_id="bar3"
254
  )
255
-
256
- with gr.Row():
257
- scatter_g2b = gr.Plot(
258
- value=create_lm_plot(df, "b4bqa", "average_g2b", "Does that matching accuracy correlate with biomedical task robustness?", "b4bqa", "average_diff"),
259
  )
 
 
 
260
 
261
 
262
 
 
3
  import plotly.express as px
4
  import plotly.graph_objects as go
5
 
6
+ # Creating data for explanation df in about section
7
  explanation_data = {
8
+ "Accuracy Scores": [
9
+ "DrugMatchQA",
10
+ "MedMCQA: G2B",
11
+ "MedMCQA: Original",
12
+ "MedMCQA: Difference",
13
+ "MedQA: G2B",
14
+ "MedQA: Original",
15
+ "MedQA: Difference",
16
+ "Adjusted Robustness Score"
17
  ],
18
  "Description": [
19
+ "A custom MC task where the model is asked to match a brand name to its generic counterpart and vice versa. This task is designed to test the model's ability to understand drug name synonyms.",
 
20
  "G2B Refers to the 'Generic' to 'Brand' name swap. This is model accuracy on MedMCQA task where generic drug names are substituted with brand names.",
21
  "Model accuracy on MedMCQA task with original data. (Only includes questions that overlap with the g2b dataset)",
22
  "Difference in MedMCQA accuracy for swapped and non-swapped datasets, highlighting the impact of G2B drug name substitution on performance.",
23
  "Model accuracy on MedQA (4 options) task where generic drug names are substituted with brand names.",
24
  "Model accuracy on MedQA (4 options) task with original data. (Only includes questions that overlap with the g2b dataset)",
25
+ "Difference in MedMCQA accuracy for swapped and non-swapped datasets, highlighting the impact of G2B drug name substitution on performance.",
26
+ "A score given by Avg Difference / Avg G2B Accuracy. A higher score indicates a model that is more robust to drug name synonym substitution."
27
  ]
28
  }
29
  explanation_df = pd.DataFrame(explanation_data)
30
 
31
+
32
+
33
+
34
+ #Loading and cleaning eval data processed by json2df.py
35
+
36
  df = pd.read_csv("data/csv/models_data.csv")
37
  df['average_g2b'] = df[['medmcqa_g2b', 'medqa_4options_g2b']].mean(axis=1).round(2)
38
+ df['average_original_acc'] = df[['medmcqa_orig_filtered', 'medqa_4options_orig_filtered']].mean(axis=1).round(2)
39
  df['average_diff'] = df[['medmcqa_diff', 'medqa_diff']].mean(axis=1).round(2)
40
+ df.drop(columns=['b4b'], inplace=True)
41
+ #Rename columns for clarity
42
+
43
+ df.rename(columns={
44
+ 'medmcqa_g2b': 'MedMCQA: G2B',
45
+ 'medmcqa_orig_filtered': 'MedMCQA: Original',
46
+ 'medmcqa_diff': 'MedMCQA: Difference',
47
+ 'medqa_4options_g2b': 'MedQA: G2B',
48
+ 'medqa_4options_orig_filtered': 'MedQA: Original',
49
+ 'medqa_diff': 'MedQA: Difference',
50
+ 'b4bqa': 'DrugMatchQA',
51
+ 'average_g2b': 'Average G2B Accuracy',
52
+ 'average_original_acc': 'Average Original Accuracy',
53
+ 'average_diff': 'Average Difference'
54
+ }, inplace=True)
55
+
56
+ #Create adjusted robustness score that accounts for g2b accuracy and difference in accuracy
57
+ # (models with low difference like phi will seem robust, but its simply because they are bad / random at both tasks)
58
+ df['Average Accuracy (Original and G2B)'] = (df['Average G2B Accuracy'] + df['Average Original Accuracy']) / 2
59
+
60
+
61
+
62
+ # Introduce a penalty factor for low average accuracy
63
+ penalty_factor = 1 / (df['Average Accuracy (Original and G2B)'] ** 2)
64
+
65
+ # Calculate the adjusted robustness score with penalty
66
+ df['Adjusted Robustness Score'] = df['Average Difference'] * penalty_factor
67
+ df['Adjusted Robustness Score'] = df['Adjusted Robustness Score'].round(2)
68
+
69
+
70
+
71
+
72
+
73
+
74
+ #if acc is 0 in DrugMatchQA column, set it to none
75
+ df['DrugMatchQA'] = df['DrugMatchQA'].apply(lambda x: None if x == 0 else x)
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+ #Defining functions for filtering and plotting
84
 
85
  filter_mapping = {
86
  "all": "all",
 
119
  return fig
120
 
121
  def create_lm_plot(df, x_col, y_col, title, x_title, y_title):
122
+ fig = px.scatter(df, x=x_col, y=y_col, color='Model', title=title, trendline='ols')
123
 
124
  fig.update_layout(
125
  xaxis_title=x_title,
 
130
  return fig
131
 
132
  def create_bar_plot(df, col, title):
133
+
134
  sorted_df = df.sort_values(by=col, ascending=True)
135
  fig = px.bar(sorted_df,
136
  x=col,
 
138
  orientation='h',
139
  title=title,
140
  color=col,
141
+ color_continuous_scale='Aggrnyl')
142
  fig.update_layout(xaxis_title=col, yaxis_title='Model', height=600, coloraxis_showscale=False)
143
+ fig.update_xaxes(range=[-20, 20])
144
  return fig
145
 
146
+
147
+ def create_bar_plot_drugmatchqa(df, col, title):
148
+ clean_df = df.dropna(subset=['DrugMatchQA'])
149
+ sorted_df = clean_df.sort_values(by=col, ascending=True)
150
+ fig = px.bar(sorted_df,
151
+ x=col,
152
+ y='Model',
153
+ orientation='h',
154
+ title=title,
155
+ color=col,
156
+ color_continuous_scale='Aggrnyl')
157
+ fig.update_layout(xaxis_title=col, yaxis_title='Model', height=600, coloraxis_showscale=False)
158
+ return fig
159
+
160
+ #Create UI/Layout
161
+
162
  with gr.Blocks(css="custom.css") as demo:
163
  with gr.Column():
164
  gr.Markdown(
165
+ """<div style="text-align: center;"><h1> <span style='color: #00BF63;'>🐰 RABBITS</span>: <span style='color: #00BF63;'>R</span>obust <span style='color: #00BF63;'>A</span>ssessment of <span style='color: #00BF63;'>B</span>iomedical <span style='color: #00BF63;'>B</span>enchmarks <span style='color: #00BF63;'>I</span>nvolving drug
166
+ <span style='color: #00BF63;'>T</span>erm <span style='color: #00BF63;'>S</span>ubstitutions<span style='color: #00BF63;'></span></h1></div>"""
167
  )
168
  with gr.Row():
169
  gr.Markdown(""" """)
 
175
  )
176
  with gr.Row():
177
  gr.Markdown(""" """)
 
 
 
178
 
179
  with gr.Row():
180
  gr.Markdown(""" """)
181
 
182
  with gr.Row():
183
  bar1 = gr.Plot(
184
+ value=create_bar_plot(df, "MedMCQA: Difference", "Impact of Generic2Brand swap on MedMCQA Accuracy"),
185
  elem_id="bar1"
186
  )
187
  bar2 = gr.Plot(
188
+ value=create_bar_plot(df, "MedQA: Difference", "Impact of Generic2Brand swap on MedQA Accuracy"),
189
  elem_id="bar2"
190
  )
191
 
 
196
  with gr.Row():
197
  gr.Markdown(""" """)
198
 
199
+ #default_visible_columns = []
200
 
201
  with gr.Tabs(elem_classes="tab-buttons"):
202
  with gr.TabItem("🔍 Evaluation table"):
 
264
  with gr.Column():
265
  with gr.Row():
266
  scatter1 = gr.Plot(
267
+ value=create_scatter_plot(df, "MedMCQA: Original", "MedMCQA: G2B",
268
+ "MedMCQA: Orig vs G2B", "MedMCQA: Original", "MedMCQA: G2B"),
269
  elem_id="scatter1"
270
  )
271
  scatter2 = gr.Plot(
272
+ value=create_scatter_plot(df, "MedQA: Original", "MedQA: G2B",
273
+ "MedQA: Orig vs G2B", "MedQA: Original", "MedQA: G2B"),
274
  elem_id="scatter2"
275
  )
 
 
 
 
 
 
276
 
277
  with gr.TabItem("📝 About"):
278
+ with gr.Column():
279
+ gr.Markdown(
280
+ """<div style="text-align: center;">
281
+ <h2>About the RABBITS LLM Leaderboard</h2>
282
+ <p>The following is an overview of the framework, along with an explanation of scores in the evaluation table.</p>
283
+ </div>""",
284
+ elem_classes="markdown-text"
285
+ )
286
+ with gr.Row():
287
+ gr.Image(value="workflow-1-2.svg", width=200, height=450)
288
+ gr.Image(value="workflow-3-4.svg", width=200, height=450)
289
+ with gr.Row():
290
+ gr.Dataframe(
291
+ value=explanation_df,
292
+ headers="keys",
293
+ datatype=["str", "str"],
294
+ interactive=False,
295
+ label="Explanation of Scores"
296
+ )
297
+
298
 
299
  with gr.TabItem("🚀 Submit Here!"):
300
  gr.Markdown(
 
309
  elem_classes="markdown-text"
310
  )
311
 
312
+
 
 
 
 
 
 
 
313
 
314
  with gr.Row():
315
  bar3 = gr.Plot(
316
+ value=create_bar_plot_drugmatchqa(df, "DrugMatchQA", "Which LLMs are best at matching brand names to generic drug names?"),
317
  elem_id="bar3"
318
  )
319
+ bar4 = gr.Plot(
320
+ value=create_bar_plot_drugmatchqa(df, "Adjusted Robustness Score", "Which LLMs are most robust to drug name synonym substitution?"),
321
+ elem_id="bar4"
 
322
  )
323
+
324
+
325
+
326
 
327
 
328