File size: 3,605 Bytes
f7b283c fc5f33b f7b283c 71ca212 f7b283c 71ca212 fc5f33b f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c 71ca212 f7b283c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from pathlib import Path
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
from datetime import datetime
class Plotter:
    """Draw training/validation plots with matplotlib and optionally save them.

    Every plotting method builds a figure, writes it as a timestamped PNG
    into ``save_dir`` (when one is configured) and then calls ``plt.show()``.
    """

    def __init__(self, save_dir: Optional[Path] = None):
        """Create a plotter.

        Args:
            save_dir: Directory that receives the PNG files. A plain string
                path is accepted too and is converted to ``Path``. When it is
                ``None`` (or empty) figures are shown but never saved. The
                directory, including missing parents, is created right away.
        """
        # Coerce to Path so a ``str`` argument does not fail on ``.mkdir``
        # here or on the ``/`` operator when file names are built later.
        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir is not None:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def _finalize(self, fig, stem: str) -> None:
        """Save *fig* (if a save directory is set), show it, then release it.

        Args:
            fig: The matplotlib figure to finish.
            stem: File-name prefix; the file is named
                ``<stem>_<YYYYmmdd_HHMMSS>.png``.
        """
        if self.save_dir:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = Path(self.save_dir) / f'{stem}_{timestamp}.png'
            # Save through the figure object rather than pyplot's notion of
            # the "current figure", which other code may have changed.
            fig.savefig(save_path)
        plt.show()
        # In interactive mode show() does not block and the window has to
        # stay open, so the figure is only closed otherwise. Without this,
        # repeated calls keep accumulating open figures in pyplot.
        if not plt.isinteractive():
            plt.close(fig)

    @staticmethod
    def _flatten_metrics(metrics):
        """Return a flat ``{name: number}`` dict built from *metrics*.

        The ``'num_queries_tested'`` key is skipped. Values that are dicts
        contribute one ``"<key>_<subkey>"`` entry per numeric sub-value (one
        nesting level only); any other non-numeric value is dropped. ``bool``
        is a subclass of ``int`` and is therefore kept.
        """
        flat = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, (int, float)):
                        flat[f"{key}_{subkey}"] = subvalue
            elif isinstance(value, (int, float)):
                flat[key] = value
        return flat

    def plot_training_history(
        self,
        history: Dict[str, List[float]],
        title: str = "Training History",
    ) -> None:
        """Plot and save training metrics history.

        The upper panel shows the train and validation loss curves; the lower
        panel shows the learning rate when ``history`` provides it.

        Args:
            history: Dict with training metrics. Must contain ``'train_loss'``
                and ``'val_loss'``; ``'learning_rate'`` is optional.
            title: Plot title.

        Raises:
            KeyError: If ``'train_loss'`` or ``'val_loss'`` is missing.
        """
        # Validate before creating the figure so that a bad ``history`` does
        # not leave an orphaned, never-closed figure behind.
        missing = [key for key in ('train_loss', 'val_loss') if key not in history]
        if missing:
            raise KeyError(
                f"history is missing required key(s): {', '.join(missing)}"
            )

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

        # Loss curves
        ax1.plot(history['train_loss'], label='Train Loss')
        ax1.plot(history['val_loss'], label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True)

        # Learning-rate schedule
        if 'learning_rate' in history:
            ax2.plot(history['learning_rate'], label='Learning Rate')
            ax2.set_xlabel('Step')
            ax2.set_ylabel('Learning Rate')
            ax2.set_title('Learning Rate Schedule')
            ax2.legend()
            ax2.grid(True)
        else:
            # Nothing to draw: hide the empty frame instead of showing a
            # blank, unlabeled axes.
            ax2.axis('off')

        fig.suptitle(title)
        fig.tight_layout()

        self._finalize(fig, 'training_history')

    def plot_validation_metrics(self, metrics: Dict[str, float]) -> None:
        """Plot validation metrics as a bar chart.

        Args:
            metrics: Dictionary of validation metrics. Values may also be
                dicts of numbers (one nesting level); those are flattened to
                ``"<key>_<subkey>"`` bars. The ``'num_queries_tested'`` key
                and non-numeric values are ignored. If no numeric value
                remains, nothing is drawn and the method returns.
        """
        import math  # local import: only needed for the finite-value check

        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            return

        metric_names = list(flat_metrics)
        values = list(flat_metrics.values())
        positions = range(len(metric_names))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, values)

        ax.set_title('Validation Metrics')
        ax.set_xticks(list(positions))
        ax.set_xticklabels(metric_names, rotation=45, ha='right')
        ax.set_ylabel('Value')

        # Value labels: above positive bars, below negative ones.
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.,
                height,
                f'{height:.3f}',
                ha='center',
                va='bottom' if height >= 0 else 'top',
            )

        # Default range is 0..1.1; it is widened only when a value would
        # otherwise be clipped. NaN/inf are ignored because they are not
        # valid axis limits.
        padded = [v * 1.1 for v in values if math.isfinite(v)]
        ax.set_ylim(min([0.0] + padded), max([1.1] + padded))
        fig.tight_layout()

        self._finalize(fig, 'validation_metrics')
|