|
import os |
|
import math |
|
import matplotlib.pyplot as plt |
|
|
|
# CSV training logs to plot, one per model, taken directly from models/
# (os.listdir is not recursive). Sorted because os.listdir returns entries in
# arbitrary order; sorting keeps the plot/legend order deterministic.
# Paths keep the "models/<file>" form that process_lines splits on "/".
files = sorted(f"models/{x}" for x in os.listdir("models") if x.endswith(".csv"))

# Model name -> (steps, natural-log training losses); filled by process_lines().
train_loss = {}

# Model name -> (steps, natural-log eval losses); only for logs with an eval column.
eval_loss = {}
|
|
|
def process_lines(lines, name=None):
    """Parse one model's CSV log and record its log-loss curves.

    Each non-blank line is expected to look like ``step,train_loss[,eval_loss]``.
    The natural log of every loss is stored as a ``(steps, log_losses)`` tuple
    in the module-level ``train_loss`` dict and, for rows that carry a
    non-empty third column, in the module-level ``eval_loss`` dict.

    Args:
        lines: Raw text lines of the CSV file, e.g. from ``f.readlines()``.
        name: Key to record the curves under. When omitted, falls back to the
            second component of the module-level ``fp`` path (``x.csv`` for
            ``models/x.csv``), which keeps the original ``process_lines(lines)``
            call working.

    Returns:
        None. Input with no non-blank lines records nothing.
    """
    if name is None:
        # Backward compatibility: the original implicitly read the script's
        # loop variable ``fp`` from module scope. Prefer passing ``name``.
        name = fp.split("/")[1]

    # Drop blank lines (e.g. a trailing newline at EOF); they would make
    # int("") raise ValueError.
    rows = [line.strip().split(",") for line in lines if line.strip()]
    if not rows:
        # Nothing to record; storing empty lists would make smooth() fail later.
        return None

    # Mutating the module-level dicts needs no ``global`` declaration.
    train_loss[name] = (
        [int(row[0]) for row in rows],
        # NOTE(review): assumes losses are strictly positive (math.log domain).
        [math.log(float(row[1])) for row in rows],
    )

    # Eval loss is optional and may be missing or empty on some rows; only
    # rows that actually carry a value contribute to the eval curve.
    eval_rows = [row for row in rows if len(row) >= 3 and row[2].strip()]
    if eval_rows:
        eval_loss[name] = (
            [int(row[0]) for row in eval_rows],
            [math.log(float(row[2])) for row in eval_rows],
        )
    return None
|
|
|
|
|
def smooth(scalars, weight):
    """Return an exponential moving average of *scalars*.

    Each output is ``prev * weight + (1 - weight) * point``. ``prev`` is
    seeded with the first input value, so the curve starts on the data.
    Larger ``weight`` means heavier smoothing, and ``weight == 0`` returns
    the input values unchanged.

    Args:
        scalars: Iterable of numbers to smooth (list, tuple, array, iterator).
        weight: Factor applied to the previous smoothed value.

    Returns:
        A list with one smoothed value per input; empty for empty input.
    """
    smoothed = []
    for point in scalars:
        # Seed from the first point itself instead of indexing scalars[0],
        # so empty input and plain iterators are handled.
        prev = smoothed[-1] if smoothed else point
        smoothed.append(prev * weight + (1 - weight) * point)
    return smoothed
|
|
|
def plot(data, fname, weight=0.9):
    """Plot smoothed curves for every model in *data* and save them to *fname*.

    Args:
        data: Mapping of model name -> ``(steps, values)`` pair, as built by
            ``process_lines``.
        fname: Output image path passed to ``Figure.savefig``.
        weight: Smoothing factor forwarded to ``smooth``. Defaults to 0.9, the
            value that was previously hard-coded.
    """
    fig, ax = plt.subplots()
    try:
        ax.grid()
        for name, (steps, values) in data.items():
            ax.plot(steps, smooth(values, weight), label=name)
        # With nothing plotted there are no labelled artists, and legend()
        # would only emit a warning.
        if data:
            ax.legend(loc="upper right")
        fig.savefig(fname, dpi=300, bbox_inches="tight")
    finally:
        # pyplot keeps figures alive until closed; release this one even if
        # plotting or saving fails.
        plt.close(fig)
|
|
|
if __name__ == "__main__":
    # An ``if`` block does not create a scope, so ``fp`` stays module-level;
    # process_lines() reads it to name each model when no name is passed.
    for fp in files:
        with open(fp, encoding="utf-8") as f:
            lines = f.readlines()
        # Parse after the file is closed; only the lines are needed.
        process_lines(lines)

    plot(train_loss, "loss.png")
    plot(eval_loss, "loss-eval.png")
|
|