dimbyTa commited on
Commit
ff0c7de
1 Parent(s): 174296d

adding feature of changing the order of the metrics on the circle of the chart

Browse files
Files changed (2) hide show
  1. src/display.py +41 -8
  2. src/load_data.py +31 -1
src/display.py CHANGED
@@ -1,10 +1,7 @@
1
- #import streamlit as st
2
- #from src.load_data import load_dataframe, sort_by
3
- #from src.plot import plot_radar_chart_index, plot_radar_chart_name
4
- #from st_aggrid import GridOptionsBuilder, AgGrid
5
  from st_aggrid import GridOptionsBuilder, AgGrid
6
  import streamlit as st
7
- from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name
8
  from .plot import plot_radar_chart_name, plot_radar_chart_rows
9
 
10
 
@@ -32,8 +29,34 @@ def display_app():
32
 
33
 
34
  name = st.text_input(label = ":mag: Search by name")
 
 
35
  selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
36
  st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  len_name_input = len(name)
38
  if len_name_input > 0:
39
  dataframe_by_search = search_by_name(name)
@@ -79,12 +102,22 @@ def display_app():
79
 
80
  with column2:
81
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
82
- figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3])
83
- #figure = plot_radar_chart_name(dataframe= dataframe, model_name=grid_response['selected_rows'][0]["model_name"])
 
 
 
 
84
  st.plotly_chart(figure, use_container_width=False)
 
85
  else:
86
  if len(subdata)>0:
87
- figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
 
 
 
 
 
88
  st.plotly_chart(figure, use_container_width=True)
89
 
90
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
 
1
+
 
 
 
2
  from st_aggrid import GridOptionsBuilder, AgGrid
3
  import streamlit as st
4
+ from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
5
  from .plot import plot_radar_chart_name, plot_radar_chart_rows
6
 
7
 
 
29
 
30
 
31
  name = st.text_input(label = ":mag: Search by name")
32
+
33
+ #Sidebar configurations
34
  selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
35
  st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
36
+ ordering_metrics = st.sidebar.text_input(label = "Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
37
+ placeholder = "ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
38
+
39
+ ordering_metrics = ordering_metrics.replace(" ", "")
40
+ ordering_metrics = ordering_metrics.split(",")
41
+
42
+ st.sidebar.markdown("""
43
+ As a reminder, here are the different metrics:
44
+ * ARC
45
+ * GSM8K
46
+ * TruthfulQA
47
+ * Winogrande
48
+ * HellaSwag
49
+ * MMLU
50
+ """)
51
+ st.sidebar.markdown("""
52
+ If there are **typos** in the name of the metrics, or the number of metrics
53
+ is **different of six**, there will be no effect on the chart and the
54
+ default ordering will be used.
55
+ """)
56
+
57
+ valid_categories = validate_categories(ordering_metrics)
58
+
59
+ # Search bar
60
  len_name_input = len(name)
61
  if len_name_input > 0:
62
  dataframe_by_search = search_by_name(name)
 
102
 
103
  with column2:
104
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
105
+ figure = None
106
+ if valid_categories:
107
+
108
+ figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3], categories = ordering_metrics)
109
+ else:
110
+ figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3])
111
  st.plotly_chart(figure, use_container_width=False)
112
+
113
  else:
114
  if len(subdata)>0:
115
+ figure = None
116
+ if valid_categories:
117
+ figure = plot_radar_chart_name(dataframe=subdata, categories = ordering_metrics, model_name=model_name)
118
+ else:
119
+ figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
120
+
121
  st.plotly_chart(figure, use_container_width=True)
122
 
123
  if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
src/load_data.py CHANGED
@@ -54,4 +54,34 @@ def search_by_name(name: str) -> pd.DataFrame:
54
  """
55
  dataframe = load_dataframe()
56
  indexes = dataframe["model_name"].str.contains(name)
57
- return dataframe[indexes]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  """
55
  dataframe = load_dataframe()
56
  indexes = dataframe["model_name"].str.contains(name)
57
+ return dataframe[indexes]
58
+
59
+ def validate_categories(categories: list) -> bool:
60
+ """
61
+ validate a list of categories to the columns in the dataframe
62
+ Arguments:
63
+ - categories: a list of categories for the ordering of the columns in the dataframe
64
+
65
+ This expects a list with six elements that should be (not necessary in order):
66
+ - ARC
67
+ - GSM8K
68
+ - TruthfulQA
69
+ - Winogrande
70
+ - HellaSwag
71
+ - MMLU
72
+
73
+ Returns
74
+ - True if the list has the right number of element and right elements
75
+ - False otherwise
76
+ """
77
+ valid_categories = False
78
+ if len(categories) == 6:
79
+ if ("ARC" in categories and "GSM8K" in categories and "TruthfulQA" in categories
80
+ and "Winogrande" in categories and "HellaSwag" in categories and "MMLU" in categories):
81
+ valid_categories = True
82
+ else:
83
+ valid_categories = False
84
+ else:
85
+ valid_categories = False
86
+
87
+ return valid_categories