# Copyright (c) 2019 Western Digital Corporation or its affiliates.

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from ..builder import NECKS


class DetectionBlock(nn.Module):
    """Detection block in YOLO neck.

    Let out_channels = n, the DetectionBlock contains five ConvModules
    applied sequentially:
        1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn.
    The 3x3 convolutions use ``padding=1``. The output therefore has
    ``out_channels`` channels.
    The input channel is arbitrary (in_channels).

    Note:
        Any further layers of the YOLO detection pipeline (e.g. the final
        prediction convolution) are not built here; presumably they live in
        the detection head.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
        super().__init__()
        double_out_channels = out_channels * 2

        # shortcut: the same conv/norm/act settings are shared by all layers
        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv1 = ConvModule(in_channels, out_channels, 1, **cfg)
        self.conv2 = ConvModule(
            out_channels, double_out_channels, 3, padding=1, **cfg)
        self.conv3 = ConvModule(double_out_channels, out_channels, 1, **cfg)
        self.conv4 = ConvModule(
            out_channels, double_out_channels, 3, padding=1, **cfg)
        self.conv5 = ConvModule(double_out_channels, out_channels, 1, **cfg)

    def forward(self, x):
        """Run ``x`` through conv1 ... conv5 in order and return the result."""
        tmp = self.conv1(x)
        tmp = self.conv2(tmp)
        tmp = self.conv3(tmp)
        tmp = self.conv4(tmp)
        out = self.conv5(tmp)
        return out


@NECKS.register_module()
class YOLOV3Neck(nn.Module):
    """The neck of YOLOV3.

    It can be treated as a simplified version of FPN. It takes the multi-level
    features from the backbone, and for every level after the first one it
    upsamples the previous result by 2 and concatenates it with the next
    lower-level feature before applying a DetectionBlock.

    Note:
        The input feats should be ordered from low-lvl to high-lvl, i.e. the
        last element of ``feats`` is the highest-level feature.
        YOLOV3Neck processes them in reversed order, i.e. it starts from
        ``feats[-1]``. Consequently ``in_channels[0]`` / ``out_channels[0]``
        describe the highest-level feature (``feats[-1]``),
        ``in_channels[1]`` describes ``feats[-2]``, and so on.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (list[int]): The number of input channels per scale,
            in processing order (highest level first).
        out_channels (list[int]): The number of output channels per scale,
            in processing order (highest level first).
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
    """

    def __init__(self,
                 num_scales,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
        super().__init__()
        assert (num_scales == len(in_channels) == len(out_channels))
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # shortcut: the same conv/norm/act settings are shared by all layers
        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # To support arbitrary scales, sub-modules are registered under
        # generated names (conv1, detect2, conv2, detect3, ...) and looked up
        # by name again in forward().
        self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
        for i in range(1, self.num_scales):
            in_c, out_c = self.in_channels[i], self.out_channels[i]
            # conv{i} consumes the output of detect{i}, whose channel count
            # is out_channels[i - 1] (NOT in_channels[i]; the two only happen
            # to coincide in some configurations).
            inter_c = self.out_channels[i - 1]
            self.add_module(f'conv{i}', ConvModule(inter_c, out_c, 1, **cfg))
            # in_c + out_c : High-lvl feats will be cat with low-lvl feats
            self.add_module(f'detect{i+1}',
                            DetectionBlock(in_c + out_c, out_c, **cfg))

    def forward(self, feats):
        """Fuse the multi-level features.

        Args:
            feats (Sequence[Tensor]): ``num_scales`` feature maps; the last
                one is processed first.

        Returns:
            tuple[Tensor]: ``num_scales`` outputs in processing order, i.e.
                ``outs[0]`` is computed from ``feats[-1]``.
        """
        assert len(feats) == self.num_scales

        # processed from bottom (high-lvl) to top (low-lvl)
        outs = []
        out = self.detect1(feats[-1])
        outs.append(out)

        for i, x in enumerate(reversed(feats[:-1])):
            conv = getattr(self, f'conv{i+1}')
            tmp = conv(out)

            # Cat with low-lvl feats. The upsampled map must match the
            # spatial size of ``x`` for the channel-wise concatenation, so
            # each lower level is expected to be exactly 2x larger.
            tmp = F.interpolate(tmp, scale_factor=2)
            tmp = torch.cat((tmp, x), 1)

            detect = getattr(self, f'detect{i+2}')
            out = detect(tmp)
            outs.append(out)

        return tuple(outs)

    def init_weights(self):
        """Initialize the weights of module."""
        # init is done in ConvModule
        pass