kyleledbetter committed
Commit 0ec25a0 · 1 Parent(s): 7bd4255

feat(app): support more models and datasets

Files changed (1)
  1. app.py +278 -61
app.py CHANGED
@@ -1,104 +1,321 @@
--- app.py (before)
  import gradio as gr
  import requests
  import json
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
  from datasets import load_dataset
  import plotly.io as pio
  import plotly.graph_objects as go
  import plotly.express as px
  import pandas as pd
  from sklearn.metrics import confusion_matrix

- def load_model(endpoint: str):
-     tokenizer = AutoTokenizer.from_pretrained(endpoint)
-     model = AutoModelForSequenceClassification.from_pretrained(endpoint)
-     return tokenizer, model


  def test_model(tokenizer, model, test_data: list, label_map: dict):
-     results = []
-     for text, true_label in test_data:
-         inputs = tokenizer(text, return_tensors="pt",
-                            truncation=True, padding=True)
-         outputs = model(**inputs)
-         pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
-         results.append((text, true_label, pred_label))
-     return results

  def generate_label_map(dataset):
-     num_labels = len(dataset.features["label"].names)
-     label_map = {i: label for i, label in enumerate(dataset.features["label"].names)}
      return label_map


- def generate_report_card(results, label_map):
      true_labels = [r[1] for r in results]
      pred_labels = [r[2] for r in results]

-     cm = confusion_matrix(true_labels, pred_labels,
-                           labels=list(label_map.values()))

-     fig = go.Figure(
-         data=go.Heatmap(
-             z=cm,
-             x=list(label_map.values()),
-             y=list(label_map.values()),
-             colorscale='Viridis',
-             colorbar=dict(title='Number of Samples')
-         ),
-         layout=go.Layout(
-             title='Confusion Matrix',
-             xaxis=dict(title='Predicted Labels'),
-             yaxis=dict(title='True Labels', autorange='reversed')
-         )
-     )

-     fig.update_layout(height=600, width=800)

-     # return fig in new window
-     # fig.show()  # uncomment this line to show the plot in a new window

-     # Convert the Plotly figure to an HTML string < i was trying this bc i couldn't get Plot() to work before
-     # plot_html = pio.to_html(fig, full_html=True, include_plotlyjs=True, config={
-     #     "displayModeBar": False, "responsive": True})
-     #return plot_html
-     return fig

- def app(model_endpoint: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int):
-     tokenizer, model = load_model(model_endpoint)

-     # Load the dataset
-     num_samples = int(num_samples)  # Add this line to cast num_samples to an integer
-     dataset = load_dataset(
-         dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
-     test_data = [(item["sentence"], dataset.features["label"].names[item["label"]])
-                  for item in dataset]
-
-     label_map = generate_label_map(dataset)

-     results = test_model(tokenizer, model, test_data, label_map)
-     report_card = generate_report_card(results, label_map)

-     return report_card

  interface = gr.Interface(
      fn=app,
      inputs=[
-         gr.inputs.Textbox(lines=1, label="Model Endpoint",
-                           placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english"),
          gr.inputs.Textbox(lines=1, label="Dataset Name",
-                           placeholder="ex: glue"),
          gr.inputs.Textbox(lines=1, label="Config Name",
-                           placeholder="ex: sst2"),
          gr.inputs.Dropdown(
-             choices=["train", "validation", "test"], label="Dataset Split"),
          gr.inputs.Number(default=100, label="Number of Samples"),
      ],
-     # outputs=gr.outputs.Plotly(),
      # outputs=gr.outputs.HTML(),
-     outputs=gr.Plot(),
      title="Fairness and Bias Testing",
-     description="Enter a model endpoint and dataset to test for fairness and bias.",
  )

  # Define the label map globally
+++ app.py (after)
  import gradio as gr
  import requests
  import json
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering
+
  from datasets import load_dataset
+ import datasets
  import plotly.io as pio
  import plotly.graph_objects as go
  import plotly.express as px
+ from plotly.subplots import make_subplots
  import pandas as pd
  from sklearn.metrics import confusion_matrix
+ import importlib
+ import torch
+ from dash import Dash, html, dcc
+ import numpy as np
+ from sklearn.metrics import accuracy_score
+ from sklearn.metrics import f1_score
+
+
+ def load_model(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str):
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+     if model_type == "text_classification":
+         dataset = load_dataset(dataset_name, config_name)
+         num_labels = len(dataset["train"].features["label"].names)
+
+         if "roberta" in model_name_or_path.lower():
+             from transformers import RobertaForSequenceClassification
+             model = RobertaForSequenceClassification.from_pretrained(
+                 model_name_or_path, num_labels=num_labels)
+         else:
+             model = AutoModelForSequenceClassification.from_pretrained(
+                 model_name_or_path, num_labels=num_labels)
+     elif model_type == "token_classification":
+         dataset = load_dataset(dataset_name, config_name)
+         num_labels = len(
+             dataset["train"].features["ner_tags"].feature.names)
+         model = AutoModelForTokenClassification.from_pretrained(
+             model_name_or_path, num_labels=num_labels)
+     elif model_type == "question_answering":
+         model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
+     else:
+         raise ValueError(f"Invalid model type: {model_type}")

+     return tokenizer, model


  def test_model(tokenizer, model, test_data: list, label_map: dict):
+     results = []
+     for text, _, true_label in test_data:
+         inputs = tokenizer(text, return_tensors="pt",
+                            truncation=True, padding=True)
+         outputs = model(**inputs)
+         pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
+         results.append((text, true_label, pred_label))
+     return results
+

  def generate_label_map(dataset):
+     if "label" not in dataset.features or dataset.features["label"] is None:
+         return {}
+
+     if isinstance(dataset.features["label"], datasets.ClassLabel):
+         num_labels = dataset.features["label"].num_classes
+         label_map = {i: label for i, label in enumerate(dataset.features["label"].names)}
+     else:
+         num_labels = len(set(dataset["label"]))
+         label_map = {i: label for i, label in enumerate(set(dataset["label"]))}
      return label_map

+ def calculate_fairness_score(results, label_map):
+     true_labels = [r[1] for r in results]
+     pred_labels = [r[2] for r in results]

+     # Overall accuracy
+     # accuracy = (true_labels == pred_labels).mean()
+     accuracy = accuracy_score(true_labels, pred_labels)
+     # Calculate confusion matrix for each group
+     group_names = label_map.values()
+     group_cms = {}
+     for group in group_names:
+         true_group_indices = [i for i, label in enumerate(true_labels) if label == group]
+         pred_group_labels = [pred_labels[i] for i in true_group_indices]
+         true_group_labels = [true_labels[i] for i in true_group_indices]
+
+         cm = confusion_matrix(true_group_labels, pred_group_labels, labels=list(group_names))
+         group_cms[group] = cm
+
+     # Calculate fairness score
+     score = 0
+     for i, group1 in enumerate(group_names):
+         for j, group2 in enumerate(group_names):
+             if i < j:
+                 cm1 = group_cms[group1]
+                 cm2 = group_cms[group2]
+                 diff = np.abs(cm1 - cm2)
+                 score += (diff.sum() / 2) / cm1.sum()
+
+     return accuracy, score
+
+ def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy'):
+     unique_labels = sorted(label_map.values())
+     metrics = []
+
+     if metric == 'accuracy':
+         for label in unique_labels:
+             label_indices = [i for i, true_label in enumerate(true_labels) if true_label == label]
+             true_label_subset = [true_labels[i] for i in label_indices]
+             pred_label_subset = [pred_labels[i] for i in label_indices]
+             accuracy = accuracy_score(true_label_subset, pred_label_subset)
+             metrics.append(accuracy)
+     elif metric == 'f1':
+         f1_scores = f1_score(true_labels, pred_labels, labels=unique_labels, average=None)
+         metrics = f1_scores.tolist()
+     else:
+         raise ValueError(f"Invalid metric: {metric}")
+
+     return metrics
+
+ def generate_visualization(visualization_type, results, label_map):
      true_labels = [r[1] for r in results]
      pred_labels = [r[2] for r in results]

+     if visualization_type == "confusion_matrix":
+         return generate_report_card(results, label_map)["fig"]
+     elif visualization_type == "per_class_accuracy":
+         per_class_accuracy = calculate_per_class_metrics(
+             true_labels, pred_labels, label_map, metric='accuracy')
+
+         colors = px.colors.qualitative.Plotly
+         fig = go.Figure()
+         for i, label in enumerate(label_map.values()):
+             fig.add_trace(go.Bar(
+                 x=[label],
+                 y=[per_class_accuracy[i]],
+                 name=label,
+                 marker_color=colors[i % len(colors)]
+             ))
+
+         fig.update_layout(title='Per-Class Accuracy',
+                           xaxis_title='Class', yaxis_title='Accuracy')
+         return fig
+     elif visualization_type == "per_class_f1":
+         per_class_f1 = calculate_per_class_metrics(
+             true_labels, pred_labels, label_map, metric='f1')
+
+         colors = px.colors.qualitative.Plotly
+         fig = go.Figure()
+         for i, label in enumerate(label_map.values()):
+             fig.add_trace(go.Bar(
+                 x=[label],
+                 y=[per_class_f1[i]],
+                 name=label,
+                 marker_color=colors[i % len(colors)]
+             ))
+
+         fig.update_layout(title='Per-Class F1-Score',
+                           xaxis_title='Class', yaxis_title='F1-Score')
+         return fig
+     else:
+         raise ValueError(f"Invalid visualization type: {visualization_type}")


+ def generate_report_card(results, label_map):
+     true_labels = [r[1] for r in results]
+     pred_labels = [r[2] for r in results]
+
+     cm = confusion_matrix(true_labels, pred_labels,
+                           labels=list(label_map.values()))
+
+     # Create the plotly figure
+     fig = make_subplots(rows=1, cols=1)
+     fig.add_trace(go.Heatmap(
+         z=cm,
+         x=list(label_map.values()),
+         y=list(label_map.values()),
+         colorscale='RdYlGn',
+         colorbar=dict(title='# of Samples')
+     ))
+     fig.update_layout(
+         height=500, width=600,
+         title='Confusion Matrix',
+         xaxis=dict(title='Predicted Labels'),
+         yaxis=dict(title='True Labels', autorange='reversed')
+     )
+
+     # Create the text output
+     # accuracy = pd.Series(true_labels) == pd.Series(pred_labels)
+     accuracy = accuracy_score(true_labels, pred_labels, normalize=False)
+     fairness_score = calculate_fairness_score(results, label_map)

+     per_class_accuracy = calculate_per_class_metrics(
+         true_labels, pred_labels, label_map, metric='accuracy')
+     per_class_f1 = calculate_per_class_metrics(
+         true_labels, pred_labels, label_map, metric='f1')


+     text_output = html.Div(children=[
+         html.H2('Performance Metrics'),
+         html.Div(children=[
+             html.Div(children=[
+                 html.H3('Accuracy'),
+                 html.H4(f'{accuracy}')
+             ], className='metric'),
+             html.Div(children=[
+                 html.H3('Fairness Score'),
+                 # html.H4(f'{fairness_score}')
+                 html.H4(
+                     f'Accuracy: {fairness_score[0]:.2f}, Score: {fairness_score[1]:.2f}')
+             ], className='metric'),
+         ], className='metric-container'),
+     ], className='text-output')

+     # Combine the plot and text output into a Dash container
+     # report_card = html.Div([
+     #     dcc.Graph(figure=fig),
+     #     text_output,
+     # ])

+     # return report_card
+
+     report_card = {
+         "fig": fig,
+         "accuracy": accuracy,
+         "fairness_score": fairness_score,
+         "per_class_accuracy": per_class_accuracy,
+         "per_class_f1": per_class_f1
+     }
+     return report_card
+
+     # return fig, text_output
+
+
+ def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str):
+     tokenizer, model = load_model(
+         model_type, model_name_or_path, dataset_name, config_name)
+
+     # Load the dataset
+     # Add this line to cast num_samples to an integer
+     num_samples = int(num_samples)
+     dataset = load_dataset(
+         dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
+     test_data = []
+
+     if dataset_name == "glue":
+         test_data = [(item["sentence"], None,
+                       dataset.features["label"].names[item["label"]]) for item in dataset]
+     elif dataset_name == "tweet_eval":
+         test_data = [(item["text"], None, dataset.features["label"].names[item["label"]])
+                      for item in dataset]
+     else:
+         test_data = [(item["sentence"], None,
+                       dataset.features["label"].names[item["label"]]) for item in dataset]
+
+     # if model_type == "text_classification":
+     #     for item in dataset:
+     #         text = item["sentence"]
+     #         context = None
+     #         true_label = item["label"]
+     #         test_data.append((text, context, true_label))
+     # elif model_type == "question_answering":
+     #     for item in dataset:
+     #         text = item["question"]
+     #         context = item["context"]
+     #         true_label = None
+     #         test_data.append((text, context, true_label))
+     # else:
+     #     raise ValueError(f"Invalid model type: {model_type}")
+
+     label_map = generate_label_map(dataset)
+
+     results = test_model(tokenizer, model, test_data, label_map)
+     # fig, text_output = generate_report_card(results, label_map)
+
+     # return fig, text_output
+
+     report_card = generate_report_card(results, label_map)
+     visualization = generate_visualization(visualization_type, results, label_map)
+
+     per_class_metrics_str = "\n".join([f"{label}: Acc {acc:.2f}, F1 {f1:.2f}" for label, acc, f1 in zip(
+         label_map.values(), report_card['per_class_accuracy'], report_card['per_class_f1'])])
+
+
+     # return report_card["fig"], f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}"
+     # return f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}", report_card["fig"]
+     return (f"Accuracy: {report_card['accuracy']}, Fairness Score: {report_card['fairness_score'][1]:.2f}\n\n"
+             f"Per-Class Metrics:\n{per_class_metrics_str}"), visualization

  interface = gr.Interface(
      fn=app,
      inputs=[
+         gr.inputs.Radio(["text_classification", "token_classification",
+                          "question_answering"], label="Model Type", default="text_classification"),
+         gr.inputs.Textbox(lines=1, label="Model Name or Path",
+                           placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english", default="distilbert-base-uncased-finetuned-sst-2-english"),
          gr.inputs.Textbox(lines=1, label="Dataset Name",
+                           placeholder="ex: glue", default="glue"),
          gr.inputs.Textbox(lines=1, label="Config Name",
+                           placeholder="ex: sst2", default="cola"),
          gr.inputs.Dropdown(
+             choices=["train", "validation", "test"], label="Dataset Split", default="validation"),
          gr.inputs.Number(default=100, label="Number of Samples"),
+         gr.inputs.Dropdown(
+             choices=["confusion_matrix", "per_class_accuracy", "per_class_f1"], label="Visualization Type", default="confusion_matrix"
+         ),
      ],
+     # outputs=gr.Plot(),
      # outputs=gr.outputs.HTML(),
+     # outputs=[gr.outputs.HTML(), gr.Plot()],
+     outputs=[
+         gr.outputs.Textbox(label="Fairness and Bias Metrics"),
+         gr.Plot(label="Graph")
+     ],
      title="Fairness and Bias Testing",
+     description="Enter a model and dataset to test for fairness and bias.",
  )

  # Define the label map globally
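
The new calculate_fairness_score in this hunk builds one confusion matrix per true-label group and sums the normalized absolute differences between every pair of group matrices, returning that disparity score alongside overall accuracy. The following is a minimal standalone sketch of the same calculation on hypothetical toy labels ("pos"/"neg" are illustrative, not part of the commit); it is not code from app.py, just the idea in isolation:

# Standalone sketch (hypothetical labels, not from app.py) of the pairwise
# fairness score introduced in this commit.
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

true_labels = ["pos", "neg", "pos", "neg", "pos", "neg"]   # toy ground truth
pred_labels = ["pos", "neg", "neg", "neg", "pos", "pos"]   # toy predictions
labels = ["pos", "neg"]

# One confusion matrix per group, restricted to samples whose true label is that group
group_cms = {}
for group in labels:
    idx = [i for i, t in enumerate(true_labels) if t == group]
    group_cms[group] = confusion_matrix([true_labels[i] for i in idx],
                                        [pred_labels[i] for i in idx],
                                        labels=labels)

# Sum of normalized absolute differences between each pair of group matrices;
# 0 means the groups are treated identically, larger values mean more disparity
score = 0.0
for i, g1 in enumerate(labels):
    for j, g2 in enumerate(labels):
        if i < j:
            diff = np.abs(group_cms[g1] - group_cms[g2])
            score += (diff.sum() / 2) / group_cms[g1].sum()

print(f"Accuracy: {accuracy_score(true_labels, pred_labels):.2f}, Fairness score: {score:.2f}")

The same (accuracy, score) pair is what generate_report_card formats into the "Fairness Score" line of the Gradio text output.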