dimbyTa commited on
Commit
a7f9f33
1 Parent(s): ff0c7de

adding search boxes with suggestions, and automatic sorting of models for easier plotting

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. src/display.py +98 -33
requirements.txt CHANGED
@@ -4,3 +4,4 @@ matplotlib
4
  plotly
5
  streamlit-nightly
6
  streamlit-aggrid
 
 
4
  plotly
5
  streamlit-nightly
6
  streamlit-aggrid
7
+ streamlit-searchbox
src/display.py CHANGED
@@ -1,5 +1,6 @@
1
 
2
  from st_aggrid import GridOptionsBuilder, AgGrid
 
3
  import streamlit as st
4
  from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
5
  from .plot import plot_radar_chart_name, plot_radar_chart_rows
@@ -7,14 +8,22 @@ from .plot import plot_radar_chart_name, plot_radar_chart_rows
7
 
8
  def display_app():
9
  st.markdown("# Open LLM Leaderboard Viz")
 
10
  st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
11
- st.markdown("To select a model, click on the checkbox beside its name.")
12
- st.markdown("This displays the top 100 models by default, but you can change that using the number input in the sidebar.")
 
 
 
 
 
13
  st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
14
  st.markdown("If your model doesn't show up, please search it by its name.")
15
 
16
  dataframe = load_dataframe()
 
17
 
 
18
  sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns), index = 7)
19
  number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
20
  ascending = True
@@ -27,11 +36,20 @@ def display_app():
27
  else:
28
  ascending = False
29
 
30
-
31
- name = st.text_input(label = ":mag: Search by name")
 
 
 
 
 
 
 
 
 
32
 
33
  #Sidebar configurations
34
- selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
35
  st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
36
  ordering_metrics = st.sidebar.text_input(label = "Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
37
  placeholder = "ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
@@ -55,25 +73,9 @@ def display_app():
55
  """)
56
 
57
  valid_categories = validate_categories(ordering_metrics)
58
-
59
- # Search bar
60
- len_name_input = len(name)
61
- if len_name_input > 0:
62
- dataframe_by_search = search_by_name(name)
63
- if len(dataframe_by_search) > 0:
64
- #st.write("number of model name with name", len(dataframe_by_search))
65
- dataframe = dataframe_by_search
66
- else:
67
- dataframe = load_dataframe()
68
-
69
  dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending)
70
- dataframe_display = dataframe.copy()
71
-
72
- if len_name_input == 0:
73
- # Show every only top n row
74
- dataframe_display = show_dataframe_top(number_of_row,dataframe_display)
75
-
76
-
77
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float)
78
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] *100
79
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].round(2)
@@ -93,7 +95,22 @@ def display_app():
93
  height=300,
94
  width='40%'
95
  )
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  subdata = dataframe.head(1)
98
  if len(subdata) > 0:
99
  model_name = subdata["model_name"].values[0]
@@ -103,13 +120,25 @@ def display_app():
103
  with column2:
104
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
105
  figure = None
 
 
 
 
106
  if valid_categories:
107
-
108
- figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3], categories = ordering_metrics)
109
  else:
110
- figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3])
111
  st.plotly_chart(figure, use_container_width=False)
112
 
 
 
 
 
 
 
 
 
 
113
  else:
114
  if len(subdata)>0:
115
  figure = None
@@ -120,14 +149,50 @@ def display_app():
120
 
121
  st.plotly_chart(figure, use_container_width=True)
122
 
123
- if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
124
- n_col = len(grid_response['selected_rows']) if len(grid_response['selected_rows']) <=3 else 3
125
  st.markdown("## Models")
126
  columns = st.columns(n_col)
127
  for i in range(n_col):
128
  with columns[i]:
129
- st.markdown("**Model name:** %s" % grid_response['selected_rows'][i]["model_name"])
130
- elif grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) == 1:
131
- st.markdown("**Model name:** %s" % grid_response['selected_rows'][0]["model_name"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  else:
133
- st.markdown("**Model name:** %s" % model_name)
 
 
1
 
2
  from st_aggrid import GridOptionsBuilder, AgGrid
3
+ from streamlit_searchbox import st_searchbox
4
  import streamlit as st
5
  from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
6
  from .plot import plot_radar_chart_name, plot_radar_chart_rows
 
8
 
9
  def display_app():
10
  st.markdown("# Open LLM Leaderboard Viz")
11
+ st.markdown("## Some explanations")
12
  st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
13
+ st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** bellow.")
14
+ st.markdown("You can select up to three models using the search boxes and/or the checkboxes.")
15
+ st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes,
16
+ i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes.
17
+ Please, search models using the search boxes first, and then use the checkboxes.
18
+ """)
19
+ st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.")
20
  st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
21
  st.markdown("If your model doesn't show up, please search it by its name.")
22
 
23
  dataframe = load_dataframe()
24
+ categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]
25
 
26
+ st.markdown("## Leaderboard")
27
  sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns), index = 7)
28
  number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
29
  ascending = True
 
36
  else:
37
  ascending = False
38
 
