import torch
import torch.nn as nn
import torchvision


class FeatureExtractor(nn.Module):
    """Truncated ``cnn.features`` stack used as a perceptual feature extractor.

    ``forward`` expects images in [-1, 1] (presumably shaped (N, C, H, W) —
    the grayscale branch uses a 4-argument ``expand``). It rescales them to
    [0, 1], applies the ImageNet mean/std normalization, and runs the
    truncated network.
    """

    # Per-channel normalization constants used by ``forward``
    # (torchvision's documented ImageNet statistics).
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, cnn, feature_layer=11):
        """Keep children ``0..feature_layer`` (inclusive) of ``cnn.features``.

        Args:
            cnn: a model exposing a ``features`` module (e.g. torchvision VGG).
            feature_layer: index of the last ``features`` child to keep.
        """
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def normalize(self, tensors, mean, std):
        """Normalize each image of ``tensors`` in place, channel by channel.

        For every item of ``tensors`` (iterated over the first dimension),
        subtract ``mean[c]`` and divide by ``std[c]`` for channel ``c``.
        The same object is mutated and returned.

        Kept for backward compatibility only: ``forward`` no longer uses it,
        because in-place edits to ``.data`` are invisible to autograd.

        Raises:
            TypeError: if ``tensors`` is not a tensor.
        """
        if not torch.is_tensor(tensors):
            raise TypeError('tensor is not a torch image.')
        for tensor in tensors:
            for t, m, s in zip(tensor, mean, std):
                t.sub_(m).div_(s)
        return tensors

    def forward(self, x):
        """Return truncated-VGG features of ``x`` (images in [-1, 1])."""
        # A single-channel (grayscale) image is broadcast to the 3 channels VGG expects.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        # [-1, 1] -> [0, 1]. This is the ONLY place the rescale happens;
        # callers must not rescale again.
        x = (x + 1) * 0.5
        # Differentiable ImageNet normalization. The previous in-place update of
        # ``x.data`` hid the 1/std factor from autograd, so input gradients were
        # wrong. Building the constants from x's dtype/device keeps this
        # working on any device.
        # https://pytorch.org/docs/stable/torchvision/models.html
        mean = torch.tensor(self.IMAGENET_MEAN, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
        std = torch.tensor(self.IMAGENET_STD, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
        x = (x - mean) / std
        return self.features(x)


# Feature extracting using vgg19.
# NOTE: this runs at import time; torchvision may download the pretrained weights.
vgg19 = torchvision.models.vgg19(pretrained=True)
# In torchvision's VGG19 layout, features[35] should be the ReLU after conv5_4
# (TODO confirm if the torchvision layout changes).
feature_extractor = FeatureExtractor(vgg19, feature_layer=35)
feature_extractor.eval()
# The loss only needs gradients w.r.t. its inputs, not the VGG weights; freezing
# avoids computing and accumulating unused .grad buffers on every backward pass.
feature_extractor.requires_grad_(False)


class VGG19Loss(object):
    """Perceptual loss: MSE between VGG19 features of ``output`` and ``target``.

    Both images are expected in [-1, 1]. ``FeatureExtractor.forward`` performs
    the [-1, 1] -> [0, 1] conversion. All instances share the module-level
    ``feature_extractor``, so moving it to a device affects every instance.
    """

    def __init__(self):
        self.initialized = False  # True once the extractor has been moved to a device
        self.device = None  # device the extractor was last moved to
        self.feature_extractor = feature_extractor
        self.MSE = nn.MSELoss()

    def __call__(self, output, target, device):
        """Compute the perceptual MSE between ``output`` and ``target``.

        Args:
            output: predicted images in [-1, 1]; gradients flow back through it.
            target: reference images in [-1, 1]; treated as a constant.
            device: device to run the feature extractor on.

        Returns:
            Scalar MSE loss between the two feature maps.
        """
        # Move lazily, and again if the caller switches devices.
        if not self.initialized or self.device != device:
            self.feature_extractor = self.feature_extractor.to(device)
            self.MSE = self.MSE.to(device)
            self.device = device
            self.initialized = True
        # No rescaling here: FeatureExtractor.forward already maps [-1, 1] -> [0, 1].
        # Rescaling in both places used to squeeze inputs into [0.5, 1].
        output_features = self.feature_extractor(output)
        # Target features are constants for the loss, so skip building a graph.
        with torch.no_grad():
            target_features = self.feature_extractor(target)
        return self.MSE(output_features, target_features)