import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.nn.functional import conv2d


class VGG16(nn.Module):
    """Frozen VGG16 feature extractor returning three intermediate activations.

    The torchvision VGG16 ``features`` stack is split into three consecutive
    stages (``features[:5]``, ``features[5:10]`` and ``features[10:17]``).
    ``forward`` returns the output of each stage, in order. All parameters
    have ``requires_grad`` set to ``False``.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        # NOTE(review): ``pretrained=True`` downloads weights on first use;
        # newer torchvision releases deprecate it in favour of ``weights=``.
        vgg16 = models.vgg16(pretrained=True)
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        """Return ``[enc_1(image), enc_2(enc_1(image)), enc_3(...)]``.

        Each stage is moved to ``image.device`` before it runs, so the
        module does not have to be moved to the input's device beforehand.
        """
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1)).to(image.device)
            results.append(func(results[-1]))
        return results[1:]


class VGG19(nn.Module):
    """Frozen VGG19 feature extractor returning every named ReLU activation.

    ``forward`` returns a dict keyed ``'relu{block}_{index}'`` (``'relu1_1'``
    through ``'relu5_4'``). Each entry is produced by a small
    ``nn.Sequential`` stage applied to the previous stage's output. Together
    the stages cover ``features[0]`` through ``features[35]`` of the
    torchvision VGG19. All parameters have ``requires_grad`` set to
    ``False``.

    Args:
        resize_input: if True, the normalized input is bilinearly resized
            to 256x256 before it is fed through the stages.
    """

    def __init__(self, resize_input=False):
        super(VGG19, self).__init__()
        # NOTE(review): ``pretrained=True`` downloads weights on first use;
        # newer torchvision releases deprecate it in favour of ``weights=``.
        features = models.vgg19(pretrained=True).features
        self.resize_input = resize_input
        # Per-channel normalization constants for the pretrained network
        # (the standard ImageNet mean/std). They were previously created with
        # ``.cuda()``, which made construction fail on CPU-only machines.
        # Non-persistent buffers follow ``module.to(...)`` / ``.half()`` and
        # leave the state_dict keys unchanged.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False)
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]), persistent=False)

        prefix = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
        posfix = [1, 2, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
        # Indices into ``features`` belonging to each stage. This assumes
        # torchvision's VGG19 ``features`` layout: each group is intended to
        # end at the ReLU its name refers to.
        nums = [[0, 1], [2, 3], [4, 5, 6], [7, 8], [9, 10, 11], [12, 13],
                [14, 15], [16, 17], [18, 19, 20], [21, 22], [23, 24],
                [25, 26], [27, 28, 29], [30, 31], [32, 33], [34, 35]]

        self.relus = []
        for (pre, pos), indices in zip(zip(prefix, posfix), nums):
            name = 'relu{}_{}'.format(pre, pos)
            self.relus.append(name)
            stage = torch.nn.Sequential()
            for num in indices:
                # Child names are the original ``features`` indices.
                stage.add_module(str(num), features[num])
            setattr(self, name, stage)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return a dict mapping each ``'reluX_Y'`` name to its activation.

        Args:
            x: image batch. The first step maps ``x`` from [-1, 1] to [0, 1],
                so inputs are presumably expected in [-1, 1] (TODO confirm
                with callers). The ``view(1, 3, 1, 1)`` broadcast below
                implies a (N, 3, H, W) layout.

        Returns:
            dict of activations in stage order, keyed by ``self.relus``.
        """
        # resize and normalize input for pretrained vgg19
        x = (x + 1.0) / 2.0
        mean = self.mean.view(1, 3, 1, 1).to(x.device)
        std = self.std.view(1, 3, 1, 1).to(x.device)
        x = (x - mean) / std
        if self.resize_input:
            x = F.interpolate(
                x, size=(256, 256), mode='bilinear', align_corners=True)

        features = []
        for layer in self.relus:
            # Each stage is moved to the input's device before it runs,
            # as in VGG16.forward.
            x = getattr(self, layer).to(x.device)(x)
            features.append(x)
        return dict(zip(self.relus, features))