kyleledbetter committed
Commit a5ba058
1 Parent(s): 6784da7

feat(app): gpt, dashboard, and dark mode

Files changed (1)
  1. app.py +168 -57
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import requests
 import json
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering
 
 from datasets import load_dataset
 import datasets
@@ -70,6 +70,7 @@ def generate_label_map(dataset):
     label_map = {i: label for i, label in enumerate(set(dataset["label"]))}
     return label_map
 
+# Fairness score reference: https://arxiv.org/pdf/1908.09635.pdf
 def calculate_fairness_score(results, label_map):
     true_labels = [r[1] for r in results]
     pred_labels = [r[2] for r in results]
@@ -88,7 +89,7 @@ def calculate_fairness_score(results, label_map):
         cm = confusion_matrix(true_group_labels, pred_group_labels, labels=list(group_names))
         group_cms[group] = cm
 
-    # Calculate fairness score
+    # Calculate the fairness score as the average difference between group confusion matrices
     score = 0
     for i, group1 in enumerate(group_names):
        for j, group2 in enumerate(group_names):
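Note: the loop body that actually compares the group confusion matrices falls outside this hunk, so the exact formula is not visible here. A minimal, self-contained sketch of one way to realize the comment above (row-normalized matrices, element-wise absolute difference, averaged); the normalization choice is an assumption, not something this diff confirms:

    import numpy as np

    def avg_cm_difference(cm_a, cm_b):
        # Row-normalize so groups of different sizes are comparable, then average
        # the element-wise absolute differences between the two matrices.
        a = cm_a.astype(float) / cm_a.sum(axis=1, keepdims=True)
        b = cm_b.astype(float) / cm_b.sum(axis=1, keepdims=True)
        return np.abs(a - b).mean()

    print(avg_cm_difference(np.array([[8, 2], [1, 9]]), np.array([[6, 4], [3, 7]])))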
@@ -100,6 +101,7 @@ def calculate_fairness_score(results, label_map):
 
     return accuracy, score
 
+# Per-class metrics: compute the chosen metric separately for each class defined by label_map
 def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy'):
     unique_labels = sorted(label_map.values())
     metrics = []
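Note: scikit-learn can return these per-class scores directly. A small sketch with toy data; per-class "accuracy" is taken here as per-class recall, which is an assumption and not necessarily what calculate_per_class_metrics does:

    from sklearn.metrics import f1_score, recall_score

    true_labels = ["pos", "neg", "neg", "pos"]
    pred_labels = ["pos", "neg", "pos", "pos"]
    classes = ["neg", "pos"]

    # average=None returns one score per class, ordered like `labels`.
    per_class_f1 = f1_score(true_labels, pred_labels, labels=classes, average=None, zero_division=0)
    per_class_recall = recall_score(true_labels, pred_labels, labels=classes, average=None, zero_division=0)
    print(per_class_f1, per_class_recall)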
@@ -119,12 +121,31 @@ def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='acc
 
     return metrics
 
-def generate_visualization(visualization_type, results, label_map):
+def generate_fairness_statement(accuracy, fairness_score):
+    accuracy_level = "high" if accuracy >= 0.85 else "moderate" if accuracy >= 0.7 else "low"
+    fairness_level = "low" if fairness_score <= 0.15 else "moderate" if fairness_score <= 0.3 else "high"
+
+    # statement = f"The model has a {accuracy_level} overall accuracy of {accuracy * 100:.2f}% and a {fairness_level} fairness score of {fairness_score:.2f}. "
+    statement = f"Assessment: "
+
+    if fairness_level == "low":
+        statement += f"The low fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) indicate that the model is relatively fair and does not exhibit significant bias across different groups."
+    elif fairness_level == "moderate":
+        statement += f"The moderate fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) suggest that the model may have some bias across different groups, and further investigation is needed to ensure it does not disproportionately affect certain groups."
+    else:
+        statement += f"The high fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) indicate that the model exhibits significant bias across different groups, and it's recommended to address this issue to ensure fair predictions for all groups."
+
+    return statement
+
+def generate_visualization(visualization_type, results, label_map, chart_mode):
     true_labels = [r[1] for r in results]
     pred_labels = [r[2] for r in results]
+
+    background_color = "white" if chart_mode == "Light" else "black"
+    text_color = "black" if chart_mode == "Light" else "white"
 
     if visualization_type == "confusion_matrix":
-        return generate_report_card(results, label_map)["fig"]
+        return generate_report_card(results, label_map, chart_mode)["fig"]
     elif visualization_type == "per_class_accuracy":
         per_class_accuracy = calculate_per_class_metrics(
             true_labels, pred_labels, label_map, metric='accuracy')
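For reference, a worked example of the thresholds introduced in generate_fairness_statement above (the values are made up):

    accuracy, fairness_score = 0.92, 0.10
    accuracy_level = "high" if accuracy >= 0.85 else "moderate" if accuracy >= 0.7 else "low"
    fairness_level = "low" if fairness_score <= 0.15 else "moderate" if fairness_score <= 0.3 else "high"
    print(accuracy_level, fairness_level)  # -> "high" "low", i.e. the "relatively fair" branch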
@@ -139,8 +160,17 @@ def generate_visualization(visualization_type, results, label_map):
             marker_color=colors[i % len(colors)]
         ))
 
-        fig.update_layout(title='Per-Class Accuracy',
-                          xaxis_title='Class', yaxis_title='Accuracy')
+        fig.update_xaxes(showgrid=True, gridwidth=1,
+                         gridcolor='LightGray', linecolor='black', linewidth=1)
+        fig.update_yaxes(showgrid=True, gridwidth=1,
+                         gridcolor='LightGray', linecolor='black', linewidth=1)
+        fig.update_layout(plot_bgcolor=background_color,
+                          paper_bgcolor=background_color,
+                          font=dict(color=text_color),
+                          title='Per-Class Accuracy',
+                          xaxis_title='Class', yaxis_title='Accuracy'
+
+                          )
         return fig
     elif visualization_type == "per_class_f1":
         per_class_f1 = calculate_per_class_metrics(
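Note: an alternative to hand-setting background, font, and grid colors for the Light/Dark toggle is Plotly's built-in templates. This is only a sketch of that option, not what the commit does:

    import plotly.graph_objects as go

    chart_mode = "Dark"  # or "Light"
    fig = go.Figure(go.Bar(x=["a", "b"], y=[1, 2]))
    # "plotly_dark" / "plotly_white" bundle background, font, and gridline colors in one setting.
    fig.update_layout(template="plotly_dark" if chart_mode == "Dark" else "plotly_white")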
@@ -156,35 +186,107 @@ def generate_visualization(visualization_type, results, label_map):
             marker_color=colors[i % len(colors)]
         ))
 
-        fig.update_layout(title='Per-Class F1-Score',
-                          xaxis_title='Class', yaxis_title='F1-Score')
+        fig.update_xaxes(showgrid=True, gridwidth=1,
+                         gridcolor='LightGray', linecolor='black', linewidth=1)
+        fig.update_yaxes(showgrid=True, gridwidth=1,
+                         gridcolor='LightGray', linecolor='black', linewidth=1)
+        fig.update_layout(plot_bgcolor=background_color,
+                          paper_bgcolor=background_color,
+                          font=dict(color=text_color),
+                          title='Per-Class F1-Score',
+                          xaxis_title='Class', yaxis_title='F1-Score'
+                          )
         return fig
+    elif visualization_type == "interactive_dashboard":
+        return generate_interactive_dashboard(results, label_map, chart_mode)
     else:
         raise ValueError(f"Invalid visualization type: {visualization_type}")
 
-
-
-def generate_report_card(results, label_map):
+def generate_interactive_dashboard(results, label_map, chart_mode):
+    true_labels = [r[1] for r in results]
+    pred_labels = [r[2] for r in results]
+
+    colors = ['#EF553B', '#00CC96', '#636EFA', '#AB63FA', '#FFA15A',
+              '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
+
+    background_color = "white" if chart_mode == "Light" else "black"
+    text_color = "black" if chart_mode == "Light" else "white"
+
+    # Create confusion matrix
+    cm_fig = generate_report_card(results, label_map, chart_mode)["fig"]
+
+    # Create per-class accuracy bar chart
+    pca_data = calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy')
+    pca_fig = go.Bar(x=list(label_map.values()), y=pca_data, marker=dict(color=colors))
+
+    # Create per-class F1-score bar chart
+    pcf_data = calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='f1')
+    pcf_fig = go.Bar(x=list(label_map.values()), y=pcf_data, marker=dict(color=colors))
+
+    # Combine all charts into a mixed subplot
+    fig = make_subplots(rows=2, cols=2, shared_xaxes=True, specs=[[{"colspan": 2}, None],
+                                                                  [{}, {}]],
+                        print_grid=True, subplot_titles=(
+                            "Confusion Matrix", "Per-Class Accuracy", "Per-Class F1-Score"))
+    fig.add_trace(cm_fig['data'][0], row=1, col=1)
+    fig.add_trace(pca_fig, row=2, col=1)
+    fig.add_trace(pcf_fig, row=2, col=2)
+
+    fig.update_xaxes(showgrid=True, gridwidth=1,
+                     gridcolor='LightGray', linecolor='black', linewidth=1)
+    fig.update_yaxes(showgrid=True, gridwidth=1,
+                     gridcolor='LightGray', linecolor='black', linewidth=1)
+    # Update layout
+    fig.update_layout(height=700, width=650,
+                      plot_bgcolor=background_color,
+                      paper_bgcolor=background_color,
+                      font=dict(color=text_color),
+                      title="Fairness Report", showlegend=False
+                      )
+
+    return fig
+
+def generate_report_card(results, label_map, chart_mode):
     true_labels = [r[1] for r in results]
     pred_labels = [r[2] for r in results]
-
-    cm = confusion_matrix(true_labels, pred_labels,
-                          labels=list(label_map.values()))
-
-    # Create the plotly figure
-    fig = make_subplots(rows=1, cols=1)
-    fig.add_trace(go.Heatmap(
-        z=cm,
-        x=list(label_map.values()),
-        y=list(label_map.values()),
-        colorscale='RdYlGn',
-        colorbar=dict(title='# of Samples')
-    ))
+
+    background_color = "white" if chart_mode == "Light" else "black"
+    text_color = "black" if chart_mode == "Light" else "white"
+
+    cm = confusion_matrix(true_labels, pred_labels)
+
+    # Normalize the confusion matrix
+    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+
+    # Create a custom color scale
+    custom_color_scale = np.zeros(cm_normalized.shape, dtype='str')
+    for i in range(cm_normalized.shape[0]):
+        for j in range(cm_normalized.shape[1]):
+            custom_color_scale[i, j] = '#EF553B' if i == j else '#00CC96'
+
+    fig = go.Figure(go.Heatmap(z=cm_normalized,
+                               x=list(label_map.values()),
+                               y=list(label_map.values()),
+                               text=cm,
+                               hovertemplate='%{text}',
+                               colorscale=[[0, '#EF553B'], [
+                                   1, '#00CC96']],
+                               showscale=False,
+                               zmin=0, zmax=1,
+                               customdata=custom_color_scale))
+
+    fig.update_xaxes(showgrid=True, gridwidth=1,
+                     gridcolor='LightGray', linecolor='black', linewidth=1)
+    fig.update_yaxes(showgrid=True, gridwidth=1,
+                     gridcolor='LightGray', linecolor='black', linewidth=1)
     fig.update_layout(
+        plot_bgcolor=background_color,
+        paper_bgcolor=background_color,
+        font=dict(color=text_color),
         height=500, width=600,
         title='Confusion Matrix',
         xaxis=dict(title='Predicted Labels'),
-        yaxis=dict(title='True Labels', autorange='reversed')
+        yaxis=dict(title='True Labels')
     )
 
     # Create the text output
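Note: the row normalization above (cm.sum(axis=1)[:, np.newaxis]) produces NaNs if some class never occurs in true_labels. A hedged, defensive sketch of the same normalization for that edge case (not something this diff handles):

    import numpy as np

    def normalize_rows(cm):
        # Divide each row by its sum, leaving all-zero rows at 0 instead of NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        return np.divide(cm.astype(float), row_sums,
                         out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    print(normalize_rows(np.array([[3, 1], [0, 0]])))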
@@ -197,31 +299,6 @@ def generate_report_card(results, label_map):
     per_class_f1 = calculate_per_class_metrics(
         true_labels, pred_labels, label_map, metric='f1')
 
-
-    text_output = html.Div(children=[
-        html.H2('Performance Metrics'),
-        html.Div(children=[
-            html.Div(children=[
-                html.H3('Accuracy'),
-                html.H4(f'{accuracy}')
-            ], className='metric'),
-            html.Div(children=[
-                html.H3('Fairness Score'),
-                # html.H4(f'{fairness_score}')
-                html.H4(
-                    f'Accuracy: {fairness_score[0]:.2f}, Score: {fairness_score[1]:.2f}')
-            ], className='metric'),
-        ], className='metric-container'),
-    ], className='text-output')
-
-    # Combine the plot and text output into a Dash container
-    # report_card = html.Div([
-    #     dcc.Graph(figure=fig),
-    #     text_output,
-    # ])
-
-    # return report_card
-
     report_card = {
         "fig": fig,
         "accuracy": accuracy,
@@ -232,9 +309,26 @@ def generate_report_card(results, label_map):
     return report_card
 
     # return fig, text_output
+
+
+def generate_insights(custom_text, model_name, dataset_name, accuracy, fairness_score, report_card, generator):
+    per_class_metrics = {
+        'accuracy': report_card.get('per_class_accuracy', []),
+        'f1': report_card.get('per_class_f1', [])
+    }
+
+    if not per_class_metrics['accuracy'] or not per_class_metrics['f1']:
+        input_text = f"{custom_text} The model {model_name} has been evaluated on the {dataset_name} dataset. It has an overall accuracy of {accuracy * 100:.2f}%. The fairness score is {fairness_score:.2f}. Per-class metrics could not be calculated. Please provide some interesting insights about the fairness and bias of the model."
+    else:
+        input_text = f"{custom_text} The model {model_name} has been evaluated on the {dataset_name} dataset. It has an overall accuracy of {accuracy * 100:.2f}%. The fairness score is {fairness_score:.2f}. The per-class metrics are: {per_class_metrics}. Please provide some interesting insights about the fairness, bias, and per-class performance."
 
 
-def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str):
+    insights = generator(input_text, max_length=600,
+                         do_sample=True, temperature=0.7)
+    return insights[0]['generated_text']
+
+
+def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str, chart_mode: str):
     tokenizer, model = load_model(
         model_type, model_name_or_path, dataset_name, config_name)
 
@@ -277,17 +371,33 @@ def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name
 
     # return fig, text_output
 
-    report_card = generate_report_card(results, label_map)
-    visualization = generate_visualization(visualization_type, results, label_map)
+    report_card = generate_report_card(results, label_map, chart_mode)
+    visualization = generate_visualization(visualization_type, results, label_map, chart_mode)
 
     per_class_metrics_str = "\n".join([f"{label}: Acc {acc:.2f}, F1 {f1:.2f}" for label, acc, f1 in zip(
         label_map.values(), report_card['per_class_accuracy'], report_card['per_class_f1'])])
-
+
+    accuracy, fairness_score = calculate_fairness_score(results, label_map)
+    fairness_statement = generate_fairness_statement(accuracy, fairness_score)
+
+    # device=-1 runs the generator on CPU; set device=0 to use the first GPU.
+    generator = pipeline(
+        'text-generation', model='gpt2', device=-1)  # Swap in EleutherAI/gpt-neo-1.3B or EleutherAI/GPT-J-6B for a larger GPT-style model, or distilgpt2 for a lighter GPT-2
+    per_class_metrics = {
+        'accuracy': report_card['per_class_accuracy'],
+        'f1': report_card['per_class_f1']
+    }
+
+    custom_text = fairness_statement
+
+    insights = generate_insights(custom_text, model_name_or_path,
+                                 dataset_name, accuracy, fairness_score, report_card, generator)
 
     # return report_card["fig"], f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}"
     # return f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}", report_card["fig"]
-    return (f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}\n\n"
-            f"Per-Class Metrics:\n{per_class_metrics_str}"), visualization
+    return (f"{insights}\n\n"
+            f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}\n\n"
+            f"Per-Class Metrics:\n{per_class_metrics_str}"), visualization
 
 interface = gr.Interface(
     fn=app,
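Note on the generator call above: with transformers text-generation pipelines, max_length counts the prompt tokens as well, and the prompt built in generate_insights() is long, so max_new_tokens (which bounds only the generated continuation) is often the safer knob. A hedged sketch of that variant, not what the commit uses:

    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2", device=-1)
    out = generator("Accuracy is 91.20% and the fairness score is 0.12.",
                    max_new_tokens=80, do_sample=True, temperature=0.7)
    print(out[0]["generated_text"])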
@@ -304,8 +414,9 @@ interface = gr.Interface(
             choices=["train", "validation", "test"], label="Dataset Split", default="validation"),
         gr.inputs.Number(default=100, label="Number of Samples"),
         gr.inputs.Dropdown(
-            choices=["confusion_matrix", "per_class_accuracy", "per_class_f1"], label="Visualization Type", default="confusion_matrix"
+            choices=["interactive_dashboard", "confusion_matrix", "per_class_accuracy", "per_class_f1"], label="Visualization Type", default="interactive_dashboard"
         ),
+        gr.inputs.Radio(["Light", "Dark"], label="Chart Mode", default="Light"),
     ],
     # outputs=gr.Plot(),
     # outputs=gr.outputs.HTML(),
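Note: the inputs above use the legacy gr.inputs.* namespace. In Gradio 3.x and later the equivalent components look roughly like this (a sketch, assuming a newer Gradio than the one this Space pins):

    import gradio as gr

    visualization_type = gr.Dropdown(
        choices=["interactive_dashboard", "confusion_matrix", "per_class_accuracy", "per_class_f1"],
        value="interactive_dashboard", label="Visualization Type")
    chart_mode = gr.Radio(["Light", "Dark"], value="Light", label="Chart Mode")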