Create benchmark.py
Browse files- benchmark.py +263 -0
benchmark.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# PentachoraViT CIFAR-100 Evaluation
|
| 3 |
+
# ============================================
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
def evaluate_pentachora_vit(model, test_loader, device='cuda'):
|
| 13 |
+
"""Properly evaluate PentachoraViT model."""
|
| 14 |
+
model.eval()
|
| 15 |
+
|
| 16 |
+
# Get class names
|
| 17 |
+
class_names = get_cifar100_class_names()
|
| 18 |
+
|
| 19 |
+
# Check model configuration
|
| 20 |
+
print(f"Model Configuration:")
|
| 21 |
+
print(f" Internal dim: {model.dim}")
|
| 22 |
+
print(f" Vocab dim: {model.vocab_dim}")
|
| 23 |
+
print(f" Num classes: {model.num_classes}")
|
| 24 |
+
|
| 25 |
+
# Get the class crystals
|
| 26 |
+
if hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora'):
|
| 27 |
+
crystals = model.cls_tokens.class_pentachora # [100, 5, vocab_dim]
|
| 28 |
+
print(f" Crystal shape: {crystals.shape}")
|
| 29 |
+
else:
|
| 30 |
+
print(" No crystals found!")
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
# Track metrics
|
| 34 |
+
all_predictions = []
|
| 35 |
+
all_targets = []
|
| 36 |
+
all_confidences = []
|
| 37 |
+
geometric_alignments_by_class = defaultdict(list)
|
| 38 |
+
aux_predictions = []
|
| 39 |
+
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
for images, targets in tqdm(test_loader, desc="Evaluating"):
|
| 42 |
+
images = images.to(device)
|
| 43 |
+
targets = targets.to(device)
|
| 44 |
+
|
| 45 |
+
# Get model outputs dictionary
|
| 46 |
+
outputs = model(images)
|
| 47 |
+
|
| 48 |
+
# Main predictions from primary head
|
| 49 |
+
logits = outputs['logits'] # [batch, 100]
|
| 50 |
+
probs = F.softmax(logits, dim=1)
|
| 51 |
+
confidence, predicted = torch.max(probs, 1)
|
| 52 |
+
|
| 53 |
+
# Store predictions
|
| 54 |
+
all_predictions.extend(predicted.cpu().numpy())
|
| 55 |
+
all_targets.extend(targets.cpu().numpy())
|
| 56 |
+
all_confidences.extend(confidence.cpu().numpy())
|
| 57 |
+
|
| 58 |
+
# Auxiliary predictions
|
| 59 |
+
if 'aux_logits' in outputs:
|
| 60 |
+
aux_probs = F.softmax(outputs['aux_logits'], dim=1)
|
| 61 |
+
_, aux_pred = torch.max(aux_probs, 1)
|
| 62 |
+
aux_predictions.extend(aux_pred.cpu().numpy())
|
| 63 |
+
|
| 64 |
+
# Geometric alignments - these show how patches align with class crystals
|
| 65 |
+
if 'geometric_alignments' in outputs:
|
| 66 |
+
# Shape: [batch, num_patches, num_classes]
|
| 67 |
+
geo_align = outputs['geometric_alignments']
|
| 68 |
+
# Average over patches to get per-sample class alignments
|
| 69 |
+
geo_align_mean = geo_align.mean(dim=1) # [batch, num_classes]
|
| 70 |
+
|
| 71 |
+
for i, target_class in enumerate(targets):
|
| 72 |
+
class_idx = target_class.item()
|
| 73 |
+
# Store alignment score for the true class
|
| 74 |
+
geometric_alignments_by_class[class_idx].append(
|
| 75 |
+
geo_align_mean[i, class_idx].item()
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Convert to numpy arrays
|
| 79 |
+
all_predictions = np.array(all_predictions)
|
| 80 |
+
all_targets = np.array(all_targets)
|
| 81 |
+
all_confidences = np.array(all_confidences)
|
| 82 |
+
|
| 83 |
+
# Calculate per-class metrics
|
| 84 |
+
class_results = []
|
| 85 |
+
for class_idx in range(len(class_names)):
|
| 86 |
+
mask = all_targets == class_idx
|
| 87 |
+
if mask.sum() == 0:
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
class_preds = all_predictions[mask]
|
| 91 |
+
correct = (class_preds == class_idx).sum()
|
| 92 |
+
total = mask.sum()
|
| 93 |
+
accuracy = 100.0 * correct / total
|
| 94 |
+
|
| 95 |
+
# Average confidence for this class
|
| 96 |
+
class_conf = all_confidences[mask].mean()
|
| 97 |
+
|
| 98 |
+
# Geometric alignment for this class
|
| 99 |
+
geo_align = np.mean(geometric_alignments_by_class[class_idx]) if geometric_alignments_by_class[class_idx] else 0
|
| 100 |
+
|
| 101 |
+
# Crystal statistics
|
| 102 |
+
class_crystal = crystals[class_idx].detach().cpu() # [5, vocab_dim]
|
| 103 |
+
vertex_variance = class_crystal.var(dim=0).mean().item()
|
| 104 |
+
|
| 105 |
+
# Crystal norm (average magnitude)
|
| 106 |
+
crystal_norm = class_crystal.norm(dim=-1).mean().item()
|
| 107 |
+
|
| 108 |
+
class_results.append({
|
| 109 |
+
'class_idx': class_idx,
|
| 110 |
+
'class_name': class_names[class_idx],
|
| 111 |
+
'accuracy': accuracy,
|
| 112 |
+
'correct': int(correct),
|
| 113 |
+
'total': int(total),
|
| 114 |
+
'avg_confidence': class_conf,
|
| 115 |
+
'geometric_alignment': geo_align,
|
| 116 |
+
'vertex_variance': vertex_variance,
|
| 117 |
+
'crystal_norm': crystal_norm
|
| 118 |
+
})
|
| 119 |
+
|
| 120 |
+
# Sort by accuracy
|
| 121 |
+
class_results.sort(key=lambda x: x['accuracy'], reverse=True)
|
| 122 |
+
|
| 123 |
+
# Overall metrics
|
| 124 |
+
overall_acc = 100.0 * (all_predictions == all_targets).mean()
|
| 125 |
+
|
| 126 |
+
# Auxiliary head accuracy if available
|
| 127 |
+
aux_acc = None
|
| 128 |
+
if aux_predictions:
|
| 129 |
+
aux_predictions = np.array(aux_predictions)
|
| 130 |
+
aux_acc = 100.0 * (aux_predictions == all_targets).mean()
|
| 131 |
+
|
| 132 |
+
# Print results
|
| 133 |
+
print(f"\n" + "="*80)
|
| 134 |
+
print(f"EVALUATION RESULTS")
|
| 135 |
+
print(f"="*80)
|
| 136 |
+
print(f"\nOverall Accuracy: {overall_acc:.2f}%")
|
| 137 |
+
if aux_acc:
|
| 138 |
+
print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")
|
| 139 |
+
|
| 140 |
+
# Top 10 classes
|
| 141 |
+
print(f"\nTop 10 Classes:")
|
| 142 |
+
print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
|
| 143 |
+
print("-"*70)
|
| 144 |
+
for r in class_results[:10]:
|
| 145 |
+
print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
|
| 146 |
+
f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")
|
| 147 |
+
|
| 148 |
+
# Bottom 10 classes
|
| 149 |
+
print(f"\nBottom 10 Classes:")
|
| 150 |
+
print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
|
| 151 |
+
print("-"*70)
|
| 152 |
+
for r in class_results[-10:]:
|
| 153 |
+
print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
|
| 154 |
+
f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")
|
| 155 |
+
|
| 156 |
+
# Analyze correlations
|
| 157 |
+
accuracies = [r['accuracy'] for r in class_results]
|
| 158 |
+
geo_aligns = [r['geometric_alignment'] for r in class_results]
|
| 159 |
+
crystal_norms = [r['crystal_norm'] for r in class_results]
|
| 160 |
+
vertex_vars = [r['vertex_variance'] for r in class_results]
|
| 161 |
+
|
| 162 |
+
print(f"\nCorrelations with Accuracy:")
|
| 163 |
+
print(f" Geometric Alignment: {np.corrcoef(accuracies, geo_aligns)[0,1]:.3f}")
|
| 164 |
+
print(f" Crystal Norm: {np.corrcoef(accuracies, crystal_norms)[0,1]:.3f}")
|
| 165 |
+
print(f" Vertex Variance: {np.corrcoef(accuracies, vertex_vars)[0,1]:.3f}")
|
| 166 |
+
|
| 167 |
+
# Visualizations
|
| 168 |
+
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
| 169 |
+
|
| 170 |
+
# 1. Accuracy distribution
|
| 171 |
+
ax = axes[0, 0]
|
| 172 |
+
ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
|
| 173 |
+
ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
|
| 174 |
+
ax.set_xlabel('Accuracy (%)')
|
| 175 |
+
ax.set_ylabel('Count')
|
| 176 |
+
ax.set_title('Per-Class Accuracy Distribution')
|
| 177 |
+
ax.legend()
|
| 178 |
+
ax.grid(True, alpha=0.3)
|
| 179 |
+
|
| 180 |
+
# 2. Accuracy vs Geometric Alignment
|
| 181 |
+
ax = axes[0, 1]
|
| 182 |
+
scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
|
| 183 |
+
ax.set_xlabel('Geometric Alignment Score')
|
| 184 |
+
ax.set_ylabel('Accuracy (%)')
|
| 185 |
+
ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
|
| 186 |
+
plt.colorbar(scatter, ax=ax)
|
| 187 |
+
ax.grid(True, alpha=0.3)
|
| 188 |
+
|
| 189 |
+
# 3. Crystal Analysis
|
| 190 |
+
ax = axes[1, 0]
|
| 191 |
+
ax.scatter(crystal_norms, accuracies, alpha=0.6)
|
| 192 |
+
ax.set_xlabel('Crystal Norm (avg magnitude)')
|
| 193 |
+
ax.set_ylabel('Accuracy (%)')
|
| 194 |
+
ax.set_title('Accuracy vs Crystal Norm')
|
| 195 |
+
ax.grid(True, alpha=0.3)
|
| 196 |
+
|
| 197 |
+
# 4. Top/Bottom comparison
|
| 198 |
+
ax = axes[1, 1]
|
| 199 |
+
top10_acc = [r['accuracy'] for r in class_results[:10]]
|
| 200 |
+
bottom10_acc = [r['accuracy'] for r in class_results[-10:]]
|
| 201 |
+
top10_geo = [r['geometric_alignment'] for r in class_results[:10]]
|
| 202 |
+
bottom10_geo = [r['geometric_alignment'] for r in class_results[-10:]]
|
| 203 |
+
|
| 204 |
+
x = np.arange(10)
|
| 205 |
+
width = 0.35
|
| 206 |
+
ax.bar(x - width/2, top10_acc, width, label='Top 10 Accuracy', color='green', alpha=0.7)
|
| 207 |
+
ax.bar(x + width/2, bottom10_acc, width, label='Bottom 10 Accuracy', color='red', alpha=0.7)
|
| 208 |
+
ax.set_xlabel('Rank within group')
|
| 209 |
+
ax.set_ylabel('Accuracy (%)')
|
| 210 |
+
ax.set_title('Top 10 vs Bottom 10 Classes')
|
| 211 |
+
ax.legend()
|
| 212 |
+
ax.grid(True, alpha=0.3)
|
| 213 |
+
|
| 214 |
+
plt.tight_layout()
|
| 215 |
+
plt.show()
|
| 216 |
+
# ===================================================================================
|
| 217 |
+
# FULL 100-CLASS DIAGNOSTIC SPECTRUM (SORTED BY CLASS IDX FOR CONSISTENCY)
|
| 218 |
+
# ===================================================================================
|
| 219 |
+
print(f"\n{'='*90}")
|
| 220 |
+
print("Sparky — Full Class Spectrum")
|
| 221 |
+
print(f"{'='*90}")
|
| 222 |
+
print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12} {'Variance':<10}")
|
| 223 |
+
print("-" * 90)
|
| 224 |
+
|
| 225 |
+
for r in sorted(class_results, key=lambda x: x['class_idx']):
|
| 226 |
+
print(f"{r['class_idx']:<5} {r['class_name']:<20} "
|
| 227 |
+
f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
|
| 228 |
+
f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
|
| 229 |
+
f"{r['vertex_variance']:>9.8f}")
|
| 230 |
+
|
| 231 |
+
return class_results, overall_acc
|
| 232 |
+
|
| 233 |
+
# Run evaluation
|
| 234 |
+
if 'model' in globals():
|
| 235 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 236 |
+
_, test_loader = get_cifar100_dataloaders(batch_size=100)
|
| 237 |
+
|
| 238 |
+
results, overall = evaluate_pentachora_vit(model, test_loader, device)
|
| 239 |
+
|
| 240 |
+
# Additional crystal analysis
|
| 241 |
+
print("\nCrystal Geometry Analysis:")
|
| 242 |
+
print("-"*50)
|
| 243 |
+
|
| 244 |
+
# Get crystals
|
| 245 |
+
crystals = model.cls_tokens.class_pentachora.detach().cpu()
|
| 246 |
+
|
| 247 |
+
# Compute pairwise similarities between class crystals
|
| 248 |
+
crystals_flat = crystals.mean(dim=1) # Average over 5 vertices
|
| 249 |
+
crystals_norm = F.normalize(crystals_flat, dim=1)
|
| 250 |
+
similarities = torch.matmul(crystals_norm, crystals_norm.T)
|
| 251 |
+
|
| 252 |
+
# Find confused pairs (high similarity, both low accuracy)
|
| 253 |
+
print("\nMost similar classes with poor performance:")
|
| 254 |
+
for i in range(100):
|
| 255 |
+
for j in range(i+1, 100):
|
| 256 |
+
if results[i]['accuracy'] < 20 and results[j]['accuracy'] < 20:
|
| 257 |
+
sim = similarities[results[i]['class_idx'], results[j]['class_idx']].item()
|
| 258 |
+
if sim > 0.5:
|
| 259 |
+
print(f" {results[i]['class_name']:<15} ({results[i]['accuracy']:.1f}%) ↔ "
|
| 260 |
+
f"{results[j]['class_name']:<15} ({results[j]['accuracy']:.1f}%) : {sim:.3f}")
|
| 261 |
+
|
| 262 |
+
else:
|
| 263 |
+
print("No model found in memory!")
|