|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import timm |
|
|
|
from transformers import PreTrainedModel |
|
|
|
from .heads import ArcMarginProduct, ElasticArcFace, ArcFaceSubCenterDynamic |
|
from .configuration_miewid import MiewIdNetConfig |
|
|
|
def weights_init_kaiming(m):
    """Initialise ``m`` in place according to its layer family.

    The family is picked by substring match on the module's class name, so
    the function can be passed to ``nn.Module.apply``:

    * ``Linear``: Kaiming-normal weight (``mode='fan_out'``), zero bias.
    * ``Conv``: Kaiming-normal weight (``mode='fan_in'``), zero bias.
    * ``BatchNorm``: weight 1 and bias 0, only when the layer is affine.

    Modules matching none of these names are left untouched.

    Args:
        m: Module to initialise.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # A layer built with ``bias=False`` exposes ``bias`` as None; skip it
        # instead of crashing inside ``nn.init.constant_``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in classname:
        # Non-affine batch norm has no learnable weight/bias to reset.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
|
|
|
|
|
def weights_init_classifier(m):
    """Initialise a classifier-style ``Linear`` layer in place.

    Modules whose class name contains ``Linear`` get weights drawn from a
    normal distribution with ``std=0.001`` and, when a bias exists, a zero
    bias. Any other module is left untouched, so the function can be passed
    to ``nn.Module.apply``.

    Args:
        m: Module to initialise.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.normal_(m.weight, std=0.001)
        # Compare against None: taking the truth value of a bias tensor with
        # more than one element raises a RuntimeError.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
|
|
|
class GeM(nn.Module):
    """Generalised-mean pooling over the last two dimensions.

    The input is clamped from below at ``eps``, raised to the power ``p``,
    averaged over its full spatial extent and then raised to ``1 / p``.
    ``p`` is stored as a one-element learnable parameter.
    """

    def __init__(self, p=3, eps=1e-6):
        """Create the layer.

        Args:
            p: Initial value of the learnable pooling exponent.
            eps: Lower clamp applied to the input before exponentiation.
        """
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` using the layer's own exponent and epsilon."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply generalised-mean pooling to ``x``.

        Args:
            x: Input tensor; its last two dimensions are pooled to size 1.
            p: Pooling exponent (number or tensor).
            eps: Lower clamp applied to ``x`` before exponentiation.

        Returns:
            The pooled tensor, with the last two dimensions reduced to 1.
        """
        # The pooling window spans the whole spatial extent of the input.
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        averaged = F.avg_pool2d(powered, window)
        return averaged.pow(1. / p)

    def __repr__(self):
        """Return ``ClassName(p=<4 decimals>, eps=<eps>)``."""
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps!s})"
|
|
|
class MiewIdNet(PreTrainedModel):
    """Embedding network: timm backbone, GeM pooling, BatchNorm neck and a head.

    ``forward`` returns the neck output (the embedding). ``extract_logits``
    additionally runs the head selected by ``config.loss_module``.
    """

    config_class = MiewIdNetConfig

    def __init__(self, config):
        """Build the backbone, neck and head described by ``config``.

        Args:
            config: Configuration object. This constructor reads its
                ``model_name``, ``n_classes``, ``loss_module`` and ``k``
                attributes.
        """
        super().__init__(config)
        print('Building Model Backbone for {} model'.format(config.model_name))
        print('config.model_name', config.model_name)

        n_classes = config.n_classes
        model_name = config.model_name
        use_fc = False
        fc_dim = 512
        dropout = 0.0
        loss_module = config.loss_module
        s = 30.0
        margin = 0.50
        pretrained = True
        # NOTE(review): ``margins`` is read from ``config.k``, the same field
        # as ``k`` below. This looks like it was meant to be a separate
        # per-class margins setting; left unchanged because switching the
        # source could alter the head that existing checkpoints were saved
        # with. Confirm against the config class.
        margins = config.k
        k = config.k

        print('model_name', model_name)

        # ``num_classes=0`` asks timm for a backbone without a classifier.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Hard-coded width of the flattened backbone output; it must match
        # the backbone selected by ``model_name``.
        # NOTE(review): consider deriving this from the backbone instead.
        final_in_features = 2152

        print('final_in_features', final_in_features)

        # Swap the backbone's pooling layer for learnable GeM pooling.
        self.backbone.global_pool = GeM()

        self.use_fc = use_fc
        if use_fc:
            # Optional projection neck: dropout -> fc -> bn. ``fc`` maps to
            # ``fc_dim`` so that it agrees with ``bn`` and with the head
            # below, both of which are sized by ``fc_dim``.
            self.dropout = nn.Dropout(p=dropout)
            self.fc = nn.Linear(final_in_features, fc_dim, bias=False)
            self.bn = nn.BatchNorm1d(fc_dim)
            self.bn.bias.requires_grad_(False)
            self.bn.apply(weights_init_kaiming)
            self.fc.apply(weights_init_classifier)
            final_in_features = fc_dim
        else:
            self.bn = nn.BatchNorm1d(final_in_features)

        self.loss_module = loss_module
        if loss_module == 'arcface':
            self.final = ElasticArcFace(final_in_features, n_classes,
                                        s=s, m=margin)
        elif loss_module == 'arcface_subcenter_dynamic':
            if margins is None:
                margins = [0.3] * n_classes
            print(final_in_features, n_classes)
            self.final = ArcFaceSubCenterDynamic(
                embedding_dim=final_in_features,
                output_classes=n_classes,
                margins=margins,
                s=s,
                k=k)
        else:
            self.final = nn.Linear(final_in_features, n_classes)

    def _init_params(self):
        """Re-initialise the neck: Xavier-normal ``fc`` (if any), identity ``bn``."""
        # ``fc`` only exists when the projection neck is enabled, and it is
        # created without a bias, so both must be checked before use.
        fc = getattr(self, 'fc', None)
        if fc is not None:
            nn.init.xavier_normal_(fc.weight)
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, x, label=None):
        """Return the embedding for ``x``.

        Args:
            x: Input batch passed to the backbone.
            label: Unused; accepted so callers can pass labels uniformly.

        Returns:
            The output of ``extract_feat(x)``.
        """
        return self.extract_feat(x)

    def extract_feat(self, x):
        """Run the backbone and neck.

        Args:
            x: Input batch; ``x.shape[0]`` is taken as the batch size.

        Returns:
            The backbone output flattened to ``(batch, -1)``, optionally
            projected by ``dropout`` and ``fc``, then batch-normalised.
        """
        batch_size = x.shape[0]
        x = self.backbone(x).view(batch_size, -1)

        if self.use_fc:
            # ``bn`` is sized for the projected features in this mode, so the
            # projection has to run before it.
            x = self.dropout(x)
            x = self.fc(x)

        return self.bn(x)

    def extract_logits(self, x, label=None):
        """Return the head output for ``x``.

        Args:
            x: Input batch passed to ``extract_feat``.
            label: Labels forwarded to the head. Required for the
                ``'arcface'`` and ``'arcface_subcenter_dynamic'`` heads;
                ignored by the plain linear head.

        Returns:
            The output of ``self.final``.

        Raises:
            ValueError: If the configured head needs labels and ``label`` is
                None.
        """
        needs_label = self.loss_module in ('arcface', 'arcface_subcenter_dynamic')
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if needs_label and label is None:
            raise ValueError(
                "label is required when loss_module is {!r}".format(self.loss_module))

        feature = self.extract_feat(x)
        if needs_label:
            return self.final(feature, label)
        return self.final(feature)