from collections import namedtuple

import torch
from torchvision import models

from src.utils import utils

"""
    More detail about the VGG architecture (if you want to understand magic/hardcoded numbers) can be found here:

    https://github.com/pytorch/vision/blob/3c254fb7af5f8af252c24e89949c54a3461ff0be/torchvision/models/vgg.py
"""


def _make_slice(features, start, stop):
    """Return a Sequential holding ``features[start:stop]``.

    Each sub-module is registered under its original index (as a string) in
    ``features``, so the naming matches the torchvision ``features`` container.
    """
    block = torch.nn.Sequential()
    for idx in range(start, stop):
        block.add_module(str(idx), features[idx])
    return block


def _freeze(module):
    """Disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = False


class Vgg16(torch.nn.Module):
    """VGG16 feature extractor that exposes four ReLU activations.

    Only those layers are exposed which have already proven to work nicely.

    Args:
        requires_grad: When False (default) all parameters are frozen.
        show_progress: Forwarded to ``models.vgg16`` as ``progress``.

    Attributes:
        layer_names: Maps each exposed layer name to its 1-based position in
            the tuple returned by ``forward``.
        content_feature_maps_index: 0-based position of the content layer.
        style_feature_maps_indices: 0-based positions of the style layers
            (all exposed layers).
    """

    def __init__(self, requires_grad=False, show_progress=False):
        super().__init__()
        vgg_pretrained_features = models.vgg16(
            pretrained=True, progress=show_progress
        ).features

        self.layer_names = {'relu1_2': 1, 'relu2_2': 2, 'relu3_3': 3, 'relu4_3': 4}
        # The content layer name comes from utils.yamlGet('contentLayer'); the
        # original note here suggested relu2_2. A name that is not a key of
        # layer_names raises KeyError. The stored positions are 1-based, hence -1.
        self.content_feature_maps_index = (
            self.layer_names[utils.yamlGet('contentLayer')] - 1
        )
        # All layers are used for the style representation.
        self.style_feature_maps_indices = list(range(len(self.layer_names)))

        # Each slice ends on the ReLU it is named after in forward().
        self.slice1 = _make_slice(vgg_pretrained_features, 0, 4)    # ... relu1_2
        self.slice2 = _make_slice(vgg_pretrained_features, 4, 9)    # ... relu2_2
        self.slice3 = _make_slice(vgg_pretrained_features, 9, 16)   # ... relu3_3
        self.slice4 = _make_slice(vgg_pretrained_features, 16, 23)  # ... relu4_3

        if not requires_grad:
            _freeze(self)

    def forward(self, x):
        """Run ``x`` through the four slices and return their outputs.

        Returns:
            A ``VggOutputs`` namedtuple with one field per key of
            ``layer_names``, in order (relu1_2, relu2_2, relu3_3, relu4_3).
        """
        x = self.slice1(x)
        relu1_2 = x
        x = self.slice2(x)
        relu2_2 = x
        x = self.slice3(x)
        relu3_3 = x
        x = self.slice4(x)
        relu4_3 = x

        vgg_outputs = namedtuple("VggOutputs", self.layer_names.keys())
        return vgg_outputs(relu1_2, relu2_2, relu3_3, relu4_3)


class Vgg16Experimental(torch.nn.Module):
    """VGG16 with every feature layer held as its own attribute.

    Everything is exposed so you can play with different combinations for the
    style and content representation.

    Args:
        requires_grad: When False (default) all parameters are frozen.
        show_progress: Forwarded to ``models.vgg16`` as ``progress``.

    Attributes:
        layer_names: Field names of the tuple returned by ``forward``.
        content_feature_maps_index: Position of the content layer in that tuple.
        style_feature_maps_indices: Positions of the style layers (all of them).
    """

    def __init__(self, requires_grad=False, show_progress=False):
        super().__init__()
        vgg_pretrained_features = models.vgg16(
            pretrained=True, progress=show_progress
        ).features

        self.layer_names = [
            'relu1_1', 'relu2_1', 'relu2_2', 'relu3_1',
            'relu3_2', 'relu4_1', 'relu4_3', 'relu5_1',
        ]
        self.content_feature_maps_index = 4
        # All layers are used for the style representation.
        self.style_feature_maps_indices = list(range(len(self.layer_names)))

        # Block 1
        self.conv1_1 = vgg_pretrained_features[0]
        self.relu1_1 = vgg_pretrained_features[1]
        self.conv1_2 = vgg_pretrained_features[2]
        self.relu1_2 = vgg_pretrained_features[3]
        self.max_pooling1 = vgg_pretrained_features[4]
        # Block 2
        self.conv2_1 = vgg_pretrained_features[5]
        self.relu2_1 = vgg_pretrained_features[6]
        self.conv2_2 = vgg_pretrained_features[7]
        self.relu2_2 = vgg_pretrained_features[8]
        self.max_pooling2 = vgg_pretrained_features[9]
        # Block 3
        self.conv3_1 = vgg_pretrained_features[10]
        self.relu3_1 = vgg_pretrained_features[11]
        self.conv3_2 = vgg_pretrained_features[12]
        self.relu3_2 = vgg_pretrained_features[13]
        self.conv3_3 = vgg_pretrained_features[14]
        self.relu3_3 = vgg_pretrained_features[15]
        self.max_pooling3 = vgg_pretrained_features[16]
        # Block 4
        self.conv4_1 = vgg_pretrained_features[17]
        self.relu4_1 = vgg_pretrained_features[18]
        self.conv4_2 = vgg_pretrained_features[19]
        self.relu4_2 = vgg_pretrained_features[20]
        self.conv4_3 = vgg_pretrained_features[21]
        self.relu4_3 = vgg_pretrained_features[22]
        self.max_pooling4 = vgg_pretrained_features[23]
        # Block 5
        self.conv5_1 = vgg_pretrained_features[24]
        self.relu5_1 = vgg_pretrained_features[25]
        self.conv5_2 = vgg_pretrained_features[26]
        self.relu5_2 = vgg_pretrained_features[27]
        self.conv5_3 = vgg_pretrained_features[28]
        self.relu5_3 = vgg_pretrained_features[29]
        self.max_pooling5 = vgg_pretrained_features[30]

        if not requires_grad:
            _freeze(self)

    def forward(self, x):
        """Run ``x`` through every layer and return the selected activations.

        Every intermediate activation is bound to a local name, including the
        ones that are not returned, so that swapping the exposed layers only
        requires editing ``layer_names`` and the final tuple.

        Returns:
            A ``VggOutputs`` namedtuple whose fields are ``layer_names``.
        """
        x = self.conv1_1(x)
        conv1_1 = x
        x = self.relu1_1(x)
        relu1_1 = x
        x = self.conv1_2(x)
        conv1_2 = x
        x = self.relu1_2(x)
        relu1_2 = x
        x = self.max_pooling1(x)

        x = self.conv2_1(x)
        conv2_1 = x
        x = self.relu2_1(x)
        relu2_1 = x
        x = self.conv2_2(x)
        conv2_2 = x
        x = self.relu2_2(x)
        relu2_2 = x
        x = self.max_pooling2(x)

        x = self.conv3_1(x)
        conv3_1 = x
        x = self.relu3_1(x)
        relu3_1 = x
        x = self.conv3_2(x)
        conv3_2 = x
        x = self.relu3_2(x)
        relu3_2 = x
        x = self.conv3_3(x)
        conv3_3 = x
        x = self.relu3_3(x)
        relu3_3 = x
        x = self.max_pooling3(x)

        x = self.conv4_1(x)
        conv4_1 = x
        x = self.relu4_1(x)
        relu4_1 = x
        x = self.conv4_2(x)
        conv4_2 = x
        x = self.relu4_2(x)
        relu4_2 = x
        x = self.conv4_3(x)
        conv4_3 = x
        x = self.relu4_3(x)
        relu4_3 = x
        x = self.max_pooling4(x)

        x = self.conv5_1(x)
        conv5_1 = x
        x = self.relu5_1(x)
        relu5_1 = x
        x = self.conv5_2(x)
        conv5_2 = x
        x = self.relu5_2(x)
        relu5_2 = x
        x = self.conv5_3(x)
        conv5_3 = x
        x = self.relu5_3(x)
        relu5_3 = x
        x = self.max_pooling5(x)

        # expose only the layers that you want to experiment with here
        vgg_outputs = namedtuple("VggOutputs", self.layer_names)
        return vgg_outputs(
            relu1_1, relu2_1, relu2_2, relu3_1,
            relu3_2, relu4_1, relu4_3, relu5_1,
        )


class Vgg19(torch.nn.Module):
    """VGG19 feature extractor exposing the layers of the original NST paper.

    Used in the original NST paper, only those layers are exposed which were
    used in the original paper.

    'conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1' were used for style
    representation; 'conv4_2' was used for content representation (although
    they did some experiments with conv2_2 and conv5_2).

    Args:
        requires_grad: When False (default) all parameters are frozen.
        show_progress: Forwarded to ``models.vgg19`` as ``progress``.
        use_relu: When True (default) the style layers are taken after the
            ReLU that follows each convolution; when False the raw convolution
            outputs are used, as in the paper. The content layer is conv4_2
            in both cases.

    Attributes:
        layer_names: Field names of the tuple returned by ``forward``.
        offset: 1 when ``use_relu`` is True, otherwise 0; shifts the slice
            boundaries by one layer.
        content_feature_maps_index: Position of conv4_2 in the output tuple.
        style_feature_maps_indices: Positions of every layer except conv4_2.
    """

    def __init__(self, requires_grad=False, show_progress=False, use_relu=True):
        super().__init__()
        vgg_pretrained_features = models.vgg19(
            pretrained=True, progress=show_progress
        ).features

        # forward() always returns six activations (five style layers plus
        # conv4_2), so both variants must list six names with conv4_2 at
        # position 4; otherwise building the namedtuple fails.
        if use_relu:  # use relu or as in original paper conv layers
            self.layer_names = [
                'relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'conv4_2', 'relu5_1'
            ]
            self.offset = 1
        else:
            self.layer_names = [
                'conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv4_2', 'conv5_1'
            ]
            self.offset = 0

        self.content_feature_maps_index = 4  # conv4_2
        # all layers used for style representation except conv4_2
        self.style_feature_maps_indices = [
            idx
            for idx in range(len(self.layer_names))
            if idx != self.content_feature_maps_index
        ]

        # The first four boundaries and the last one shift by `offset` so the
        # slice ends either on the conv layer or on the ReLU right after it.
        # slice5 always ends on conv4_2 (features index 21).
        shift = self.offset
        self.slice1 = _make_slice(vgg_pretrained_features, 0, 1 + shift)
        self.slice2 = _make_slice(vgg_pretrained_features, 1 + shift, 6 + shift)
        self.slice3 = _make_slice(vgg_pretrained_features, 6 + shift, 11 + shift)
        self.slice4 = _make_slice(vgg_pretrained_features, 11 + shift, 20 + shift)
        self.slice5 = _make_slice(vgg_pretrained_features, 20 + shift, 22)
        self.slice6 = _make_slice(vgg_pretrained_features, 22, 29 + shift)

        if not requires_grad:
            _freeze(self)

    def forward(self, x):
        """Run ``x`` through the six slices and return their outputs.

        Returns:
            A ``VggOutputs`` namedtuple whose fields are ``layer_names``:
            four style layers, conv4_2, then the fifth style layer.
        """
        x = self.slice1(x)
        layer1_1 = x
        x = self.slice2(x)
        layer2_1 = x
        x = self.slice3(x)
        layer3_1 = x
        x = self.slice4(x)
        layer4_1 = x
        x = self.slice5(x)
        conv4_2 = x
        x = self.slice6(x)
        layer5_1 = x

        vgg_outputs = namedtuple("VggOutputs", self.layer_names)
        return vgg_outputs(
            layer1_1, layer2_1, layer3_1, layer4_1, conv4_2, layer5_1
        )