AbstractPhil committed
Commit c1e8b54 · verified · 1 Parent(s): 3b9e20c

Create benchmark.py

Files changed (1)
  1. benchmark.py +263 -0
benchmark.py ADDED
@@ -0,0 +1,263 @@
# ============================================
# PentachoraViT CIFAR-100 Evaluation
# ============================================

import torch
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
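
# --------------------------------------------------------------------------
# NOTE: get_cifar100_class_names() and get_cifar100_dataloaders() are called
# below but not defined in this file; they presumably live elsewhere in the
# notebook/repo. The two definitions here are a minimal torchvision-based
# sketch (with commonly used CIFAR-100 normalization statistics) so the
# script can run standalone -- swap in the project's own versions if present.
# --------------------------------------------------------------------------
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

_CIFAR100_STATS = ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))

def get_cifar100_class_names(root='./data'):
    # torchvision's CIFAR100 exposes the 100 class names via `.classes`
    return datasets.CIFAR100(root=root, train=False, download=True).classes

def get_cifar100_dataloaders(batch_size=100, root='./data'):
    # Plain ToTensor + Normalize; the real pipeline may differ (e.g. resizing
    # to the ViT's expected input size, or train-time augmentation).
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*_CIFAR100_STATS),
    ])
    train_set = datasets.CIFAR100(root=root, train=True, download=True, transform=tfm)
    test_set = datasets.CIFAR100(root=root, train=False, download=True, transform=tfm)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader
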
def evaluate_pentachora_vit(model, test_loader, device='cuda'):
    """Evaluate a PentachoraViT model on CIFAR-100 and report per-class metrics."""
    model.eval()

    # Get class names
    class_names = get_cifar100_class_names()

    # Check model configuration
    print("Model Configuration:")
    print(f"  Internal dim: {model.dim}")
    print(f"  Vocab dim: {model.vocab_dim}")
    print(f"  Num classes: {model.num_classes}")

    # Get the class crystals
    if hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora'):
        crystals = model.cls_tokens.class_pentachora  # [100, 5, vocab_dim]
        print(f"  Crystal shape: {crystals.shape}")
    else:
        print("  No crystals found!")
        return None, None  # keep the caller's tuple unpack from raising

    # Track metrics
    all_predictions = []
    all_targets = []
    all_confidences = []
    geometric_alignments_by_class = defaultdict(list)
    aux_predictions = []

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            # Get the model's output dictionary
            outputs = model(images)

            # Main predictions from the primary head
            logits = outputs['logits']  # [batch, 100]
            probs = F.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

            # Store predictions
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_confidences.extend(confidence.cpu().numpy())

            # Auxiliary predictions
            if 'aux_logits' in outputs:
                aux_probs = F.softmax(outputs['aux_logits'], dim=1)
                _, aux_pred = torch.max(aux_probs, 1)
                aux_predictions.extend(aux_pred.cpu().numpy())

            # Geometric alignments - these show how patches align with class crystals
            if 'geometric_alignments' in outputs:
                # Shape: [batch, num_patches, num_classes]
                geo_align = outputs['geometric_alignments']
                # Average over patches to get per-sample class alignments
                geo_align_mean = geo_align.mean(dim=1)  # [batch, num_classes]

                for i, target_class in enumerate(targets):
                    class_idx = target_class.item()
                    # Store the alignment score for the true class
                    geometric_alignments_by_class[class_idx].append(
                        geo_align_mean[i, class_idx].item()
                    )

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_confidences = np.array(all_confidences)

    # Calculate per-class metrics
    class_results = []
    for class_idx in range(len(class_names)):
        mask = all_targets == class_idx
        if mask.sum() == 0:
            continue

        class_preds = all_predictions[mask]
        correct = (class_preds == class_idx).sum()
        total = mask.sum()
        accuracy = 100.0 * correct / total

        # Average confidence for this class
        class_conf = all_confidences[mask].mean()

        # Mean geometric alignment for this class (0 if none were recorded)
        geo_align = (np.mean(geometric_alignments_by_class[class_idx])
                     if geometric_alignments_by_class[class_idx] else 0.0)

        # Crystal statistics
        class_crystal = crystals[class_idx].detach().cpu()  # [5, vocab_dim]
        vertex_variance = class_crystal.var(dim=0).mean().item()

        # Crystal norm (average vertex magnitude)
        crystal_norm = class_crystal.norm(dim=-1).mean().item()

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_names[class_idx],
            'accuracy': accuracy,
            'correct': int(correct),
            'total': int(total),
            'avg_confidence': class_conf,
            'geometric_alignment': geo_align,
            'vertex_variance': vertex_variance,
            'crystal_norm': crystal_norm
        })

    # Sort by accuracy, best first
    class_results.sort(key=lambda x: x['accuracy'], reverse=True)

    # Overall metrics
    overall_acc = 100.0 * (all_predictions == all_targets).mean()

    # Auxiliary head accuracy, if the model produced aux logits
    aux_acc = None
    if aux_predictions:
        aux_predictions = np.array(aux_predictions)
        aux_acc = 100.0 * (aux_predictions == all_targets).mean()

    # Print results
    print("\n" + "=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    if aux_acc is not None:  # `is not None` so a 0.0% aux accuracy still prints
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")

    # Top 10 classes
    print("\nTop 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in class_results[:10]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    # Bottom 10 classes
    print("\nBottom 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in class_results[-10:]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    # Correlations between per-class accuracy and the geometric statistics
    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]
    vertex_vars = [r['vertex_variance'] for r in class_results]

    print("\nCorrelations with Accuracy:")
    print(f"  Geometric Alignment: {np.corrcoef(accuracies, geo_aligns)[0, 1]:.3f}")
    print(f"  Crystal Norm: {np.corrcoef(accuracies, crystal_norms)[0, 1]:.3f}")
    print(f"  Vertex Variance: {np.corrcoef(accuracies, vertex_vars)[0, 1]:.3f}")

    # Visualizations
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # 1. Per-class accuracy distribution
    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. Accuracy vs geometric alignment, colored by crystal norm
    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    # 3. Accuracy vs crystal norm
    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    # 4. Top 10 vs bottom 10 comparison
    ax = axes[1, 1]
    top10_acc = [r['accuracy'] for r in class_results[:10]]
    bottom10_acc = [r['accuracy'] for r in class_results[-10:]]

    x = np.arange(10)
    width = 0.35
    ax.bar(x - width / 2, top10_acc, width, label='Top 10 Accuracy', color='green', alpha=0.7)
    ax.bar(x + width / 2, bottom10_acc, width, label='Bottom 10 Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Top 10 vs Bottom 10 Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # ===================================================================================
    # FULL 100-CLASS DIAGNOSTIC SPECTRUM (SORTED BY CLASS IDX FOR CONSISTENCY)
    # ===================================================================================
    print(f"\n{'=' * 90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'=' * 90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} "
          f"{'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)

    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>10.8f}")

    return class_results, overall_acc
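
# Optional sketch (not called by this script): the per-sample alignment loop
# inside evaluate_pentachora_vit can be vectorized with a single gather. The
# helper name is illustrative; the loop above is kept for readability.
def true_class_alignments(geo_align_mean, targets):
    """Pick each sample's alignment score for its true class.

    geo_align_mean: [batch, num_classes] tensor; targets: [batch] long tensor.
    Returns a [batch] tensor, equivalent to looping over the batch.
    """
    return geo_align_mean.gather(1, targets.unsqueeze(1)).squeeze(1)
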
# Run evaluation
if 'model' in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)

    results, overall = evaluate_pentachora_vit(model, test_loader, device)

    # Additional crystal analysis
    print("\nCrystal Geometry Analysis:")
    print("-" * 50)

    # Get crystals
    crystals = model.cls_tokens.class_pentachora.detach().cpu()

    # Compute pairwise cosine similarities between class crystals
    crystals_flat = crystals.mean(dim=1)  # Average over the 5 vertices
    crystals_norm = F.normalize(crystals_flat, dim=1)
    similarities = torch.matmul(crystals_norm, crystals_norm.T)

    # Find confused pairs (high crystal similarity, both classes low accuracy)
    print("\nMost similar classes with poor performance:")
    for i in range(len(results)):
        for j in range(i + 1, len(results)):
            if results[i]['accuracy'] < 20 and results[j]['accuracy'] < 20:
                sim = similarities[results[i]['class_idx'], results[j]['class_idx']].item()
                if sim > 0.5:
                    print(f"  {results[i]['class_name']:<15} ({results[i]['accuracy']:.1f}%) ↔ "
                          f"{results[j]['class_name']:<15} ({results[j]['accuracy']:.1f}%) : {sim:.3f}")
else:
    print("No model found in memory!")
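
# --------------------------------------------------------------------------
# Optional extension (sketch): the scan above only reports pairs where both
# classes fall under 20% accuracy. Listing the most similar crystal pairs
# overall can also flag likely confusions. The helper name and `k` are
# illustrative, not part of the original script.
# --------------------------------------------------------------------------
def top_similar_crystal_pairs(similarities, class_names, k=10):
    """Return the k most similar distinct class pairs from a [C, C] cosine
    similarity matrix (upper triangle only, so each pair appears once)."""
    sims = similarities.clone()
    n = sims.shape[0]
    # Mask the diagonal and lower triangle; cosine is >= -1, so -2 is never picked
    sims[torch.tril(torch.ones(n, n)).bool()] = -2.0
    vals, flat = torch.topk(sims.flatten(), k)
    return [(class_names[int(f) // n], class_names[int(f) % n], float(v))
            for v, f in zip(vals, flat)]

# Example usage, after the analysis block above has built `similarities`:
#   for a, b, s in top_similar_crystal_pairs(similarities, get_cifar100_class_names()):
#       print(f"  {a:<15} ↔ {b:<15} : {s:.3f}")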