# Source path: resefa/models/perceptual_model.py
# (Hugging Face file-viewer residue: uploaded by akhaliq (HF staff), commit
# 8ca3a29 "add files", 20.8 kB.)
# python3.7
"""Contains the VGG16 model, which is used for inference ONLY.
VGG16 is commonly used for perceptual feature extraction. The model implemented
in this file can be used for evaluation (like computing LPIPS, perceptual path
length, etc.), OR be used in training for loss computation (like perceptual
loss, etc.).
The pre-trained model is officially shared by
https://www.robots.ox.ac.uk/~vgg/research/very_deep/
and ported by
https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt
Compared to the official VGG16 model, this ported model also supports evaluating
LPIPS, which is introduced in
https://github.com/richzhang/PerceptualSimilarity
"""
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from utils.misc import download_url
__all__ = ['PerceptualModel']
# pylint: disable=line-too-long
# Registry of the supported pre-trained checkpoints. Each key is a model
# source name and each value is a `(download URL, expected SHA-256 digest)`
# pair. `PerceptualModel.build_model` unpacks the pair to fetch the file, and
# `_convert_weights` rejects any source name that is not a key of this dict.
_MODEL_URL_SHA256 = {
# This model is provided by `torchvision`, which is ported from TensorFlow.
'torchvision_official': (
'https://download.pytorch.org/models/vgg16-397923af.pth',
'397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0' # hash sha256
),
# This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
'vgg_perceptual_lpips': (
'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3' # hash sha256
)
}
# pylint: enable=line-too-long
class PerceptualModel(object):
    """Defines the perceptual model, which is based on VGG16 structure.

    This is a static class, which is used to avoid this model to be built
    repeatedly. Consequently, this model is particularly used for inference,
    like computing LPIPS, or for loss computation, like perceptual loss. If
    training is required, please use the model from `torchvision.models` or
    implement by yourself.

    NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
    order and pixel range [-1, 1], and will NOT resize the input automatically
    if only perceptual feature is needed.
    """

    # Cache of already-built models, keyed by the tuple
    # `(model_source, no_top, enable_lpips)`.
    models = dict()

    @staticmethod
    def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
        """Builds the model and loads pre-trained weights.

        Models are cached in `PerceptualModel.models`, so calling this function
        repeatedly with the same arguments returns the same instance.

        1. If `use_torchvision` is set as True, the model released by
            `torchvision` will be loaded, otherwise, the LPIPS-capable port
            of the model released by
            https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
            (default: False)
        2. To save computing resources, there is an option to only load the
            backbone (i.e., without the last three fully-connected layers).
            This is commonly used for perceptual loss or LPIPS loss
            computation. Please use argument `no_top` to control this.
            (default: True)
        3. For LPIPS loss computation, some additional weights (which is used
            for balancing the features from different resolutions) are
            employed on top of the original VGG16 backbone. Details can be
            found at https://github.com/richzhang/PerceptualSimilarity. Please
            use `enable_lpips` to enable this feature. (default: True)

        The built model is switched to evaluation mode with gradients
        disabled, and is moved to CUDA when a CUDA device is available
        (otherwise it stays on CPU).

        The built model supports following arguments when forwarding:

        - resize_input: Whether to resize the input image to size [224, 224]
            before forwarding. For feature-based computation (i.e., only
            convolutional layers are used), image resizing is not essential.
            (default: False)
        - return_tensor: This field resolves the model behavior. Following
            options are supported:
                `feature1`: Before the first max pooling layer.
                `pool1`: After the first max pooling layer.
                `feature2`: Before the second max pooling layer.
                `pool2`: After the second max pooling layer.
                `feature3`: Before the third max pooling layer.
                `pool3`: After the third max pooling layer.
                `feature4`: Before the fourth max pooling layer.
                `pool4`: After the fourth max pooling layer.
                `feature5`: Before the fifth max pooling layer.
                `pool5`: After the fifth max pooling layer.
                `flatten`: The flattened feature, after `adaptive_avgpool`.
                `feature`: The 4096d feature for logits computation. (default)
                `logits`: The 1000d categorical logits.
                `prediction`: The 1000d predicted probability.
                `lpips`: The LPIPS score between two input images.
            NOTE: `flatten`, `feature`, `logits` and `prediction` rely on the
            fully-connected layers and are hence only available when the
            model is built with `no_top=False`. In particular, a model built
            with the default `no_top=True` must be called with an explicit
            `return_tensor`. Similarly, `lpips` is only available when the
            model is built with `enable_lpips=True`.
        """
        if use_torchvision:
            model_source = 'torchvision_official'
            align_tf_resize = False
            is_torch_script = False
        else:
            model_source = 'vgg_perceptual_lpips'
            align_tf_resize = True
            is_torch_script = True

        if enable_lpips and model_source != 'vgg_perceptual_lpips':
            warnings.warn('The pre-trained model officially released by '
                          '`torchvision` does not support LPIPS computation! '
                          'Equal weights will be used for each resolution.')

        fingerprint = (model_source, no_top, enable_lpips)

        if fingerprint not in PerceptualModel.models:
            # Build model.
            model = VGG16(align_tf_resize=align_tf_resize,
                          no_top=no_top,
                          enable_lpips=enable_lpips)

            # Download pre-trained weights. In distributed runs, every
            # non-zero rank blocks on this barrier until rank 0 reaches the
            # matching barrier below, i.e., until rank 0 has fetched and
            # loaded the file.
            if dist.is_initialized() and dist.get_rank() != 0:
                dist.barrier()  # Download by chief.
            url, sha256 = _MODEL_URL_SHA256[model_source]
            filename = f'perceptual_model_{model_source}_{sha256}.pth'
            model_path, hash_check = download_url(url,
                                                  filename=filename,
                                                  sha256=sha256)
            # Report a digest mismatch BEFORE deserializing, so that the
            # warning is visible even if loading the (possibly truncated or
            # altered) file fails afterwards.
            if hash_check is False:
                warnings.warn(f'Hash check failed! The remote file from URL '
                              f'`{url}` may be changed, or the downloading is '
                              f'interrupted. The loaded perceptual model may '
                              f'have unexpected behavior.')
            # NOTE(security): both loaders below deserialize the downloaded
            # file. Only keep using them with the trusted URLs listed in
            # `_MODEL_URL_SHA256`.
            if is_torch_script:
                src_state_dict = torch.jit.load(model_path, map_location='cpu')
            else:
                src_state_dict = torch.load(model_path, map_location='cpu')
            if dist.is_initialized() and dist.get_rank() == 0:
                dist.barrier()  # Wait for other replicas.

            # Load weights. `strict=False` tolerates converted entries that
            # the built model does not own (e.g., `fc*` when `no_top=True`).
            dst_state_dict = _convert_weights(src_state_dict, model_source)
            model.load_state_dict(dst_state_dict, strict=False)
            del src_state_dict, dst_state_dict

            # For inference only.
            model.eval().requires_grad_(False)
            # Use the GPU when there is one, instead of crashing on CPU-only
            # hosts.
            if torch.cuda.is_available():
                model.cuda()
            PerceptualModel.models[fingerprint] = model

        return PerceptualModel.models[fingerprint]
