| """
|
| Plotting utilities for training metrics visualization
|
| """
|
|
|
| import matplotlib.pyplot as plt
|
| import seaborn as sns
|
| import numpy as np
|
| from typing import Dict, List, Optional
|
| from pathlib import Path
|
| import json
|
|
|
|
|
def set_style():
    """Apply the shared plot style used by every plotting function here.

    Prefers matplotlib's ``seaborn-v0_8-whitegrid`` style sheet (the name
    used since matplotlib 3.6), then the older ``seaborn-whitegrid`` name,
    and finally seaborn's own ``whitegrid`` theme. This way an unknown style
    name on a different matplotlib version does not make ``plt.style.use``
    raise. Also sets the global seaborn colour palette to ``husl``.

    Note: this mutates global matplotlib/seaborn state.
    """
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    else:
        # Neither matplotlib-bundled seaborn sheet exists; use seaborn directly.
        sns.set_style('whitegrid')

    sns.set_palette("husl")
|
|
|
|
|
def _history_has_series(history: Dict, key: str) -> bool:
    """Return True when ``history[key]`` exists and is non-empty."""
    values = history.get(key)
    # len() instead of truthiness so NumPy arrays are accepted too.
    return values is not None and len(values) > 0


def _plot_train_val_panel(ax, history: Dict, metric: str,
                          ylabel: str, panel_title: str):
    """Draw ``train_<metric>`` (blue) and ``val_<metric>`` (red) on ``ax``.

    Each series gets its own 1-based epoch axis, so a missing train series
    or train/val series of different lengths no longer cause a matplotlib
    x/y shape-mismatch error.
    """
    for split, line_style, label in (('train', 'b-', 'Train'),
                                     ('val', 'r-', 'Val')):
        key = f'{split}_{metric}'
        if _history_has_series(history, key):
            values = history[key]
            ax.plot(range(1, len(values) + 1), values, line_style,
                    label=label, linewidth=2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(panel_title)
    # Only add a legend when something was drawn (avoids a matplotlib warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_final_val_bars(ax, history: Dict):
    """Bar chart of the last validation IoU / Dice / Precision / Recall.

    Drawn only when both ``val_iou`` and ``val_dice`` are present. A missing
    precision or recall series is shown as 0.
    """
    if not (_history_has_series(history, 'val_iou')
            and _history_has_series(history, 'val_dice')):
        return

    bar_labels = ['IoU', 'Dice', 'Precision', 'Recall']
    keys = ['val_iou', 'val_dice', 'val_precision', 'val_recall']
    final_values = [history[key][-1] if _history_has_series(history, key) else 0
                    for key in keys]

    colors = sns.color_palette("husl", len(bar_labels))
    bars = ax.bar(bar_labels, final_values, color=colors)
    ax.set_ylabel('Score')
    ax.set_title('Final Validation Metrics')
    # Headroom above 1.0 so value labels on near-perfect bars stay in the axes.
    ax.set_ylim(0, 1.1)

    for bar, val in zip(bars, final_values):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
                f'{val:.3f}', ha='center', fontsize=10)


def plot_training_curves(history: Dict,
                         save_path: str,
                         title: str = "Training Progress"):
    """
    Plot training and validation curves on a 2x3 grid and save the figure.

    Panels: loss, IoU, Dice, precision, recall (each reading the
    ``train_<name>`` / ``val_<name>`` lists from ``history``), plus a bar chart
    of the final validation metrics.

    Args:
        history: Training history dictionary mapping metric keys
            (e.g. ``'train_loss'``, ``'val_dice'``) to per-epoch values.
            Missing keys are simply not drawn.
        save_path: Path to save plot
        title: Plot title (figure suptitle)
    """
    set_style()

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    panels = [
        (axes[0, 0], 'loss', 'Loss', 'Loss'),
        (axes[0, 1], 'iou', 'IoU', 'Intersection over Union'),
        (axes[0, 2], 'dice', 'Dice', 'Dice Score (F1)'),
        (axes[1, 0], 'precision', 'Precision', 'Precision'),
        (axes[1, 1], 'recall', 'Recall', 'Recall'),
    ]
    for ax, metric, ylabel, panel_title in panels:
        _plot_train_val_panel(ax, history, metric, ylabel, panel_title)

    _plot_final_val_bars(axes[1, 2], history)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"Training curves saved to {save_path}")
