import torch
import torch.nn as nn
from losses.vggNet import VGGFeatureExtractor


class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        use_pcp_loss (bool): If True, the perceptual loss will be
            calculated. Default: True.
        use_style_loss (bool): If True, the style loss will be calculated.
            Default: False.
        norm_img (bool): If True, the image will be normed from [-1, 1] to
            [0, 1]. Note that this is different from `use_input_norm`, which
            normalizes the input inside the forward function of vgg according
            to the statistics of the dataset. Importantly, when `norm_img` is
            True, the input image must be in range [-1, 1]. Default: True.
        criterion (str): Criterion used for perceptual loss.
            Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 use_pcp_loss=True,
                 use_style_loss=False,
                 norm_img=True,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.norm_img = norm_img
        self.use_pcp_loss = use_pcp_loss
        self.use_style_loss = use_style_loss
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            # Frobenius norm is computed directly in forward(); no module needed.
            self.criterion = None
        else:
            raise NotImplementedError(
                '%s criterion has not been supported.' % self.criterion_type)

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: (percep_loss, style_loss). Each is a Tensor, or None if
                the corresponding loss is disabled.
        """
        if self.norm_img:
            # map images from [-1, 1] to [0, 1] before feeding them to vgg
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5

        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.use_pcp_loss:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(
                        x_features[k] - gt_features[k],
                        p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(
                        x_features[k], gt_features[k]) * self.layer_weights[k]
        else:
            percep_loss = None

        # calculate style loss
        if self.use_style_loss:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) -
                        self._gram_mat(gt_features[k]),
                        p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(
                        self._gram_mat(x_features[k]),
                        self._gram_mat(gt_features[k])) \
                        * self.layer_weights[k]
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c).
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        # normalized channel-by-channel feature correlations
        gram = features.bmm(features_t) / (c * h * w)
        return gram
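

# ----------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It assumes that
# `losses.vggNet.VGGFeatureExtractor` loads pretrained VGG weights and returns
# a dict mapping the requested layer names to feature tensors (the forward
# loop above iterates over such a dict), and that layer names follow the
# 'convX_Y' convention from the docstring example.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    # weight the conv5_4 features only, as in the docstring example
    pcp = PerceptualLoss(
        layer_weights={'conv5_4': 1.0},
        vgg_type='vgg19',
        use_pcp_loss=True,
        use_style_loss=True,
        norm_img=True,  # the dummy inputs below are in [-1, 1]
        criterion='l1')

    # dummy prediction / ground truth in [-1, 1]
    sr = torch.rand(2, 3, 128, 128) * 2. - 1.
    hr = torch.rand(2, 3, 128, 128) * 2. - 1.

    percep_loss, style_loss = pcp(sr, hr)
    print('perceptual loss:', percep_loss, 'style loss:', style_loss)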