# GCycleGAN / utils/callbacks.py
# Author: Egrt — commit 95e767b ("init"), 1.99 kB
import os
import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
class LossHistory:
    """Record per-epoch scalars (e.g. losses) to text files, TensorBoard and a plot.

    Every keyword value passed to ``append_loss`` is:

    * appended to a list attribute named after the keyword,
    * appended as one line to ``<log_dir>/<key>.txt``,
    * logged to TensorBoard via ``add_scalar``.

    Afterwards, the accumulated history of the keys given in that call is
    re-plotted into ``<log_dir>/epoch_loss.png``.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create ``log_dir`` and a TensorBoard writer, then try to log model graphs.

        Args:
            log_dir: Directory for the text logs, TensorBoard events and the
                loss plot. It may already exist.
            model: An iterable of modules (e.g. a list of networks) or a single
                module. Each one's graph is added to TensorBoard on a
                best-effort basis; failures only emit a warning.
            input_shape: Sequence whose first two entries become the last two
                dimensions of the ``(2, 3, input_shape[0], input_shape[1])``
                dummy input used for graph tracing — presumably
                (height, width); confirm against callers.
        """
        self.log_dir = log_dir
        # exist_ok: reusing an existing log directory must not crash.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        self._add_graphs(model, input_shape)

    def _add_graphs(self, model, input_shape):
        """Best-effort: add each model's graph to TensorBoard, warning on failure."""
        import warnings

        try:
            models = iter(model)
        except TypeError:
            # A plain module is not iterable; treat it as a one-element collection.
            models = iter([model])

        for m in models:
            # One try per model, so a failure to trace one model does not
            # skip the remaining ones.
            try:
                # Build the dummy input on the model's device (default device
                # if it has no parameters) so tracing a GPU model can succeed.
                first_param = next(m.parameters(), None)
                device = first_param.device if first_param is not None else None
                dummy_input = torch.randn(
                    2, 3, input_shape[0], input_shape[1], device=device
                )
                self.writer.add_graph(m, dummy_input)
            except Exception as exc:  # graph logging is optional; never abort setup
                warnings.warn(
                    f"LossHistory: could not add graph of {type(m).__name__} "
                    f"to TensorBoard: {exc}"
                )

    def append_loss(self, epoch, **kwargs):
        """Record one value per keyword for ``epoch`` and refresh the loss plot.

        Args:
            epoch: Step index passed to TensorBoard's ``add_scalar``.
            **kwargs: ``name=value`` pairs. Each value is appended to the list
                attribute ``self.<name>`` (created on first use), written as
                ``str(value)`` plus a newline to ``<log_dir>/<name>.txt``, and
                logged to TensorBoard under tag ``name``.

        NOTE(review): names are stored as instance attributes, so a keyword
        that collides with an existing attribute (e.g. ``writer``,
        ``log_dir``) fails with AttributeError on ``.append``.
        """
        os.makedirs(self.log_dir, exist_ok=True)
        for key, value in kwargs.items():
            # Append the value to this key's in-memory history.
            if not hasattr(self, key):
                setattr(self, key, [])
            getattr(self, key).append(value)

            # Append the value to the key's text log.
            txt_path = os.path.join(self.log_dir, key + ".txt")
            with open(txt_path, 'a', encoding="utf-8") as f:
                f.write(str(value) + "\n")

            # Log the value to TensorBoard.
            self.writer.add_scalar(key, value, epoch)
        self.loss_plot(**kwargs)

    def loss_plot(self, **kwargs):
        """Plot the recorded history of each keyword's key to ``epoch_loss.png``.

        Only the keys of ``kwargs`` are used; their values are ignored. Each
        key must already have a history list (see ``append_loss``).
        """
        fig, ax = plt.subplots()
        try:
            for key in kwargs:
                losses = getattr(self, key)
                ax.plot(range(len(losses)), losses, linewidth=2, label=key)
            ax.grid(True)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.legend(loc="upper right")
            fig.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Close only our own figure (not the caller's), even if saving fails.
            plt.close(fig)