| """ |
| Visualization utilities for model analysis |
| """ |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import seaborn as sns |
| from pathlib import Path |
| from sklearn.metrics import confusion_matrix |
|
|
def setup_plotting():
    """Apply the module's global matplotlib/seaborn plotting defaults.

    Selects the seaborn dark-grid matplotlib style, the "husl" seaborn
    colour palette, and default figure size and font sizes. All changes
    go through ``plt.style`` / ``plt.rcParams`` / ``sns.set_palette`` and
    therefore affect every figure created afterwards in this process.
    """
    try:
        plt.style.use('seaborn-v0_8-darkgrid')
    except OSError:
        # Matplotlib < 3.6 ships this style only under its pre-rename name;
        # 'seaborn-v0_8-*' was introduced when the seaborn styles were renamed.
        plt.style.use('seaborn-darkgrid')
    sns.set_palette("husl")

    plt.rcParams.update({
        'figure.figsize': (10, 6),
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
    })
|
|
def plot_training_history(metrics_file: str, save_path: str = None):
    """Plot training and validation loss/accuracy curves from a JSON metrics file.

    Args:
        metrics_file: Path to a JSON file holding a list of per-epoch records.
            Each record must contain ``'epoch'`` plus ``'train'`` and
            ``'validation'`` mappings, each with ``'loss'`` and ``'accuracy'``.
        save_path: If given, the figure is also written to this path
            (150 dpi, tight bounding box).

    Returns:
        The created ``matplotlib.figure.Figure`` with two side-by-side panels:
        loss on the left, accuracy on the right. The figure is not closed.

    Raises:
        KeyError: If any record lacks one of the keys listed above.
    """
    import json

    with open(metrics_file, 'r', encoding='utf-8') as f:
        metrics = json.load(f)

    # Pull every series out before creating the figure so that a malformed
    # record fails fast without leaving an orphaned open figure behind.
    epochs = [m['epoch'] for m in metrics]
    series = {
        key: (
            [m['train'][key] for m in metrics],
            [m['validation'][key] for m in metrics],
        )
        for key in ('loss', 'accuracy')
    }

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, metric key, y-axis label, human-readable metric name)
    panels = (
        (axes[0], 'loss', 'Loss', 'Loss'),
        (axes[1], 'accuracy', 'Accuracy (%)', 'Accuracy'),
    )
    for ax, key, ylabel, name in panels:
        train_values, val_values = series[key]
        ax.plot(epochs, train_values, 'b-', label=f'Training {name}', linewidth=2)
        ax.plot(epochs, val_values, 'r-', label=f'Validation {name}', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Training and Validation {name}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Operate on this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|
def plot_confusion_matrix(model, dataloader, device='cpu', save_path: str = None,
                          labels=None):
    """Run ``model`` over ``dataloader`` and plot the resulting confusion matrix.

    Args:
        model: A torch module called as ``model(data)``. Predictions are the
            argmax of its output along dim 1.
        dataloader: Iterable yielding ``(data, target)`` tensor pairs.
        device: Device the input batches are moved to before the forward pass.
        save_path: If given, the figure is also written to this path
            (150 dpi, tight bounding box).
        labels: Optional sequence of class labels fixing the row/column order
            of the matrix. Defaults to the sorted union of labels observed in
            the targets and predictions, which is sklearn's default ordering.

    Returns:
        The created ``matplotlib.figure.Figure``. Rows are true labels,
        columns are predicted labels, and each cell is annotated with its count.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Remember the caller's mode so eval() does not leak out of this helper.
    was_training = getattr(model, 'training', False)
    model.eval()

    pred_batches = []
    target_batches = []
    try:
        with torch.no_grad():
            for data, target in dataloader:
                output = model(data.to(device))
                pred_batches.append(output.argmax(dim=1).cpu().numpy())
                # Targets are only compared on the host; no device round-trip needed.
                target_batches.append(target.cpu().numpy())
    finally:
        if was_training:
            model.train()

    if not target_batches:
        raise ValueError("dataloader yielded no samples; cannot build a confusion matrix")

    all_preds = np.concatenate(pred_batches)
    all_targets = np.concatenate(target_batches)
    if all_targets.size == 0:
        raise ValueError("dataloader yielded no samples; cannot build a confusion matrix")

    if labels is None:
        # Same sorted-union ordering sklearn would use internally. It is made
        # explicit here so the axis ticks can name the real class labels.
        labels = np.unique(np.concatenate([all_targets, all_preds]))
    cm = confusion_matrix(all_targets, all_preds, labels=labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)

    # Without explicit ticks, matrix index k would be shown as "k" even when
    # the k-th label is something else (e.g. when a class never occurs).
    tick_positions = np.arange(len(labels))
    tick_labels = [str(label) for label in labels]
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)

    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')

    # Use white text on dark cells (above half the max count) for legibility.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|
def _as_image(sample):
    """Convert one image sample (tensor or array-like) into an array ``imshow`` accepts."""
    if isinstance(sample, torch.Tensor):
        # detach() lets tensors produced with autograd enabled (typical for
        # adversarial examples) be converted; cpu() handles GPU tensors.
        sample = sample.detach().cpu().numpy()
    img = np.asarray(sample).squeeze()
    # Assumes colour images use channel-first (C, H, W) layout; imshow needs
    # (H, W, C). Single-channel images are already 2-D after squeeze().
    if img.ndim == 3 and img.shape[0] in (3, 4) and img.shape[-1] not in (3, 4):
        img = np.transpose(img, (1, 2, 0))
    return img


def visualize_attacks(original, adversarial, predictions, save_path: str = None,
                      max_samples: int = 10):
    """Plot original images (top row) above their adversarial versions (bottom row).

    Args:
        original: Indexable collection of clean images.
        adversarial: Indexable collection of adversarial images aligned with
            ``original``.
        predictions: Mapping with ``'original'`` and ``'adversarial'`` entries,
            each indexable per sample. ``predictions[...][i]`` is shown as the
            title of the corresponding panel.
        save_path: If given, the figure is also written to this path
            (150 dpi, tight bounding box).
        max_samples: Maximum number of columns (sample pairs) to draw.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        ValueError: If there is nothing to plot (empty ``original`` or
            ``max_samples < 1``).
    """
    n_samples = min(max_samples, len(original))
    if n_samples < 1:
        raise ValueError("visualize_attacks needs at least one sample to plot")

    # squeeze=False keeps axes 2-D even for a single column, so axes[row, col]
    # always works (the default would return a 1-D array when n_samples == 1).
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 2, 4), squeeze=False)

    rows = (
        (original, predictions['original'], 'Orig'),
        (adversarial, predictions['adversarial'], 'Adv'),
    )
    for row, (images, preds, prefix) in enumerate(rows):
        for col in range(n_samples):
            ax = axes[row, col]
            ax.imshow(_as_image(images[col]), cmap='gray')
            ax.set_title(f"{prefix}: {preds[col]}")
            ax.axis('off')

    fig.suptitle('Original vs Adversarial Examples')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|