# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import fvcore.nn.weight_init as weight_init import torch import torch.nn.functional as F from detectron2.layers import Conv2d, FrozenBatchNorm2d, get_norm from detectron2.modeling import BACKBONE_REGISTRY, ResNet, ResNetBlockBase, make_stage from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock from .trident_conv import TridentConv __all__ = ["TridentBottleneckBlock", "make_trident_stage", "build_trident_resnet_backbone"] class TridentBottleneckBlock(ResNetBlockBase): def __init__( self, in_channels, out_channels, *, bottleneck_channels, stride=1, num_groups=1, norm="BN", stride_in_1x1=False, num_branch=3, dilations=(1, 2, 3), concat_output=False, test_branch_idx=-1, ): """ Args: num_branch (int): the number of branches in TridentNet. dilations (tuple): the dilations of multiple branches in TridentNet. concat_output (bool): if concatenate outputs of multiple branches in TridentNet. Use 'True' for the last trident block. """ super().__init__(in_channels, out_channels, stride) assert num_branch == len(dilations) self.num_branch = num_branch self.concat_output = concat_output self.test_branch_idx = test_branch_idx if in_channels != out_channels: self.shortcut = Conv2d( in_channels, out_channels, kernel_size=1, stride=stride, bias=False, norm=get_norm(norm, out_channels), ) else: self.shortcut = None stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) self.conv1 = Conv2d( in_channels, bottleneck_channels, kernel_size=1, stride=stride_1x1, bias=False, norm=get_norm(norm, bottleneck_channels), ) self.conv2 = TridentConv( bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride_3x3, paddings=dilations, bias=False, groups=num_groups, dilations=dilations, num_branch=num_branch, test_branch_idx=test_branch_idx, norm=get_norm(norm, bottleneck_channels), ) self.conv3 = Conv2d( bottleneck_channels, out_channels, kernel_size=1, bias=False, norm=get_norm(norm, out_channels), ) for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: if layer is not None: # shortcut can be None weight_init.c2_msra_fill(layer) def forward(self, x): num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 if not isinstance(x, list): x = [x] * num_branch out = [self.conv1(b) for b in x] out = [F.relu_(b) for b in out] out = self.conv2(out) out = [F.relu_(b) for b in out] out = [self.conv3(b) for b in out] if self.shortcut is not None: shortcut = [self.shortcut(b) for b in x] else: shortcut = x out = [out_b + shortcut_b for out_b, shortcut_b in zip(out, shortcut)] out = [F.relu_(b) for b in out] if self.concat_output: out = torch.cat(out) return out def make_trident_stage(block_class, num_blocks, first_stride, **kwargs): """ Create a resnet stage by creating many blocks for TridentNet. """ blocks = [] for i in range(num_blocks - 1): blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs)) kwargs["in_channels"] = kwargs["out_channels"] blocks.append(block_class(stride=1, concat_output=True, **kwargs)) return blocks @BACKBONE_REGISTRY.register() def build_trident_resnet_backbone(cfg, input_shape): """ Create a ResNet instance from config for TridentNet. Returns: ResNet: a :class:`ResNet` instance. """ # need registration of new blocks/stems? norm = cfg.MODEL.RESNETS.NORM stem = BasicStem( in_channels=input_shape.channels, out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, norm=norm, ) freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT if freeze_at >= 1: for p in stem.parameters(): p.requires_grad = False stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem) # fmt: off out_features = cfg.MODEL.RESNETS.OUT_FEATURES depth = cfg.MODEL.RESNETS.DEPTH num_groups = cfg.MODEL.RESNETS.NUM_GROUPS width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP bottleneck_channels = num_groups * width_per_group in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH branch_dilations = cfg.MODEL.TRIDENT.BRANCH_DILATIONS trident_stage = cfg.MODEL.TRIDENT.TRIDENT_STAGE test_branch_idx = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX # fmt: on assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] stages = [] res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5} out_stage_idx = [res_stage_idx[f] for f in out_features] trident_stage_idx = res_stage_idx[trident_stage] max_stage_idx = max(out_stage_idx) for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): dilation = res5_dilation if stage_idx == 5 else 1 first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 stage_kargs = { "num_blocks": num_blocks_per_stage[idx], "first_stride": first_stride, "in_channels": in_channels, "bottleneck_channels": bottleneck_channels, "out_channels": out_channels, "num_groups": num_groups, "norm": norm, "stride_in_1x1": stride_in_1x1, "dilation": dilation, } if stage_idx == trident_stage_idx: assert not deform_on_per_stage[ idx ], "Not support deformable conv in Trident blocks yet." stage_kargs["block_class"] = TridentBottleneckBlock stage_kargs["num_branch"] = num_branch stage_kargs["dilations"] = branch_dilations stage_kargs["test_branch_idx"] = test_branch_idx stage_kargs.pop("dilation") elif deform_on_per_stage[idx]: stage_kargs["block_class"] = DeformBottleneckBlock stage_kargs["deform_modulated"] = deform_modulated stage_kargs["deform_num_groups"] = deform_num_groups else: stage_kargs["block_class"] = BottleneckBlock blocks = ( make_trident_stage(**stage_kargs) if stage_idx == trident_stage_idx else make_stage(**stage_kargs) ) in_channels = out_channels out_channels *= 2 bottleneck_channels *= 2 if freeze_at >= stage_idx: for block in blocks: block.freeze() stages.append(blocks) return ResNet(stem, stages, out_features=out_features)