|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import matplotlib as mpl |
|
import argparse |
|
|
|
|
|
# Global Matplotlib text styling for every figure this script creates:
# serif family resolved to Georgia, with explicit sizes for the base font,
# axes titles, axes labels, and tick labels.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Georgia'],
    'font.size': 20,
    'axes.titlesize': 20,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})
|
|
|
|
|
def plot_two_loss_curves(
    csv_file1,
    csv_file2,
    title="Loss Comparison on Qwen3-8B",
    dataset1_name="Dataset1",
    dataset2_name="Dataset2"
):
    """Plot the loss curves of two CSV logs on one figure, save it, and show it.

    Each CSV file must provide a ``Step`` column (x axis) and a ``Loss``
    column (y axis). The figure is written to ``loss_comparison_qwen38b.svg``
    and ``loss_comparison.png`` (relative paths), then displayed with
    ``plt.show()``.

    Args:
        csv_file1: Path to the first CSV file.
        csv_file2: Path to the second CSV file.
        title: Title drawn above the plot.
        dataset1_name: Legend label for the curve read from ``csv_file1``.
        dataset2_name: Legend label for the curve read from ``csv_file2``.

    Raises:
        ValueError: If either file lacks a ``Step`` or ``Loss`` column.
    """
    svg_path = 'loss_comparison_qwen38b.svg'
    png_path = 'loss_comparison.png'

    df1 = pd.read_csv(csv_file1)
    df2 = pd.read_csv(csv_file2)

    # Validate both inputs before any figure is created so a bad file fails
    # fast and leaves no half-built figure behind.
    for df, path in ((df1, csv_file1), (df2, csv_file2)):
        if 'Step' not in df.columns or 'Loss' not in df.columns:
            raise ValueError(f"Missing 'Step' or 'Loss' columns in {path}")

    plt.figure(figsize=(12, 8))

    # The dataset names label the curves; without them (and the legend
    # below) the two lines cannot be told apart in the saved figure.
    plt.plot(df1['Step'], df1['Loss'],
             color='#1f77b4', linewidth=2.5, label=dataset1_name)
    plt.plot(df2['Step'], df2['Loss'],
             color='#2ca02c', linewidth=2.5, label=dataset2_name)

    plt.title(title, fontweight='bold')
    plt.xlabel('Steps', fontweight='bold')
    plt.ylabel('Loss', fontweight='bold')

    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    plt.tight_layout(pad=3.0)

    # Save before show(): the files are written while the figure is still
    # the current pyplot figure.
    plt.savefig(svg_path, format='svg')
    plt.savefig(png_path, dpi=300)

    plt.show()

    # Report the names that were actually written.
    print(f"Saved: {svg_path}, {png_path}")
|
|
|
|
|
def main():
    """Parse the command line and draw the two-curve loss comparison.

    Positional arguments ``csv_file1`` and ``csv_file2`` are the CSV paths;
    ``--title``, ``--dataset1-name`` and ``--dataset2-name`` are optional
    and fall back to the defaults listed below. The parsed values are
    forwarded to ``plot_two_loss_curves``.
    """
    cli = argparse.ArgumentParser(
        description='Plot comparison of two training loss curves'
    )

    # (name, help) for the required positional arguments, in CLI order.
    positionals = (
        ('csv_file1', 'Path to the first CSV file'),
        ('csv_file2', 'Path to the second CSV file'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)

    # (flag, default, help) for the optional arguments, in CLI order.
    options = (
        ('--title', 'Training Loss Comparison', 'Title for the plot'),
        ('--dataset1-name', 'Original Dataset', 'Name for the first dataset'),
        ('--dataset2-name', 'Revised Dataset', 'Name for the second dataset'),
    )
    for flag, fallback, help_text in options:
        cli.add_argument(flag, default=fallback, help=help_text)

    opts = cli.parse_args()

    plot_two_loss_curves(
        opts.csv_file1,
        opts.csv_file2,
        title=opts.title,
        dataset1_name=opts.dataset1_name,
        dataset2_name=opts.dataset2_name,
    )
|
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()