Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import fvcore.nn.weight_init as weight_init | |
| import torch.nn.functional as F | |
| from detectron2.layers import CNNBlockBase, Conv2d, get_norm | |
| from detectron2.modeling import BACKBONE_REGISTRY | |
| from detectron2.modeling.backbone.resnet import ( | |
| BasicStem, | |
| BottleneckBlock, | |
| DeformBottleneckBlock, | |
| ResNet, | |
| ) | |
class DeepLabStem(CNNBlockBase):
    """
    DeepLab-flavoured ResNet stem: every layer that runs before the first residual stage.

    Three 3x3 convolutions, each followed by its normalization layer and an in-place
    ReLU, then a 3x3 max-pool. The first convolution and the max-pool both use
    stride 2, giving the stem an overall stride of 4.
    """

    def __init__(self, in_channels=3, out_channels=128, norm="BN"):
        """
        Args:
            in_channels (int): number of channels of the incoming tensor.
            out_channels (int): number of channels produced by the stem. The first
                two convolutions output half of this.
            norm (str or callable): normalization applied after every conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels

        half = out_channels // 2
        shared = dict(kernel_size=3, padding=1, bias=False)

        # Create all three convs before any weight fill, so the RNG is consumed
        # in the same order as constructing-then-initialising one by one would not.
        self.conv1 = Conv2d(in_channels, half, stride=2, norm=get_norm(norm, half), **shared)
        self.conv2 = Conv2d(half, half, stride=1, norm=get_norm(norm, half), **shared)
        self.conv3 = Conv2d(
            half, out_channels, stride=1, norm=get_norm(norm, out_channels), **shared
        )

        for layer in (self.conv1, self.conv2, self.conv3):
            weight_init.c2_msra_fill(layer)

    def forward(self, x):
        # conv -> (norm inside Conv2d) -> in-place ReLU, three times over.
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.relu_(layer(x))
        return F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
# Register the builder so it can be selected by name via cfg.MODEL.BACKBONE.NAME;
# BACKBONE_REGISTRY is imported by this module for exactly this purpose.
@BACKBONE_REGISTRY.register()
def build_resnet_deeplab_backbone(cfg, input_shape):
    """
    Create a DeepLab-style ResNet instance from config.

    On top of a plain ResNet, this supports a "deeplab" stem (:class:`DeepLabStem`),
    dilated res4/res5 stages (``RES4_DILATION`` / ``RES5_DILATION``) and per-block
    multi-grid dilation in res5 (``RES5_MULTI_GRID``).

    Args:
        cfg: a detectron2 config. Reads ``cfg.MODEL.RESNETS.*`` and
            ``cfg.MODEL.BACKBONE.FREEZE_AT``.
        input_shape: presumably a ``ShapeSpec``; only ``input_shape.channels`` is used.

    Returns:
        ResNet: a :class:`ResNet` instance.

    Raises:
        ValueError: if ``cfg.MODEL.RESNETS.STEM_TYPE`` is neither "basic" nor "deeplab".
        KeyError: if ``DEPTH`` is not 50/101/152, or an ``OUT_FEATURES`` entry is not
            one of res2..res5.
        AssertionError: for an unsupported res4/res5 dilation combination.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem_type = cfg.MODEL.RESNETS.STEM_TYPE
    if stem_type == "basic":
        stem_class = BasicStem
    elif stem_type == "deeplab":
        stem_class = DeepLabStem
    else:
        raise ValueError("Unknown stem type: {}".format(stem_type))
    stem = stem_class(
        in_channels=input_shape.channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    # fmt: off
    freeze_at           = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features        = cfg.MODEL.RESNETS.OUT_FEATURES
    depth               = cfg.MODEL.RESNETS.DEPTH
    num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels         = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res4_dilation       = cfg.MODEL.RESNETS.RES4_DILATION
    res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    res5_multi_grid     = cfg.MODEL.RESNETS.RES5_MULTI_GRID
    # fmt: on

    assert res4_dilation in {1, 2}, "res4_dilation cannot be {}.".format(res4_dilation)
    assert res5_dilation in {1, 2, 4}, "res5_dilation cannot be {}.".format(res5_dilation)
    if res4_dilation == 2:
        # Always dilate res5 if res4 is dilated.
        assert res5_dilation == 4, (
            "res5_dilation must be 4 when res4_dilation is 2, got {}.".format(res5_dilation)
        )

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    # Avoid creating variables without gradients: only build stages up to the
    # deepest requested output. Unused stages consume extra memory and may
    # cause allreduce to fail.
    out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
    max_stage_idx = max(out_stage_idx)

    stages = []
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        dilation = {4: res4_dilation, 5: res5_dilation}.get(stage_idx, 1)
        # res2 keeps the stem's resolution (stride 1); a dilated stage trades its
        # downsampling stride for dilation; every other stage downsamples by 2.
        first_stride = 1 if idx == 0 or dilation > 1 else 2
        num_blocks = num_blocks_per_stage[idx]
        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
            "bottleneck_channels": bottleneck_channels,
            "stride_in_1x1": stride_in_1x1,
            "num_groups": num_groups,
        }
        if deform_on_per_stage[idx]:
            stage_kargs["block_class"] = DeformBottleneckBlock
            stage_kargs["deform_modulated"] = deform_modulated
            stage_kargs["deform_num_groups"] = deform_num_groups
        else:
            stage_kargs["block_class"] = BottleneckBlock
        if stage_idx == 5:
            # Multi-grid: each res5 block gets its own dilation, scaled by the
            # stage dilation, instead of one shared value.
            stage_kargs["dilation_per_block"] = [dilation * mg for mg in res5_multi_grid]
        else:
            stage_kargs["dilation"] = dilation

        stages.append(ResNet.make_stage(**stage_kargs))
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)