Spaces:
Runtime error
Runtime error
| """Visualization tools for training monitoring.""" | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from typing import Dict, List, Optional, Any | |
| from pathlib import Path | |
| import logging | |
# Per-module logger: the name mirrors the dotted module path so that handlers
# and levels can be tuned through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
class Visualizer:
    """
    Creates visualizations for training metrics.

    Supports TensorBoard integration (when ``torch.utils.tensorboard`` can be
    imported) and static matplotlib plots saved under ``output_dir``.

    Metric histories (the ``metrics`` arguments) map a metric name to a list
    of entries; every entry must provide a ``'step'`` and a ``'value'`` key.

    Instances may be used as context managers: leaving the ``with`` block
    closes the TensorBoard writer if one was opened.
    """

    # Metrics shown by plot_training_curves, in panel order.
    _KEY_METRICS = ('reward', 'loss', 'total_reward', 'episode_time')

    def __init__(self, output_dir: str = "visualizations"):
        """
        Initialize visualizer.

        Creates ``output_dir`` (including parents) if it does not exist and
        probes for TensorBoard support. No writer is opened until
        :meth:`initialize_tensorboard` is called.

        Args:
            output_dir: Directory to save visualizations
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard is optional; only record whether it can be imported.
        self.tensorboard_available = False
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            logger.warning("TensorBoard not available")
        else:
            self.SummaryWriter = SummaryWriter
            self.tensorboard_available = True
            logger.info("TensorBoard available")

        self.writer = None
        logger.info("Visualizer initialized: output_dir=%s", output_dir)

    # ------------------------------------------------------------------
    # TensorBoard
    # ------------------------------------------------------------------

    def initialize_tensorboard(self, log_dir: Optional[str] = None) -> None:
        """
        Initialize TensorBoard writer.

        Does nothing (besides logging a warning) when TensorBoard could not
        be imported. Calling this again closes the previous writer before
        opening a new one.

        Args:
            log_dir: Optional TensorBoard log directory; defaults to
                ``<output_dir>/tensorboard``.
        """
        if not self.tensorboard_available:
            logger.warning("TensorBoard not available, skipping initialization")
            return

        if log_dir is None:
            log_dir = str(self.output_dir / "tensorboard")

        # Re-initialization must not leak the old writer's open event file.
        if self.writer is not None:
            self.writer.close()

        self.writer = self.SummaryWriter(log_dir)
        logger.info("TensorBoard initialized: %s", log_dir)

    def log_scalar_to_tensorboard(
        self,
        tag: str,
        value: float,
        step: int
    ) -> None:
        """
        Log scalar value to TensorBoard.

        Silently does nothing if no writer is open (TensorBoard unavailable,
        not initialized, or already closed).

        Args:
            tag: Metric name
            value: Metric value
            step: Step number
        """
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_series(data: List[Dict[str, Any]]) -> "tuple[List[Any], List[Any]]":
        """Split a metric history into parallel ``(steps, values)`` lists.

        Raises:
            KeyError: If an entry lacks ``'step'`` or ``'value'``.
        """
        steps = [entry['step'] for entry in data]
        values = [entry['value'] for entry in data]
        return steps, values

    @staticmethod
    def _pretty(metric_name: str) -> str:
        """Turn a snake_case metric name into a title, e.g. 'Total Reward'."""
        return metric_name.replace('_', ' ').title()

    @staticmethod
    def _fit_trend(steps: List[Any], values: List[Any]) -> Optional[np.ndarray]:
        """Least-squares linear fit of ``values`` against ``steps``.

        Points where either coordinate is non-finite (NaN/inf) are ignored.

        Returns:
            Coefficients ``[slope, intercept]`` usable with ``np.polyval``,
            or ``None`` when fewer than two distinct finite steps remain,
            in which case a line is not defined.
        """
        x = np.asarray(steps, dtype=float)
        y = np.asarray(values, dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        x, y = x[finite], y[finite]
        if np.unique(x).size < 2:
            return None
        return np.polyfit(x, y, 1)

    def _save_figure(self, fig, filename: str) -> Path:
        """Save ``fig`` as ``output_dir/filename`` and return that path.

        The caller remains responsible for closing the figure.
        """
        output_path = self.output_dir / filename
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        return output_path

    # ------------------------------------------------------------------
    # Static plots
    # ------------------------------------------------------------------

    def plot_training_curve(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        metric_name: str,
        title: Optional[str] = None,
        filename: Optional[str] = None
    ) -> str:
        """
        Plot training curve for a metric.

        Args:
            metrics: Dictionary of metrics
            metric_name: Name of metric to plot
            title: Optional plot title
            filename: Optional output filename (default
                ``<metric_name>_curve.png``)

        Returns:
            Path to saved plot

        Raises:
            ValueError: If ``metric_name`` is not a key of ``metrics``.
        """
        if metric_name not in metrics:
            raise ValueError(f"Metric '{metric_name}' not found")

        steps, values = self._extract_series(metrics[metric_name])
        label = self._pretty(metric_name)
        if filename is None:
            filename = f"{metric_name}_curve.png"

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(steps, values, linewidth=2)
            ax.set_xlabel('Step')
            ax.set_ylabel(label)
            ax.set_title(title or f'{label} Over Time')
            ax.grid(True, alpha=0.3)
            output_path = self._save_figure(fig, filename)
        finally:
            # Always release the figure, even if plotting/saving failed.
            plt.close(fig)

        logger.info("Training curve saved: %s", output_path)
        return str(output_path)

    def plot_multiple_metrics(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        metric_names: List[str],
        title: Optional[str] = None,
        filename: Optional[str] = None
    ) -> str:
        """
        Plot multiple metrics on the same figure.

        Names not present in ``metrics`` are skipped with a warning; the
        figure is still saved even if none of them are present.

        Args:
            metrics: Dictionary of metrics
            metric_names: List of metric names to plot
            title: Optional plot title
            filename: Optional output filename (default
                ``multiple_metrics.png``)

        Returns:
            Path to saved plot
        """
        if filename is None:
            filename = "multiple_metrics.png"

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            plotted = 0
            for metric_name in metric_names:
                if metric_name not in metrics:
                    logger.warning("Metric '%s' not found, skipping", metric_name)
                    continue
                steps, values = self._extract_series(metrics[metric_name])
                ax.plot(steps, values, label=metric_name, linewidth=2)
                plotted += 1

            ax.set_xlabel('Step')
            ax.set_ylabel('Value')
            ax.set_title(title or 'Training Metrics')
            # A legend with no labelled artists only produces a warning.
            if plotted:
                ax.legend()
            ax.grid(True, alpha=0.3)
            output_path = self._save_figure(fig, filename)
        finally:
            plt.close(fig)

        logger.info("Multi-metric plot saved: %s", output_path)
        return str(output_path)

    def plot_training_curves(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        title: str = "Training Progress",
        filename: Optional[str] = None
    ) -> str:
        """
        Plot comprehensive training curves with subplots.

        Draws a 2x2 grid with one panel per key metric ('reward', 'loss',
        'total_reward', 'episode_time') that is present in ``metrics``, each
        with a dashed linear trend line when one can be fitted. Unused panels
        are hidden.

        Args:
            metrics: Dictionary of all metrics
            title: Main title for the figure
            filename: Optional output filename. Defaults to
                ``training_curves_<N>_episodes.png`` where ``N`` is the
                length of the longest plotted series.

        Returns:
            Path to saved plot, or ``""`` when ``metrics`` is empty or
            contains none of the key metrics (nothing is written then).
        """
        if not metrics:
            logger.warning("No metrics to plot")
            return ""

        selected = [name for name in self._KEY_METRICS if name in metrics]
        if not selected:
            logger.warning(
                "None of the key metrics %s present; nothing to plot",
                list(self._KEY_METRICS),
            )
            return ""

        series = {name: self._extract_series(metrics[name]) for name in selected}

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        try:
            fig.suptitle(title, fontsize=16, fontweight='bold')
            flat_axes = axes.flatten()

            for ax, metric_name in zip(flat_axes, selected):
                steps, values = series[metric_name]
                label = self._pretty(metric_name)
                ax.plot(steps, values, linewidth=2, marker='o', markersize=4)
                ax.set_xlabel('Episode')
                ax.set_ylabel(label)
                ax.set_title(label)
                ax.grid(True, alpha=0.3)

                coeffs = self._fit_trend(steps, values)
                if coeffs is not None:
                    ax.plot(steps, np.polyval(coeffs, steps), "--",
                            alpha=0.5, color='red', label='Trend')
                    ax.legend()

            # Hide panels that received no metric.
            for ax in flat_axes[len(selected):]:
                ax.axis('off')

            fig.tight_layout()

            if filename is None:
                num_episodes = max(len(steps) for steps, _ in series.values())
                filename = f"training_curves_{num_episodes}_episodes.png"
            output_path = self._save_figure(fig, filename)
        finally:
            plt.close(fig)

        logger.info("Training curves saved: %s", output_path)
        return str(output_path)

    def plot_reward_distribution(
        self,
        rewards: List[float],
        title: Optional[str] = None,
        filename: Optional[str] = None
    ) -> str:
        """
        Plot reward distribution histogram.

        Adds vertical lines at the mean and at mean +/- one (population)
        standard deviation.

        Args:
            rewards: List of reward values
            title: Optional plot title
            filename: Optional output filename (default
                ``reward_distribution.png``)

        Returns:
            Path to saved plot, or ``""`` when ``rewards`` is empty.
        """
        if len(rewards) == 0:
            # Mean/std of an empty sample are NaN; nothing meaningful to draw.
            logger.warning("No rewards to plot")
            return ""

        mean_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        if filename is None:
            filename = "reward_distribution.png"

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.hist(rewards, bins=30, alpha=0.7, edgecolor='black')
            ax.set_xlabel('Reward')
            ax.set_ylabel('Frequency')
            ax.set_title(title or 'Reward Distribution')
            ax.grid(True, alpha=0.3, axis='y')

            ax.axvline(mean_reward, color='red', linestyle='--',
                       label=f'Mean: {mean_reward:.3f}')
            ax.axvline(mean_reward + std_reward, color='orange',
                       linestyle=':', alpha=0.7, label='±1 Std')
            ax.axvline(mean_reward - std_reward, color='orange',
                       linestyle=':', alpha=0.7)
            ax.legend()
            output_path = self._save_figure(fig, filename)
        finally:
            plt.close(fig)

        logger.info("Reward distribution saved: %s", output_path)
        return str(output_path)

    # ------------------------------------------------------------------
    # Reports / lifecycle
    # ------------------------------------------------------------------

    def generate_summary_report(
        self,
        metrics: Dict[str, List[Dict[str, Any]]],
        statistics: Dict[str, Dict[str, float]],
        output_filename: str = "training_summary.txt"
    ) -> str:
        """
        Generate text summary report.

        Args:
            metrics: Dictionary of metrics (accepted for API compatibility;
                not currently used in the report).
            statistics: Per-metric statistics; each value must provide
                'count', 'mean', 'std', 'min' and 'max'.
            output_filename: Output filename

        Returns:
            Path to saved report

        Raises:
            KeyError: If a statistics entry is missing one of the keys above.
        """
        lines = [
            "=" * 60,
            "TRAINING SUMMARY REPORT",
            "=" * 60,
            "",
            "METRIC STATISTICS:",
            "-" * 60,
        ]

        for metric_name, stats in statistics.items():
            lines.append(f"\n{metric_name}:")
            lines.append(f"  Count: {stats['count']}")
            lines.append(f"  Mean:  {stats['mean']:.6f}")
            lines.append(f"  Std:   {stats['std']:.6f}")
            lines.append(f"  Min:   {stats['min']:.6f}")
            lines.append(f"  Max:   {stats['max']:.6f}")

        lines.append("")
        lines.append("=" * 60)

        output_path = self.output_dir / output_filename
        # Explicit encoding so the report does not depend on the platform locale.
        output_path.write_text("\n".join(lines), encoding="utf-8")

        logger.info("Summary report saved: %s", output_path)
        return str(output_path)

    def close(self) -> None:
        """Close TensorBoard writer if open.

        Safe to call more than once; afterwards ``self.writer`` is ``None``
        so later logging calls become no-ops instead of writing to a closed
        writer.
        """
        if self.writer is None:
            return
        self.writer.close()
        self.writer = None
        logger.info("TensorBoard writer closed")

    def __enter__(self) -> "Visualizer":
        """Return self so the visualizer can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the TensorBoard writer; exceptions are not suppressed."""
        self.close()