Rezky Mulia Kam commited on
Commit
6071c10
·
verified ·
1 Parent(s): 27f7905

added metrics

Browse files
_multiclass_confusion_matrix.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from sklearn.metrics import confusion_matrix
5
+ import seaborn as sns
6
+ import matplotlib
7
+ matplotlib.use('Qt5Agg')
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.model_selection import train_test_split
10
+ import numpy as np
11
+ import os
12
+ os.environ['QT_QPA_PLATFORM'] = 'xcb'
13
+
14
+ # Define label mappings
15
+ label_map = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
16
+ reverse_label_map = {v: k for k, v in label_map.items()} # Reverse mapping for converting labels to integers
17
+
18
+ # Load the dataset
19
+ df = pd.read_csv('./dataset/emotions.csv')
20
+
21
+ # Ensure the 'label' column exists
22
+ if 'label' not in df.columns:
23
+ print("Error: 'label' column is missing from the dataset.")
24
+ exit(1)
25
+
26
+ # Convert text labels to numeric if they're not already numeric
27
+ if df['label'].dtype == 'object':
28
+ df['label'] = df['label'].map(reverse_label_map)
29
+
30
+ # Verify label conversion
31
+ if df['label'].isnull().any():
32
+ print("Error: Some labels could not be mapped properly.")
33
+ exit(1)
34
+
35
+ # Sample a smaller subset for faster debugging
36
+ sample_size = 20000 # Adjust sample size as needed
37
+ df_sampled = df.sample(n=sample_size, random_state=42)
38
+
39
+
40
+ # Split the sampled dataset
41
+ train_texts, val_texts, train_labels, val_labels = train_test_split(
42
+ df_sampled['text'].tolist(),
43
+ df_sampled['label'].tolist(),
44
+ test_size=0.2,
45
+ random_state=42
46
+ )
47
+
48
+ model_6_path = "./models/stardust_6"
49
+ tokenizer = AutoTokenizer.from_pretrained(model_6_path)
50
+ model = AutoModelForSequenceClassification.from_pretrained(model_6_path, num_labels=6)
51
+ model.eval() # Set model to evaluation mode
52
+
53
+ # Define a function for tokenization and encoding
54
+ def tokenize_and_encode(texts, labels):
55
+ inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
56
+ inputs['labels'] = torch.tensor(labels)
57
+ return inputs
58
+
59
+ # Create datasets with labels
60
+ train_dataset = tokenize_and_encode(train_texts, train_labels)
61
+ val_dataset = tokenize_and_encode(val_texts, val_labels)
62
+
63
+ # Move model to GPU if available
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ model.to(device)
66
+
67
+ # Move validation inputs to the device
68
+ val_inputs = {k: v.to(device) for k, v in val_dataset.items() if k != 'labels'}
69
+ val_labels = val_dataset['labels'].to(device)
70
+
71
+ def plot_classification_analysis(val_labels, val_inputs, model, label_map):
72
+ # Convert labels if they're one-hot encoded
73
+ true_labels = val_labels.argmax(dim=-1).cpu().numpy() if len(val_labels.shape) > 1 else val_labels.cpu().numpy()
74
+
75
+ with torch.no_grad():
76
+ # Get the raw logits from the model
77
+ outputs = model(**val_inputs)
78
+ logits = outputs.logits.cpu().numpy()
79
+
80
+ # Calculate softmax probabilities
81
+ probabilities = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
82
+ predictions_softmax = np.argmax(probabilities, axis=-1)
83
+
84
+ # Convert label_map to list for plotting
85
+ label_map_list = list(label_map.values())
86
+
87
+ # Create figure with two subplots
88
+ fig, axes = plt.subplots(1, 2, figsize=(20, 8))
89
+
90
+ # First subplot: Confusion Matrix
91
+ cm_softmax = confusion_matrix(true_labels, predictions_softmax)
92
+ sns.heatmap(
93
+ cm_softmax,
94
+ annot=True,
95
+ fmt="d",
96
+ cmap="Oranges",
97
+ xticklabels=label_map_list,
98
+ yticklabels=label_map_list,
99
+ ax=axes[0],
100
+ square=True
101
+ )
102
+ axes[0].set_xlabel("Prediction")
103
+ axes[0].set_ylabel("Truth")
104
+ axes[0].set_title(f"Softmax [{sample_size}]")
105
+
106
+ # Rotate x-axis labels for better readability
107
+ axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=45, ha='right')
108
+ axes[0].set_yticklabels(axes[0].get_yticklabels(), rotation=0)
109
+
110
+ # Second subplot: Raw Logits Heatmap
111
+ sample_size_r = min(sample_size, logits.shape[0]) # Show up to 50 samples
112
+ logits_subset = logits[:sample_size_r]
113
+
114
+ sns.heatmap(
115
+ logits_subset,
116
+ annot=False,
117
+ cmap="Oranges",
118
+ cbar=True,
119
+ xticklabels=label_map_list,
120
+ yticklabels=False,
121
+ ax=axes[1]
122
+ )
123
+ axes[1].set_xlabel("Classes")
124
+ axes[1].set_ylabel("Samples")
125
+ axes[1].set_title(f"Logits Distribution [{sample_size}]")
126
+
127
+ # Rotate x-axis labels for better readability
128
+ axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=45, ha='right')
129
+
130
+ # Add color bar labels
131
+ for im, title in zip(axes, ['Number of Samples', 'Logit Value']):
132
+ cbar = im.collections[0].colorbar
133
+ cbar.set_label(title)
134
+
135
+ plt.tight_layout()
136
+
137
+ # Calculate and return additional metrics
138
+ metrics = {
139
+ 'confusion_matrix': cm_softmax,
140
+ 'raw_logits_stats': {
141
+ 'mean': np.mean(logits, axis=0),
142
+ 'std': np.std(logits, axis=0),
143
+ 'min': np.min(logits, axis=0),
144
+ 'max': np.max(logits, axis=0)
145
+ }
146
+ }
147
+
148
+ return fig, metrics
149
+
150
+ fig, metrics = plot_classification_analysis(
151
+ val_labels=val_labels,
152
+ val_inputs=val_inputs,
153
+ model=model,
154
+ label_map=label_map
155
+ )
156
+
157
+ plt.show()
158
+
159
+
160
+
161
+
_multiclass_metrics.png ADDED
multiclass_f1's.png ADDED