miewid-imagenet / modeling_miewid.py
lashao's picture
Upload MiewIdNet
143c69d verified
raw
history blame contribute delete
No virus
5.34 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from transformers import PreTrainedModel
from .heads import ArcMarginProduct, ElasticArcFace, ArcFaceSubCenterDynamic
from .configuration_miewid import MiewIdNetConfig
def weights_init_kaiming(m):
    """Kaiming-initialize Linear/Conv weights and reset BatchNorm affine params.

    Intended to be used via ``module.apply(weights_init_kaiming)``: dispatches
    on the module's class name. Linear layers use fan_out, Conv layers fan_in;
    biases are zeroed when present; affine BatchNorm weight/bias become 1/0.
    Modules of any other type are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # Fix: guard against bias=False layers (the Conv branch below already
        # did this; the original Linear branch crashed on a None bias).
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        # Only affine BatchNorm has learnable weight/bias to reset.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
    """Initialize classifier Linear layers: weights ~ N(0, 0.001), bias 0.

    Intended to be used via ``module.apply(weights_init_classifier)``.
    Non-Linear modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        # Fix: the original ``if m.bias:`` evaluated tensor truthiness, which
        # raises RuntimeError for any bias with more than one element.
        # Test for presence instead.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the spatial dimensions.

    Computes ``mean(x.clamp(min=eps) ** p) ** (1/p)`` over H and W with a
    single learnable exponent ``p`` shared by all channels (p=1 recovers
    average pooling; p -> inf approaches max pooling).
    """

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        # Learnable pooling exponent, stored as a 1-element parameter.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # Clamp keeps the power well-defined and numerically stable.
        spatial_size = (x.size(-2), x.size(-1))
        pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), spatial_size)
        return pooled.pow(1. / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return '{}(p={:.4f}, eps={})'.format(
            self.__class__.__name__, p_value, self.eps)
class MiewIdNet(PreTrainedModel):
    """MiewID re-identification model: a timm backbone with GeM pooling,
    a BatchNorm "neck", and a margin-based head selected by
    ``config.loss_module``.

    ``forward`` returns embedding features only; ``extract_logits`` runs the
    head (ArcFace-style heads require labels to apply their margins).
    """

    # Hugging Face integration: config class used by from_pretrained and
    # save_pretrained.
    config_class = MiewIdNetConfig

    def __init__(self, config):
        """Build the network from a ``MiewIdNetConfig``.

        Reads ``config.model_name``, ``config.n_classes``,
        ``config.loss_module`` and ``config.k``; all other hyperparameters
        (scale, margin, fc_dim, dropout, ...) are fixed locals below.
        """
        super(MiewIdNet, self).__init__(config)
        print('Building Model Backbone for {} model'.format(config.model_name))
        print('config.model_name', config.model_name)
        n_classes=config.n_classes
        model_name=config.model_name
        use_fc=False          # bottleneck-FC branch disabled (hard-coded)
        fc_dim=512
        dropout=0.0
        loss_module=config.loss_module
        s=30.0                # ArcFace scale
        margin=0.50           # ArcFace margin
        ls_eps=0.0            # unused here (label-smoothing epsilon)
        theta_zero=0.785      # unused here (AdaCos param; head commented out below)
        pretrained=True
        # NOTE(review): ``margins`` is taken from config.k, the same field as
        # the sub-center count ``k`` — looks suspicious; confirm the config
        # actually stores a per-class margin list in ``k``.
        margins=config.k
        k=config.k            # sub-center count for arcface_subcenter_dynamic
        print('model_name', model_name)
        # num_classes=0 strips timm's classifier; backbone returns pooled features.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Hard-coded feature width; presumably matches the expected backbone's
        # pooled feature dim — TODO confirm before swapping backbones.
        final_in_features = 2152#self.backbone.classifier.in_features
        print('final_in_features', final_in_features)
        # self.backbone.classifier = nn.Identity()
        # Replace timm's global pooling with learnable generalized-mean pooling.
        self.backbone.global_pool = GeM()#nn.Identity()
        # self.pooling = GeM()
        # BatchNorm "neck" applied to the flattened backbone features.
        self.bn = nn.BatchNorm1d(final_in_features)
        self.use_fc = use_fc
        if use_fc:
            # Optional bottleneck: dropout -> BN(fc_dim) -> FC. Dead code while
            # use_fc is hard-coded False above; note it REPLACES self.bn with a
            # fc_dim-sized BatchNorm.
            self.dropout = nn.Dropout(p=dropout)
            self.bn = nn.BatchNorm1d(fc_dim)
            self.bn.bias.requires_grad_(False)
            self.fc = nn.Linear(final_in_features, n_classes, bias = False)
            self.bn.apply(weights_init_kaiming)
            self.fc.apply(weights_init_classifier)
            final_in_features = fc_dim
        self.loss_module = loss_module
        # Select the classification head.
        if loss_module == 'arcface':
            self.final = ElasticArcFace(final_in_features, n_classes,
                                        s=s, m=margin)
        elif loss_module == 'arcface_subcenter_dynamic':
            if margins is None:
                # Fall back to a uniform per-class margin.
                margins = [0.3] * n_classes
            print(final_in_features, n_classes)
            self.final = ArcFaceSubCenterDynamic(
                embedding_dim=final_in_features,
                output_classes=n_classes,
                margins=margins,
                s=s,
                k=k )
        # elif loss_module == 'cosface':
        #     self.final = AddMarginProduct(final_in_features, n_classes, s=s, m=margin)
        # elif loss_module == 'adacos':
        #     self.final = AdaCos(final_in_features, n_classes, m=margin, theta_zero=theta_zero)
        else:
            # Any other value: plain linear classifier, no margin.
            self.final = nn.Linear(final_in_features, n_classes)

    def _init_params(self):
        # Manual (re)initialization of the neck/FC; never called in __init__.
        # Assumes self.fc exists (only created when use_fc was True).
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, x, label=None):
        """Return embedding features for batch ``x``; ``label`` is ignored."""
        feature = self.extract_feat(x)
        return feature
        # if not self.training:
        #     return feature
        # else:
        #     assert label is not None
        #     if self.loss_module in ('arcface', 'arcface_subcenter_dynamic'):
        #         logits = self.final(feature, label)
        #     else:
        #         logits = self.final(feature)
        #     return logits

    def extract_feat(self, x):
        """Backbone -> flatten -> BN neck; returns the BN-normalized features.

        NOTE(review): when ``use_fc`` is True, the x1 branch is computed but
        discarded (the method still returns ``x``), and ``self.bn`` would then
        be the fc_dim BatchNorm while ``x`` has final_in_features channels.
        That path looks broken but is unreachable with use_fc=False — confirm
        before enabling use_fc.
        """
        batch_size = x.shape[0]
        x = self.backbone(x).view(batch_size, -1)
        # x = self.pooling(x).view(batch_size, -1)
        x = self.bn(x)
        if self.use_fc:
            x1 = self.dropout(x)
            x1 = self.bn(x1)
            x1 = self.fc(x1)
        return x

    def extract_logits(self, x, label=None):
        """Return head logits for ``x``; ArcFace-style heads require ``label``."""
        feature = self.extract_feat(x)
        assert label is not None
        if self.loss_module in ('arcface', 'arcface_subcenter_dynamic'):
            logits = self.final(feature, label)
        else:
            logits = self.final(feature)
        return logits