code-completion / archive-misc /plot_metrics.py
jblitzar's picture
Upload folder using huggingface_hub
a8639ac verified
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import re
def extract_run_name(filename):
    """Extract the run name from a ``*_tensorboard.csv`` filename.

    The run name is the underscore-free segment located directly before
    ``_tensorboard.csv``; a trailing ``-loss`` on that segment is dropped.
    Filenames that do not follow this pattern fall back to the text between
    the first and second underscore, cut at the first hyphen. When the
    basename contains no underscore at all, the basename without its
    extension is returned.

    Args:
        filename: Path or bare filename of the CSV file.

    Returns:
        The extracted run name as a string.
    """
    basename = os.path.basename(filename)
    # The capture group is lazy on purpose: a greedy ``[^_]+`` would swallow
    # the optional ``-loss`` suffix, so the suffix would never be stripped.
    match = re.search(r'_([^_]+?)(?:-loss)?_tensorboard\.csv$', basename)
    if match:
        return match.group(1)

    _, separator, remainder = basename.partition('_')
    if not separator:
        # Without an underscore there is no second segment to pick, so use
        # the file stem instead of failing with an IndexError.
        return os.path.splitext(basename)[0]
    # Second underscore-delimited segment, truncated at the first hyphen.
    return remainder.split('_')[0].split('-')[0]
def setup_plot_style():
    """Install the shared, publication-style matplotlib defaults.

    Updates the global ``plt.rcParams``, so every figure created afterwards
    in this process picks up these settings unless it overrides them.
    """
    publication_style = {
        # Typography
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 10,
        # Figure geometry and line weight
        'figure.dpi': 300,
        'figure.figsize': (10, 6),
        'lines.linewidth': 2.5,
        # Dashed, semi-transparent grid
        'axes.grid': True,
        'grid.linestyle': '--',
        'grid.alpha': 0.6,
        # Hide the top and right axes borders
        'axes.spines.top': False,
        'axes.spines.right': False,
    }
    plt.rcParams.update(publication_style)
def get_metric_label(metric_name):
    """Map a metric identifier to the label shown on plots.

    Known metrics get a fixed label; any other name is turned into a label
    by replacing underscores with spaces and title-casing the result.
    """
    known_labels = {
        'loss_epoch': 'Loss',
        'perplexityval_epoch': 'Validation Perplexity',
        'topkacc_epoch': 'Top-K Accuracy',
        'acc_trainstep': 'Training Accuracy'
    }
    derived_label = metric_name.replace('_', ' ').title()
    return known_labels.get(metric_name, derived_label)
def get_color_mapping(run_names):
    """Assign each run a color that stays the same across all plots.

    Runs are sorted by name and handed colors from a fixed ten-entry
    palette in turn. The palette wraps around, so colors repeat once there
    are more runs than palette entries.

    Returns:
        A dict mapping each run name to a hex color string.
    """
    palette = (
        "#e6194b",  # Red
        "#f58231",  # Orange
        "#ffe119",  # Yellow
        "#bfef45",  # Lime
        "#3cb44b",  # Green
        "#42d4f4",  # Cyan
        "#4363d8",  # Blue
        "#911eb4",  # Purple
        "#f032e6",  # Magenta
        "#a9a9a9",  # Grey
    )
    mapping = {}
    for position, run in enumerate(sorted(run_names)):
        # Modulo makes the palette cycle instead of running out.
        mapping[run] = palette[position % len(palette)]
    return mapping
def plot_metric(metric_dir, color_mapping, output_dir):
    """Plot every run's CSV for one metric and save the figure as a PNG.

    Each ``*.csv`` file in ``metric_dir`` is read with pandas and drawn as
    one line using its ``Step`` and ``Value`` columns. A file that cannot
    be read or plotted is reported on stdout and skipped, so one bad file
    does not prevent the plot from being produced.

    Args:
        metric_dir: Directory holding the CSV files; its name is used as
            the metric name.
        color_mapping: Mapping of run name to line color. Runs missing from
            it are drawn in gray.
        output_dir: Directory the ``<metric>_plot.png`` file is written to;
            created if it does not exist.
    """
    # normpath drops a trailing separator, which would otherwise make
    # basename return '' and yield an output file named '_plot.png'.
    metric_name = os.path.basename(os.path.normpath(metric_dir))
    csv_files = glob.glob(os.path.join(metric_dir, '*.csv'))
    if not csv_files:
        print(f"No CSV files found in {metric_dir}")
        return

    metric_label = get_metric_label(metric_name)
    figure = plt.figure(figsize=(12, 7))
    try:
        plotted_runs = 0
        for csv_file in sorted(csv_files):
            # Deliberately broad: this is a best-effort batch job and a
            # malformed file should be reported, not abort the whole plot.
            try:
                df = pd.read_csv(csv_file)
                run_name = extract_run_name(csv_file)
                color = color_mapping.get(run_name, 'gray')
                plt.plot(df['Step'], df['Value'], label=run_name, color=color, alpha=0.9)
                plotted_runs += 1
            except Exception as exc:
                print(f"Error processing {csv_file}: {exc}")

        # Axis labels and title
        plt.xlabel('Step')
        plt.ylabel(metric_label)
        comparison = "Epoch" if "epoch" in metric_name else "Step"
        plt.title(f'{metric_label} vs. {comparison}', fontweight='bold')

        # Only build a legend when there is something to list; calling
        # legend() with no labeled lines makes matplotlib emit a warning.
        if plotted_runs:
            plt.legend(loc='best', frameon=True, fancybox=True, framealpha=0.9,
                       shadow=True, borderpad=1, ncol=2 if len(csv_files) > 5 else 1)

        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{metric_name}_plot.png')
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Saved plot to {output_path}")
    finally:
        # Always release the figure, even if layout or saving fails.
        plt.close(figure)
def main():
    """Generate one plot per metric directory found next to this script.

    Metric directories are the sub-directories of ``runs_jsons``; the
    resulting images are written to ``plots``. Both paths are resolved
    relative to the directory containing this file. Run names are gathered
    from every metric first so that a run keeps the same color in all plots.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, 'runs_jsons')
    output_dir = os.path.join(script_dir, 'plots')
    os.makedirs(output_dir, exist_ok=True)

    setup_plot_style()

    # Every sub-directory of base_dir is treated as one metric.
    metric_dirs = [
        entry
        for entry in glob.glob(os.path.join(base_dir, '*'))
        if os.path.isdir(entry)
    ]

    # Gather run names over all metrics before assigning colors.
    all_run_names = set()
    for metric_dir in metric_dirs:
        all_run_names.update(
            extract_run_name(csv_path)
            for csv_path in glob.glob(os.path.join(metric_dir, '*.csv'))
        )
    color_mapping = get_color_mapping(all_run_names)

    for metric_dir in metric_dirs:
        plot_metric(metric_dir, color_mapping, output_dir)
    print(f"All plots have been generated in {output_dir}")


# Run the plotting pipeline only when executed as a script.
if __name__ == '__main__':
    main()