# Source: resefa/models/inception_model.py (uploaded by akhaliq, commit 8ca3a29, "add files").
# python3.7
"""Contains the Inception V3 model, which is used for inference ONLY.
This file is mostly borrowed from `torchvision/models/inception.py`.
The Inception model is widely used to compute the FID or IS metrics for
evaluating generative models. However, the pre-trained model from torchvision
is slightly different from the TensorFlow version
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
which is used by the official FID implementation
https://github.com/bioinf-jku/TTUR
In particular:
(1) The TensorFlow model predicts 1008 classes instead of 1000.
(2) The avg_pool() layers in the TensorFlow model do not count the zero padding.
(3) The last Inception E block in the TensorFlow model uses max_pool() instead
of avg_pool().
Hence, to align the evaluation results with those from the TensorFlow
implementation, we modified the Inception model to support both versions.
Please use the `align_tf` argument to choose the version.
"""
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from utils.misc import download_url
# Public API: only the cached model builder is exported by `import *`.
__all__ = ['InceptionModel']
# pylint: disable=line-too-long
# Pre-trained Inception V3 checkpoints keyed by source name. Each value is a
# `(url, sha256)` pair; the digest is handed to the downloader for verification.
_MODEL_URL_SHA256 = dict(
    # Weights shipped with `torchvision` (ported from TensorFlow); loaded into
    # the 1000-class variant.
    torchvision_official=(
        'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
        '1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8',
    ),
    # Weights released by https://github.com/mseitzer/pytorch-fid; loaded into
    # the 1008-class, TensorFlow-aligned variant.
    tf_inception_v3=(
        'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
        '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2',
    ),
)
class InceptionModel(object):
    """Static, process-wide cache of pre-trained Inception V3 networks.

    Building the network and loading its weights is expensive, so each variant
    is built at most once and then shared through the class attribute
    `models`. The cached networks are frozen (eval mode, gradients disabled)
    and intended for inference only, e.g., computing FID. If training is
    required, use the model from `torchvision.models` or your own
    implementation.

    NOTE: The pre-trained network expects `RGB` input in the range [-1, 1] and
    resizes it to 299 x 299 by itself. Input normalized by subtracting
    (0.485, 0.456, 0.406) and dividing by (0.229, 0.224, 0.225) can be mapped
    back by passing `transform_input=True` to `forward()`.
    """

    # Weight-source name (a key of `_MODEL_URL_SHA256`) -> built network.
    models = {}

    @staticmethod
    def build_model(align_tf=True):
        """Returns the shared Inception V3 network, building it on first use.

        Args:
            align_tf: If True, build the 1008-class variant that mimics the
                TensorFlow FID model and load the weights released by
                `https://github.com/mseitzer/pytorch-fid`. Otherwise, build the
                1000-class variant and load the `torchvision` weights.
                (default: True)

        Returns:
            The cached network, in eval mode, with gradients disabled, moved
            to the current CUDA device.

        The returned network accepts these keyword arguments in `forward()`:
            - transform_input: Map ImageNet-normalized input back to the pixel
              range (-1, 1). Leave it off if the input is already in (-1, 1).
              (default: False)
            - output_logits: Return the classification logits instead of the
              pooled features. (default: False)
            - remove_logits_bias: Subtract the bias of the last linear layer
              from the logits, as the official implementation does; see
              `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
              (default: False)
            - output_predictions: Return `softmax(logits)`. (default: False)
        """
        model_source = 'tf_inception_v3' if align_tf else 'torchvision_official'
        num_classes = 1008 if align_tf else 1000
        fingerprint = model_source

        # Cache hit: reuse the network built by an earlier call.
        if fingerprint in InceptionModel.models:
            return InceptionModel.models[fingerprint]

        network = Inception3(num_classes=num_classes,
                             aux_logits=False,
                             init_weights=False,
                             align_tf=align_tf)

        # In a distributed run, rank 0 downloads first while the other ranks
        # block on the first barrier; rank 0 releases them via the second one.
        distributed = dist.is_initialized()
        rank = dist.get_rank() if distributed else 0
        if distributed and rank != 0:
            dist.barrier()
        url, sha256 = _MODEL_URL_SHA256[model_source]
        model_path, hash_check = download_url(
            url,
            filename=f'inception_model_{model_source}_{sha256}.pth',
            sha256=sha256)
        weights = torch.load(model_path, map_location='cpu')
        if hash_check is False:
            warnings.warn(f'Hash check failed! The remote file from URL '
                          f'`{url}` may be changed, or the downloading is '
                          f'interrupted. The loaded inception model may '
                          f'have unexpected behavior.')
        if distributed and rank == 0:
            dist.barrier()

        # Non-strict loading: the checkpoint may hold keys this network does
        # not build (e.g. `AuxLogits.*`, since aux_logits=False above).
        network.load_state_dict(weights, strict=False)
        del weights
        network.eval().requires_grad_(False).cuda()
        InceptionModel.models[fingerprint] = network
        return network
