Spaces:
Sleeping
Sleeping
adding feature of changing the order of the metrics on the circle of the chart
Browse files- src/display.py +41 -8
- src/load_data.py +31 -1
src/display.py
CHANGED
@@ -1,10 +1,7 @@
|
|
1 |
-
|
2 |
-
#from src.load_data import load_dataframe, sort_by
|
3 |
-
#from src.plot import plot_radar_chart_index, plot_radar_chart_name
|
4 |
-
#from st_aggrid import GridOptionsBuilder, AgGrid
|
5 |
from st_aggrid import GridOptionsBuilder, AgGrid
|
6 |
import streamlit as st
|
7 |
-
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name
|
8 |
from .plot import plot_radar_chart_name, plot_radar_chart_rows
|
9 |
|
10 |
|
@@ -32,8 +29,34 @@ def display_app():
|
|
32 |
|
33 |
|
34 |
name = st.text_input(label = ":mag: Search by name")
|
|
|
|
|
35 |
selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
|
36 |
st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
len_name_input = len(name)
|
38 |
if len_name_input > 0:
|
39 |
dataframe_by_search = search_by_name(name)
|
@@ -79,12 +102,22 @@ def display_app():
|
|
79 |
|
80 |
with column2:
|
81 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
|
82 |
-
figure =
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
st.plotly_chart(figure, use_container_width=False)
|
|
|
85 |
else:
|
86 |
if len(subdata)>0:
|
87 |
-
figure =
|
|
|
|
|
|
|
|
|
|
|
88 |
st.plotly_chart(figure, use_container_width=True)
|
89 |
|
90 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
|
|
|
1 |
+
|
|
|
|
|
|
|
2 |
from st_aggrid import GridOptionsBuilder, AgGrid
|
3 |
import streamlit as st
|
4 |
+
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
|
5 |
from .plot import plot_radar_chart_name, plot_radar_chart_rows
|
6 |
|
7 |
|
|
|
29 |
|
30 |
|
31 |
name = st.text_input(label = ":mag: Search by name")
|
32 |
+
|
33 |
+
#Sidebar configurations
|
34 |
selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
|
35 |
st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
|
36 |
+
ordering_metrics = st.sidebar.text_input(label = "Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
|
37 |
+
placeholder = "ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
|
38 |
+
|
39 |
+
ordering_metrics = ordering_metrics.replace(" ", "")
|
40 |
+
ordering_metrics = ordering_metrics.split(",")
|
41 |
+
|
42 |
+
st.sidebar.markdown("""
|
43 |
+
As a reminder, here are the different metrics:
|
44 |
+
* ARC
|
45 |
+
* GSM8K
|
46 |
+
* TruthfulQA
|
47 |
+
* Winogrande
|
48 |
+
* HellaSwag
|
49 |
+
* MMLU
|
50 |
+
""")
|
51 |
+
st.sidebar.markdown("""
|
52 |
+
If there are **typos** in the name of the metrics, or the number of metrics
|
53 |
+
is **different of six**, there will be no effect on the chart and the
|
54 |
+
default ordering will be used.
|
55 |
+
""")
|
56 |
+
|
57 |
+
valid_categories = validate_categories(ordering_metrics)
|
58 |
+
|
59 |
+
# Search bar
|
60 |
len_name_input = len(name)
|
61 |
if len_name_input > 0:
|
62 |
dataframe_by_search = search_by_name(name)
|
|
|
102 |
|
103 |
with column2:
|
104 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
|
105 |
+
figure = None
|
106 |
+
if valid_categories:
|
107 |
+
|
108 |
+
figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3], categories = ordering_metrics)
|
109 |
+
else:
|
110 |
+
figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3])
|
111 |
st.plotly_chart(figure, use_container_width=False)
|
112 |
+
|
113 |
else:
|
114 |
if len(subdata)>0:
|
115 |
+
figure = None
|
116 |
+
if valid_categories:
|
117 |
+
figure = plot_radar_chart_name(dataframe=subdata, categories = ordering_metrics, model_name=model_name)
|
118 |
+
else:
|
119 |
+
figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
|
120 |
+
|
121 |
st.plotly_chart(figure, use_container_width=True)
|
122 |
|
123 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
|
src/load_data.py
CHANGED
@@ -54,4 +54,34 @@ def search_by_name(name: str) -> pd.DataFrame:
|
|
54 |
"""
|
55 |
dataframe = load_dataframe()
|
56 |
indexes = dataframe["model_name"].str.contains(name)
|
57 |
-
return dataframe[indexes]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
"""
|
55 |
dataframe = load_dataframe()
|
56 |
indexes = dataframe["model_name"].str.contains(name)
|
57 |
+
return dataframe[indexes]
|
58 |
+
|
59 |
+
def validate_categories(categories: list) -> bool:
|
60 |
+
"""
|
61 |
+
validate a list of categories to the columns in the dataframe
|
62 |
+
Arguments:
|
63 |
+
- categories: a list of categories for the ordering of the columns in the dataframe
|
64 |
+
|
65 |
+
This expects a list with six elements that should be (not necessary in order):
|
66 |
+
- ARC
|
67 |
+
- GSM8K
|
68 |
+
- TruthfulQA
|
69 |
+
- Winogrande
|
70 |
+
- HellaSwag
|
71 |
+
- MMLU
|
72 |
+
|
73 |
+
Returns
|
74 |
+
- True if the list has the right number of element and right elements
|
75 |
+
- False otherwise
|
76 |
+
"""
|
77 |
+
valid_categories = False
|
78 |
+
if len(categories) == 6:
|
79 |
+
if ("ARC" in categories and "GSM8K" in categories and "TruthfulQA" in categories
|
80 |
+
and "Winogrande" in categories and "HellaSwag" in categories and "MMLU" in categories):
|
81 |
+
valid_categories = True
|
82 |
+
else:
|
83 |
+
valid_categories = False
|
84 |
+
else:
|
85 |
+
valid_categories = False
|
86 |
+
|
87 |
+
return valid_categories
|