dimbyTa commited on
Commit
864cb6d
1 Parent(s): a5b8391

adding filtering by dtype and updating dataset

Browse files
Files changed (1) hide show
  1. src/display.py +8 -2
src/display.py CHANGED
@@ -24,7 +24,9 @@ def display_app():
24
  categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]
25
 
26
  st.markdown("## Leaderboard")
27
- sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns), index = 7)
 
 
28
  number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
29
  ascending = True
30
 
@@ -74,6 +76,8 @@ def display_app():
74
 
75
  valid_categories = validate_categories(ordering_metrics)
76
  dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending)
 
 
77
  dataframe_display = dataframe.copy()
78
  dataframe_display = show_dataframe_top(number_of_row,dataframe_display)
79
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float)
@@ -173,6 +177,7 @@ def display_app():
173
  round(model_list[i]["HellaSwag"],2),
174
  round(model_list[i]["MMLU"],2)
175
  ))
 
176
  elif len(model_list) == 1:
177
  st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[0]["model_name"] , model_list[0]["model_name"]))
178
  st.markdown("**Results:**")
@@ -192,7 +197,8 @@ def display_app():
192
  round(model_list[0]["HellaSwag"],2),
193
  round(model_list[0]["MMLU"],2)
194
  ))
 
195
  st.markdown("For more details, hover over the radar chart.")
196
  else:
197
  st.markdown("**Model name:** %s" % model_name)
198
- st.markdown("For more details, select the model.")
 
24
  categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]
25
 
26
  st.markdown("## Leaderboard")
27
+ sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns.difference(["model_dtype"])), index = 1)
28
+ d_type_options = ["all", "torch.bfloat16", "torch.float16", "4bit", "8bit"]
29
+ d_type = st.radio(label = "Filter by dtype", options = d_type_options, index = 0, horizontal = True)
30
  number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
31
  ascending = True
32
 
 
76
 
77
  valid_categories = validate_categories(ordering_metrics)
78
  dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending)
79
+ if d_type != "all":
80
+ dataframe = dataframe[dataframe["model_dtype"] == d_type]
81
  dataframe_display = dataframe.copy()
82
  dataframe_display = show_dataframe_top(number_of_row,dataframe_display)
83
  dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float)
 
177
  round(model_list[i]["HellaSwag"],2),
178
  round(model_list[i]["MMLU"],2)
179
  ))
180
+ st.markdown("**dtype:** %s" % model_list[i]["model_dtype"])
181
  elif len(model_list) == 1:
182
  st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[0]["model_name"] , model_list[0]["model_name"]))
183
  st.markdown("**Results:**")
 
197
  round(model_list[0]["HellaSwag"],2),
198
  round(model_list[0]["MMLU"],2)
199
  ))
200
+ st.markdown("**dtype:** %s" % model_list[0]["model_dtype"])
201
  st.markdown("For more details, hover over the radar chart.")
202
  else:
203
  st.markdown("**Model name:** %s" % model_name)
204
+ st.markdown("For more details, select the first model in the list/leaderboard.")