'''
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
Original author cavalleria
'''

import torch
import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module


class Flatten(Module):
    """Collapse all dimensions after the batch dimension into one."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class ConvBlock(Module):
    """Conv2d -> BatchNorm2d -> PReLU."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
            BatchNorm2d(num_features=out_c),
            PReLU(num_parameters=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class LinearBlock(Module):
    """Conv2d -> BatchNorm2d with no activation (linear bottleneck)."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
            BatchNorm2d(num_features=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class DepthWise(Module):
    """Depthwise-separable block: 1x1 pointwise expansion to `groups` channels,
    depthwise convolution, then linear 1x1 projection, with an optional
    identity shortcut."""

    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(DepthWise, self).__init__()
        self.residual = residual
        self.layers = nn.Sequential(
            ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
            ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
            LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        )

    def forward(self, x):
        short_cut = None
        if self.residual:
            short_cut = x
        x = self.layers(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    """A stack of `num_block` DepthWise blocks with identity shortcuts."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
        self.layers = Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


class GDC(Module):
    """Global Depthwise Convolution head: a 7x7 depthwise conv replaces global
    average pooling, followed by a linear embedding layer and BatchNorm1d."""

    def __init__(self, embedding_size):
        super(GDC, self).__init__()
        self.layers = nn.Sequential(
            LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
            Flatten(),
            Linear(512, embedding_size, bias=False),
            BatchNorm1d(embedding_size)
        )

    def forward(self, x):
        return self.layers(x)


class MobileFaceNet(Module):
    def __init__(self, fp16=False, num_features=512):
        super(MobileFaceNet, self).__init__()
        scale = 2
        self.fp16 = fp16
        self.layers = nn.Sequential(
            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.features = GDC(num_features)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        # Run the convolutional body under autocast when fp16 is enabled;
        # the separable conv and embedding head always run in fp32.
        with torch.cuda.amp.autocast(self.fp16):
            x = self.layers(x)
        x = self.conv_sep(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def get_mbf(fp16, num_features):
    return MobileFaceNet(fp16, num_features)
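

# Minimal usage sketch. It assumes the standard 112x112 aligned face crop for
# this backbone (the 7x7 GDC kernel matches the 7x7 feature map that a 112x112
# input produces after four stride-2 stages); batch size and flags are
# illustrative only.
if __name__ == "__main__":
    net = get_mbf(fp16=False, num_features=512)
    net.eval()  # eval mode so BatchNorm uses running stats (works with batch size 1)
    with torch.no_grad():
        dummy = torch.randn(1, 3, 112, 112)  # (batch, RGB, height, width)
        emb = net(dummy)
    print(emb.shape)  # torch.Size([1, 512])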