39
+ # Dynamic search boxes
40
+ def search_model(model_name: str):
41
+ model_list = None
42
+ if model_name is not None:
43
+ models = dataframe["model_name"].str.contains(model_name)
44
+ model_list = dataframe["model_name"][models]
45
+ else:
46
+ model_list = []
47
+ return model_list
48
+
49
+ model_list = []
50
 
51
  #Sidebar configurations
52
+ selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=1)
53
  st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
54
  ordering_metrics = st.sidebar.text_input(label = "Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
55
  placeholder = "ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
 
73
  """)
74
 
75
  valid_categories = validate_categories(ordering_metrics)
 
 
 
 
 
 
 
 
 
 
 
76
  dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending)
77
+ dataframe_display = dataframe.copy()
78
+ dataframe_display = show_dataframe_top(number_of_row,dataframe_display)
 
 
 
 
 
79
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float)
80
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] *100
81
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].round(2)
 
95
  height=300,
96
  width='40%'
97
  )
98
+ model_one = st_searchbox(label = "Model 1", search_function = search_model, key = "model_1", default= None)
99
+ model_two = st_searchbox(label = "Model 2", search_function = search_model, key = "model_2", default= None)
100
+ model_three = st_searchbox(label = "Model 3", search_function = search_model, key = "model_3", default= None)
101
+ if model_one is not None:
102
+ row = dataframe[dataframe["model_name"] == model_one]
103
+ row[categories_display] = row[categories_display]*100
104
+ model_list.append(row.to_dict("records")[0])
105
+ if model_two is not None:
106
+ row = dataframe[dataframe["model_name"] == model_two]
107
+ row[categories_display] = row[categories_display]*100
108
+ model_list.append(row.to_dict("records")[0])
109
+ if model_three is not None:
110
+ row = dataframe[dataframe["model_name"] == model_three]
111
+ row[categories_display] = row[categories_display]*100
112
+ model_list.append(row.to_dict("records")[0])
113
+
114
  subdata = dataframe.head(1)
115
  if len(subdata) > 0:
116
  model_name = subdata["model_name"].values[0]
 
120
  with column2:
121
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
122
  figure = None
123
+ model_list += grid_response['selected_rows']
124
+ model_list = model_list[:3]
125
+ model_list = sorted(model_list, key = lambda x: x["Average"], reverse = True)
126
+
127
  if valid_categories:
128
+ figure = plot_radar_chart_rows(rows=model_list, categories = ordering_metrics)
 
129
  else:
130
+ figure = plot_radar_chart_rows(rows=model_list)
131
  st.plotly_chart(figure, use_container_width=False)
132
 
133
+ elif len(model_list) > 0:
134
+ figure = None
135
+ model_list = sorted(model_list, key = lambda x: x["Average"], reverse = True)
136
+
137
+ if valid_categories:
138
+ figure = plot_radar_chart_rows(rows=model_list, categories = ordering_metrics)
139
+ else:
140
+ figure = plot_radar_chart_rows(rows=model_list)
141
+ st.plotly_chart(figure, use_container_width=False)
142
  else:
143
  if len(subdata)>0:
144
  figure = None
 
149
 
150
  st.plotly_chart(figure, use_container_width=True)
151
 
152
+ if len(model_list) > 1:
153
+ n_col = len(model_list) if len(model_list) <=3 else 3
154
  st.markdown("## Models")
155
  columns = st.columns(n_col)
156
  for i in range(n_col):
157
  with columns[i]:
158
+ st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[i]["model_name"] , model_list[i]["model_name"]))
159
+ st.markdown("**Results:**")
160
+ st.markdown("""
161
+ * Average: %s
162
+ * ARC: %s
163
+ * GSM8K: %s
164
+ * TruthfulQA: %s
165
+ * Winogrande: %s
166
+ * HellaSwag: %s
167
+ * MMLU: %s
168
+ """ % (round(model_list[i]["Average"],2),
169
+ round(model_list[i]["ARC"],2),
170
+ round(model_list[i]["GSM8K"],2),
171
+ round(model_list[i]["TruthfulQA"],2),
172
+ round(model_list[i]["Winogrande"],2),
173
+ round(model_list[i]["HellaSwag"],2),
174
+ round(model_list[i]["MMLU"],2)
175
+ ))
176
+ elif len(model_list) == 1:
177
+ st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[0]["model_name"] , model_list[i]["model_name"]))
178
+ st.markdown("**Results:**")
179
+ st.markdown("""
180
+ * Average: %s
181
+ * ARC: %s
182
+ * GSM8K: %s
183
+ * TruthfulQA: %s
184
+ * Winogrande: %s
185
+ * HellaSwag: %s
186
+ * MMLU: %s
187
+ """ % (round(model_list[0]["Average"],2),
188
+ round(model_list[0]["ARC"],2),
189
+ round(model_list[0]["GSM8K"],2),
190
+ round(model_list[0]["TruthfulQA"],2),
191
+ round(model_list[0]["Winogrande"],2),
192
+ round(model_list[0]["HellaSwag"],2),
193
+ round(model_list[0]["MMLU"],2)
194
+ ))
195
+ st.markdown("For more details, hover over the radar chart.")
196
  else:
197
+ st.markdown("**Model name:** %s" % model_name)
198
+ st.markdown("For more details, select the model.")