"""This file is for Inception model borrowed from torch metrics / fidelity. | |
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). | |
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. | |
Reference: | |
https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py | |
""" | |
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from torch_fidelity.feature_extractor_base import FeatureExtractorBase
from torch_fidelity.feature_extractor_inceptionv3 import (
    BasicConv2d,
    InceptionA,
    InceptionB,
    InceptionC,
    InceptionD,
    InceptionE_1,
    InceptionE_2,
)
from torch_fidelity.helpers import vassert
from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # torchvision.models.utils was removed in newer torchvision releases.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Note: the shasum of these weights was compared against the reference; the models should be identical.
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
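
# A small optional helper (not in the original file) sketching how the shasum
# comparison mentioned above could be automated; `expected_prefix` is a
# hypothetical placeholder, since the reference digest is not recorded here.
import hashlib


def _sha256_matches(path, expected_prefix):
    """Return True if the SHA-256 digest of the file at `path` starts with `expected_prefix`."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        # Hash in 1 MiB chunks to avoid loading the whole weight file into memory.
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    return digest.hexdigest().startswith(expected_prefix)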


class FeatureExtractorInceptionV3(FeatureExtractorBase):
    INPUT_IMAGE_SIZE = 299

    def __init__(
        self,
        name,
        features_list,
        **kwargs,
    ):
        """InceptionV3 feature extractor for 2D RGB 24-bit images.

        Args:
            name (str): Unique name of the feature extractor, must be the same as used in
                :func:`register_feature_extractor`.
            features_list (list): A list of the requested feature names, which will be produced for each
                input. This feature extractor provides the following features:

                - '64'
                - '192'
                - '768'
                - '2048'
                - 'logits_unbiased'
                - 'logits'

        See the usage sketch at the bottom of this file for an example invocation.
        """
        super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
        self.feature_extractor_internal_dtype = torch.float64

        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE_1(1280)
        self.Mixed_7c = InceptionE_2(2048)
        self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = torch.nn.Linear(2048, 1008)

        state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
        # state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
        self.load_state_dict(state_dict)

        self.to(self.feature_extractor_internal_dtype)
        self.requires_grad_(False)
        self.eval()

    def forward(self, x):
        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
        vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')

        features = {}
        remaining_features = self.features_list.copy()

        x = x.to(self.feature_extractor_internal_dtype)
        # N x 3 x ? x ?
        x = interpolate_bilinear_2d_like_tensorflow1x(
            x,
            size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
            align_corners=False,
        )
        # N x 3 x 299 x 299

        # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device)  # really happening in graph
        x = (x - 128) / 128  # but this gives bit-exact output _of this step_ too
        # N x 3 x 299 x 299

        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.MaxPool_1(x)
        # N x 64 x 73 x 73

        if '64' in remaining_features:
            features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove('64')
            if len(remaining_features) == 0:
                return features

        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.MaxPool_2(x)
        # N x 192 x 35 x 35

        if '192' in remaining_features:
            features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove('192')
            if len(remaining_features) == 0:
                return features

        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17

        if '768' in remaining_features:
            features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
            remaining_features.remove('768')
            if len(remaining_features) == 0:
                return features

        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        x = self.AvgPool(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048

        if '2048' in remaining_features:
            features['2048'] = x
            remaining_features.remove('2048')
            if len(remaining_features) == 0:
                return features

        if 'logits_unbiased' in remaining_features:
            # Drop the bias term of the final linear layer to obtain the "unbiased" logits.
            x = x.mm(self.fc.weight.T)
            # N x 1008 (num_classes)
            features['logits_unbiased'] = x
            remaining_features.remove('logits_unbiased')
            if len(remaining_features) == 0:
                return features
            # Re-add the bias to recover the standard logits.
            x = x + self.fc.bias.unsqueeze(0)
        else:
            x = self.fc(x)
            # N x 1008 (num_classes)

        features['logits'] = x
        return features

    @staticmethod
    def get_provided_features_list():
        return '64', '192', '768', '2048', 'logits_unbiased', 'logits'

    @staticmethod
    def get_default_feature_layer_for_metric(metric):
        # Default feature layer per metric: Inception Score (isc), FID (fid),
        # Kernel Inception Distance (kid), and Precision/Recall (prc).
        return {
            'isc': 'logits_unbiased',
            'fid': '2048',
            'kid': '2048',
            'prc': '2048',
        }[metric]

    @staticmethod
    def can_be_compiled():
        return True

    @staticmethod
    def get_dummy_input_for_compile():
        return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)


def get_inception_model():
    model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
    return model
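

# A minimal usage sketch (not part of the original file): feed a batch of
# uint8 RGB images and read out the requested features. The random images
# below are placeholders.
if __name__ == "__main__":
    model = get_inception_model()
    # The extractor expects uint8 tensors of shape Bx3xHxW and resizes them
    # internally to 299x299 with TF1.x-compatible bilinear interpolation.
    images = torch.randint(0, 256, (2, 3, 64, 64), dtype=torch.uint8)
    with torch.no_grad():
        feats = model(images)
    print(feats['2048'].shape)             # torch.Size([2, 2048])
    print(feats['logits_unbiased'].shape)  # torch.Size([2, 1008])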