Spaces:
Build error
Build error
import torch.nn as nn | |
from mmcv.cnn import ConvModule, xavier_init | |
from ..builder import NECKS | |
# NOTE(review): ``NECKS`` is imported by this module but never used, and this
# class has no ``@NECKS.register_module()`` decorator, so it is presumably not
# visible to registry-based (config ``type=...``) builders. Confirm whether the
# decorator was dropped or the class is registered elsewhere.
class ChannelMapper(nn.Module):
    r"""Channel Mapper to reduce/increase channels of backbone features.

    Each input feature map is passed through its own ``ConvModule`` that maps
    ``in_channels[i]`` channels to ``out_channels`` channels, so every output
    scale ends up with the same number of channels.

    Args:
        in_channels (List[int] | Tuple[int]): Number of input channels per
            scale. One ``ConvModule`` is built for each entry.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): Kernel size of the per-scale
            convolution. Padding is ``(kernel_size - 1) // 2``. For odd kernel
            sizes this keeps the spatial size unchanged, assuming
            ConvModule's default stride of 1. Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU')):
        # The ``act_cfg`` default dict is shared between calls. This block
        # only forwards it to ConvModule and never mutates it; this assumes
        # ConvModule does not mutate it either.
        super().__init__()
        assert isinstance(in_channels, (list, tuple)), (
            f'in_channels must be a list or tuple, got {type(in_channels)}')

        # "Same" padding for odd kernel sizes, so each scale keeps its
        # spatial resolution (see the class example).
        padding = (kernel_size - 1) // 2
        self.convs = nn.ModuleList(
            ConvModule(
                in_channel,
                out_channels,
                kernel_size,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for in_channel in in_channels)

    def init_weights(self):
        """Initialize the weights of ChannelMapper module.

        Every ``nn.Conv2d`` among this module's submodules is re-initialized
        with uniform Xavier initialization. This method does not touch
        normalization layers.
        """
        # NOTE: this overrides whatever init ConvModule applied to its conv at
        # construction time (previously described here as msra/kaiming; confirm
        # against the mmcv version in use).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): One feature map per scale, in the same
                order as ``in_channels`` (4-D ``(N, C_i, H_i, W_i)`` tensors
                in the class example).

        Returns:
            tuple[Tensor]: ``self.convs[i](inputs[i])`` for every scale ``i``.
        """
        assert len(inputs) == len(self.convs), (
            f'expected {len(self.convs)} inputs, got {len(inputs)}')
        return tuple(conv(x) for conv, x in zip(self.convs, inputs))