|
|
|
|
|
def plot_confusion_matrix(cm: np.ndarray,
                          class_names: List[str],
                          save_path: str,
                          title: str = "Confusion Matrix"):
    """
    Plot a confusion matrix as an annotated heatmap and save it.

    Integer matrices (raw counts) are annotated as integers. Any other dtype,
    e.g. a row-normalised float matrix, is annotated with two decimals.

    Args:
        cm: Confusion matrix; rows are labelled "True", columns "Predicted".
        class_names: Class names used for both axes' tick labels.
        save_path: Path to save plot
        title: Plot title
    """
    set_style()

    cm = np.asarray(cm)
    # fmt='d' only works for integers; on a float matrix seaborn's
    # annotation formatting would raise ValueError.
    annot_fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    fig, ax = plt.subplots(figsize=(8, 6))

    sns.heatmap(cm, annot=True, fmt=annot_fmt, cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                ax=ax)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"Confusion matrix saved to {save_path}")
|
|
|
|
|
def plot_feature_importance(importance: List[tuple],
                            save_path: str,
                            title: str = "Feature Importance"):
    """
    Render feature importances as a horizontal bar chart and save it.

    Bars appear top-to-bottom in the order given (the y-axis is inverted),
    coloured along the ``viridis`` palette.

    Args:
        importance: Sequence of ``(feature_name, importance)`` tuples; only
            the first two elements of each tuple are read.
        save_path: Destination file for the figure.
        title: Title drawn above the axes.
    """
    set_style()

    fig, ax = plt.subplots(figsize=(10, 8))

    feature_names = []
    scores = []
    for entry in importance:
        feature_names.append(entry[0])
        scores.append(entry[1])

    palette = sns.color_palette("viridis", len(importance))
    positions = np.arange(len(feature_names))

    ax.barh(positions, scores, color=palette)
    ax.set_yticks(positions)
    ax.set_yticklabels(feature_names)
    # Flip so the first entry of ``importance`` is drawn at the top.
    ax.invert_yaxis()
    ax.set_xlabel('Importance (Gain)')
    ax.set_title(title)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Feature importance saved to {save_path}")
|
|
|
|
|
def plot_dataset_comparison(all_histories: Dict[str, Dict],
                            save_path: str):
    """
    Plot validation Dice and IoU curves of several datasets side by side.

    Each dataset whose history contains a non-empty ``val_dice`` /
    ``val_iou`` list contributes one line (labelled with its name) to the
    corresponding panel.

    Args:
        all_histories: Dictionary of {dataset_name: history}
        save_path: Path to save plot
    """
    set_style()

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # (history key, y-axis label, panel title). Labels are explicit because
    # deriving them via str.title() rendered 'iou' as 'Iou'.
    panels = [
        ('val_dice', 'Dice', 'Validation Dice Score'),
        ('val_iou', 'IoU', 'Validation IoU'),
    ]

    for ax, (metric, ylabel, panel_title) in zip(axes, panels):
        for dataset_name, history in all_histories.items():
            values = history.get(metric)
            if values:
                ax.plot(range(1, len(values) + 1), values,
                        label=dataset_name, linewidth=2)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        # Skip the legend when no dataset had this metric (avoids a warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"Dataset comparison saved to {save_path}")