def _convert_weights(src_state_dict, model_source):
    """Renames pre-trained weights to the parameter names used by `VGG16`.

    Args:
        src_state_dict: The loaded checkpoint. For `torchvision_official` it
            is indexed directly by parameter name; for `vgg_perceptual_lpips`
            its `state_dict()` method is called first.
        model_source: Name of the checkpoint source, which must be a key of
            `_MODEL_URL_SHA256`.

    Returns:
        A `dict` from destination parameter name (e.g. `conv11.weight`) to
        tensor. Convolutional and fully-connected tensors are cloned, while
        LPIPS tensors get an extra leading dimension via `unsqueeze(0)`.

    Raises:
        ValueError: If `model_source` is not a key of `_MODEL_URL_SHA256`.
        NotImplementedError: If `model_source` is registered but no naming
            scheme is defined for it here.
    """
    if model_source not in _MODEL_URL_SHA256:
        raise ValueError(f'Invalid model source `{model_source}`!\n'
                         f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.')

    # Destination layer names, listed in network order.
    conv_layers = ('conv11', 'conv12',
                   'conv21', 'conv22',
                   'conv31', 'conv32', 'conv33',
                   'conv41', 'conv42', 'conv43',
                   'conv51', 'conv52', 'conv53')
    fc_layers = ('fc1', 'fc2', 'fc3')

    # Resolve the source-side layer prefixes, aligned with the lists above.
    if model_source == 'torchvision_official':
        conv_positions = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
        src_conv_layers = tuple(f'features.{pos}' for pos in conv_positions)
        src_fc_layers = tuple(f'classifier.{pos}' for pos in (0, 3, 6))
        lpips_names = ()
    elif model_source == 'vgg_perceptual_lpips':
        src_state_dict = src_state_dict.state_dict()
        src_conv_layers = tuple(f'layers.conv{num}' for num in range(1, 14))
        src_fc_layers = tuple(f'layers.{name}' for name in fc_layers)
        lpips_names = tuple(
            (f'lpips.{num}.weight', f'lpips{num}') for num in range(5))
    else:
        raise NotImplementedError(f'Not implemented model source '
                                  f'`{model_source}`!')

    converted = {}
    layer_pairs = zip(conv_layers + fc_layers, src_conv_layers + src_fc_layers)
    for dst_layer, src_layer in layer_pairs:
        for param in ('weight', 'bias'):
            src_tensor = src_state_dict[f'{src_layer}.{param}']
            converted[f'{dst_layer}.{param}'] = src_tensor.clone()
    # LPIPS weights come last and only gain a leading dimension.
    for dst_name, src_name in lpips_names:
        converted[dst_name] = src_state_dict[src_name].unsqueeze(0)
    return converted
# Per-channel statistics used to normalize inputs after they are mapped from
# [-1, 1] to [0, 1].
_IMG_MEAN = (0.485, 0.456, 0.406)
_IMG_STD = (0.229, 0.224, 0.225)

# All values accepted by the `return_tensor` argument of `VGG16.forward()`.
_ALLOWED_RETURN = [
    'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4',
    'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction',
    'lpips'
]

# Outputs produced by the layers that only exist when `no_top=False`.
_TOP_RETURN = ('flatten', 'feature', 'logits', 'prediction')

# Name suffixes of the `conv*`/`relu*` layers in each of the five stages.
# Stage `k` (1-based) is followed by layer `pool{k}`.
_STAGE_LAYERS = (
    ('11', '12'),
    ('21', '22'),
    ('31', '32', '33'),
    ('41', '42', '43'),
    ('51', '52', '53'),
)

# Spatial size that inputs are resized to when `resize_input=True`.
_INPUT_SIZE = 224

# pylint: disable=missing-function-docstring


class VGG16(nn.Module):
    """Defines the VGG16 structure.

    This model takes `RGB` images with data format `NCHW` as the raw inputs. The
    pixel range are assumed to be [-1, 1].
    """

    def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True):
        """Defines the network structure.

        Args:
            align_tf_resize: Selects the resizing routine used by `forward()`
                when `resize_input=True`: a shifted affine sampling grid if
                True, plain `F.interpolate` otherwise. (default: False)
            no_top: If True, the average pooling, flatten, fully-connected
                and softmax layers are NOT built. (default: True)
            enable_lpips: If True, one bias-free 1x1 convolution per stage is
                built for LPIPS, with weights initialized to ones.
                (default: True)

        NOTE: The shapes in the comments below assume a 224x224 input.
        """
        super().__init__()
        self.align_tf_resize = align_tf_resize
        self.no_top = no_top
        self.enable_lpips = enable_lpips

        self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu11 = nn.ReLU(inplace=True)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu12 = nn.ReLU(inplace=True)
        # output `feature1`, with shape [N, 64, 224, 224]

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool1`, with shape [N, 64, 112, 112]

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu21 = nn.ReLU(inplace=True)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu22 = nn.ReLU(inplace=True)
        # output `feature2`, with shape [N, 128, 112, 112]

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool2`, with shape [N, 128, 56, 56]

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.relu31 = nn.ReLU(inplace=True)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.relu32 = nn.ReLU(inplace=True)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.relu33 = nn.ReLU(inplace=True)
        # output `feature3`, with shape [N, 256, 56, 56]

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool3`, with shape [N, 256, 28, 28]

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.relu41 = nn.ReLU(inplace=True)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu42 = nn.ReLU(inplace=True)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu43 = nn.ReLU(inplace=True)
        # output `feature4`, with shape [N, 512, 28, 28]

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool4`, with shape [N, 512, 14, 14]

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu51 = nn.ReLU(inplace=True)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu52 = nn.ReLU(inplace=True)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu53 = nn.ReLU(inplace=True)
        # output `feature5`, with shape [N, 512, 14, 14]

        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool5`, with shape [N, 512, 7, 7]

        if self.enable_lpips:
            # One 1x1 convolution per stage, weighting the channels of the
            # squared feature difference. Weights start as all ones.
            self.lpips = nn.ModuleList()
            for idx, ch in enumerate([64, 128, 256, 512, 512]):
                self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False))
                self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1))

        if not self.no_top:
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
            # output `flatten`, with shape [N, 25088]

            self.fc1 = nn.Linear(512 * 7 * 7, 4096)
            self.fc1_relu = nn.ReLU(inplace=True)
            self.fc1_dropout = nn.Dropout(0.5, inplace=False)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc2_relu = nn.ReLU(inplace=True)
            self.fc2_dropout = nn.Dropout(0.5, inplace=False)
            # output `feature`, with shape [N, 4096]

            self.fc3 = nn.Linear(4096, 1000)
            # output `logits`, with shape [N, 1000]

            self.out = nn.Softmax(dim=1)
            # output `prediction`, with shape [N, 1000]

        img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32)
        img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32)
        self.register_buffer('img_mean', torch.from_numpy(img_mean))
        self.register_buffer('img_std', torch.from_numpy(img_std))

    def _resize(self, x):
        """Resizes `x` to `_INPUT_SIZE` x `_INPUT_SIZE` with bilinear sampling."""
        if not self.align_tf_resize:
            return F.interpolate(x,
                                 size=(_INPUT_SIZE, _INPUT_SIZE),
                                 mode='bilinear',
                                 align_corners=False)

        # Identity affine transform whose translation is shifted by
        # `1 / input_size - 1 / _INPUT_SIZE` along each axis.
        # NOTE(review): judging by the flag name, this offset is presumably
        # meant to reproduce TensorFlow-style resizing -- confirm.
        theta = torch.eye(2, 3).to(x)
        theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / _INPUT_SIZE
        theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / _INPUT_SIZE
        theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
        grid = F.affine_grid(
            theta,
            size=(x.shape[0], x.shape[1], _INPUT_SIZE, _INPUT_SIZE),
            align_corners=False)
        return F.grid_sample(x, grid,
                             mode='bilinear',
                             padding_mode='border',
                             align_corners=False)

    def _lpips_score(self, features):
        """Computes LPIPS from stage features of a concatenated `[x; y]` batch.

        Each feature map is normalized along the channel axis, split into its
        two halves along the batch axis, and the squared difference is
        weighted by the matching `self.lpips` layer and spatially averaged.
        """
        score = 0
        for head, feature in zip(self.lpips, features):
            norm = feature.norm(dim=1, keepdim=True)
            feature = feature / (norm + 1e-10)  # Epsilon avoids 0 / 0.
            feature_x, feature_y = feature.chunk(2, dim=0)
            diff = (feature_x - feature_y).square()
            score = score + head(diff).mean(dim=(2, 3), keepdim=False)
        return score.sum(dim=1, keepdim=False)

    def forward(self,
                x,
                y=None,
                *,
                resize_input=False,
                return_tensor='feature'):
        """Runs the network and returns the requested intermediate tensor.

        Args:
            x: Input images with format `NCHW` and pixel range [-1, 1].
                Single-channel inputs are repeated to three channels.
            y: Second batch of images, with the same shape as `x`. Only used
                (and required) when `return_tensor` is `lpips`.
                (default: None)
            resize_input: Whether to resize the input to 224x224 before
                forwarding. (default: False)
            return_tensor: Case-insensitive name of the tensor to return,
                which must be one of `_ALLOWED_RETURN`. (default: `feature`)

        Returns:
            The requested tensor. For `lpips`, a 1D tensor with one score per
            image pair.

        Raises:
            ValueError: If `return_tensor` is unknown; if it needs layers
                that were not built (`flatten`, `feature`, `logits` and
                `prediction` need `no_top=False`, `lpips` needs
                `enable_lpips=True`); or if, for `lpips`, `y` is missing or
                its shape differs from that of `x`.
        """
        return_tensor = return_tensor.lower()
        if return_tensor not in _ALLOWED_RETURN:
            raise ValueError(f'Invalid output tensor name `{return_tensor}` '
                             f'for perceptual model (VGG16)!\n'
                             f'Names allowed: {_ALLOWED_RETURN}.')

        # Fail early, with an actionable message, if the requested output
        # depends on layers that `__init__()` skipped.
        if self.no_top and return_tensor in _TOP_RETURN:
            raise ValueError(f'Output tensor `{return_tensor}` requires the '
                             f'fully-connected layers, but this model is '
                             f'built with `no_top=True`!')

        compute_lpips = (return_tensor == 'lpips')
        if compute_lpips:
            if not self.enable_lpips:
                raise ValueError('LPIPS computation is requested, but this '
                                 'model is built with `enable_lpips=False`!')
            if y is None:
                raise ValueError('Two images are required for LPIPS '
                                 'computation, but only one is received!')
            if x.shape != y.shape:
                raise ValueError(f'Two images for LPIPS computation should be '
                                 f'with the same shape, but '
                                 f'`{tuple(x.shape)}` and `{tuple(y.shape)}` '
                                 f'are received!')
            # Process both batches in one pass; they are split again in
            # `_lpips_score()`.
            x = torch.cat([x, y], dim=0)

        if resize_input:
            x = self._resize(x)
        if x.shape[1] == 1:
            x = x.repeat((1, 3, 1, 1))

        # Map [-1, 1] to [0, 1], then normalize with per-channel statistics.
        x = (x + 1) / 2
        x = (x - self.img_mean) / self.img_std

        features = []
        for stage, layer_ids in enumerate(_STAGE_LAYERS, start=1):
            for layer_id in layer_ids:
                x = getattr(self, f'conv{layer_id}')(x)
                x = getattr(self, f'relu{layer_id}')(x)
            if return_tensor == f'feature{stage}':
                return x
            if compute_lpips:
                features.append(x)
            x = getattr(self, f'pool{stage}')(x)
            if return_tensor == f'pool{stage}':
                return x

        if compute_lpips:
            return self._lpips_score(features)

        x = self.avgpool(x)
        x = self.flatten(x)
        if return_tensor == 'flatten':
            return x

        x = self.fc1(x)
        x = self.fc1_relu(x)
        x = self.fc1_dropout(x)
        x = self.fc2(x)
        x = self.fc2_relu(x)
        x = self.fc2_dropout(x)
        if return_tensor == 'feature':
            return x

        x = self.fc3(x)
        if return_tensor == 'logits':
            return x

        x = self.out(x)
        if return_tensor == 'prediction':
            return x

        raise NotImplementedError(f'Output tensor name `{return_tensor}` is '
                                  f'not implemented!')

# pylint: enable=missing-function-docstring