|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import json |
|
import torch |
|
import matplotlib.gridspec as gridspec |
|
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset |
|
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar |
|
|
|
|
|
def get_perplexity(filename):
    """Return the most recent evaluation perplexity reported in a text file.

    Scans ``filename`` from the last line backwards for a line containing
    ``"evaluation perplexity:"`` and parses the text after that marker as a
    float.

    Args:
        filename: Path to a text file containing evaluation output.

    Returns:
        The perplexity as a ``float``, or ``None`` if no line contains the
        marker.

    Raises:
        ValueError: If the text following the marker is not a valid float.
    """
    marker = "evaluation perplexity:"

    # Decode explicitly as UTF-8; undecodable bytes elsewhere in the log are
    # replaced rather than aborting the parse (the marker and digits are ASCII).
    with open(filename, "r", encoding="utf-8", errors="replace") as log_file:
        lines = log_file.readlines()

    # Walk backwards so the latest evaluation in the file wins.
    for line in reversed(lines):
        if marker in line:
            # rpartition keeps only the text after the *last* marker, so a line
            # that repeats the marker does not break a fixed 2-way unpacking.
            _, _, value = line.rpartition(marker)
            return float(value)

    return None
|
|
|
|
|
def _collect_interpolation_points(base, epoch=6):
    """Gather (base ppl, chat ppl, color) points for one fine-tuning run.

    For each interpolation coefficient ``alpha`` in 0.00, 0.05, ..., 1.00 this
    looks for the paired evaluation outputs
    ``{base}/checkpoints/{chat,base}-eval-interpolate-{alpha:.2f}-epoch_{epoch}.pt``
    and keeps the point only if both files exist and both report a perplexity.

    Args:
        base: Run directory containing a ``checkpoints`` sub-directory.
        epoch: Epoch number embedded in the evaluation file names.

    Returns:
        Three parallel lists ``(xs, ys, colors)``: base-set perplexities,
        chat-set perplexities, and ``1 - alpha`` used for coloring.
    """
    xs, ys, colors = [], [], []
    for alpha in np.arange(0, 1.01, 0.05):
        suffix = f"interpolate-{alpha:.2f}-epoch_{epoch}.pt"
        chat_eval = f"{base}/checkpoints/chat-eval-{suffix}"
        base_eval = f"{base}/checkpoints/base-eval-{suffix}"
        if not (os.path.exists(chat_eval) and os.path.exists(base_eval)):
            continue

        chat_y = get_perplexity(chat_eval)
        base_y = get_perplexity(base_eval)
        if chat_y is None or base_y is None:
            continue

        print(f"{alpha:.2f}")  # progress: which coefficients were found
        xs.append(base_y)
        ys.append(chat_y)
        colors.append(1 - alpha)
    return xs, ys, colors


if __name__ == "__main__":
    # Plot base-set vs. chat-set perplexity for checkpoints obtained by
    # interpolating between a base model and its chat fine-tune, colored by
    # the interpolation coefficient.
    run_dirs = [
        "/fsx/home-mitchellw/experimetns/lmtune/instruction-tune-3b-2e-5-6",
    ]

    fig = plt.figure(figsize=(6 * 3, 5 * 3))
    gs = gridspec.GridSpec(3, 3)
    ax = fig.add_subplot(gs[0, 0])

    scatter = None
    for base in run_dirs:
        xs, ys, colors = _collect_interpolation_points(base)
        if not xs:
            print(f"No complete evaluation pairs under {base}; skipping.")
            continue

        # Match the model size case-insensitively: run directories spell it
        # in lower case (e.g. "3b").
        is_3b = "3b" in base.lower()
        scatter = ax.scatter(
            xs,
            ys,
            c=colors,
            cmap="cool",
            # Pin the color scale to the full coefficient range so every run
            # (even with missing alphas) agrees with the colorbar.
            vmin=0.0,
            vmax=1.0,
            marker="o" if is_3b else "d",
            label="OpenLM-3B" if is_3b else "OpenLM-1B",
        )

    ax.set_xlabel("Base evaluation set (perplexity)", fontsize=12)
    ax.set_ylabel("Chat evaluation set (perplexity)", fontsize=12)
    ax.tick_params(axis="x", labelsize=11)
    ax.tick_params(axis="y", labelsize=11)
    ax.grid()

    # Legend and colorbar only make sense once at least one series is plotted.
    if scatter is not None:
        ax.legend(fontsize=12)
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label(
            "Interpolation coefficient when interpolating\nbetween base and chat models",
            labelpad=10,
        )

    # savefig does not create missing parent directories.
    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/interpolation.png", bbox_inches="tight")
|
|