|
|
|
|
|
def plot_chunked_training_progress(chunk_histories: List[Dict],
                                   save_path: str,
                                   title: str = "Chunked Training Progress"):
    """
    Plot progress across consecutive training chunks on a shared epoch axis.

    Four panels (loss, Dice, IoU, precision). In each panel, every chunk's
    validation series is drawn as a solid line labelled ``Chunk <n>`` and its
    training series as a faded dashed line, both in that chunk's colour.
    Dotted vertical lines mark the boundaries between chunks.

    Args:
        chunk_histories: List of history dictionaries per chunk, in order.
        save_path: Path to save plot
        title: Plot title (figure suptitle)
    """
    set_style()

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    colors = sns.color_palette("husl", len(chunk_histories))

    metrics = [
        ('train_loss', 'val_loss', 'Loss'),
        ('train_dice', 'val_dice', 'Dice Score'),
        ('train_iou', 'val_iou', 'IoU'),
        ('train_precision', 'val_precision', 'Precision'),
    ]
    plotted_keys = [key for train_key, val_key, _ in metrics
                    for key in (train_key, val_key)]

    # Starting epoch offset of every chunk, computed once for all panels.
    # A chunk's length is its longest plotted series, so a chunk missing some
    # metric still advances the axis and later chunks stay aligned. The
    # previous per-key accumulation shifted subsequent chunks left.
    offsets = []
    running_total = 0
    for history in chunk_histories:
        offsets.append(running_total)
        running_total += max((len(history.get(key) or []) for key in plotted_keys),
                             default=0)

    for ax, (train_key, val_key, ylabel) in zip(axes.flat, metrics):
        for i, (history, start) in enumerate(zip(chunk_histories, offsets)):
            train_values = history.get(train_key)
            if train_values:
                ax.plot(range(start + 1, start + len(train_values) + 1),
                        train_values, '--', color=colors[i], alpha=0.5)

            val_values = history.get(val_key)
            if val_values:
                ax.plot(range(start + 1, start + len(val_values) + 1),
                        val_values, '-', color=colors[i],
                        label=f'Chunk {i+1}', linewidth=2)

        # Boundary after every chunk except the last (= start of the next one).
        for boundary in offsets[1:]:
            ax.axvline(x=boundary, color='gray', linestyle=':', alpha=0.5)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Validation {ylabel}')
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"Chunked training progress saved to {save_path}")
|
|
|
|
|
def generate_training_report(history: Dict,
                             save_path: str,
                             dataset_name: str = "unknown"):
    """
    Generate a plain-text training report (UTF-8) at ``save_path``.

    The report lists the total epoch count (length of ``train_loss``), the
    final value of every numeric list-valued entry, and the best value with
    its 1-based epoch. "Best" is the minimum for keys containing ``'loss'``
    and the maximum otherwise.

    Args:
        history: Training history mapping metric names to per-epoch lists.
            Non-list entries are ignored. A series containing non-numeric
            items is skipped in the "Best Metrics" section.
        save_path: Path to save report
        dataset_name: Dataset name shown in the header
    """
    import numbers

    def _is_real(value):
        # numbers.Real also accepts NumPy scalars such as np.float32, which
        # an ``(int, float)`` check would silently drop.
        return isinstance(value, numbers.Real)

    # Check the type before truthiness so array-valued entries cannot raise.
    series = {key: values for key, values in history.items()
              if isinstance(values, list) and values}

    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("=" * 60 + "\n")
        f.write(f"Training Report - {dataset_name}\n")
        f.write("=" * 60 + "\n\n")

        num_epochs = len(history.get('train_loss') or [])
        f.write(f"Total Epochs: {num_epochs}\n\n")

        f.write("Final Metrics:\n")
        f.write("-" * 40 + "\n")
        for key, values in series.items():
            final_value = values[-1]
            if _is_real(final_value):
                f.write(f" {key}: {final_value:.4f}\n")

        f.write("\n")
        f.write("Best Metrics:\n")
        f.write("-" * 40 + "\n")
        for key, values in series.items():
            # min()/max() raise TypeError on mixed types (e.g. [None, 0.2]),
            # so only fully numeric series are ranked.
            if not all(_is_real(v) for v in values):
                continue
            pick_best = min if 'loss' in key else max
            best_value = pick_best(values)
            best_epoch = values.index(best_value) + 1
            f.write(f" {key}: {best_value:.4f} (epoch {best_epoch})\n")

    print(f"Training report saved to {save_path}")
|
|
|