import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from .nn import mean_flat


def _load_vgg19_features():
    """Return the ``features`` sub-network of an ImageNet-pretrained VGG19.

    Newer torchvision releases expose ``VGG19_Weights`` and deprecate the
    ``pretrained=True`` flag. In those releases the equivalent
    ``weights=VGG19_Weights.IMAGENET1K_V1`` is used. Older releases without
    that enum fall back to ``pretrained=True``.
    """
    weights_enum = getattr(models, "VGG19_Weights", None)
    if weights_enum is not None:
        vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
    else:
        vgg = models.vgg19(pretrained=True)
    return vgg.features


# input image range [-1,1]
class VGG(nn.Module):
    """Perceptual loss computed on frozen, pretrained VGG19 features.

    Both images are mapped from [-1, 1] to [0, rgb_range] and normalized
    with ImageNet mean/std by ``MeanShift``. They are then passed through a
    truncated VGG19 ``features`` stack. The loss is the squared feature
    difference reduced by ``mean_flat``.

    Args:
        conv_index: Selects the truncation point. A string containing
            ``'22'`` keeps ``features[:8]``; otherwise one containing ``'54'``
            keeps ``features[:35]``. ``'22'`` is checked first.
        rgb_range: Upper end of the pixel range the normalization layer is
            built for. Inputs are rescaled to [0, rgb_range] before
            normalization.

    Raises:
        ValueError: If ``conv_index`` contains neither ``'22'`` nor ``'54'``.
    """

    def __init__(self, conv_index='22', rgb_range=1):
        super(VGG, self).__init__()
        modules = list(_load_vgg19_features())
        if '22' in conv_index:
            self.vgg = nn.Sequential(*modules[:8])
        elif '54' in conv_index:
            self.vgg = nn.Sequential(*modules[:35])
        else:
            # Previously self.vgg was silently left unset, failing later
            # in forward() with an opaque AttributeError.
            raise ValueError(
                "conv_index must contain '22' or '54', got %r" % (conv_index,)
            )

        self.rgb_range = rgb_range
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)
        # The loss network is a fixed feature extractor; never train it.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        """Return the perceptual loss between ``sr`` and ``hr``.

        Args:
            sr: Predicted images in [-1, 1] with 3 channels (``MeanShift``
                is a ``Conv2d(3, 3)``). Gradients flow through this input.
            hr: Reference images in [-1, 1]. They are detached and evaluated
                under ``no_grad``, so they act as a constant target.

        Returns:
            ``mean_flat`` of the squared difference of the VGG features.
        """
        def _forward(x):
            x = self.sub_mean(x)
            return self.vgg(x)

        # Map [-1, 1] -> [0, rgb_range], the range sub_mean normalizes for.
        # (Identical to the old [0, 1] mapping when rgb_range == 1.)
        sr = (sr + 1.) / 2. * self.rgb_range
        hr = (hr + 1.) / 2. * self.rgb_range
        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = mean_flat((vgg_sr - vgg_hr) ** 2)
        return loss


class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine normalization.

    For each of the 3 channels ``c`` it computes
    ``out[c] = (x[c] + sign * rgb_range * mean[c]) / std[c]``.
    With the default ``sign=-1`` this subtracts ``rgb_range * mean`` and
    divides by ``std``.

    Args:
        rgb_range: Scale applied to ``rgb_mean`` in the bias term.
        rgb_mean: Per-channel means (3 values).
        rgb_std: Per-channel standard deviations (3 values).
        sign: ``-1`` to subtract the mean, ``+1`` to add it back.
    """

    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.tensor(rgb_std, dtype=torch.float32)
        mean = torch.tensor(rgb_mean, dtype=torch.float32)
        # Write the fixed weights in place rather than reassigning `.data`.
        with torch.no_grad():
            # Diagonal 1x1 kernel: each output channel sees only its own input
            # channel, scaled by 1/std.
            self.weight.copy_(torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
            self.bias.copy_(sign * rgb_range * mean / std)
        for p in self.parameters():
            p.requires_grad = False