desco / maskrcnn_benchmark /modeling /roi_heads /box_head /roi_box_feature_extractors.py
zdou0830's picture
desco
749745d
raw
history blame
No virus
7.51 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.modeling.make_layers import make_fc
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("LightheadFeatureExtractor")
class LightheadFeatureExtractor(nn.Module):
    """
    Box-head feature extractor built from two parallel separable-conv branches.

    Every incoming feature map is fed to two branches, each a (15, 1)
    convolution followed by a (1, 15) convolution. The two branch outputs are
    summed, pooled per proposal, flattened, and passed through a single
    fully-connected layer followed by ReLU.
    """

    def __init__(self, cfg):
        """
        Build the pooler, the four branch convolutions and the final FC layer.

        Arguments:
            cfg: config object; reads the pooler settings, ``MLP_HEAD_DIM`` and
                ``USE_GN`` from ``MODEL.ROI_BOX_HEAD`` and the input channel
                count from ``MODEL.BACKBONE.OUT_CHANNELS``.
        """
        super().__init__()

        box_cfg = cfg.MODEL.ROI_BOX_HEAD
        resolution = box_cfg.POOLER_RESOLUTION
        pooler = Pooler(
            output_size=(resolution, resolution),
            scales=box_cfg.POOLER_SCALES,
            sampling_ratio=box_cfg.POOLER_SAMPLING_RATIO,
        )

        bins = resolution**2
        # The branches emit 10 channels per pooled bin.
        light_channels = 10 * bins
        representation_size = box_cfg.MLP_HEAD_DIM
        use_gn = box_cfg.USE_GN
        backbone_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        hidden_channels = 256

        # Conv2d kernel sizes are (height, width): a 15x1 kernel followed by a
        # 1x15 kernel, each padded so the spatial size is preserved.
        tall = {"kernel_size": (15, 1), "stride": 1, "padding": (7, 0)}
        wide = {"kernel_size": (1, 15), "stride": 1, "padding": (0, 7)}

        # NOTE(review): both branches use the same (15, 1) -> (1, 15) order;
        # confirm this is intended rather than a mirrored second branch.
        self.separable_conv_11 = nn.Conv2d(backbone_channels, hidden_channels, **tall)
        self.separable_conv_12 = nn.Conv2d(hidden_channels, light_channels, **wide)
        self.separable_conv_21 = nn.Conv2d(backbone_channels, hidden_channels, **tall)
        self.separable_conv_22 = nn.Conv2d(hidden_channels, light_channels, **wide)

        branch_convs = (
            self.separable_conv_11,
            self.separable_conv_12,
            self.separable_conv_21,
            self.separable_conv_22,
        )
        for conv in branch_convs:
            # Caffe2 implementation uses XavierFill, which in fact
            # corresponds to kaiming_uniform_ in PyTorch
            nn.init.kaiming_uniform_(conv.weight, a=1)

        self.pooler = pooler
        # The pooled tensor has `light_channels` channels over `bins` spatial
        # positions, so the flattened FC input is their product.
        # TODO: wait for the official repo to support psroi.
        self.fc6 = make_fc(light_channels * bins, representation_size, use_gn)

    def forward(self, x, proposals):
        """
        Compute one feature vector per proposal.

        Arguments:
            x: iterable of feature maps; each is run through both conv branches.
            proposals: passed unchanged to the pooler.

        Returns:
            The ReLU-activated output of ``fc6`` on the flattened pooled features.
        """
        light_maps = [
            self.separable_conv_12(self.separable_conv_11(level))
            + self.separable_conv_22(self.separable_conv_21(level))
            for level in x
        ]
        pooled = self.pooler(light_maps, proposals)
        flat = pooled.view(pooled.size(0), -1)
        return F.relu(self.fc6(flat))
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
class ResNet50Conv5ROIFeatureExtractor(nn.Module):
    """
    Pools features per proposal and runs them through a ``resnet.ResNetHead``.
    """

    def __init__(self, config):
        """
        Build the pooler and the ResNet head.

        Arguments:
            config: config object; reads the pooler settings from
                ``MODEL.ROI_BOX_HEAD`` and the head settings from
                ``MODEL.RESNETS``.
        """
        super().__init__()

        box_cfg = config.MODEL.ROI_BOX_HEAD
        side = box_cfg.POOLER_RESOLUTION
        roi_pooler = Pooler(
            output_size=(side, side),
            scales=box_cfg.POOLER_SCALES,
            sampling_ratio=box_cfg.POOLER_SAMPLING_RATIO,
        )

        resnet_cfg = config.MODEL.RESNETS
        # The head is a single stage: index 4 with 3 blocks.
        head_stage = resnet.StageSpec(index=4, block_count=3, return_features=False)
        resnet_head = resnet.ResNetHead(
            block_module=resnet_cfg.TRANS_FUNC,
            stages=(head_stage,),
            num_groups=resnet_cfg.NUM_GROUPS,
            width_per_group=resnet_cfg.WIDTH_PER_GROUP,
            stride_in_1x1=resnet_cfg.STRIDE_IN_1X1,
            stride_init=None,
            res2_out_channels=resnet_cfg.RES2_OUT_CHANNELS,
            dilation=resnet_cfg.RES5_DILATION,
        )

        self.pooler = roi_pooler
        self.head = resnet_head

    def forward(self, x, proposals):
        """
        Pool ``x`` for the given ``proposals`` and return the head's output.
        """
        pooled = self.pooler(x, proposals)
        return self.head(pooled)
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
class FPN2MLPFeatureExtractor(nn.Module):
    """
    Heads for FPN for classification: pooling followed by two FC layers.
    """

    def __init__(self, cfg):
        """
        Build the pooler and the two fully-connected layers.

        Arguments:
            cfg: config object; reads the pooler settings, ``MLP_HEAD_DIM`` and
                ``USE_GN`` from ``MODEL.ROI_BOX_HEAD`` and the channel count
                from ``MODEL.BACKBONE.OUT_CHANNELS``.
        """
        super().__init__()

        box_cfg = cfg.MODEL.ROI_BOX_HEAD
        side = box_cfg.POOLER_RESOLUTION
        roi_pooler = Pooler(
            output_size=(side, side),
            scales=box_cfg.POOLER_SCALES,
            sampling_ratio=box_cfg.POOLER_SAMPLING_RATIO,
        )

        # Flattened size of one pooled region: channels * height * width.
        flat_features = cfg.MODEL.BACKBONE.OUT_CHANNELS * side**2
        hidden_dim = box_cfg.MLP_HEAD_DIM
        with_gn = box_cfg.USE_GN

        self.pooler = roi_pooler
        self.fc6 = make_fc(flat_features, hidden_dim, with_gn)
        self.fc7 = make_fc(hidden_dim, hidden_dim, with_gn)

    def forward(self, x, proposals):
        """
        Pool ``x`` for ``proposals``, flatten, and apply ``fc6`` and ``fc7``,
        each followed by ReLU.
        """
        pooled = self.pooler(x, proposals)
        hidden = pooled.view(pooled.size(0), -1)
        for fc in (self.fc6, self.fc7):
            hidden = F.relu(fc(hidden))
        return hidden
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPNXconv1fcFeatureExtractor")
class FPNXconv1fcFeatureExtractor(nn.Module):
    """
    Heads for FPN for classification: pooling, a stack of 3x3 convolutions,
    then one FC layer.
    """

    def __init__(self, cfg):
        """
        Build the pooler, the stacked conv tower and the final FC layer.

        Arguments:
            cfg: config object; reads the pooler settings, ``USE_GN``,
                ``CONV_HEAD_DIM``, ``NUM_STACKED_CONVS``, ``DILATION`` and
                ``MLP_HEAD_DIM`` from ``MODEL.ROI_BOX_HEAD`` and the input
                channel count from ``MODEL.BACKBONE.OUT_CHANNELS``.
        """
        super().__init__()

        box_cfg = cfg.MODEL.ROI_BOX_HEAD
        side = box_cfg.POOLER_RESOLUTION
        self.pooler = Pooler(
            output_size=(side, side),
            scales=box_cfg.POOLER_SCALES,
            sampling_ratio=box_cfg.POOLER_SAMPLING_RATIO,
        )

        use_gn = box_cfg.USE_GN
        conv_in = cfg.MODEL.BACKBONE.OUT_CHANNELS
        conv_head_dim = box_cfg.CONV_HEAD_DIM
        num_stacked_convs = box_cfg.NUM_STACKED_CONVS
        dilation = box_cfg.DILATION

        tower = []
        for _ in range(num_stacked_convs):
            # The conv bias is dropped whenever group norm follows the conv.
            conv = nn.Conv2d(
                conv_in,
                conv_head_dim,
                kernel_size=3,
                stride=1,
                padding=dilation,
                dilation=dilation,
                bias=not use_gn,
            )
            tower.append(conv)
            # Every conv after the first takes conv_head_dim channels as input.
            conv_in = conv_head_dim
            if use_gn:
                tower.append(group_norm(conv_head_dim))
            tower.append(nn.ReLU(inplace=True))
        self.xconvs = nn.Sequential(*tower)

        # Re-initialise the conv weights (and biases, when they exist).
        for layer in self.xconvs.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            torch.nn.init.normal_(layer.weight, std=0.01)
            if not use_gn:
                torch.nn.init.constant_(layer.bias, 0)

        flat_features = conv_head_dim * side**2
        hidden_dim = box_cfg.MLP_HEAD_DIM
        self.fc6 = make_fc(flat_features, hidden_dim, use_gn=False)

    def forward(self, x, proposals):
        """
        Pool ``x`` for ``proposals``, apply the conv tower, flatten, and
        return the ReLU-activated output of ``fc6``.
        """
        pooled = self.pooler(x, proposals)
        conv_out = self.xconvs(pooled)
        flat = conv_out.view(conv_out.size(0), -1)
        return F.relu(self.fc6(flat))
def make_roi_box_feature_extractor(cfg):
    """
    Instantiate the box feature extractor selected by the config.

    Looks up ``cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR`` in
    ``registry.ROI_BOX_FEATURE_EXTRACTORS`` and calls the registered entry
    with ``cfg``.
    """
    extractor_name = cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
    extractor_factory = registry.ROI_BOX_FEATURE_EXTRACTORS[extractor_name]
    return extractor_factory(cfg)