|
import os |
|
import glob |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import re |
|
|
|
def extract_run_name(filename):
    """Extract the run name from a TensorBoard CSV export filename.

    The run name is the underscore-free segment directly in front of
    ``_tensorboard.csv``, with a trailing ``-loss`` suffix removed, e.g.
    ``metrics_baseline-loss_tensorboard.csv`` -> ``baseline``.

    Filenames that do not follow that pattern fall back to the second
    underscore-separated field of the extension-less name, cut at its first
    ``-``. When the name contains no underscore at all, the whole stem is
    used as that field instead of raising ``IndexError``.

    Args:
        filename: Path or bare filename of the CSV file.

    Returns:
        The extracted run name as a string.
    """
    basename = os.path.basename(filename)

    # The capture has to be lazy: '-' is a legal character inside the run
    # name, so a greedy ``[^_]+`` would swallow the optional ``-loss`` suffix
    # and the non-capturing group could never strip it.
    match = re.search(r'_([^_]+?)(?:-loss)?_tensorboard\.csv$', basename)
    if match:
        return match.group(1)

    # Fallback for other naming schemes. Work on the stem so the file
    # extension cannot end up in the run name.
    stem = os.path.splitext(basename)[0]
    fields = stem.split('_')
    field = fields[1] if len(fields) > 1 else fields[0]
    return field.split('-')[0]
|
|
|
def setup_plot_style():
    """Install the shared look-and-feel into matplotlib's global rcParams.

    The settings are grouped by concern below and applied with a single
    ``plt.rcParams.update`` call, in the same key order as listed.
    """
    typography = {
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 10,
    }
    figure_and_lines = {
        'figure.dpi': 300,
        'figure.figsize': (10, 6),
        'lines.linewidth': 2.5,
    }
    grid_and_spines = {
        'axes.grid': True,
        'grid.linestyle': '--',
        'grid.alpha': 0.6,
        'axes.spines.top': False,
        'axes.spines.right': False,
    }

    plt.rcParams.update({**typography, **figure_and_lines, **grid_and_spines})
|
|
|
def get_metric_label(metric_name):
    """Translate a metric identifier into a display label.

    Known metrics map to hand-written labels. Any other name is shown with
    underscores turned into spaces and each word capitalised.

    Args:
        metric_name: Metric identifier string, e.g. ``'loss_epoch'``.

    Returns:
        The label string to show on axes and in titles.
    """
    known_labels = {
        'loss_epoch': 'Loss',
        'perplexityval_epoch': 'Validation Perplexity',
        'topkacc_epoch': 'Top-K Accuracy',
        'acc_trainstep': 'Training Accuracy',
    }

    prettified = metric_name.replace('_', ' ').title()
    if metric_name in known_labels:
        return known_labels[metric_name]
    return prettified
|
|
|
def get_color_mapping(run_names):
    """Assign every run a colour from a fixed ten-entry palette.

    Names are ordered with ``sorted`` before colours are handed out, so the
    same collection of names always produces the same mapping. The palette
    wraps around when there are more runs than colours.

    Args:
        run_names: Iterable of run-name strings.

    Returns:
        Dict mapping each run name to a hex colour string.
    """
    palette = (
        "#e6194b",
        "#f58231",
        "#ffe119",
        "#bfef45",
        "#3cb44b",
        "#42d4f4",
        "#4363d8",
        "#911eb4",
        "#f032e6",
        "#a9a9a9",
    )
    palette_size = len(palette)

    assignment = {}
    for position, run in enumerate(sorted(run_names)):
        assignment[run] = palette[position % palette_size]
    return assignment
|
|
|
def plot_metric(metric_dir, color_mapping, output_dir):
    """Plot every run's curve for one metric and save the figure as a PNG.

    Each ``*.csv`` file in ``metric_dir`` is read with pandas and drawn as one
    line from its ``Step`` and ``Value`` columns. The figure is written to
    ``<output_dir>/<metric_name>_plot.png``, where ``metric_name`` is the name
    of the ``metric_dir`` directory.

    Args:
        metric_dir: Directory holding the CSV files for one metric.
        color_mapping: Dict mapping run name to a matplotlib colour. Runs
            missing from the mapping are drawn in gray.
        output_dir: Directory for the PNG; created if it does not exist.

    Returns:
        None. If the directory contains no CSV files a message is printed and
        nothing is plotted.
    """
    # abspath normalises the path first; a trailing separator would otherwise
    # make basename return '' and the output be named '_plot.png'.
    metric_name = os.path.basename(os.path.abspath(metric_dir))
    csv_files = sorted(glob.glob(os.path.join(metric_dir, '*.csv')))

    if not csv_files:
        print(f"No CSV files found in {metric_dir}")
        return

    metric_label = get_metric_label(metric_name)

    fig = plt.figure(figsize=(12, 7))
    try:
        plotted = 0
        for csv_file in csv_files:
            # Deliberately best-effort: one unreadable or malformed file is
            # reported and skipped so the remaining runs still get plotted.
            try:
                df = pd.read_csv(csv_file)
                run_name = extract_run_name(csv_file)
                color = color_mapping.get(run_name, 'gray')
                plt.plot(df['Step'], df['Value'], label=run_name,
                         color=color, alpha=0.9)
            except Exception as e:
                print(f"Error processing {csv_file}: {e}")
            else:
                plotted += 1

        plt.xlabel('Step')
        plt.ylabel(metric_label)

        comparison = "Epoch" if "epoch" in metric_name else "Step"
        plt.title(f'{metric_label} vs. {comparison}', fontweight='bold')

        # Only build a legend when there is something to list, and size it by
        # the number of curves actually drawn rather than the number of files.
        if plotted:
            plt.legend(loc='best', frameon=True, fancybox=True, framealpha=0.9,
                       shadow=True, borderpad=1, ncol=2 if plotted > 5 else 1)

        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{metric_name}_plot.png')
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Saved plot to {output_path}")
    finally:
        # Release the figure even if labelling or saving failed, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
|
|
|
def main():
    """Generate one plot per metric directory found under ``runs_jsons``.

    Both the input directory (``runs_jsons``) and the output directory
    (``plots``) are resolved relative to this script's location. Run names
    are gathered from every CSV filename up front so that a run is given the
    same colour in every plot.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, 'runs_jsons')
    output_dir = os.path.join(script_dir, 'plots')
    os.makedirs(output_dir, exist_ok=True)

    setup_plot_style()

    # Every sub-directory of the base directory is treated as one metric.
    metric_dirs = []
    for candidate in glob.glob(os.path.join(base_dir, '*')):
        if os.path.isdir(candidate):
            metric_dirs.append(candidate)

    all_run_names = {
        extract_run_name(csv_path)
        for directory in metric_dirs
        for csv_path in glob.glob(os.path.join(directory, '*.csv'))
    }
    color_mapping = get_color_mapping(all_run_names)

    for directory in metric_dirs:
        plot_metric(directory, color_mapping, output_dir)

    print(f"All plots have been generated in {output_dir}")
|
|
|
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|