# GCycleGAN / utils / callbacks.py
# Egrt - init (95e767b)
import os
import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
class LossHistory():
    """Record named scalar values per epoch and persist them in three forms.

    Every call to :meth:`append_loss` stores each keyword value

    * in a list attribute on the instance named after the keyword,
    * as a new line in ``<log_dir>/<key>.txt``,
    * as a TensorBoard scalar via ``SummaryWriter.add_scalar``,

    and then redraws ``<log_dir>/epoch_loss.png`` with one curve per keyword.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Directory that receives the text logs, the TensorBoard
                event files and the loss plot. Created if missing.
            model: Iterable of models whose graphs are exported to
                TensorBoard on a best-effort basis.
            input_shape: Indexable whose first two items are used as the last
                two dimensions of the ``(2, 3, H, W)`` dummy input.
        """
        self.log_dir = log_dir
        # exist_ok avoids a FileExistsError when the run directory is reused.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)

        # Graph export is optional: a failure must never abort training, but
        # it is reported instead of being silently discarded.
        try:
            models = list(model)
        except Exception as exc:
            print(f"LossHistory: model graph export skipped: {exc}")
            models = []

        for m in models:
            # Each model is handled separately so that one failure does not
            # prevent the remaining graphs from being written.
            try:
                dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
                self.writer.add_graph(m, dummy_input)
            except Exception as exc:
                print(f"LossHistory: model graph export skipped: {exc}")

    def append_loss(self, epoch, **kwargs):
        """Record one value per keyword for ``epoch`` and refresh the plot.

        Args:
            epoch: Step index passed to ``SummaryWriter.add_scalar``.
            **kwargs: Mapping of metric name to the value to record.

        Raises:
            AttributeError: If a metric name collides with an existing
                attribute of the instance that is not a history list.
        """
        os.makedirs(self.log_dir, exist_ok=True)

        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, [])
            history = getattr(self, key)
            if not isinstance(history, list):
                raise AttributeError(
                    f"cannot record metric {key!r}: the name is already used "
                    f"by a non-list attribute of {type(self).__name__}"
                )

            # Append the value to the in-memory history.
            history.append(value)

            # Append the value to the per-metric text file.
            log_path = os.path.join(self.log_dir, key + ".txt")
            with open(log_path, 'a', encoding="utf-8") as f:
                f.write(str(value))
                f.write("\n")

            # Write the value to TensorBoard.
            self.writer.add_scalar(key, value, epoch)

        self.loss_plot(**kwargs)

    def loss_plot(self, **kwargs):
        """Draw the recorded history of every keyword into ``epoch_loss.png``.

        Only the keyword names are used; the values are ignored and the
        curves are read from the history lists stored on the instance.
        """
        fig, ax = plt.subplots()
        try:
            for key in kwargs:
                # A name that was never recorded is drawn as an empty series.
                losses = getattr(self, key, [])
                ax.plot(range(len(losses)), losses, linewidth=2, label=key)

            ax.grid(True)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            if kwargs:
                # A legend without labelled curves only triggers a warning.
                ax.legend(loc="upper right")

            fig.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Close only this figure, even when saving fails, so figures
            # owned by other code stay open and none is leaked here.
            plt.close(fig)