from __future__ import division, absolute_import
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

__all__ = ['xception']

pretrained_settings = {
    'xception': {
        'imagenet': {
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            # The resize parameter of the validation transform should be 333,
            # and make sure to center crop at 299x299.
            'scale': 0.8975
        }
    }
}


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size of the depthwise convolution.
        stride (int): stride of the depthwise convolution.
        padding (int): padding of the depthwise convolution.
        dilation (int): dilation of the depthwise convolution.
        bias (bool): whether both convolutions carry a bias term.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=False
    ):
        super(SeparableConv2d, self).__init__()

        # groups == in_channels makes this a per-channel (depthwise) conv.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups=in_channels,
            bias=bias
        )
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    """Stack of (ReLU, 3x3 SeparableConv2d, BatchNorm2d) units with a skip path.

    The main path holds ``reps`` such units and, when ``strides != 1``, a
    trailing 3x3 max-pool with that stride. The skip path is the identity
    unless the channel count or the stride changes, in which case it is a
    1x1 convolution (with the same stride) followed by batch norm.

    Args:
        in_filters (int): number of input channels.
        out_filters (int): number of output channels.
        reps (int): number of separable-convolution units in the main path.
        strides (int): stride of the trailing max-pool and of the skip conv.
        start_with_relu (bool): if False, the leading ReLU is dropped.
        grow_first (bool): if True, the first unit changes the channel count
            from ``in_filters`` to ``out_filters``; otherwise the last one
            does.
    """

    def __init__(
        self,
        in_filters,
        out_filters,
        reps,
        strides=1,
        start_with_relu=True,
        grow_first=True
    ):
        super(Block, self).__init__()

        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(
                in_filters, out_filters, 1, stride=strides, bias=False
            )
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        # One in-place ReLU instance is shared by every unit of the main path.
        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(
                SeparableConv2d(
                    in_filters,
                    out_filters,
                    3,
                    stride=1,
                    padding=1,
                    bias=False
                )
            )
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(
                SeparableConv2d(
                    filters, filters, 3, stride=1, padding=1, bias=False
                )
            )
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(
                SeparableConv2d(
                    in_filters,
                    out_filters,
                    3,
                    stride=1,
                    padding=1,
                    bias=False
                )
            )
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first ReLU receives the block input, which forward() still
            # needs for the skip path, so it must not overwrite it in place.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x


class Xception(nn.Module):
    """Xception.

    Reference:
        Chollet. Xception: Deep Learning with Depthwise
        Separable Convolutions. CVPR 2017.

    Public keys:
        - ``xception``: Xception.

    Args:
        num_classes (int): output size of the classifier layer.
        loss (str): ``'softmax'`` or ``'triplet'``; selects what ``forward``
            returns in training mode. It is only checked there.
        fc_dims (list or tuple, optional): sizes of extra fully connected
            layers inserted before the classifier. None or empty means none.
        dropout_p (float, optional): dropout probability used inside the
            extra fully connected layers. None disables dropout.
        **kwargs: accepted for signature compatibility and ignored.
    """

    def __init__(
        self, num_classes, loss, fc_dims=None, dropout_p=None, **kwargs
    ):
        super(Xception, self).__init__()
        self.loss = loss

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True
        )
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True
        )
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True
        )

        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )

        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False
        )

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.feature_dim = 2048
        # _construct_fc_layer updates self.feature_dim to the width that the
        # classifier below has to consume.
        self.fc = self._construct_fc_layer(
            fc_dims, self.feature_dim, dropout_p
        )
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Sets ``self.feature_dim`` to the output width of the constructed
        stack (or to ``input_dim`` when no layer is constructed).

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None or
                empty, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused

        Returns:
            nn.Sequential or None: the fc stack, or None when ``fc_dims`` is
            None or empty.

        Raises:
            TypeError: if ``fc_dims`` is neither None, a list nor a tuple.
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        # Raised explicitly (not asserted) so the check survives `python -O`.
        if not isinstance(fc_dims, (list, tuple)):
            raise TypeError(
                'fc_dims must be either list or tuple, but got {}'.format(
                    type(fc_dims)
                )
            )

        # An empty sequence requests no layers; without this guard the
        # fc_dims[-1] lookup below would fail with IndexError.
        if not fc_dims:
            self.feature_dim = input_dim
            return None

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        """Initializes conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, input):
        """Runs the convolutional trunk and returns the final feature maps."""
        x = self.conv1(input)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x, inplace=True)

        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x, inplace=True)
        return x

    def forward(self, x):
        """Returns features in eval mode, loss-dependent outputs in training.

        Returns:
            In eval mode, the pooled (and optionally fc-transformed) feature
            vector ``v``. In training mode, the classifier output ``y`` for
            ``loss == 'softmax'`` or the pair ``(y, v)`` for
            ``loss == 'triplet'``.

        Raises:
            KeyError: in training mode, if ``self.loss`` is not supported.
        """
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initialize models with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept
    unchanged.

    Args:
        model (nn.Module): model whose parameters are updated in place.
        model_url (str): URL of the checkpoint to download.
    """
    # NOTE(review): the configured URL is plain HTTP and load_url
    # deserializes the downloaded file; consider an HTTPS source or hash
    # verification if the network path is not trusted.
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from; load_state_dict copies the values into the
    # model's own parameters afterwards.
    pretrain_dict = model_zoo.load_url(model_url, map_location='cpu')
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def xception(num_classes, loss='softmax', pretrained=True, **kwargs):
    """Builds an Xception model without extra fc layers.

    Args:
        num_classes (int): output size of the classifier layer.
        loss (str): ``'softmax'`` or ``'triplet'``.
        pretrained (bool): if True, downloads the ImageNet checkpoint listed
            in ``pretrained_settings`` and loads every matching layer.
        **kwargs: forwarded to ``Xception``.

    Returns:
        Xception: the constructed model.
    """
    model = Xception(num_classes, loss, fc_dims=None, dropout_p=None, **kwargs)
    if pretrained:
        model_url = pretrained_settings['xception']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model