Spaces:
Runtime error
Runtime error
import argparse | |
import numpy as np | |
from matplotlib import pyplot as plt | |
def parse_args(argv=None):
    """Parse the command-line options for the loss plot.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of a plain command-line invocation.

    Returns:
        ``argparse.Namespace`` with ``log_file`` (str, required) and
        ``fake_weight`` (float, default 1.4).
    """
    parser = argparse.ArgumentParser(description="Plot losses from log")
    parser.add_argument("--log-file", help="path to log file", required=True)
    parser.add_argument(
        "--fake-weight", help="weight for fake loss", default=1.4, type=float
    )
    return parser.parse_args(argv)
def _read_losses(log_path):
    """Collect the ``real_loss`` and ``fake_loss`` values from a log file.

    A line is a loss entry when, after stripping surrounding whitespace, it
    starts with ``real_loss`` or ``fake_loss``. The value is the last
    whitespace-separated token on that line.

    Returns:
        Tuple ``(real_losses, fake_losses)`` of 1-D numpy arrays, in file order.

    Raises:
        ValueError: if the last token of a matching line is not a number.
    """
    real_losses = []
    fake_losses = []
    with open(log_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            # split() (not split(" ")) also tolerates tabs and repeated spaces.
            if line.startswith("fake_loss"):
                fake_losses.append(float(line.split()[-1]))
            elif line.startswith("real_loss"):
                real_losses.append(float(line.split()[-1]))
    return np.array(real_losses), np.array(fake_losses)


def _best_indices(loss, top_k, min_index):
    """Return indices of the ``top_k`` smallest ``loss`` values at or after ``min_index``.

    Indices are ordered from lowest to highest loss. Fewer than ``top_k`` are
    returned only when fewer than ``top_k`` entries exist from ``min_index`` on.
    """
    # Rank only the eligible tail, then shift back to absolute indices, so the
    # skipped early entries cannot crowd later ones out of the top-k.
    return np.argsort(loss[min_index:])[:top_k] + min_index


def main():
    """Plot the weighted loss parsed from the ``--log-file`` log.

    The plotted curve is ``(fake_weight * fake_loss + real_loss) / 2``, one
    point per loss entry. The five lowest points from index 17 onward are
    marked in red and annotated with their index.
    """
    import sys  # local: only needed for the length-mismatch warning below

    args = parse_args()
    real_losses, fake_losses = _read_losses(args.log_file)

    # If the two counts differ (e.g. a log cut off mid-entry), element-wise
    # arithmetic would raise, or silently broadcast when one side has a single
    # value. Keep only the common prefix instead.
    n = min(len(real_losses), len(fake_losses))
    if len(real_losses) != len(fake_losses):
        print(
            "warning: {} real_loss vs {} fake_loss entries; using the first {}".format(
                len(real_losses), len(fake_losses), n
            ),
            file=sys.stderr,
        )
        real_losses = real_losses[:n]
        fake_losses = fake_losses[:n]

    loss = (fake_losses * args.fake_weight + real_losses) / 2
    plt.title("Weighted loss ({}*fake_loss + real_loss)/2".format(args.fake_weight))

    # Ignore early epochs: the loss is quite noisy there and could spike.
    # Index 17 is the first one considered (equivalent to the old `> 16` filter).
    best_loss_idx = _best_indices(loss, top_k=5, min_index=17)
    plt.scatter(best_loss_idx, loss[best_loss_idx], c="red")
    for idx in best_loss_idx:
        plt.annotate(str(idx), (idx, loss[idx]))
    plt.plot(loss)
    plt.show()
if __name__ == "__main__":
    # Only plot when executed as a script, not when imported as a module.
    main()