|
import math
import numbers
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
|
|
|
class Plotter:
    """Draw training/validation plots with matplotlib and optionally save them.

    When a ``save_dir`` is configured, each figure is written there as a PNG
    whose name ends in a ``%Y%m%d_%H%M%S`` timestamp. The figure is then shown
    with ``plt.show()``, or closed when the caller passes ``show=False``.
    """

    def __init__(self, save_dir: Optional[Union[str, Path]] = None):
        """Create a plotter.

        Args:
            save_dir: Directory that saved figures are written to. A ``str``
                is converted to ``Path``. The directory (and any missing
                parents) is created immediately. If falsy (``None`` or
                ``""``), figures are never saved.
        """
        self.save_dir: Optional[Path] = Path(save_dir) if save_dir else None
        if self.save_dir is not None:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def plot_training_history(
        self,
        history: Dict[str, List[float]],
        title: str = "Training History",
        *,
        show: bool = True,
    ):
        """Plot and save training metrics history.

        Args:
            history: Mapping of metric name to a sequence of values.
                'train_loss' and 'val_loss' are drawn together on the loss
                axes (x-label 'Epoch'). 'learning_rate' is drawn on a second
                axes (x-label 'Step'). Missing keys are skipped rather than
                raising KeyError, and the learning-rate axes is only created
                when that key is present.
            title: Figure-level title.
            show: If True (default) call ``plt.show()``. If False, close the
                figure after saving so batch jobs don't accumulate open
                figures.

        Returns:
            The matplotlib Figure that was drawn.
        """
        has_lr = 'learning_rate' in history
        n_rows = 2 if has_lr else 1
        # squeeze=False keeps `axes` an array even for a single row.
        fig, axes = plt.subplots(n_rows, 1, figsize=(10, 6 * n_rows), squeeze=False)
        axes = axes.ravel()

        loss_ax = axes[0]
        plotted_loss = False
        for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
            if key in history:
                loss_ax.plot(history[key], label=label)
                plotted_loss = True
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Loss')
        if plotted_loss:
            # legend() with no labelled artists only emits a warning.
            loss_ax.legend()
        loss_ax.grid(True)

        if has_lr:
            lr_ax = axes[1]
            lr_ax.plot(history['learning_rate'], label='Learning Rate')
            lr_ax.set_xlabel('Step')
            lr_ax.set_ylabel('Learning Rate')
            lr_ax.set_title('Learning Rate Schedule')
            lr_ax.legend()
            lr_ax.grid(True)

        fig.suptitle(title)
        fig.tight_layout()

        self._save_and_show(fig, 'training_history', show)
        return fig

    def plot_validation_metrics(self, metrics: Dict[str, float], *, show: bool = True):
        """Plot validation metrics as a bar chart.

        Args:
            metrics: Dictionary of validation metrics. One level of nested
                dictionaries is flattened into ``"<key>_<subkey>"`` entries.
                Non-numeric values and the 'num_queries_tested' key are
                ignored.
            show: If True (default) call ``plt.show()``. If False, close the
                figure after saving.

        Returns:
            The matplotlib Figure, or None when there was nothing numeric to
            plot.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            return None

        metric_names = list(flat_metrics)
        values = [float(v) for v in flat_metrics.values()]
        positions = list(range(len(metric_names)))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, values)

        ax.set_title('Validation Metrics')
        ax.set_xticks(positions)
        ax.set_xticklabels(metric_names, rotation=45, ha='right')
        ax.set_ylabel('Value')

        # Annotate each bar with its value at the bar's tip.
        for bar in bars:
            height = bar.get_height()
            if not math.isfinite(height):
                continue  # a NaN/inf label has no valid position
            ax.text(bar.get_x() + bar.get_width() / 2.0, height,
                    f'{height:.3f}',
                    ha='center', va='bottom' if height >= 0 else 'top')

        # Keep the familiar [0, 1.1] range for [0, 1] metrics, but widen it
        # so larger or negative values are not clipped out of view.
        ax.set_ylim(*self._y_limits(values))
        fig.tight_layout()

        self._save_and_show(fig, 'validation_metrics', show)
        return fig

    @staticmethod
    def _flatten_metrics(metrics: Dict[str, float]) -> Dict[str, float]:
        """Flatten one level of nesting, keeping only real-number values."""
        flat_metrics = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    # numbers.Real also accepts NumPy scalars (e.g. np.float32),
                    # which are not subclasses of the builtin float.
                    if isinstance(subvalue, numbers.Real):
                        flat_metrics[f"{key}_{subkey}"] = subvalue
            elif isinstance(value, numbers.Real):
                flat_metrics[key] = value
        return flat_metrics

    @staticmethod
    def _y_limits(values: List[float], headroom: float = 1.1) -> Tuple[float, float]:
        """Return (bottom, top) y-limits covering [0, 1] and every finite value.

        The span [min(0, values), max(1, values)] is scaled by ``headroom``.
        For values inside [0, 1] this gives exactly (0.0, 1.1).
        """
        finite = [v for v in values if math.isfinite(v)]
        bottom = min([0.0, *finite]) * headroom
        top = max([1.0, *finite]) * headroom
        return bottom, top

    def _save_and_show(self, fig, prefix: str, show: bool) -> None:
        """Save *fig* as ``<prefix>_<timestamp>.png`` (if save_dir is set), then show or close it."""
        if self.save_dir:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            fig.savefig(self.save_dir / f'{prefix}_{timestamp}.png')
        if show:
            plt.show()
        else:
            # Nothing will display this figure, so release it from pyplot's
            # figure manager instead of letting figures pile up.
            plt.close(fig)
|
|