|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
from torchvision.ops.misc import FrozenBatchNorm2d |
|
|
|
from torchvision.models import resnet, detection, segmentation |
|
|
|
import timm |
|
|
|
|
|
|
|
@torch.no_grad()
def convert_frozen_batchnorm(module):
    """Replace BatchNorm2d / SyncBatchNorm layers with FrozenBatchNorm2d.

    If ``module`` is itself a ``BatchNorm2d`` or ``SyncBatchNorm``, a new
    ``FrozenBatchNorm2d`` is returned that holds *copies* of the layer's
    affine parameters (when ``affine`` is set), running statistics and
    ``eps``. Otherwise every child is converted recursively. Any child that
    changed is swapped into ``module`` in place via ``add_module``, and
    ``module`` itself is returned.

    Runs under ``torch.no_grad()``, so the copies are not tracked by autograd.

    Args:
        module: the ``nn.Module`` (or module tree) to convert.

    Returns:
        The converted module. It is a new object only when ``module`` is
        itself a batch-norm layer; otherwise it is ``module``, mutated in place.
    """
    bn_types = (
        nn.modules.batchnorm.BatchNorm2d,
        nn.modules.batchnorm.SyncBatchNorm,
    )

    if isinstance(module, bn_types):
        frozen = FrozenBatchNorm2d(module.num_features)
        if module.affine:
            frozen.weight.data = module.weight.data.clone().detach()
            frozen.bias.data = module.bias.data.clone().detach()
        # Copy (rather than alias) the running statistics, consistent with the
        # affine parameters above. Otherwise later updates to the source BN's
        # buffers would leak into the frozen layer. Assigning through `.data`
        # keeps the source tensors' device and dtype.
        frozen.running_mean.data = module.running_mean.data.clone()
        frozen.running_var.data = module.running_var.data.clone()
        frozen.eps = module.eps
        return frozen

    for name, child in module.named_children():
        converted = convert_frozen_batchnorm(child)
        if converted is not child:
            # Re-registering an existing name replaces the child in place and
            # does not change the size of the dict being iterated.
            module.add_module(name, converted)
    return module
|
|
|
|
|
def get_backbone(backbone, pretrained=True):
    """Build a backbone network by name, with its batch-norm layers frozen.

    Supported names:
      * ``'resnet18'``, ``'resnet34'``, ``'resnet50'``, ``'resnet101'`` --
        torchvision ResNets constructed with ``norm_layer=FrozenBatchNorm2d``.
      * ``'resnet50d'`` -- ``.backbone.body`` of torchvision's
        ``fasterrcnn_resnet50_fpn``.
      * ``'resnet50s'`` / ``'resnet101s'`` -- ``.backbone`` of torchvision's
        ``deeplabv3_resnet50`` / ``deeplabv3_resnet101``.
      * ``'cspdarknet53'``, ``'efficientnet-b0'``, ``'efficientnet-b3'`` --
        timm models (``-`` mapped to ``_``), created with ``num_classes=0``
        and ``global_pool=''``.
    Every name except the first group is passed through
    ``convert_frozen_batchnorm`` after construction.

    Args:
        backbone: one of the names listed above.
        pretrained: forwarded as ``pretrained=`` to the underlying constructor.

    Returns:
        The backbone ``nn.Module``.

    Raises:
        RuntimeError: if ``backbone`` is not a supported name.
    """
    natively_frozen = ('resnet18', 'resnet34', 'resnet50', 'resnet101')
    timm_names = ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3')

    def build_unconverted():
        """Construct a model that is later passed through convert_frozen_batchnorm."""
        if backbone == 'resnet50d':
            return detection.fasterrcnn_resnet50_fpn(pretrained=pretrained).backbone.body
        if backbone == 'resnet50s':
            return segmentation.deeplabv3_resnet50(pretrained=pretrained).backbone
        if backbone == 'resnet101s':
            return segmentation.deeplabv3_resnet101(pretrained=pretrained).backbone
        if backbone in timm_names:
            # num_classes=0 and global_pool='' drop timm's classifier head and
            # global pooling, leaving a feature extractor.
            return timm.create_model(
                backbone.replace('-', '_'),
                pretrained=pretrained,
                num_classes=0,
                global_pool='',
            )
        raise RuntimeError(f'{backbone} is not a valid backbone')

    if backbone in natively_frozen:
        # torchvision ResNets accept a norm-layer factory, so they are built
        # with frozen BN directly and need no conversion pass.
        constructor = getattr(resnet, backbone)
        model = constructor(pretrained=pretrained, norm_layer=FrozenBatchNorm2d)
    else:
        unconverted = build_unconverted()
        model = convert_frozen_batchnorm(unconverted)

    # NOTE(review): presumably releases allocator cache left over from building
    # the full models above (parts of which are discarded) -- confirm still needed.
    torch.cuda.empty_cache()
    return model
|
|