max
reinit
b6dd358
raw history blame
No virus
4.59 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from losses.vggNet import VGGFeatureExtractor
import numpy as np
class PerceptualLoss(nn.Module):
    """Perceptual loss with an optional Gram-matrix style loss.

    Features of the prediction and of the ground truth are extracted with a
    VGG network, compared layer by layer, and the per-layer distances are
    summed using ``layer_weights``.

    Args:
        layer_weights (dict): Weight for each VGG feature layer. The keys are
            passed to ``VGGFeatureExtractor`` as ``layer_name_list``; e.g.
            ``{'conv5_4': 1.}`` extracts only the ``conv5_4`` features and
            weights their loss by 1.0.
        vgg_type (str): Type of VGG network used as the feature extractor,
            forwarded to ``VGGFeatureExtractor``. Default: 'vgg19'.
        use_input_norm (bool): Forwarded to ``VGGFeatureExtractor``;
            presumably normalizes the input with dataset statistics inside
            the VGG forward pass (TODO confirm in ``losses.vggNet``).
            Default: True.
        use_pcp_loss (bool): If True, compute the perceptual (feature)
            loss. Default: True.
        use_style_loss (bool): If True, compute the style loss on the Gram
            matrices of the features. Default: False.
        norm_img (bool): If True, map both inputs from [-1, 1] to [0, 1]
            via ``(x + 1) * 0.5`` before feature extraction. This is
            independent of ``use_input_norm``. Default: True.
        criterion (str): Distance used to compare features: 'l1'
            (``nn.L1Loss``), 'l2' (``nn.MSELoss``) or 'fro' (Frobenius norm
            of the difference). Default: 'l1'.

    Raises:
        NotImplementedError: If ``criterion`` is not one of the above.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 use_pcp_loss=True,
                 use_style_loss=False,
                 norm_img=True,
                 criterion='l1'):
        super().__init__()
        self.norm_img = norm_img
        self.use_pcp_loss = use_pcp_loss
        self.use_style_loss = use_style_loss
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm)
        self.criterion_type = criterion
        self.criterion = self._build_criterion(criterion)

    @staticmethod
    def _build_criterion(criterion):
        """Return the loss module for ``criterion`` (None for 'fro').

        Raises:
            NotImplementedError: If ``criterion`` is not 'l1', 'l2' or 'fro'.
        """
        if criterion == 'l1':
            return torch.nn.L1Loss()
        if criterion == 'l2':
            # torch has no ``L2loss``; MSELoss is the L2 feature loss.
            return torch.nn.MSELoss()
        if criterion == 'fro':
            # Handled directly with torch.norm in ``_distance``.
            return None
        raise NotImplementedError('%s criterion has not been supported.' % criterion)

    def _distance(self, a, b):
        """Distance between two tensors under the configured criterion."""
        if self.criterion_type == 'fro':
            return torch.norm(a - b, p='fro')
        return self.criterion(a, b)

    def forward(self, x, gt):
        """Compute the perceptual and/or style loss.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w). It is
                detached before feature extraction, so no gradient flows
                into it.

        Returns:
            tuple: ``(percep_loss, style_loss)``. Each entry is the weighted
            sum over the layers in ``layer_weights``, or None when the
            corresponding loss is disabled.
        """
        if self.norm_img:
            # [-1, 1] -> [0, 1]
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5

        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        percep_loss = None
        if self.use_pcp_loss:
            percep_loss = 0
            for k in x_features.keys():
                percep_loss += self._distance(x_features[k], gt_features[k]) * self.layer_weights[k]

        # calculate style loss
        style_loss = None
        if self.use_style_loss:
            style_loss = 0
            for k in x_features.keys():
                style_loss += self._distance(
                    self._gram_mat(x_features[k]),
                    self._gram_mat(gt_features[k])) * self.layer_weights[k]

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate the Gram matrix normalized by ``c * h * w``.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c).
        """
        n, c, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted.
        features = x.reshape(n, c, h * w)
        features_t = features.transpose(1, 2)
        return features.bmm(features_t) / (c * h * w)