mbellan's picture
Initial deployment
c3efd49
"""Visualization tools for training monitoring."""
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Optional, Any
from pathlib import Path
import logging
logger = logging.getLogger(__name__)
class Visualizer:
    """
    Creates visualizations for training metrics.

    Supports TensorBoard integration and static plots.

    The plotting methods take a metrics dictionary that maps a metric name
    to a list of entries. Each entry must provide a ``'step'`` key and a
    ``'value'`` key; these are the only keys read here.

    Each plotting method creates its own figure and always closes it, even
    when plotting or saving raises. Repeated calls therefore do not pile up
    open matplotlib figures.

    The instance can be used as a context manager. On exit it closes the
    TensorBoard writer, if one was opened.
    """

    # Metrics shown by plot_training_curves() when no explicit list is given.
    DEFAULT_KEY_METRICS = ('reward', 'loss', 'total_reward', 'episode_time')
    # plot_training_curves() lays its panels out on a fixed 2x2 grid.
    _MAX_PANELS = 4

    def __init__(self, output_dir: str = "visualizations"):
        """
        Initialize visualizer.

        Args:
            output_dir: Directory to save visualizations. It is created
                (including parents) if it does not already exist.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard is optional. Record whether its writer can be imported
        # so that initialize_tensorboard() can degrade gracefully.
        self.tensorboard_available = False
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            logger.warning("TensorBoard not available")
        else:
            self.SummaryWriter = SummaryWriter
            self.tensorboard_available = True
            logger.info("TensorBoard available")

        self.writer = None
        logger.info("Visualizer initialized: output_dir=%s", output_dir)

    def __enter__(self) -> "Visualizer":
        """Return self so the visualizer can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the TensorBoard writer on leaving a ``with`` block."""
        self.close()

    def initialize_tensorboard(self, log_dir: Optional[str] = None) -> None:
        """
        Initialize TensorBoard writer.

        If a writer is already open, it is closed first, so calling this
        method again does not leak the previous writer.

        Args:
            log_dir: Optional TensorBoard log directory. Defaults to
                ``<output_dir>/tensorboard``.
        """
        if not self.tensorboard_available:
            logger.warning("TensorBoard not available, skipping initialization")
            return

        if log_dir is None:
            log_dir = str(self.output_dir / "tensorboard")

        if self.writer is not None:
            self.writer.close()

        self.writer = self.SummaryWriter(log_dir)
        logger.info("TensorBoard initialized: %s", log_dir)

    def log_scalar_to_tensorboard(
        self,
        tag: str,
        value: float,
        step: int
    ) -> None:
        """
        Log scalar value to TensorBoard.

        This is a no-op when no writer is open, for example before
        initialize_tensorboard() or after close().

        Args:
            tag: Metric name
            value: Metric value
            step: Step number
        """
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    @staticmethod
    def _series(data: List[Dict[str, Any]]) -> tuple:
        """Split metric entries into parallel lists of steps and values."""
        steps = [entry['step'] for entry in data]
        values = [entry['value'] for entry in data]
        return steps, values

    @staticmethod
    def _pretty(metric_name: str) -> str:
        """Turn a snake_case metric name into a title-cased label."""
        return metric_name.replace('_', ' ').title()

    def plot_training_curve(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        metric_name: str,
        title: Optional[str] = None,
        filename: Optional[str] = None
    ) -> str:
        """
        Plot training curve for a metric.

        Args:
            metrics: Dictionary of metrics
            metric_name: Name of metric to plot
            title: Optional plot title. Defaults to "<Metric> Over Time".
            filename: Optional output filename. Defaults to
                "<metric_name>_curve.png".

        Returns:
            Path to saved plot

        Raises:
            ValueError: If ``metric_name`` is not a key of ``metrics``.
        """
        if metric_name not in metrics:
            raise ValueError(f"Metric '{metric_name}' not found")

        steps, values = self._series(metrics[metric_name])
        label = self._pretty(metric_name)

        if filename is None:
            filename = f"{metric_name}_curve.png"
        output_path = self.output_dir / filename

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(steps, values, linewidth=2)
            ax.set_xlabel('Step')
            ax.set_ylabel(label)
            ax.set_title(title or f'{label} Over Time')
            ax.grid(True, alpha=0.3)
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

        logger.info("Training curve saved: %s", output_path)
        return str(output_path)

    def plot_multiple_metrics(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        metric_names: List[str],
        title: Optional[str] = None,
        filename: Optional[str] = None
    ) -> str:
        """
        Plot multiple metrics on the same figure.

        Names in ``metric_names`` that are absent from ``metrics`` are
        skipped, and a warning is logged for them.

        Args:
            metrics: Dictionary of metrics
            metric_names: List of metric names to plot
            title: Optional plot title. Defaults to "Training Metrics".
            filename: Optional output filename. Defaults to
                "multiple_metrics.png".

        Returns:
            Path to saved plot
        """
        missing = [name for name in metric_names if name not in metrics]
        if missing:
            logger.warning("Metrics not found, skipping: %s", missing)

        if filename is None:
            filename = "multiple_metrics.png"
        output_path = self.output_dir / filename

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            plotted_any = False
            for metric_name in metric_names:
                if metric_name not in metrics:
                    continue
                steps, values = self._series(metrics[metric_name])
                ax.plot(steps, values, label=metric_name, linewidth=2)
                plotted_any = True

            ax.set_xlabel('Step')
            ax.set_ylabel('Value')
            ax.set_title(title or 'Training Metrics')
            # With no labelled lines, legend() only emits a matplotlib warning.
            if plotted_any:
                ax.legend()
            ax.grid(True, alpha=0.3)
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

        logger.info("Multi-metric plot saved: %s", output_path)
        return str(output_path)

    def plot_training_curves(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        title: str = "Training Progress",
        filename: Optional[str] = None,
        key_metrics: Optional[List[str]] = None
    ) -> str:
        """
        Plot comprehensive training curves with subplots.

        Up to four of ``key_metrics`` that are present in ``metrics`` are
        drawn on a 2x2 grid, in list order. Each panel has markers and, when
        the series has at least two distinct steps, a dashed linear trend
        line. Unused panels are hidden.

        Args:
            metrics: Dictionary of all metrics
            title: Main title for the figure
            filename: Optional output filename. Defaults to
                "training_curves_<N>_episodes.png", where N is the length of
                the longest plotted series.
            key_metrics: Metric names to consider, in order. Defaults to
                DEFAULT_KEY_METRICS.

        Returns:
            Path to saved plot, or "" when there is nothing to plot.
        """
        if not metrics:
            logger.warning("No metrics to plot")
            return ""

        candidates = self.DEFAULT_KEY_METRICS if key_metrics is None else key_metrics
        selected = [name for name in candidates if name in metrics][:self._MAX_PANELS]
        if not selected:
            # Previously this path crashed with a NameError on `steps`.
            logger.warning(
                "None of the key metrics %s are present; nothing to plot",
                list(candidates),
            )
            return ""

        if filename is None:
            num_points = max(len(metrics[name]) for name in selected)
            filename = f"training_curves_{num_points}_episodes.png"
        output_path = self.output_dir / filename

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        try:
            fig.suptitle(title, fontsize=16, fontweight='bold')
            axes = axes.flatten()

            for ax, metric_name in zip(axes, selected):
                steps, values = self._series(metrics[metric_name])
                label = self._pretty(metric_name)

                ax.plot(steps, values, linewidth=2, marker='o', markersize=4)
                ax.set_xlabel('Episode')
                ax.set_ylabel(label)
                ax.set_title(label)
                ax.grid(True, alpha=0.3)

                # A degree-1 fit is only meaningful with >= 2 distinct x values.
                if len(np.unique(steps)) > 1:
                    trend = np.poly1d(np.polyfit(steps, values, 1))
                    ax.plot(steps, trend(steps), "--", alpha=0.5, color='red', label='Trend')
                    ax.legend()

            for ax in axes[len(selected):]:
                ax.axis('off')

            fig.tight_layout()
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

        logger.info("Training curves saved: %s", output_path)
        return str(output_path)

    def plot_reward_distribution(
        self,
        rewards: List[float],
        title: Optional[str] = None,
        filename: Optional[str] = None
    ) -> str:
        """
        Plot reward distribution histogram.

        The histogram uses 30 bins. It is annotated with the mean (dashed
        red line) and mean +/- one population standard deviation (dotted
        orange lines).

        Args:
            rewards: List of reward values
            title: Optional plot title. Defaults to "Reward Distribution".
            filename: Optional output filename. Defaults to
                "reward_distribution.png".

        Returns:
            Path to saved plot, or "" when ``rewards`` is empty.
        """
        values = np.asarray(rewards, dtype=float)
        if values.size == 0:
            logger.warning("No rewards to plot")
            return ""

        mean_reward = np.mean(values)
        std_reward = np.std(values)

        if filename is None:
            filename = "reward_distribution.png"
        output_path = self.output_dir / filename

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.hist(values, bins=30, alpha=0.7, edgecolor='black')
            ax.set_xlabel('Reward')
            ax.set_ylabel('Frequency')
            ax.set_title(title or 'Reward Distribution')
            ax.grid(True, alpha=0.3, axis='y')

            ax.axvline(mean_reward, color='red', linestyle='--',
                       label=f'Mean: {mean_reward:.3f}')
            ax.axvline(mean_reward + std_reward, color='orange',
                       linestyle=':', alpha=0.7, label='±1 Std')
            ax.axvline(mean_reward - std_reward, color='orange',
                       linestyle=':', alpha=0.7)
            ax.legend()

            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

        logger.info("Reward distribution saved: %s", output_path)
        return str(output_path)

    def generate_summary_report(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        statistics: Dict[str, Dict[str, float]],
        output_filename: str = "training_summary.txt"
    ) -> str:
        """
        Generate text summary report.

        Args:
            metrics: Dictionary of metrics. It is currently not read, and is
                kept for interface compatibility.
            statistics: Per-metric statistics. Each value must provide the
                'count', 'mean', 'std', 'min' and 'max' keys.
            output_filename: Output filename, relative to output_dir.

        Returns:
            Path to saved report

        Raises:
            KeyError: If a statistics entry lacks one of the required keys.
        """
        rule = "=" * 60
        lines = [rule, "TRAINING SUMMARY REPORT", rule, ""]

        lines.append("METRIC STATISTICS:")
        lines.append("-" * 60)
        for metric_name, stats in statistics.items():
            lines.extend([
                f"\n{metric_name}:",
                f" Count: {stats['count']}",
                f" Mean: {stats['mean']:.6f}",
                f" Std: {stats['std']:.6f}",
                f" Min: {stats['min']:.6f}",
                f" Max: {stats['max']:.6f}",
            ])

        lines.append("")
        lines.append(rule)

        output_path = self.output_dir / output_filename
        output_path.write_text("\n".join(lines), encoding="utf-8")

        logger.info("Summary report saved: %s", output_path)
        return str(output_path)

    def close(self) -> None:
        """
        Close the TensorBoard writer if one is open.

        Afterwards ``writer`` is reset to None. This makes close() idempotent
        and turns later log_scalar_to_tensorboard() calls into no-ops instead
        of writes to a closed writer.
        """
        if self.writer is not None:
            self.writer.close()
            self.writer = None
            logger.info("TensorBoard writer closed")