import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp from mmcv.cnn import build_conv_layer, build_norm_layer, kaiming_init from torch.nn.modules.utils import _pair from mmdet.models.backbones.resnet import Bottleneck, ResNet from mmdet.models.builder import BACKBONES class TridentConv(nn.Module): """Trident Convolution Module. Args: in_channels (int): Number of channels in input. out_channels (int): Number of channels in output. kernel_size (int): Size of convolution kernel. stride (int, optional): Convolution stride. Default: 1. trident_dilations (tuple[int, int, int], optional): Dilations of different trident branch. Default: (1, 2, 3). test_branch_idx (int, optional): In inference, all 3 branches will be used if `test_branch_idx==-1`, otherwise only branch with index `test_branch_idx` will be used. Default: 1. bias (bool, optional): Whether to use bias in convolution or not. Default: False. """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, trident_dilations=(1, 2, 3), test_branch_idx=1, bias=False): super(TridentConv, self).__init__() self.num_branch = len(trident_dilations) self.with_bias = bias self.test_branch_idx = test_branch_idx self.stride = _pair(stride) self.kernel_size = _pair(kernel_size) self.paddings = _pair(trident_dilations) self.dilations = trident_dilations self.in_channels = in_channels self.out_channels = out_channels self.bias = bias self.weight = nn.Parameter( torch.Tensor(out_channels, in_channels, *self.kernel_size)) if bias: self.bias = nn.Parameter(torch.Tensor(out_channels)) else: self.bias = None self.init_weights() def init_weights(self): kaiming_init(self, distribution='uniform', mode='fan_in') def extra_repr(self): tmpstr = f'in_channels={self.in_channels}' tmpstr += f', out_channels={self.out_channels}' tmpstr += f', kernel_size={self.kernel_size}' tmpstr += f', num_branch={self.num_branch}' tmpstr += f', test_branch_idx={self.test_branch_idx}' tmpstr += f', stride={self.stride}' tmpstr += f', paddings={self.paddings}' tmpstr += f', dilations={self.dilations}' tmpstr += f', bias={self.bias}' return tmpstr def forward(self, inputs): if self.training or self.test_branch_idx == -1: outputs = [ F.conv2d(input, self.weight, self.bias, self.stride, padding, dilation) for input, dilation, padding in zip( inputs, self.dilations, self.paddings) ] else: assert len(inputs) == 1 outputs = [ F.conv2d(inputs[0], self.weight, self.bias, self.stride, self.paddings[self.test_branch_idx], self.dilations[self.test_branch_idx]) ] return outputs # Since TridentNet is defined over ResNet50 and ResNet101, here we # only support TridentBottleneckBlock. class TridentBottleneck(Bottleneck): """BottleBlock for TridentResNet. Args: trident_dilations (tuple[int, int, int]): Dilations of different trident branch. test_branch_idx (int): In inference, all 3 branches will be used if `test_branch_idx==-1`, otherwise only branch with index `test_branch_idx` will be used. concat_output (bool): Whether to concat the output list to a Tensor. `True` only in the last Block. """ def __init__(self, trident_dilations, test_branch_idx, concat_output, **kwargs): super(TridentBottleneck, self).__init__(**kwargs) self.trident_dilations = trident_dilations self.num_branch = len(trident_dilations) self.concat_output = concat_output self.test_branch_idx = test_branch_idx self.conv2 = TridentConv( self.planes, self.planes, kernel_size=3, stride=self.conv2_stride, bias=False, trident_dilations=self.trident_dilations, test_branch_idx=test_branch_idx) def forward(self, x): def _inner_forward(x): num_branch = ( self.num_branch if self.training or self.test_branch_idx == -1 else 1) identity = x if not isinstance(x, list): x = (x, ) * num_branch identity = x if self.downsample is not None: identity = [self.downsample(b) for b in x] out = [self.conv1(b) for b in x] out = [self.norm1(b) for b in out] out = [self.relu(b) for b in out] if self.with_plugins: for k in range(len(out)): out[k] = self.forward_plugin(out[k], self.after_conv1_plugin_names) out = self.conv2(out) out = [self.norm2(b) for b in out] out = [self.relu(b) for b in out] if self.with_plugins: for k in range(len(out)): out[k] = self.forward_plugin(out[k], self.after_conv2_plugin_names) out = [self.conv3(b) for b in out] out = [self.norm3(b) for b in out] if self.with_plugins: for k in range(len(out)): out[k] = self.forward_plugin(out[k], self.after_conv3_plugin_names) out = [ out_b + identity_b for out_b, identity_b in zip(out, identity) ] return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = [self.relu(b) for b in out] if self.concat_output: out = torch.cat(out, dim=0) return out def make_trident_res_layer(block, inplanes, planes, num_blocks, stride=1, trident_dilations=(1, 2, 3), style='pytorch', with_cp=False, conv_cfg=None, norm_cfg=dict(type='BN'), dcn=None, plugins=None, test_branch_idx=-1): """Build Trident Res Layers.""" downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = [] conv_stride = stride downsample.extend([ build_conv_layer( conv_cfg, inplanes, planes * block.expansion, kernel_size=1, stride=conv_stride, bias=False), build_norm_layer(norm_cfg, planes * block.expansion)[1] ]) downsample = nn.Sequential(*downsample) layers = [] for i in range(num_blocks): layers.append( block( inplanes=inplanes, planes=planes, stride=stride if i == 0 else 1, trident_dilations=trident_dilations, downsample=downsample if i == 0 else None, style=style, with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, dcn=dcn, plugins=plugins, test_branch_idx=test_branch_idx, concat_output=True if i == num_blocks - 1 else False)) inplanes = planes * block.expansion return nn.Sequential(*layers) @BACKBONES.register_module() class TridentResNet(ResNet): """The stem layer, stage 1 and stage 2 in Trident ResNet are identical to ResNet, while in stage 3, Trident BottleBlock is utilized to replace the normal BottleBlock to yield trident output. Different branch shares the convolution weight but uses different dilations to achieve multi-scale output. / stage3(b0) \ x - stem - stage1 - stage2 - stage3(b1) - output \ stage3(b2) / Args: depth (int): Depth of resnet, from {50, 101, 152}. num_branch (int): Number of branches in TridentNet. test_branch_idx (int): In inference, all 3 branches will be used if `test_branch_idx==-1`, otherwise only branch with index `test_branch_idx` will be used. trident_dilations (tuple[int]): Dilations of different trident branch. len(trident_dilations) should be equal to num_branch. """ # noqa def __init__(self, depth, num_branch, test_branch_idx, trident_dilations, **kwargs): assert num_branch == len(trident_dilations) assert depth in (50, 101, 152) super(TridentResNet, self).__init__(depth, **kwargs) assert self.num_stages == 3 self.test_branch_idx = test_branch_idx self.num_branch = num_branch last_stage_idx = self.num_stages - 1 stride = self.strides[last_stage_idx] dilation = trident_dilations dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None if self.plugins is not None: stage_plugins = self.make_stage_plugins(self.plugins, last_stage_idx) else: stage_plugins = None planes = self.base_channels * 2**last_stage_idx res_layer = make_trident_res_layer( TridentBottleneck, inplanes=(self.block.expansion * self.base_channels * 2**(last_stage_idx - 1)), planes=planes, num_blocks=self.stage_blocks[last_stage_idx], stride=stride, trident_dilations=dilation, style=self.style, with_cp=self.with_cp, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, dcn=dcn, plugins=stage_plugins, test_branch_idx=self.test_branch_idx) layer_name = f'layer{last_stage_idx + 1}' self.__setattr__(layer_name, res_layer) self.res_layers.pop(last_stage_idx) self.res_layers.insert(last_stage_idx, layer_name) self._freeze_stages()