kyleledbetter committed on
Commit
438c90e
1 Parent(s): 992bd55

feat(): Initial app commit

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import plotly.graph_objects as go
5
+ import plotly.express as px
6
+ import pandas as pd
7
+ from sklearn.metrics import confusion_matrix
8
+ from datasets import load_dataset
9
+
10
+
11
def load_model(endpoint: str):
    """Load the tokenizer and sequence-classification model for ``endpoint``.

    Args:
        endpoint: Identifier handed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is
    # still loaded before the model.
    return (
        AutoTokenizer.from_pretrained(endpoint),
        AutoModelForSequenceClassification.from_pretrained(endpoint),
    )
15
+
16
+
17
def test_model(tokenizer, model, test_data: list, label_map: dict):
    """Run the model on every sample and collect its predicted label.

    Each text is tokenized and classified on its own (one forward pass per
    sample); the index of the largest logit is translated to a label name
    through ``label_map``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, padding=True)``; its result is unpacked as
            keyword arguments into ``model``.
        model: Callable whose output exposes a ``logits`` tensor.
        test_data: Iterable of ``(text, true_label)`` pairs.
        label_map: Maps a predicted class index (``int``) to a label name.

    Returns:
        A list of ``(text, true_label, pred_label)`` tuples, in input order.

    Raises:
        KeyError: If the predicted class index is missing from ``label_map``.
    """
    # Local import keeps this block self-contained; the file's import
    # section does not import torch directly.
    import torch

    results = []
    # Inference only: without no_grad() every forward pass would record an
    # autograd graph that is never used, wasting memory and time.
    with torch.no_grad():
        for text, true_label in test_data:
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
            )
            logits = model(**inputs).logits
            # A single text is scored per call, so argmax yields one index.
            pred_label = label_map[int(logits.argmax(dim=-1))]
            results.append((text, true_label, pred_label))
    return results
26
+
27
+
28
def generate_report_card(results, label_map):
    """Display a confusion-matrix heatmap for the evaluation results.

    Args:
        results: Sequence of ``(text, true_label, pred_label)`` tuples; only
            the second and third elements are used.
        label_map: Mapping whose values are the label names; their order
            fixes the order of both heatmap axes.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Axis categories, in label_map order, shared by the matrix and the plot.
    class_names = list(label_map.values())

    actual = [row[1] for row in results]
    predicted = [row[2] for row in results]
    cm = confusion_matrix(actual, predicted, labels=class_names)

    heatmap = go.Heatmap(
        z=cm,
        x=class_names,
        y=class_names,
        colorscale='Viridis',
        colorbar=dict(title='Number of Samples'),
    )
    layout = go.Layout(
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        # Reversed y-axis puts the first label's row at the top.
        yaxis=dict(title='True Labels', autorange='reversed'),
    )

    fig = go.Figure(data=heatmap, layout=layout)
    fig.show()
51
+
52
+
53
def load_sst2_data(split="test"):
    """Load a GLUE SST-2 split as ``(sentence, label_name)`` pairs.

    Label ``1`` becomes ``"positive"`` and label ``0`` becomes
    ``"negative"``. Examples carrying any other label value are skipped
    rather than being reported as ``"negative"``: the GLUE test split is
    published without gold labels (every label is ``-1``), so treating those
    rows as negatives would fabricate ground truth. Use
    ``split="validation"`` (or ``"train"``) when labeled data is required.

    Args:
        split: Name of the dataset split forwarded to ``load_dataset``.

    Returns:
        A list of ``(sentence, "positive" | "negative")`` tuples containing
        only the labeled examples of the split.
    """
    import warnings

    label_names = {0: "negative", 1: "positive"}
    dataset = load_dataset("glue", "sst2", split=split)

    data = []
    skipped = 0
    for item in dataset:
        label_name = label_names.get(item["label"])
        if label_name is None:
            # Unlabeled example (e.g. -1 in the hidden test split).
            skipped += 1
            continue
        data.append((item["sentence"], label_name))

    if skipped:
        warnings.warn(
            f"Skipped {skipped} SST-2 example(s) in split {split!r} that have "
            "no gold label; use a labeled split such as 'validation'.",
            stacklevel=2,
        )
    return data
58
+
59
+
60
# Define your model endpoint and label map.
# model_endpoint = "your-model-endpoint"

# Modify this according to your model's labels.
# label_map = {0: "label0", 1: "label1"}

model_endpoint = "distilbert-base-uncased-finetuned-sst-2-english"
label_map = {0: "negative", 1: "positive"}

# Load the model and tokenizer.
tokenizer, model = load_model(model_endpoint)

# Prepare your test data (list of tuples containing text and true label).
# test_data = [
#     ("Sample text 1", "label0"),
#     ("Sample text 2", "label1"),
#     # Add more test samples here
# ]

# Load evaluation data from the SST-2 dataset. The "validation" split is used
# because the GLUE "test" split ships without gold labels (all labels are -1),
# which would make every "true" label in the report meaningless.
test_data = load_sst2_data(split="validation")
# Use a smaller subset of test_data for a quicker demonstration (optional).
test_data = test_data[:100]

# Test the model and generate results.
results = test_model(tokenizer, model, test_data, label_map)

# Generate the visual report card.
generate_report_card(results, label_map)