# Robert001's picture
# first commit
# b334e29
# raw
# history blame
# No virus
# 3.23 kB
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, constant_init, xavier_init
from mmcv.runner import auto_fp16
from mmdet.models.builder import HEADS
from .fcn_mask_head import FCNMaskHead
@HEADS.register_module()
class CoarseMaskHead(FCNMaskHead):
    """Coarse mask head used in PointRend.

    Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` downsamples
    the input feature map instead of upsampling it, then predicts a coarse
    per-class mask with a stack of fully connected layers.

    Args:
        num_convs (int): Number of conv layers in the head. Default: 0.
        num_fcs (int): Number of fc layers in the head. Must be > 0.
            Default: 2.
        fc_out_channels (int): Number of output channels of fc layer.
            Default: 1024.
        downsample_factor (int): The factor that feature map is downsampled
            by. Must be >= 1; a value of 1 disables the downsample conv.
            Default: 2.
        *args, **kwargs: Forwarded unchanged to ``FCNMaskHead``.
    """

    def __init__(self,
                 num_convs=0,
                 num_fcs=2,
                 fc_out_channels=1024,
                 downsample_factor=2,
                 *args,
                 **kwargs):
        # The parent's upsample layer is disabled: this head downsamples.
        super(CoarseMaskHead, self).__init__(
            *args, num_convs=num_convs, upsample_cfg=dict(type=None), **kwargs)
        self.num_fcs = num_fcs
        assert self.num_fcs > 0, 'num_fcs must be a positive integer'
        self.fc_out_channels = fc_out_channels
        self.downsample_factor = downsample_factor
        assert self.downsample_factor >= 1, 'downsample_factor must be >= 1'
        # The parent's per-pixel conv_logits is replaced by fc_logits below.
        delattr(self, 'conv_logits')

        # Channel count of the feature map as it leaves the conv stack; with
        # no convs the RoI feature passes through with ``in_channels``.
        feat_channels = (
            self.conv_out_channels if self.num_convs > 0 else self.in_channels)

        if downsample_factor > 1:
            # kernel_size == stride with no padding: non-overlapping windows,
            # so each spatial dim shrinks to floor(size / downsample_factor),
            # matching ``output_size`` below.
            self.downsample_conv = ConvModule(
                feat_channels,
                self.conv_out_channels,
                kernel_size=downsample_factor,
                stride=downsample_factor,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
            feat_channels = self.conv_out_channels
        else:
            self.downsample_conv = None

        self.output_size = (self.roi_feat_size[0] // downsample_factor,
                            self.roi_feat_size[1] // downsample_factor)
        self.output_area = self.output_size[0] * self.output_size[1]

        # The first fc consumes the flattened (C, H', W') feature map. Use the
        # channel count that actually reaches the flatten: previously this was
        # always ``conv_out_channels``, which mis-sized the first fc when
        # num_convs == 0, downsample_factor == 1 and in_channels differed.
        last_layer_dim = feat_channels * self.output_area

        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            self.fcs.append(Linear(last_layer_dim, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels

        # One logit per class per output cell; reshaped in ``forward``.
        output_channels = self.num_classes * self.output_area
        self.fc_logits = Linear(last_layer_dim, output_channels)

    def init_weights(self):
        """Initialize the fc layers.

        Every ``nn.Linear`` inside ``self.fcs`` gets ``xavier_init``, and
        ``fc_logits`` gets ``constant_init`` with value 0.001. The parent's
        ``init_weights`` is not called and the conv layers are not touched.
        NOTE(review): presumably the override exists because the parent's
        version refers to the deleted ``conv_logits``; verify against
        ``FCNMaskHead``.
        """
        for m in self.fcs.modules():
            if isinstance(m, nn.Linear):
                xavier_init(m)
        constant_init(self.fc_logits, 0.001)

    @auto_fp16()
    def forward(self, x):
        """Predict coarse mask logits for a batch of RoI features.

        Args:
            x (Tensor): RoI features; presumably of shape
                (N, in_channels, *roi_feat_size) - TODO confirm with caller.

        Returns:
            Tensor: Mask logits of shape (N, num_classes, *output_size).
        """
        for conv in self.convs:
            x = conv(x)

        if self.downsample_conv is not None:
            x = self.downsample_conv(x)

        # Keep the batch dim, merge everything else for the fc stack.
        x = x.flatten(1)
        for fc in self.fcs:
            x = self.relu(fc(x))
        mask_pred = self.fc_logits(x).view(
            x.size(0), self.num_classes, *self.output_size)
        return mask_pred