import os
import warnings

import torch
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: plots are only saved to disk
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter


class LossHistory():
    """Log per-epoch scalar metrics to text files, TensorBoard and a PNG plot.

    Each keyword passed to :meth:`append_loss` is treated as one metric series.
    Its values are kept in a list attribute with the same name, appended line
    by line to ``<log_dir>/<key>.txt``, logged to TensorBoard under the tag
    ``key``, and drawn into ``<log_dir>/epoch_loss.png``.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and TensorBoard writer, and log model graphs.

        Args:
            log_dir: Directory for the text logs, TensorBoard event files and
                the loss plot. Created if missing; an existing directory is
                reused.
            model: A ``torch.nn.Module`` or an iterable of modules whose graphs
                are added to TensorBoard. Graph logging is best-effort: a
                failure emits a warning and does not stop construction.
            input_shape: ``(height, width)`` of the random dummy input used to
                trace each model; the dummy batch has shape
                ``(2, 3, height, width)``.
        """
        self.log_dir = log_dir
        # exist_ok: resuming / re-running into the same directory must not crash.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)

        # A plain nn.Module is not iterable; wrap it so its graph is still
        # logged. Iterables (e.g. a list of models) are traced member by member.
        try:
            models = list(model)
        except TypeError:
            models = [model]

        for m in models:
            # Per-model try: one untraceable model must not skip the others.
            try:
                # Build the dummy input on the model's device so GPU models trace.
                param = next(m.parameters(), None)
                device = param.device if param is not None else torch.device("cpu")
                # Batch of 2 — presumably so layers like BatchNorm in train mode
                # can run; TODO confirm.
                dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1], device=device)
                self.writer.add_graph(m, dummy_input)
            except Exception as err:
                # Graph logging is optional; warn instead of aborting training.
                warnings.warn(
                    f"LossHistory: could not add graph of {type(m).__name__} "
                    f"to TensorBoard: {err}"
                )

    def append_loss(self, epoch, **kwargs):
        """Record one value per metric for ``epoch`` and refresh the plot.

        Args:
            epoch: Global step passed to TensorBoard's ``add_scalar``.
            **kwargs: ``metric_name=value`` pairs. Scalar tensors are converted
                to Python numbers before being stored.

        Raises:
            ValueError: If a metric name clashes with an existing non-list
                attribute of this object (e.g. ``writer`` or ``log_dir``).
        """
        # The directory may have been removed while training was running.
        os.makedirs(self.log_dir, exist_ok=True)

        for key, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                # Store a plain number: the txt log then contains "0.5" rather
                # than "tensor(0.5000)", and matplotlib can plot CUDA values.
                value = value.item()

            series = getattr(self, key, None)
            if series is None:
                series = []
                setattr(self, key, series)
            elif not isinstance(series, list):
                raise ValueError(
                    f"metric name {key!r} clashes with an existing LossHistory attribute"
                )

            # Append the value to the in-memory series.
            series.append(value)

            # Append the value to <log_dir>/<key>.txt, one value per line.
            with open(os.path.join(self.log_dir, key + ".txt"), 'a', encoding="utf-8") as f:
                f.write(str(value))
                f.write("\n")

            # Log the value to TensorBoard.
            self.writer.add_scalar(key, value, epoch)

        self.loss_plot(**kwargs)

    def loss_plot(self, **kwargs):
        """Redraw ``<log_dir>/epoch_loss.png`` with the series named in ``kwargs``.

        Only the keyword names are used; each must already be a recorded
        series. The x axis is the 0-based index of each recorded value.
        """
        fig, ax = plt.subplots()
        try:
            for key in kwargs:
                losses = getattr(self, key)
                ax.plot(range(len(losses)), losses, linewidth=2, label=key)

            ax.grid(True)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.legend(loc="upper right")
            fig.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Close only this figure, even when saving fails, so repeated
            # calls do not accumulate open figures.
            plt.close(fig)

    def close(self):
        """Flush pending TensorBoard events and close the writer."""
        self.writer.close()