|
import os
|
|
from importlib import import_module
|
|
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
def sequence_loss(sr, hr, loss_func, gamma=0.8, max_val=None):
|
|
""" Loss function defined over sequence of flow predictions """
|
|
|
|
n_recurrence = len(sr)
|
|
total_loss = 0.0
|
|
buffer=[0.0]*n_recurrence
|
|
|
|
for i in range(n_recurrence):
|
|
i_weight = gamma**(n_recurrence - i - 1)
|
|
i_loss = loss_func(sr[i],hr)
|
|
buffer[i]=i_loss.item()
|
|
|
|
total_loss += i_weight * (i_loss)
|
|
return total_loss,buffer
|
|
|
|
class Loss(nn.modules.loss._Loss):
|
|
def __init__(self, args, ckp):
|
|
super(Loss, self).__init__()
|
|
print('Preparing loss function:')
|
|
self.buffer=[0.0]*args.recurrence
|
|
self.n_GPUs = args.n_GPUs
|
|
self.loss = []
|
|
self.loss_module = nn.ModuleList()
|
|
for loss in args.loss.split('+'):
|
|
weight, loss_type = loss.split('*')
|
|
if loss_type == 'MSE':
|
|
loss_function = nn.MSELoss()
|
|
elif loss_type == 'L1':
|
|
loss_function = nn.L1Loss()
|
|
elif loss_type.find('VGG') >= 0:
|
|
module = import_module('loss.vgg')
|
|
loss_function = getattr(module, 'VGG')(
|
|
loss_type[3:],
|
|
rgb_range=args.rgb_range
|
|
)
|
|
elif loss_type.find('GAN') >= 0:
|
|
module = import_module('loss.adversarial')
|
|
loss_function = getattr(module, 'Adversarial')(
|
|
args,
|
|
loss_type
|
|
)
|
|
|
|
self.loss.append({
|
|
'type': loss_type,
|
|
'weight': float(weight),
|
|
'function': loss_function}
|
|
)
|
|
if loss_type.find('GAN') >= 0:
|
|
self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
|
|
|
|
if len(self.loss) > 1:
|
|
self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
|
|
|
|
for l in self.loss:
|
|
if l['function'] is not None:
|
|
print('{:.3f} * {}'.format(l['weight'], l['type']))
|
|
|
|
|
|
self.log = torch.Tensor()
|
|
|
|
device = torch.device('cpu' if args.cpu else 'cuda')
|
|
self.loss_module.to(device)
|
|
if args.precision == 'half': self.loss_module.half()
|
|
if not args.cpu and args.n_GPUs > 1:
|
|
self.loss_module = nn.DataParallel(
|
|
self.loss_module, range(args.n_GPUs)
|
|
)
|
|
|
|
if args.load != '': self.load(ckp.dir, cpu=args.cpu)
|
|
|
|
def forward(self, sr, hr):
|
|
losses = []
|
|
for i, l in enumerate(self.loss):
|
|
if l['function'] is not None:
|
|
if isinstance(sr,list):
|
|
|
|
|
|
|
|
|
|
effective_loss,buffer_lst=sequence_loss(sr,hr,l['function'])
|
|
|
|
|
|
|
|
|
|
losses.append(effective_loss)
|
|
self.buffer=buffer_lst
|
|
self.log[-1, i] += effective_loss.item()
|
|
else:
|
|
loss = l['function'](sr, hr)
|
|
effective_loss = l['weight'] * loss
|
|
losses.append(effective_loss)
|
|
self.buffer[0]=effective_loss.item()
|
|
self.log[-1, i] += effective_loss.item()
|
|
elif l['type'] == 'DIS':
|
|
self.log[-1, i] += self.loss[i - 1]['function'].loss
|
|
|
|
loss_sum = sum(losses)
|
|
if len(self.loss) > 1:
|
|
self.log[-1, -1] += loss_sum.item()
|
|
|
|
return loss_sum
|
|
|
|
def step(self):
|
|
for l in self.get_loss_module():
|
|
if hasattr(l, 'scheduler'):
|
|
l.scheduler.step()
|
|
|
|
def start_log(self):
|
|
self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
|
|
|
|
def end_log(self, n_batches):
|
|
self.log[-1].div_(n_batches)
|
|
|
|
def display_loss(self, batch):
|
|
n_samples = batch + 1
|
|
log = []
|
|
for l, c in zip(self.loss, self.log[-1]):
|
|
log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
|
|
|
|
return ''.join(log)
|
|
|
|
def plot_loss(self, apath, epoch):
|
|
axis = np.linspace(1, epoch, epoch)
|
|
for i, l in enumerate(self.loss):
|
|
label = '{} Loss'.format(l['type'])
|
|
fig = plt.figure()
|
|
plt.title(label)
|
|
plt.plot(axis, self.log[:, i].numpy(), label=label)
|
|
plt.legend()
|
|
plt.xlabel('Epochs')
|
|
plt.ylabel('Loss')
|
|
plt.grid(True)
|
|
plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
|
|
plt.close(fig)
|
|
|
|
def get_loss_module(self):
|
|
if self.n_GPUs == 1:
|
|
return self.loss_module
|
|
else:
|
|
return self.loss_module.module
|
|
|
|
def save(self, apath):
|
|
torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
|
|
torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
|
|
|
|
def load(self, apath, cpu=False):
|
|
if cpu:
|
|
kwargs = {'map_location': lambda storage, loc: storage}
|
|
else:
|
|
kwargs = {}
|
|
|
|
self.load_state_dict(torch.load(
|
|
os.path.join(apath, 'loss.pt'),
|
|
**kwargs
|
|
))
|
|
self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
|
|
for l in self.get_loss_module():
|
|
if hasattr(l, 'scheduler'):
|
|
for _ in range(len(self.log)): l.scheduler.step()
|
|
|
|
|