# pylint: disable=missing-function-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=super-with-arguments
# pylint: disable=consider-merging-isinstance
# pylint: disable=import-outside-toplevel
# pylint: disable=no-else-return
class Inception3(nn.Module):
    """Inception V3 network supporting both torchvision and TF-FID behavior.

    Args:
        num_classes: Output size of the final linear layer `fc`.
        aux_logits: Whether to build the auxiliary classifier `AuxLogits`,
            which is only evaluated in training mode.
        inception_blocks: Optional sequence of exactly seven block factories,
            in the order (conv, A, B, C, D, E, aux). Defaults to the classes
            defined in this module.
        init_weights: Whether to randomly initialize the weights (conv/linear
            weights from a normal distribution truncated at two standard
            deviations; batch-norm scale 1 and shift 0).
        align_tf: Whether to mimic the TensorFlow FID model (input resizing,
            input scaling, padding-free average pooling, and max pooling in
            the last Inception E block).
    """

    def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
                 init_weights=True, align_tf=True):
        super(Inception3, self).__init__()
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC,
                                InceptionD, InceptionE, InceptionAux]
        assert len(inception_blocks) == 7
        (conv_block, block_a, block_b, block_c,
         block_d, block_e, block_aux) = inception_blocks
        self.aux_logits = aux_logits
        self.align_tf = align_tf

        # Stem convolutions (the two max-pooling steps live in `_forward`).
        # Registration order is kept stable: it fixes the state-dict order and
        # the order in which `init_weights` draws random numbers.
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)

        for name, in_channels, pool_features in (('Mixed_5b', 192, 32),
                                                 ('Mixed_5c', 256, 64),
                                                 ('Mixed_5d', 288, 64)):
            setattr(self, name, block_a(in_channels,
                                        pool_features=pool_features,
                                        align_tf=align_tf))
        self.Mixed_6a = block_b(288)
        for name, channels_7x7 in (('Mixed_6b', 128), ('Mixed_6c', 160),
                                   ('Mixed_6d', 160), ('Mixed_6e', 192)):
            setattr(self, name, block_c(768, channels_7x7=channels_7x7,
                                        align_tf=align_tf))
        if aux_logits:
            self.AuxLogits = block_aux(768, num_classes)
        self.Mixed_7a = block_d(768)
        self.Mixed_7b = block_e(1280, align_tf=align_tf)
        # The last E block switches to max pooling in the TF-aligned variant.
        self.Mixed_7c = block_e(2048, use_max_pool=align_tf)
        self.fc = nn.Linear(2048, num_classes)

        if init_weights:
            import scipy.stats as stats  # Only needed for random init.
            for module in self.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    std = getattr(module, 'stddev', 0.1)
                    sampler = stats.truncnorm(-2, 2, scale=std)
                    samples = torch.as_tensor(sampler.rvs(module.weight.numel()),
                                              dtype=module.weight.dtype)
                    with torch.no_grad():
                        module.weight.copy_(samples.view(module.weight.size()))
                elif isinstance(module, nn.BatchNorm2d):
                    nn.init.constant_(module.weight, 1)
                    nn.init.constant_(module.bias, 0)

    @staticmethod
    def _transform_input(x, transform_input=False):
        """Undoes ImageNet mean/std normalization, mapping back to [-1, 1].

        Returns `x` unchanged when `transform_input` is False; otherwise
        returns a new tensor built from the first three channels of `x`.
        """
        if not transform_input:
            return x
        restored = []
        for index, (mean, std) in enumerate(((0.485, 0.229),
                                             (0.456, 0.224),
                                             (0.406, 0.225))):
            channel = torch.unsqueeze(x[:, index], 1)
            restored.append(channel * (std / 0.5) + (mean - 0.5) / 0.5)
        return torch.cat(restored, 1)

    def _forward(self,
                 x,
                 output_logits=False,
                 remove_logits_bias=False,
                 output_predictions=False):
        """Runs the network and returns `(output, aux)`.

        `aux` is the auxiliary-classifier output in training mode with
        `aux_logits` enabled, and None otherwise.
        """
        in_height, in_width = x.shape[2], x.shape[3]
        if (in_height, in_width) != (299, 299):
            if self.align_tf:
                # The grid offset makes output pixel i sample input coordinate
                # i * W / 299 (no half-pixel centers), presumably to match
                # TensorFlow's legacy bilinear resize.
                theta = torch.eye(2, 3).to(x)
                theta[0, 2] += theta[0, 0] / in_width - theta[0, 0] / 299
                theta[1, 2] += theta[1, 1] / in_height - theta[1, 1] / 299
                theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
                grid = F.affine_grid(theta,
                                     size=(x.shape[0], x.shape[1], 299, 299),
                                     align_corners=False)
                x = F.grid_sample(x, grid, mode='bilinear',
                                  padding_mode='border', align_corners=False)
            else:
                x = F.interpolate(x, size=(299, 299), mode='bilinear',
                                  align_corners=False)
        if x.shape[1] == 1:
            # Grayscale input: replicate to three channels.
            x = x.repeat((1, 3, 1, 1))
        if self.align_tf:
            # [-1, 1] -> [0, 255] -> (v - 128) / 128.
            x = (x * 127.5 + 127.5 - 128) / 128

        # Stem: 3 x 299 x 299 -> 192 x 35 x 35 with the default blocks.
        for layer in (self.Conv2d_1a_3x3, self.Conv2d_2a_3x3,
                      self.Conv2d_2b_3x3):
            x = layer(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        for layer in (self.Conv2d_3b_1x1, self.Conv2d_4a_3x3):
            x = layer(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Inception stages up to 768 x 17 x 17.
        for layer in (self.Mixed_5b, self.Mixed_5c, self.Mixed_5d,
                      self.Mixed_6a, self.Mixed_6b, self.Mixed_6c,
                      self.Mixed_6d, self.Mixed_6e):
            x = layer(x)
        aux = self.AuxLogits(x) if (self.training and self.aux_logits) else None

        # Final stages up to 2048 x 8 x 8.
        for layer in (self.Mixed_7a, self.Mixed_7b, self.Mixed_7c):
            x = layer(x)

        # Global average pooling + dropout, flattened to N x C features.
        pooled = F.dropout(F.adaptive_avg_pool2d(x, (1, 1)),
                           training=self.training)
        features = torch.flatten(pooled, 1)
        if not (output_logits or output_predictions):
            return features, aux

        logits = self.fc(features)
        if remove_logits_bias:
            logits = logits - self.fc.bias.view(1, -1)
        if output_predictions:
            return F.softmax(logits, dim=1), aux
        return logits, aux

    def forward(self,
                x,
                transform_input=False,
                output_logits=False,
                remove_logits_bias=False,
                output_predictions=False):
        """Computes features, logits or class probabilities for an image batch.

        Args:
            x: Image batch indexed as (N, C, H, W); resized to 299 x 299 when
                needed. A single-channel input is replicated to three channels
                (not supported together with `transform_input`, which reads
                channels 0-2).
            transform_input: Undo ImageNet mean/std normalization first.
            output_logits: Return the output of `fc` instead of the features.
            remove_logits_bias: Subtract `fc.bias` from the logits (only
                relevant when logits are computed).
            output_predictions: Return `softmax(logits)`.

        Returns:
            The requested tensor, or an `(output, aux)` tuple in training mode
            with `aux_logits` enabled.
        """
        out, aux = self._forward(self._transform_input(x, transform_input),
                                 output_logits, remove_logits_bias,
                                 output_predictions)
        if self.training and self.aux_logits:
            return out, aux
        return out
class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and average-pool branches.

    Produces 64 + 64 + 96 + `pool_features` channels; every branch is padded
    so the spatial size is preserved. With `align_tf=True` the average pooling
    excludes zero padding from the mean, as in the TensorFlow model.
    """

    def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
        super(InceptionA, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)
        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)
        self.pool_include_padding = not align_tf

    def _forward(self, x):
        """Returns the four branch outputs as a list, before concatenation."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=self.pool_include_padding)
        return [out_1x1, out_5x5, out_3x3dbl, self.branch_pool(pooled)]

    def forward(self, x):
        return torch.cat(self._forward(x), 1)
class InceptionB(nn.Module):
    """Downsampling block: stride-2 3x3, stride-2 double-3x3, and max pool.

    Produces 384 + 96 + `in_channels` channels; unpadded stride-2 branches
    roughly halve the spatial size.
    """

    def __init__(self, in_channels, conv_block=None):
        super(InceptionB, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=2)
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=2)

    def _forward(self, x):
        """Returns the three branch outputs as a list, before concatenation."""
        out_3x3 = self.branch3x3(x)
        out_3x3dbl = x
        for layer in (self.branch3x3dbl_1, self.branch3x3dbl_2,
                      self.branch3x3dbl_3):
            out_3x3dbl = layer(out_3x3dbl)
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_3x3dbl, out_pool]

    def forward(self, x):
        return torch.cat(self._forward(x), 1)
