Spaces:
Runtime error
Runtime error
from __future__ import division, absolute_import | |
import warnings | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
__all__ = [ | |
'osnet_ain_x1_0', 'osnet_ain_x0_75', 'osnet_ain_x0_5', 'osnet_ain_x0_25' | |
] | |
pretrained_urls = { | |
'osnet_ain_x1_0': | |
'https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo', | |
'osnet_ain_x0_75': | |
'https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM', | |
'osnet_ain_x0_5': | |
'https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l', | |
'osnet_ain_x0_25': | |
'https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt' | |
} | |
########## | |
# Basic layers | |
########## | |
class ConvLayer(nn.Module): | |
"""Convolution layer (conv + bn + relu).""" | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
kernel_size, | |
stride=1, | |
padding=0, | |
groups=1, | |
IN=False | |
): | |
super(ConvLayer, self).__init__() | |
self.conv = nn.Conv2d( | |
in_channels, | |
out_channels, | |
kernel_size, | |
stride=stride, | |
padding=padding, | |
bias=False, | |
groups=groups | |
) | |
if IN: | |
self.bn = nn.InstanceNorm2d(out_channels, affine=True) | |
else: | |
self.bn = nn.BatchNorm2d(out_channels) | |
self.relu = nn.ReLU() | |
def forward(self, x): | |
x = self.conv(x) | |
x = self.bn(x) | |
return self.relu(x) | |
class Conv1x1(nn.Module): | |
"""1x1 convolution + bn + relu.""" | |
def __init__(self, in_channels, out_channels, stride=1, groups=1): | |
super(Conv1x1, self).__init__() | |
self.conv = nn.Conv2d( | |
in_channels, | |
out_channels, | |
1, | |
stride=stride, | |
padding=0, | |
bias=False, | |
groups=groups | |
) | |
self.bn = nn.BatchNorm2d(out_channels) | |
self.relu = nn.ReLU() | |
def forward(self, x): | |
x = self.conv(x) | |
x = self.bn(x) | |
return self.relu(x) | |
class Conv1x1Linear(nn.Module): | |
"""1x1 convolution + bn (w/o non-linearity).""" | |
def __init__(self, in_channels, out_channels, stride=1, bn=True): | |
super(Conv1x1Linear, self).__init__() | |
self.conv = nn.Conv2d( | |
in_channels, out_channels, 1, stride=stride, padding=0, bias=False | |
) | |
self.bn = None | |
if bn: | |
self.bn = nn.BatchNorm2d(out_channels) | |
def forward(self, x): | |
x = self.conv(x) | |
if self.bn is not None: | |
x = self.bn(x) | |
return x | |
class Conv3x3(nn.Module): | |
"""3x3 convolution + bn + relu.""" | |
def __init__(self, in_channels, out_channels, stride=1, groups=1): | |
super(Conv3x3, self).__init__() | |
self.conv = nn.Conv2d( | |
in_channels, | |
out_channels, | |
3, | |
stride=stride, | |
padding=1, | |
bias=False, | |
groups=groups | |
) | |
self.bn = nn.BatchNorm2d(out_channels) | |
self.relu = nn.ReLU() | |
def forward(self, x): | |
x = self.conv(x) | |
x = self.bn(x) | |
return self.relu(x) | |
class LightConv3x3(nn.Module): | |
"""Lightweight 3x3 convolution. | |
1x1 (linear) + dw 3x3 (nonlinear). | |
""" | |
def __init__(self, in_channels, out_channels): | |
super(LightConv3x3, self).__init__() | |
self.conv1 = nn.Conv2d( | |
in_channels, out_channels, 1, stride=1, padding=0, bias=False | |
) | |
self.conv2 = nn.Conv2d( | |
out_channels, | |
out_channels, | |
3, | |
stride=1, | |
padding=1, | |
bias=False, | |
groups=out_channels | |
) | |
self.bn = nn.BatchNorm2d(out_channels) | |
self.relu = nn.ReLU() | |
def forward(self, x): | |
x = self.conv1(x) | |
x = self.conv2(x) | |
x = self.bn(x) | |
return self.relu(x) | |
class LightConvStream(nn.Module): | |
"""Lightweight convolution stream.""" | |
def __init__(self, in_channels, out_channels, depth): | |
super(LightConvStream, self).__init__() | |
assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format( | |
depth | |
) | |
layers = [] | |
layers += [LightConv3x3(in_channels, out_channels)] | |
for i in range(depth - 1): | |
layers += [LightConv3x3(out_channels, out_channels)] | |
self.layers = nn.Sequential(*layers) | |
def forward(self, x): | |
return self.layers(x) | |
########## | |
# Building blocks for omni-scale feature learning | |
########## | |
class ChannelGate(nn.Module): | |
"""A mini-network that generates channel-wise gates conditioned on input tensor.""" | |
def __init__( | |
self, | |
in_channels, | |
num_gates=None, | |
return_gates=False, | |
gate_activation='sigmoid', | |
reduction=16, | |
layer_norm=False | |
): | |
super(ChannelGate, self).__init__() | |
if num_gates is None: | |
num_gates = in_channels | |
self.return_gates = return_gates | |
self.global_avgpool = nn.AdaptiveAvgPool2d(1) | |
self.fc1 = nn.Conv2d( | |
in_channels, | |
in_channels // reduction, | |
kernel_size=1, | |
bias=True, | |
padding=0 | |
) | |
self.norm1 = None | |
if layer_norm: | |
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) | |
self.relu = nn.ReLU() | |
self.fc2 = nn.Conv2d( | |
in_channels // reduction, | |
num_gates, | |
kernel_size=1, | |
bias=True, | |
padding=0 | |
) | |
if gate_activation == 'sigmoid': | |
self.gate_activation = nn.Sigmoid() | |
elif gate_activation == 'relu': | |
self.gate_activation = nn.ReLU() | |
elif gate_activation == 'linear': | |
self.gate_activation = None | |
else: | |
raise RuntimeError( | |
"Unknown gate activation: {}".format(gate_activation) | |
) | |
def forward(self, x): | |
input = x | |
x = self.global_avgpool(x) | |
x = self.fc1(x) | |
if self.norm1 is not None: | |
x = self.norm1(x) | |
x = self.relu(x) | |
x = self.fc2(x) | |
if self.gate_activation is not None: | |
x = self.gate_activation(x) | |
if self.return_gates: | |
return x | |
return input * x | |
class OSBlock(nn.Module): | |
"""Omni-scale feature learning block.""" | |
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): | |
super(OSBlock, self).__init__() | |
assert T >= 1 | |
assert out_channels >= reduction and out_channels % reduction == 0 | |
mid_channels = out_channels // reduction | |
self.conv1 = Conv1x1(in_channels, mid_channels) | |
self.conv2 = nn.ModuleList() | |
for t in range(1, T + 1): | |
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] | |
self.gate = ChannelGate(mid_channels) | |
self.conv3 = Conv1x1Linear(mid_channels, out_channels) | |
self.downsample = None | |
if in_channels != out_channels: | |
self.downsample = Conv1x1Linear(in_channels, out_channels) | |
def forward(self, x): | |
identity = x | |
x1 = self.conv1(x) | |
x2 = 0 | |
for conv2_t in self.conv2: | |
x2_t = conv2_t(x1) | |
x2 = x2 + self.gate(x2_t) | |
x3 = self.conv3(x2) | |
if self.downsample is not None: | |
identity = self.downsample(identity) | |
out = x3 + identity | |
return F.relu(out) | |
class OSBlockINin(nn.Module): | |
"""Omni-scale feature learning block with instance normalization.""" | |
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): | |
super(OSBlockINin, self).__init__() | |
assert T >= 1 | |
assert out_channels >= reduction and out_channels % reduction == 0 | |
mid_channels = out_channels // reduction | |
self.conv1 = Conv1x1(in_channels, mid_channels) | |
self.conv2 = nn.ModuleList() | |
for t in range(1, T + 1): | |
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] | |
self.gate = ChannelGate(mid_channels) | |
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False) | |
self.downsample = None | |
if in_channels != out_channels: | |
self.downsample = Conv1x1Linear(in_channels, out_channels) | |
self.IN = nn.InstanceNorm2d(out_channels, affine=True) | |
def forward(self, x): | |
identity = x | |
x1 = self.conv1(x) | |
x2 = 0 | |
for conv2_t in self.conv2: | |
x2_t = conv2_t(x1) | |
x2 = x2 + self.gate(x2_t) | |
x3 = self.conv3(x2) | |
x3 = self.IN(x3) # IN inside residual | |
if self.downsample is not None: | |
identity = self.downsample(identity) | |
out = x3 + identity | |
return F.relu(out) | |
########## | |
# Network architecture | |
########## | |
class OSNet(nn.Module): | |
"""Omni-Scale Network. | |
Reference: | |
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019. | |
- Zhou et al. Learning Generalisable Omni-Scale Representations | |
for Person Re-Identification. TPAMI, 2021. | |
""" | |
def __init__( | |
self, | |
num_classes, | |
blocks, | |
layers, | |
channels, | |
feature_dim=512, | |
loss='softmax', | |
conv1_IN=False, | |
**kwargs | |
): | |
super(OSNet, self).__init__() | |
num_blocks = len(blocks) | |
assert num_blocks == len(layers) | |
assert num_blocks == len(channels) - 1 | |
self.loss = loss | |
self.feature_dim = feature_dim | |
# convolutional backbone | |
self.conv1 = ConvLayer( | |
3, channels[0], 7, stride=2, padding=3, IN=conv1_IN | |
) | |
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) | |
self.conv2 = self._make_layer( | |
blocks[0], layers[0], channels[0], channels[1] | |
) | |
self.pool2 = nn.Sequential( | |
Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2) | |
) | |
self.conv3 = self._make_layer( | |
blocks[1], layers[1], channels[1], channels[2] | |
) | |
self.pool3 = nn.Sequential( | |
Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2) | |
) | |
self.conv4 = self._make_layer( | |
blocks[2], layers[2], channels[2], channels[3] | |
) | |
self.conv5 = Conv1x1(channels[3], channels[3]) | |
self.global_avgpool = nn.AdaptiveAvgPool2d(1) | |
# fully connected layer | |
self.fc = self._construct_fc_layer( | |
self.feature_dim, channels[3], dropout_p=None | |
) | |
# identity classification layer | |
self.classifier = nn.Linear(self.feature_dim, num_classes) | |
self._init_params() | |
def _make_layer(self, blocks, layer, in_channels, out_channels): | |
layers = [] | |
layers += [blocks[0](in_channels, out_channels)] | |
for i in range(1, len(blocks)): | |
layers += [blocks[i](out_channels, out_channels)] | |
return nn.Sequential(*layers) | |
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): | |
if fc_dims is None or fc_dims < 0: | |
self.feature_dim = input_dim | |
return None | |
if isinstance(fc_dims, int): | |
fc_dims = [fc_dims] | |
layers = [] | |
for dim in fc_dims: | |
layers.append(nn.Linear(input_dim, dim)) | |
layers.append(nn.BatchNorm1d(dim)) | |
layers.append(nn.ReLU()) | |
if dropout_p is not None: | |
layers.append(nn.Dropout(p=dropout_p)) | |
input_dim = dim | |
self.feature_dim = fc_dims[-1] | |
return nn.Sequential(*layers) | |
def _init_params(self): | |
for m in self.modules(): | |
if isinstance(m, nn.Conv2d): | |
nn.init.kaiming_normal_( | |
m.weight, mode='fan_out', nonlinearity='relu' | |
) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.BatchNorm2d): | |
nn.init.constant_(m.weight, 1) | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.BatchNorm1d): | |
nn.init.constant_(m.weight, 1) | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.InstanceNorm2d): | |
nn.init.constant_(m.weight, 1) | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.Linear): | |
nn.init.normal_(m.weight, 0, 0.01) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
def featuremaps(self, x): | |
x = self.conv1(x) | |
x = self.maxpool(x) | |
x = self.conv2(x) | |
x = self.pool2(x) | |
x = self.conv3(x) | |
x = self.pool3(x) | |
x = self.conv4(x) | |
x = self.conv5(x) | |
return x | |
def forward(self, x, return_featuremaps=False): | |
x = self.featuremaps(x) | |
if return_featuremaps: | |
return x | |
v = self.global_avgpool(x) | |
v = v.view(v.size(0), -1) | |
if self.fc is not None: | |
v = self.fc(v) | |
if not self.training: | |
return v | |
y = self.classifier(v) | |
if self.loss == 'softmax': | |
return y | |
elif self.loss == 'triplet': | |
return y, v | |
else: | |
raise KeyError("Unsupported loss: {}".format(self.loss)) | |
def init_pretrained_weights(model, key=''): | |
"""Initializes model with pretrained weights. | |
Layers that don't match with pretrained layers in name or size are kept unchanged. | |
""" | |
import os | |
import errno | |
import gdown | |
from collections import OrderedDict | |
def _get_torch_home(): | |
ENV_TORCH_HOME = 'TORCH_HOME' | |
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | |
DEFAULT_CACHE_DIR = '~/.cache' | |
torch_home = os.path.expanduser( | |
os.getenv( | |
ENV_TORCH_HOME, | |
os.path.join( | |
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch' | |
) | |
) | |
) | |
return torch_home | |
torch_home = _get_torch_home() | |
model_dir = os.path.join(torch_home, 'checkpoints') | |
try: | |
os.makedirs(model_dir) | |
except OSError as e: | |
if e.errno == errno.EEXIST: | |
# Directory already exists, ignore. | |
pass | |
else: | |
# Unexpected OSError, re-raise. | |
raise | |
filename = key + '_imagenet.pth' | |
cached_file = os.path.join(model_dir, filename) | |
if not os.path.exists(cached_file): | |
gdown.download(pretrained_urls[key], cached_file, quiet=False) | |
state_dict = torch.load(cached_file) | |
model_dict = model.state_dict() | |
new_state_dict = OrderedDict() | |
matched_layers, discarded_layers = [], [] | |
for k, v in state_dict.items(): | |
if k.startswith('module.'): | |
k = k[7:] # discard module. | |
if k in model_dict and model_dict[k].size() == v.size(): | |
new_state_dict[k] = v | |
matched_layers.append(k) | |
else: | |
discarded_layers.append(k) | |
model_dict.update(new_state_dict) | |
model.load_state_dict(model_dict) | |
if len(matched_layers) == 0: | |
warnings.warn( | |
'The pretrained weights from "{}" cannot be loaded, ' | |
'please check the key names manually ' | |
'(** ignored and continue **)'.format(cached_file) | |
) | |
else: | |
print( | |
'Successfully loaded imagenet pretrained weights from "{}"'. | |
format(cached_file) | |
) | |
if len(discarded_layers) > 0: | |
print( | |
'** The following layers are discarded ' | |
'due to unmatched keys or layer size: {}'. | |
format(discarded_layers) | |
) | |
########## | |
# Instantiation | |
########## | |
def osnet_ain_x1_0( | |
num_classes=1000, pretrained=True, loss='softmax', **kwargs | |
): | |
model = OSNet( | |
num_classes, | |
blocks=[ | |
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], | |
[OSBlockINin, OSBlock] | |
], | |
layers=[2, 2, 2], | |
channels=[64, 256, 384, 512], | |
loss=loss, | |
conv1_IN=True, | |
**kwargs | |
) | |
if pretrained: | |
init_pretrained_weights(model, key='osnet_ain_x1_0') | |
return model | |
def osnet_ain_x0_75( | |
num_classes=1000, pretrained=True, loss='softmax', **kwargs | |
): | |
model = OSNet( | |
num_classes, | |
blocks=[ | |
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], | |
[OSBlockINin, OSBlock] | |
], | |
layers=[2, 2, 2], | |
channels=[48, 192, 288, 384], | |
loss=loss, | |
conv1_IN=True, | |
**kwargs | |
) | |
if pretrained: | |
init_pretrained_weights(model, key='osnet_ain_x0_75') | |
return model | |
def osnet_ain_x0_5( | |
num_classes=1000, pretrained=True, loss='softmax', **kwargs | |
): | |
model = OSNet( | |
num_classes, | |
blocks=[ | |
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], | |
[OSBlockINin, OSBlock] | |
], | |
layers=[2, 2, 2], | |
channels=[32, 128, 192, 256], | |
loss=loss, | |
conv1_IN=True, | |
**kwargs | |
) | |
if pretrained: | |
init_pretrained_weights(model, key='osnet_ain_x0_5') | |
return model | |
def osnet_ain_x0_25( | |
num_classes=1000, pretrained=True, loss='softmax', **kwargs | |
): | |
model = OSNet( | |
num_classes, | |
blocks=[ | |
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], | |
[OSBlockINin, OSBlock] | |
], | |
layers=[2, 2, 2], | |
channels=[16, 64, 96, 128], | |
loss=loss, | |
conv1_IN=True, | |
**kwargs | |
) | |
if pretrained: | |
init_pretrained_weights(model, key='osnet_ain_x0_25') | |
return model | |