import argparse
import re
from typing import Dict, Iterable, List, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# One numeric field of a log line: an optionally signed decimal with an
# optional exponent, or a non-finite marker (nan / inf / infinity) -- every
# spelling accepted here is also accepted by float().
_NUMBER = r'[-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|(?i:nan|inf(?:inity)?))'

# Matches e.g. "Epoch 3, Step 1200, Total: 1.5e-03, Flow: 0.0012, Proj: 3e-4".
_LOG_LINE_RE = re.compile(
    r'Epoch (\d+), Step (\d+), '
    rf'Total: ({_NUMBER}), Flow: ({_NUMBER}), Proj: ({_NUMBER})'
)


def _parse_log_lines(lines: Iterable[str]) -> List[Dict[str, Union[int, float]]]:
    """Return one record per line that contains a loss report; skip the rest."""
    records = []
    for line in lines:
        match = _LOG_LINE_RE.search(line)
        if match is None:
            continue
        epoch, step, total, flow, proj = match.groups()
        records.append({
            'epoch': int(epoch),
            'step': int(step),
            'total': float(total),
            'flow': float(flow),
            'proj': float(proj),
        })
    return records


def _plot_losses(df: pd.DataFrame, plot_path: str) -> None:
    """Draw the three loss curves against the record index and save as an image."""
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for column, label in (
            ('total', 'Total Loss'),
            ('flow', 'Flow Loss'),
            ('proj', 'Proj Loss'),
        ):
            ax.plot(df.index, df[column], label=label, alpha=0.8, linewidth=1)

        ax.set_xlabel('Logging Steps')
        ax.set_ylabel('Loss Value')
        ax.set_title('Training Losses over Time')
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()
        fig.savefig(plot_path, dpi=300)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse a training log, print loss statistics and plot the loss curves.

    Lines of the form
    ``Epoch E, Step S, Total: x, Flow: y, Proj: z`` are collected; all other
    lines are ignored.  Mean / std / min / max of the three losses are printed
    and, unless ``--no_plot`` is given, the curves are written to ``--output``.

    Args:
        argv: Command-line arguments without the program name.  ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 2 when the arguments are invalid or the log
            file cannot be opened (raised through ``parser.error``).
    """
    parser = argparse.ArgumentParser(description='Parse training log and plot losses.')
    parser.add_argument('--log_file', type=str, default='logs/train_222379.out',
                        help='Path to the log file')
    parser.add_argument('--output', type=str, default='loss_plot.png',
                        help='Path of the plot image to write')
    parser.add_argument('--no_plot', action='store_true',
                        help='Only print the statistics; do not render the plot')
    args = parser.parse_args(argv)

    try:
        # errors='replace' keeps a stray undecodable byte in the log from
        # aborting the whole run; the loss lines themselves are plain ASCII.
        with open(args.log_file, 'r', encoding='utf-8', errors='replace') as f:
            data = _parse_log_lines(f)
    except OSError as exc:
        # Usage message + exit status 2 instead of a raw traceback.
        parser.error(f"cannot read log file: {exc}")

    if not data:
        print("No valid log lines found.")
        return

    df = pd.DataFrame(data)

    stats = df[['total', 'flow', 'proj']].agg(['mean', 'std', 'min', 'max'])
    print("--- Loss Statistics ---")
    print(stats)

    if args.no_plot:
        return

    _plot_losses(df, args.output)
    print(f"\nPlot successfully generated and saved to {args.output}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()