Corey Morris committed on
Commit
c671de9
1 Parent(s): ed019c6

Added MMLU overall average column. Added a few more charts comparing moral reasoning results and comparing the MMLU overall average to other data.

Browse files
Files changed (1) hide show
  1. app.py +43 -7
app.py CHANGED
@@ -25,15 +25,16 @@ class MultiURLData:
25
  data = json.load(f)
26
  df = pd.DataFrame(data['results']).T
27
 
28
- df = df.rename(columns={'acc': model_name})
29
-
30
- df.index = df.index.str.replace('hendrycksTest-', '', regex=True)
31
 
 
 
 
 
32
  df.index = df.index.str.replace('harness\|', '', regex=True)
33
-
34
  # remove |5 from the index
35
  df.index = df.index.str.replace('\|5', '', regex=True)
36
 
 
37
  dataframes.append(df[[model_name]])
38
 
39
  data = pd.concat(dataframes, axis=1)
@@ -44,7 +45,18 @@ class MultiURLData:
44
  cols = cols[-1:] + cols[:-1]
45
  data = data[cols]
46
 
 
 
 
 
 
 
 
 
 
47
  return data
 
 
48
 
49
  def get_data(self, selected_models):
50
  filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
@@ -75,6 +87,7 @@ selected_models = st.multiselect(
75
 
76
 
77
  # Get the filtered data and display it in a table
 
78
  filtered_data = data_provider.get_data(selected_models)
79
  st.dataframe(filtered_data)
80
 
@@ -111,11 +124,34 @@ def create_plot(df, model_column, arc_column, moral_column, models=None):
111
  # models_to_plot = ['Model1', 'Model2', 'Model3']
112
  # fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
113
 
114
- fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios')
115
- st.plotly_chart(fig)
116
 
117
  fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
118
  st.plotly_chart(fig)
119
 
120
- fig = create_plot(filtered_data, 'Model Name', 'moral_disputes', 'moral_scenarios')
121
  st.plotly_chart(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  data = json.load(f)
26
  df = pd.DataFrame(data['results']).T
27
 
 
 
 
28
 
29
+ # data cleanup
30
+ df = df.rename(columns={'acc': model_name})
31
+ # Replace 'hendrycksTest-' with a more descriptive column name
32
+ df.index = df.index.str.replace('hendrycksTest-', 'MMLU_', regex=True)
33
  df.index = df.index.str.replace('harness\|', '', regex=True)
 
34
  # remove |5 from the index
35
  df.index = df.index.str.replace('\|5', '', regex=True)
36
 
37
+
38
  dataframes.append(df[[model_name]])
39
 
40
  data = pd.concat(dataframes, axis=1)
 
45
  cols = cols[-1:] + cols[:-1]
46
  data = data[cols]
47
 
48
+ # create a new column that averages, for each row, the results from every column whose name contains 'MMLU'
49
+ data['MMLU_average'] = data.filter(regex='MMLU').mean(axis=1)
50
+
51
+ # move the MMLU_average column to the second column in the dataframe
52
+ cols = data.columns.tolist()
53
+ cols = cols[:1] + cols[-1:] + cols[1:-1]
54
+ data = data[cols]
55
+ data
56
+
57
  return data
58
+
59
+
60
 
61
  def get_data(self, selected_models):
62
  filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
 
87
 
88
 
89
  # Get the filtered data and display it in a table
90
+ st.header('Sortable table')
91
  filtered_data = data_provider.get_data(selected_models)
92
  st.dataframe(filtered_data)
93
 
 
124
  # models_to_plot = ['Model1', 'Model2', 'Model3']
125
  # fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
126
 
127
+ st.header('Overall benchmark comparison')
 
128
 
129
  fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
130
  st.plotly_chart(fig)
131
 
132
+ fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_average')
133
  st.plotly_chart(fig)
134
+
135
+ fig = create_plot(filtered_data, 'Model Name', 'hellaswag|10', 'MMLU_average')
136
+ st.plotly_chart(fig)
137
+
138
+ # Add heading to page to say Moral Scenarios
139
+ st.header('Moral Scenarios')
140
+
141
+ fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_moral_scenarios')
142
+ st.plotly_chart(fig)
143
+
144
+
145
+ fig = create_plot(filtered_data, 'Model Name', 'MMLU_moral_disputes', 'MMLU_moral_scenarios')
146
+ st.plotly_chart(fig)
147
+
148
+ fig = create_plot(filtered_data, 'Model Name', 'MMLU_average', 'MMLU_moral_scenarios')
149
+ st.plotly_chart(fig)
150
+
151
+ # create a histogram of moral scenarios
152
+ fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)
153
+ st.plotly_chart(fig)
154
+
155
+ # create a histogram of moral disputes
156
+ fig = px.histogram(filtered_data, x="MMLU_moral_disputes", marginal="rug", hover_data=filtered_data.columns)
157
+ st.plotly_chart(fig)