PandA / networks /genforce /models /perceptual_model.py
james-oldfield's picture
Upload 194 files
2a76164
raw
history blame
No virus
6.53 kB
# python 3.7
"""Contains the VGG16 model for perceptual feature extraction.
This module is mainly used for computing the perceptual loss, so it is highly
recommended to use it with pre-trained weights.
The PyTorch weights can be downloaded from
https://drive.google.com/file/d/1qQ-r7MYZ8ZcjQQFe17eQfJbOAuE3eS0y/view?usp=sharing
which were converted from the Keras model
https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
The variable mapping is shown below:
pth_to_tf_var_mapping = {
'layer0.weight': 'block1_conv1_W_1:0', # [64, 3, 3, 3]
'layer0.bias': 'block1_conv1_b_1:0', # [64]
'layer2.weight': 'block1_conv2_W_1:0', # [64, 64, 3, 3]
'layer2.bias': 'block1_conv2_b_1:0', # [64]
'layer5.weight': 'block2_conv1_W_1:0', # [128, 64, 3, 3]
'layer5.bias': 'block2_conv1_b_1:0', # [128]
'layer7.weight': 'block2_conv2_W_1:0', # [128, 128, 3, 3]
'layer7.bias': 'block2_conv2_b_1:0', # [128]
'layer10.weight': 'block3_conv1_W_1:0', # [256, 128, 3, 3]
'layer10.bias': 'block3_conv1_b_1:0', # [256]
'layer12.weight': 'block3_conv2_W_1:0', # [256, 256, 3, 3]
'layer12.bias': 'block3_conv2_b_1:0', # [256]
'layer14.weight': 'block3_conv3_W_1:0', # [256, 256, 3, 3]
'layer14.bias': 'block3_conv3_b_1:0', # [256]
'layer17.weight': 'block4_conv1_W_1:0', # [512, 256, 3, 3]
'layer17.bias': 'block4_conv1_b_1:0', # [512]
'layer19.weight': 'block4_conv2_W_1:0', # [512, 512, 3, 3]
'layer19.bias': 'block4_conv2_b_1:0', # [512]
'layer21.weight': 'block4_conv3_W_1:0', # [512, 512, 3, 3]
'layer21.bias': 'block4_conv3_b_1:0', # [512]
'layer24.weight': 'block5_conv1_W_1:0', # [512, 512, 3, 3]
'layer24.bias': 'block5_conv1_b_1:0', # [512]
'layer26.weight': 'block5_conv2_W_1:0', # [512, 512, 3, 3]
'layer26.bias': 'block5_conv2_b_1:0', # [512]
'layer28.weight': 'block5_conv3_W_1:0', # [512, 512, 3, 3]
'layer28.bias': 'block5_conv3_b_1:0', # [512]
}
"""
import os
import warnings
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
__all__ = ['PerceptualModel']

# Per-channel means subtracted in `PerceptualModel.forward()`. They are applied
# after the pixel range is shifted to [0, 255] and the channels are flipped from
# RGB to BGR, so the order here is (B, G, R).
_MEAN_STATS = (103.939, 116.779, 123.68)


class PerceptualModel(nn.Module):
    """Defines the VGG16 structure as the perceptual network.

    This model takes `RGB` images with data format `NCHW` as the raw inputs, and
    outputs the perceptual feature. The following operations are performed to
    preprocess the inputs so that they match the preprocessing used during model
    training:

    (1) Shift pixel range to [0, 255].
    (2) Change channel order to `BGR`.
    (3) Subtract the statistical mean.

    NOTE: The three fully connected layers on top of the model are dropped.
    """

    def __init__(self,
                 output_layer_idx=23,
                 min_val=-1.0,
                 max_val=1.0,
                 pretrained_weight_path=None):
        """Defines the network structure.

        Args:
            output_layer_idx: Index of the first layer that is NOT applied.
                Layers `[0, output_layer_idx)` are run, and the output of the
                last one is used as the perceptual feature. If the index does
                not match any layer (e.g. it is >= 31), all layers are applied.
                (default: 23, i.e. the output of the `block4_conv3` layer
                activated by the `ReLU` function)
            min_val: Minimum value of the raw input. (default: -1.0)
            max_val: Maximum value of the raw input. (default: 1.0)
            pretrained_weight_path: Path to the pretrained weights. If it is
                `None` or does not point to an existing file, the weights keep
                their random initialization and a warning is issued.
                (default: None)
        """
        super().__init__()
        # The layer names must match the keys of the converted checkpoint
        # (see the variable mapping in the module docstring).
        self.vgg16 = nn.Sequential(OrderedDict({
            'layer0': nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            'layer1': nn.ReLU(inplace=True),
            'layer2': nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            'layer3': nn.ReLU(inplace=True),
            'layer4': nn.MaxPool2d(kernel_size=2, stride=2),
            'layer5': nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            'layer6': nn.ReLU(inplace=True),
            'layer7': nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            'layer8': nn.ReLU(inplace=True),
            'layer9': nn.MaxPool2d(kernel_size=2, stride=2),
            'layer10': nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            'layer11': nn.ReLU(inplace=True),
            'layer12': nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            'layer13': nn.ReLU(inplace=True),
            'layer14': nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            'layer15': nn.ReLU(inplace=True),
            'layer16': nn.MaxPool2d(kernel_size=2, stride=2),
            'layer17': nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            'layer18': nn.ReLU(inplace=True),
            'layer19': nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            'layer20': nn.ReLU(inplace=True),
            'layer21': nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            'layer22': nn.ReLU(inplace=True),
            'layer23': nn.MaxPool2d(kernel_size=2, stride=2),
            'layer24': nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            'layer25': nn.ReLU(inplace=True),
            'layer26': nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            'layer27': nn.ReLU(inplace=True),
            'layer28': nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            'layer29': nn.ReLU(inplace=True),
            'layer30': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.output_layer_idx = output_layer_idx
        self.min_val = min_val
        self.max_val = max_val

        # Plain tensor attribute (not a registered buffer), so it is not part
        # of the state dict; `forward()` moves it to the input's device/dtype.
        self.mean = torch.from_numpy(
            np.asarray(_MEAN_STATS, dtype=np.float32)).view(1, 3, 1, 1)

        self.pretrained_weight_path = pretrained_weight_path
        # `os.path.isfile(None)` raises TypeError, so the default `None` must
        # be handled explicitly and fall through to the warning below.
        if (pretrained_weight_path is not None
                and os.path.isfile(pretrained_weight_path)):
            # NOTE: `torch.load` unpickles the file; only load trusted
            # checkpoints.
            state_dict = torch.load(pretrained_weight_path, map_location='cpu')
            self.vgg16.load_state_dict(state_dict)
        else:
            warnings.warn('No pre-trained weights found for perceptual model!')

    def forward(self, x):
        """Extracts the flattened perceptual feature of `x`.

        Args:
            x: Input images with data format `NCHW`, channels in `RGB` order and
                values in [`min_val`, `max_val`].

        Returns:
            The output of the last applied layer, flattened per sample into a
            2D tensor with shape [N, D].
        """
        x = (x - self.min_val) * 255.0 / (self.max_val - self.min_val)
        x = x.flip(1)  # RGB to BGR
        x = x - self.mean.to(x)
        # TODO: Resize image?
        for idx, layer in enumerate(self.vgg16.children()):
            if idx == self.output_layer_idx:
                break
            x = layer(x)
        return x.flatten(start_dim=1)