Spaces:
Sleeping
Sleeping
adding filtering by dtype and updating dataset
Browse files- src/display.py +8 -2
src/display.py
CHANGED
@@ -24,7 +24,9 @@ def display_app():
|
|
24 |
categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]
|
25 |
|
26 |
st.markdown("## Leaderboard")
|
27 |
-
sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns), index =
|
|
|
|
|
28 |
number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
|
29 |
ascending = True
|
30 |
|
@@ -74,6 +76,8 @@ def display_app():
|
|
74 |
|
75 |
valid_categories = validate_categories(ordering_metrics)
|
76 |
dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending)
|
|
|
|
|
77 |
dataframe_display = dataframe.copy()
|
78 |
dataframe_display = show_dataframe_top(number_of_row,dataframe_display)
|
79 |
dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float)
|
@@ -173,6 +177,7 @@ def display_app():
|
|
173 |
round(model_list[i]["HellaSwag"],2),
|
174 |
round(model_list[i]["MMLU"],2)
|
175 |
))
|
|
|
176 |
elif len(model_list) == 1:
|
177 |
st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[0]["model_name"] , model_list[0]["model_name"]))
|
178 |
st.markdown("**Results:**")
|
@@ -192,7 +197,8 @@ def display_app():
|
|
192 |
round(model_list[0]["HellaSwag"],2),
|
193 |
round(model_list[0]["MMLU"],2)
|
194 |
))
|
|
|
195 |
st.markdown("For more details, hover over the radar chart.")
|
196 |
else:
|
197 |
st.markdown("**Model name:** %s" % model_name)
|
198 |
-
st.markdown("For more details, select the model.")
|
|
|
24 |
categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]
|
25 |
|
26 |
st.markdown("## Leaderboard")
|
27 |
+
sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns.difference(["model_dtype"])), index = 1)
|
28 |
+
d_type_options = ["all", "torch.bfloat16", "torch.float16", "4bit", "8bit"]
|
29 |
+
d_type = st.radio(label = "Filter by dtype", options = d_type_options, index = 0, horizontal = True)
|
30 |
number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
|
31 |
ascending = True
|
32 |
|
|
|
76 |
|
77 |
valid_categories = validate_categories(ordering_metrics)
|
78 |
dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending)
|
79 |
+
if d_type != "all":
|
80 |
+
dataframe = dataframe[dataframe["model_dtype"] == d_type]
|
81 |
dataframe_display = dataframe.copy()
|
82 |
dataframe_display = show_dataframe_top(number_of_row,dataframe_display)
|
83 |
dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float)
|
|
|
177 |
round(model_list[i]["HellaSwag"],2),
|
178 |
round(model_list[i]["MMLU"],2)
|
179 |
))
|
180 |
+
st.markdown("**dtype:** %s" % model_list[i]["model_dtype"])
|
181 |
elif len(model_list) == 1:
|
182 |
st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[0]["model_name"] , model_list[0]["model_name"]))
|
183 |
st.markdown("**Results:**")
|
|
|
197 |
round(model_list[0]["HellaSwag"],2),
|
198 |
round(model_list[0]["MMLU"],2)
|
199 |
))
|
200 |
+
st.markdown("**dtype:** %s" % model_list[0]["model_dtype"])
|
201 |
st.markdown("For more details, hover over the radar chart.")
|
202 |
else:
|
203 |
st.markdown("**Model name:** %s" % model_name)
|
204 |
+
st.markdown("For more details, select the first model in the list/leaderboard.")
|