YoloGesture / utils /callbacks.py
Kedreamix's picture
YoloGesture推理主要代码
4a3ab35
raw
history blame
No virus
2.32 kB
import datetime
import os
import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
class LossHistory():
def __init__(self, log_dir, model, input_shape):
time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
self.losses = []
self.val_loss = []
os.makedirs(self.log_dir)
self.writer = SummaryWriter(self.log_dir)
try:
dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
self.writer.add_graph(model, dummy_input)
except:
pass
def append_loss(self, epoch, loss, val_loss):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.losses.append(loss)
self.val_loss.append(val_loss)
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
f.write(str(loss))
f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
f.write(str(val_loss))
f.write("\n")
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)
self.loss_plot()
def loss_plot(self):
iters = range(len(self.losses))
plt.figure()
plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
try:
if len(self.losses) < 25:
num = 5
else:
num = 15
plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
except:
pass
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc="upper right")
plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
plt.cla()
plt.close("all")