class InceptionC(nn.Module):
    """Inception block with factorized 7x7 convolutions (1x7 / 7x1 pairs).

    Produces 4 x 192 = 768 channels at unchanged spatial size;
    `channels_7x7` is the width inside the factorized branches. With
    `align_tf=True` the average pooling excludes zero padding.
    """

    def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
        super(InceptionC, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        width = channels_7x7
        self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7_1 = make_conv(in_channels, width, kernel_size=1)
        self.branch7x7_2 = make_conv(width, width, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make_conv(width, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_1 = make_conv(in_channels, width, kernel_size=1)
        self.branch7x7dbl_2 = make_conv(width, width, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make_conv(width, width, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make_conv(width, width, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make_conv(width, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)
        self.pool_include_padding = not align_tf

    def _forward(self, x):
        """Returns the four branch outputs as a list, before concatenation."""
        def chain(tensor, *layers):
            # Applies `layers` to `tensor` one after another.
            for layer in layers:
                tensor = layer(tensor)
            return tensor

        out_1x1 = self.branch1x1(x)
        out_7x7 = chain(x, self.branch7x7_1, self.branch7x7_2, self.branch7x7_3)
        out_7x7dbl = chain(x, self.branch7x7dbl_1, self.branch7x7dbl_2,
                           self.branch7x7dbl_3, self.branch7x7dbl_4,
                           self.branch7x7dbl_5)
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=self.pool_include_padding)
        return [out_1x1, out_7x7, out_7x7dbl, self.branch_pool(pooled)]

    def forward(self, x):
        return torch.cat(self._forward(x), 1)
class InceptionD(nn.Module):
    """Downsampling block: 1x1 -> stride-2 3x3, factorized 7x7 -> stride-2 3x3,
    and max pool.

    Produces 320 + 192 + `in_channels` channels.
    """

    def __init__(self, in_channels, conv_block=None):
        super(InceptionD, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=2)
        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=2)

    def _forward(self, x):
        """Returns the three branch outputs as a list, before concatenation."""
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))
        out_7x7x3 = x
        for layer in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            out_7x7x3 = layer(out_7x7x3)
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_7x7x3, out_pool]

    def forward(self, x):
        return torch.cat(self._forward(x), 1)
class InceptionE(nn.Module):
    """Inception block with split 1x3 / 3x1 convolutions and a pooling branch.

    Produces 320 + 768 + 768 + 192 = 2048 channels at unchanged spatial size.
    The pooling branch uses max pooling when `use_max_pool` is set; otherwise
    it uses average pooling, which excludes zero padding when `align_tf` is
    True.
    """

    def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
        super(InceptionE, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)
        self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)
        self.pool_include_padding = not align_tf
        self.use_max_pool = use_max_pool

    def _forward(self, x):
        """Returns the four branch outputs as a list, before concatenation."""
        out_1x1 = self.branch1x1(x)

        stem = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl),
                                self.branch3x3dbl_3b(stem_dbl)], 1)

        if not self.use_max_pool:
            pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                  count_include_pad=self.pool_include_padding)
        else:
            pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        return [out_1x1, out_3x3, out_3x3dbl, self.branch_pool(pooled)]

    def forward(self, x):
        return torch.cat(self._forward(x), 1)
