| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torchvision import models |
| from torchvision.ops.misc import FrozenBatchNorm2d |
|
|
|
|
class Backbone(nn.Module):
    """ResNet feature extractor returning multi-scale features at one resolution.

    Runs a torchvision ResNet (built with ``FrozenBatchNorm2d``) up to
    ``layer4``. The ``layer2``, ``layer3`` and ``layer4`` outputs are
    bilinearly resized to ``(H // reduction, W // reduction)`` and
    concatenated along the channel dimension. The ResNet's ``avgpool`` and
    ``fc`` head are never used by :meth:`forward`.

    Attributes:
        backbone: The underlying torchvision ResNet module.
        reduction: Spatial down-sampling factor of the output w.r.t. the input.
        num_channels: Channel count of the concatenated output (hard-coded:
            896 for resnet18/34, 3584 for any other name; assumes torchvision
            ResNet stage widths).
    """

    def __init__(
        self,
        name: str,
        pretrained: bool,
        dilation: bool,
        reduction: int,
        swav: bool,
        requires_grad: bool
    ):
        """Build the backbone.

        Args:
            name: Name of a ResNet constructor in ``torchvision.models``
                (e.g. ``'resnet18'``, ``'resnet50'``).
            pretrained: Passed straight through to the torchvision constructor.
            dilation: If True, the last stage uses dilation instead of stride
                (``replace_stride_with_dilation=[False, False, dilation]``).
            reduction: Output spatial size is ``H // reduction`` x
                ``W // reduction`` of the input. Must be positive.
            swav: If True and ``name == 'resnet50'``, load the SwAV (800 epoch)
                checkpoint over the model weights. For any other ``name`` the
                flag has no effect and a warning is emitted.
            requires_grad: Whether ``layer2``-``layer4`` parameters are
                trainable. All other parameters (stem, ``layer1``, ``fc``) are
                always frozen.

        Raises:
            ValueError: If ``reduction`` is not positive.
        """
        import warnings  # local import keeps the file-level import block untouched

        super(Backbone, self).__init__()

        # Fail fast, before any weight download, rather than hitting a
        # ZeroDivisionError / interpolate error on the first forward pass.
        if reduction <= 0:
            raise ValueError(f"reduction must be positive, got {reduction!r}")

        # Previously this combination was silently ignored.
        if swav and name != 'resnet50':
            warnings.warn(
                f"swav=True is only supported for 'resnet50'; ignoring it for {name!r}.",
                stacklevel=2,
            )

        resnet = getattr(models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=pretrained, norm_layer=FrozenBatchNorm2d
        )

        self.backbone = resnet
        self.reduction = reduction

        if name == 'resnet50' and swav:
            checkpoint = torch.hub.load_state_dict_from_url(
                'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
                map_location="cpu"
            )
            # Strip the "module." key prefix (presumably left by a
            # DataParallel-style wrapper when the checkpoint was saved).
            state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
            # strict=False tolerates key mismatches with the torchvision model.
            # Any missing key other than the unused fc head means backbone
            # weights silently kept their previous values, so report it.
            load_result = self.backbone.load_state_dict(state_dict, strict=False)
            missing = [k for k in load_result.missing_keys if not k.startswith('fc.')]
            if missing:
                warnings.warn(
                    f"SwAV checkpoint did not provide {len(missing)} backbone "
                    f"tensor(s) (e.g. {missing[0]!r}); they keep their previous values.",
                    stacklevel=2,
                )

        # Sum of the channel widths of the three concatenated stages.
        self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584

        # Only layer2-layer4 may train; everything else is always frozen.
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for param_name, param in self.backbone.named_parameters():
            in_trainable_stage = any(stage in param_name for stage in trainable_stages)
            param.requires_grad_(requires_grad if in_trainable_stage else False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract and fuse multi-scale features.

        Args:
            x: Input batch; the last two dims are treated as (H, W). Presumably
                (B, 3, H, W) images, as expected by the ResNet stem — confirm.

        Returns:
            Tensor of shape (B, num_channels, H // reduction, W // reduction):
            the layer2/layer3/layer4 outputs, each resized bilinearly
            (``align_corners=True``) and concatenated on dim 1.
        """
        size = x.size(-2) // self.reduction, x.size(-1) // self.reduction

        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        layer2 = self.backbone.layer2(x)
        layer3 = self.backbone.layer3(layer2)
        layer4 = self.backbone.layer4(layer3)

        return torch.cat([
            F.interpolate(f, size=size, mode='bilinear', align_corners=True)
            for f in (layer2, layer3, layer4)
        ], dim=1)
|
|