zz / scripts /analysis /evaluate.py
lzwjava's picture
refactor: reorganize project structure
4f685ca
#!/usr/bin/env python3
import re
import matplotlib.pyplot as plt
import numpy as np
import argparse
from pathlib import Path
# A number is digits with at most one decimal point, so malformed tokens such
# as "1.2.3" simply fail to match instead of crashing float() afterwards.
_UNSIGNED = r"(\d+\.?\d*|\.\d+)"
_SIGNED = r"(-?(?:\d+\.?\d*|\.\d+))"

# "step N: train loss X, val loss Y"
_STEP_LINE = re.compile(
    rf"step (\d+): train loss {_UNSIGNED}, val loss {_UNSIGNED}"
)
# "iter N: loss X, time Yms, mfu Z%" (only the MFU value may be negative)
_ITER_LINE = re.compile(
    rf"iter (\d+): loss {_UNSIGNED}, time {_UNSIGNED}ms, mfu {_SIGNED}%"
)


def parse_training_log(log_file_path):
    """Parse a training log file and extract step and iteration metrics.

    Each line is stripped of surrounding whitespace and matched from its
    start against two shapes::

        step <n>: train loss <x>, val loss <y>
        iter <n>: loss <x>, time <t>ms, mfu <m>%

    Lines matching neither shape (including lines with malformed numbers)
    are ignored.

    Args:
        log_file_path: Path of the log file to read.

    Returns:
        A ``(step_data, iter_data)`` tuple of lists, in file order.
        ``step_data`` holds dicts with keys ``"step"`` (int),
        ``"train_loss"`` and ``"val_loss"`` (float); ``iter_data`` holds
        dicts with keys ``"iter"`` (int), ``"loss"``, ``"time_ms"`` and
        ``"mfu"`` (float).

    Raises:
        OSError: If the file cannot be opened.
    """
    step_data = []
    iter_data = []

    # The metric lines themselves are plain ASCII; errors="replace" keeps an
    # undecodable byte elsewhere in the log from aborting the whole parse.
    with open(log_file_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()

            step_match = _STEP_LINE.match(line)
            if step_match:
                step_num, train_loss, val_loss = step_match.groups()
                step_data.append(
                    {
                        "step": int(step_num),
                        "train_loss": float(train_loss),
                        "val_loss": float(val_loss),
                    }
                )
                # Both patterns are anchored at the line start with different
                # prefixes, so a step line can never also be an iter line.
                continue

            iter_match = _ITER_LINE.match(line)
            if iter_match:
                iter_num, loss, time_ms, mfu = iter_match.groups()
                iter_data.append(
                    {
                        "iter": int(iter_num),
                        "loss": float(loss),
                        "time_ms": float(time_ms),
                        "mfu": float(mfu),
                    }
                )

    return step_data, iter_data
def _finish_axis(ax, xlabel, ylabel, title):
    """Apply axis labels, a title and a faint grid to one subplot."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


def create_visualizations(
    step_data, iter_data, title="nanoGPT Training Metrics - RTX 4070"
):
    """Create a 2x2 matplotlib figure of the training metrics.

    Panels: train/validation loss per step (top left), loss per iteration
    (top right), time per iteration (bottom left) and MFU per iteration
    (bottom right). A panel whose data list is empty is left as blank axes.

    Side effects: resets the global matplotlib style to "default", sets the
    ``figure.figsize`` and ``font.size`` rcParams, and calls ``plt.show()``
    when the ``DISPLAY`` environment variable is set to a non-empty value.

    Args:
        step_data: List of dicts with keys ``"step"``, ``"train_loss"`` and
            ``"val_loss"``.
        iter_data: List of dicts with keys ``"iter"``, ``"loss"``,
            ``"time_ms"`` and ``"mfu"``.
        title: Figure-level title. Defaults to the original hard-coded text.

    Returns:
        The created matplotlib figure.
    """
    # Local import: os is not part of this module's top-level imports.
    import os

    # Set up the plotting style (these are global matplotlib settings).
    plt.style.use("default")
    plt.rcParams["figure.figsize"] = (12, 8)
    plt.rcParams["font.size"] = 10

    # The explicit figsize here is what sizes this particular figure.
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(title, fontsize=16, fontweight="bold")

    # Panel 1: training and validation loss over steps.
    if step_data:
        ax = axes[0, 0]
        steps = [d["step"] for d in step_data]
        ax.plot(
            steps,
            [d["train_loss"] for d in step_data],
            "b-",
            label="Training Loss",
            linewidth=2,
        )
        ax.plot(
            steps,
            [d["val_loss"] for d in step_data],
            "r-",
            label="Validation Loss",
            linewidth=2,
        )
        _finish_axis(ax, "Training Step", "Loss", "Training vs Validation Loss")
        ax.legend()

    # Panels 2-4 share the iteration x-axis, so build it once.
    if iter_data:
        iters = [d["iter"] for d in iter_data]
        panels = (
            (axes[0, 1], "loss", "g-", "Loss", "Loss per Iteration"),
            (
                axes[1, 0],
                "time_ms",
                "orange",
                "Time (ms)",
                "Training Time per Iteration",
            ),
            (
                axes[1, 1],
                "mfu",
                "purple",
                "MFU (%)",
                "Model FLOP Utilization (MFU)",
            ),
        )
        for ax, key, fmt, ylabel, panel_title in panels:
            ax.plot(iters, [d[key] for d in iter_data], fmt, linewidth=1.5)
            _finish_axis(ax, "Iteration", ylabel, panel_title)

    plt.tight_layout()

    # Only show the plot when DISPLAY is set and non-empty (skips headless
    # environments that have no display configured).
    if os.environ.get("DISPLAY"):
        plt.show()

    return fig
def print_statistics(step_data, iter_data):
    """Print basic statistics about the training run to stdout.

    Always prints a banner. When ``step_data`` is non-empty, prints the last
    step number, the final training/validation loss and the relative
    training-loss improvement from the first to the last step record. When
    ``iter_data`` is non-empty, prints the last iteration number, the mean
    time per iteration and the mean MFU over valid readings.

    Args:
        step_data: List of dicts with keys ``"step"``, ``"train_loss"`` and
            ``"val_loss"``.
        iter_data: List of dicts with keys ``"iter"``, ``"time_ms"`` and
            ``"mfu"``.
    """
    print("=" * 50)
    print("TRAINING STATISTICS")
    print("=" * 50)

    if step_data:
        final_step = step_data[-1]
        print(f"Total Steps Completed: {final_step['step']}")
        print(f"Final Training Loss: {final_step['train_loss']:.4f}")
        print(f"Final Validation Loss: {final_step['val_loss']:.4f}")

        # Relative improvement of the training loss, first vs last record.
        initial_loss = step_data[0]["train_loss"]
        final_loss = final_step["train_loss"]
        if initial_loss:
            improvement = ((initial_loss - final_loss) / initial_loss) * 100
            print(f"Training Loss Improvement: {improvement:.1f}%")
        else:
            # A zero initial loss makes the percentage undefined; report that
            # instead of raising ZeroDivisionError.
            print("Training Loss Improvement: N/A")

    if iter_data:
        final_iter = iter_data[-1]
        print(f"\nTotal Iterations: {final_iter['iter']}")

        avg_time = np.mean([d["time_ms"] for d in iter_data])
        print(f"Average Time per Iteration: {avg_time:.1f}ms")

        # Readings at or below -50 are treated as invalid and excluded
        # (presumably a placeholder logged before an MFU estimate exists;
        # TODO confirm against the training script).
        valid_mfus = [d["mfu"] for d in iter_data if d["mfu"] > -50]
        if valid_mfus:
            avg_mfu = np.mean(valid_mfus)
            print(f"Average MFU: {avg_mfu:.1f}%")
def main(argv=None):
    """Command-line entry point: parse a log, print statistics, save plots.

    Reads the log given by ``--file`` (default: ``train_log_openweb.txt``
    next to this script), prints summary statistics, and writes
    ``training_metrics.png`` and ``training_metrics.pdf`` into
    ``--output-dir`` (default: the directory containing this script), then
    calls ``plt.show()``.

    If the log path is not an existing regular file, an error message is
    printed and the function returns without doing anything else.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` keeps the normal command-line
            behaviour.
    """
    parser = argparse.ArgumentParser(
        description="Parse and visualize nanoGPT training logs"
    )
    parser.add_argument("--file", "-f", type=str, help="Path to training log file")
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        help="Directory for the PNG/PDF output (default: this script's directory)",
    )
    args = parser.parse_args(argv)

    script_dir = Path(__file__).parent

    # Resolve the input log path.
    if args.file:
        log_file = Path(args.file)
    else:
        log_file = script_dir / "train_log_openweb.txt"

    # is_file() rather than exists(): a directory would pass exists() and
    # then fail later when opened for reading.
    if not log_file.is_file():
        print(f"Error: Log file {log_file} not found!")
        return

    # Resolve the output directory; only create it when the user chose one.
    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = script_dir

    print(f"Parsing training log: {log_file}")
    step_data, iter_data = parse_training_log(log_file)
    print(f"Found {len(step_data)} step records and {len(iter_data)} iteration records")

    print_statistics(step_data, iter_data)

    fig = create_visualizations(step_data, iter_data)

    # Save the plot as a 300-dpi PNG.
    output_path = output_dir / "training_metrics.png"
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"\nVisualization saved to: {output_path}")

    # Also save as PDF for high quality.
    pdf_path = output_dir / "training_metrics.pdf"
    fig.savefig(pdf_path, bbox_inches="tight")
    print(f"High-quality PDF saved to: {pdf_path}")

    plt.show()
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()