class InceptionAux(nn.Module):
    """Auxiliary classifier head applied to the 17 x 17 feature map.

    avg-pool(5, stride 3) -> 1x1 conv (128) -> 5x5 conv (768) -> global
    average pool -> linear layer with `num_classes` outputs.
    """

    def __init__(self, in_channels, num_classes, conv_block=None):
        super(InceptionAux, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.conv0 = make_conv(in_channels, 128, kernel_size=1)
        self.conv1 = make_conv(128, 768, kernel_size=5)
        # `stddev` is read by Inception3's weight initialization.
        # NOTE(review): with BasicConv2d this attribute sits on the wrapper,
        # while the init loop inspects the inner nn.Conv2d, so 0.01 appears to
        # have no effect there; only `fc.stddev` reaches an nn.Linear.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x 768 x 17 x 17 -> N x 768 x 5 x 5 for the standard network.
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        for conv in (self.conv0, self.conv1):
            x = conv(x)
        x = torch.flatten(F.adaptive_avg_pool2d(x, (1, 1)), 1)
        return self.fc(x)
class BasicConv2d(nn.Module):
    """Bias-free convolution, batch norm (eps=0.001), then in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are passed
    straight to `nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
# pylint: enable=missing-class-docstring
# pylint: enable=super-with-arguments
# pylint: enable=consider-merging-isinstance
# pylint: enable=import-outside-toplevel
# pylint: enable=no-else-return