LastingBench / utils /draw.py
kixx's picture
Upload 34 files
b1e25b1 verified
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import argparse
# Global Matplotlib styling applied at import time: serif text (Georgia)
# with explicit sizes for body text, titles, axis labels and tick labels.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Georgia'],
    'font.size': 20,
    'axes.titlesize': 20,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})
# legend.fontsize is intentionally left unset: the plot draws no legend.
def plot_two_loss_curves(
    csv_file1,
    csv_file2,
    title="Loss Comparison on Qwen3-8B",
    dataset1_name="Dataset1",
    dataset2_name="Dataset2",
    svg_path="loss_comparison_qwen38b.svg",
    png_path="loss_comparison.png",
    show=True,
):
    """Plot two training-loss curves on one figure and save it as SVG and PNG.

    Args:
        csv_file1: Path to the first CSV file; must contain 'Step' and 'Loss'
            columns.
        csv_file2: Path to the second CSV file; same column requirements.
        title: Figure title.
        dataset1_name: Label attached to the first curve's line. No legend
            is drawn, so the label is not rendered on the figure.
        dataset2_name: Label attached to the second curve's line (same note).
        svg_path: Output path for the SVG export.
        png_path: Output path for the 300-dpi PNG export.
        show: If True (default), call ``plt.show()`` after saving. Otherwise
            close the figure so repeated calls do not accumulate open figures.

    Raises:
        ValueError: If either CSV lacks a 'Step' or 'Loss' column.
    """
    df1 = pd.read_csv(csv_file1)
    df2 = pd.read_csv(csv_file2)

    # Validate both inputs before any figure is created.
    for df, path in ((df1, csv_file1), (df2, csv_file2)):
        if 'Step' not in df.columns or 'Loss' not in df.columns:
            raise ValueError(f"Missing 'Step' or 'Loss' columns in {path}")

    fig = plt.figure(figsize=(12, 8))

    # Colours are matplotlib's default-cycle 'tab:blue' and 'tab:green'.
    plt.plot(df1['Step'], df1['Loss'],
             color='#1f77b4', linewidth=2.5, label=dataset1_name)
    plt.plot(df2['Step'], df2['Loss'],
             color='#2ca02c', linewidth=2.5, label=dataset2_name)

    plt.title(title, fontweight='bold')
    plt.xlabel('Steps', fontweight='bold')
    plt.ylabel('Loss', fontweight='bold')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout(pad=3.0)

    plt.savefig(svg_path, format='svg')
    plt.savefig(png_path, dpi=300)
    # Report the paths that were actually written, not hard-coded names.
    print(f"Saved: {svg_path}, {png_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)
def main():
    """Command-line entry point: parse arguments and plot the two loss curves.

    Two positional CSV paths are required; the title and the two dataset
    names are optional and forwarded to ``plot_two_loss_curves``.
    """
    cli = argparse.ArgumentParser(
        description='Plot comparison of two training loss curves'
    )

    # Required inputs, registered in the order they appear on the command line.
    positional_specs = (
        ('csv_file1', 'Path to the first CSV file'),
        ('csv_file2', 'Path to the second CSV file'),
    )
    for dest, help_text in positional_specs:
        cli.add_argument(dest, help=help_text)

    # Optional presentation settings: (flag, default, help).
    optional_specs = (
        ('--title', 'Training Loss Comparison', 'Title for the plot'),
        ('--dataset1-name', 'Original Dataset', 'Name for the first dataset'),
        ('--dataset2-name', 'Revised Dataset', 'Name for the second dataset'),
    )
    for flag, fallback, help_text in optional_specs:
        cli.add_argument(flag, default=fallback, help=help_text)

    parsed = cli.parse_args()
    plot_two_loss_curves(
        parsed.csv_file1,
        parsed.csv_file2,
        title=parsed.title,
        dataset1_name=parsed.dataset1_name,
        dataset2_name=parsed.dataset2_name,
    )
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()