Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from maskrcnn_benchmark.modeling import registry | |
from maskrcnn_benchmark.modeling.backbone import resnet | |
from maskrcnn_benchmark.modeling.poolers import Pooler | |
from maskrcnn_benchmark.modeling.make_layers import group_norm | |
from maskrcnn_benchmark.modeling.make_layers import make_fc | |
class LightheadFeatureExtractor(nn.Module):
    """
    Box feature extractor built on large separable convolutions.

    Every incoming feature map is sent through two parallel branches, each a
    15x1 convolution followed by a 1x15 convolution. The two branch outputs
    are summed, handed to the pooler together with the proposals, flattened
    and passed through a single fully-connected layer with ReLU.
    """

    def __init__(self, cfg):
        super(LightheadFeatureExtractor, self).__init__()

        box_head_cfg = cfg.MODEL.ROI_BOX_HEAD
        resolution = box_head_cfg.POOLER_RESOLUTION
        roi_pooler = Pooler(
            output_size=(resolution, resolution),
            scales=box_head_cfg.POOLER_SCALES,
            sampling_ratio=box_head_cfg.POOLER_SAMPLING_RATIO,
        )

        # Channel count produced by each separable branch.
        thin_channels = 10 * resolution ** 2
        fc_dim = box_head_cfg.MLP_HEAD_DIM
        use_gn = box_head_cfg.USE_GN

        backbone_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        hidden_channels = 256

        # Kernel / padding pairs keep the spatial size unchanged at stride 1.
        tall_kernel, tall_pad = (15, 1), (7, 0)
        wide_kernel, wide_pad = (1, 15), (0, 7)

        self.separable_conv_11 = nn.Conv2d(
            backbone_channels, hidden_channels, tall_kernel, stride=1, padding=tall_pad
        )
        self.separable_conv_12 = nn.Conv2d(
            hidden_channels, thin_channels, wide_kernel, stride=1, padding=wide_pad
        )
        self.separable_conv_21 = nn.Conv2d(
            backbone_channels, hidden_channels, tall_kernel, stride=1, padding=tall_pad
        )
        self.separable_conv_22 = nn.Conv2d(
            hidden_channels, thin_channels, wide_kernel, stride=1, padding=wide_pad
        )

        separable_convs = (
            self.separable_conv_11,
            self.separable_conv_12,
            self.separable_conv_21,
            self.separable_conv_22,
        )
        for conv in separable_convs:
            # Caffe2 implementation uses XavierFill, which in fact
            # corresponds to kaiming_uniform_ in PyTorch
            nn.init.kaiming_uniform_(conv.weight, a=1)

        self.pooler = roi_pooler
        # <TODO> wait official repo to support psroi
        self.fc6 = make_fc(thin_channels * resolution ** 2, fc_dim, use_gn)

    def forward(self, x, proposals):
        """
        Arguments:
            x: iterable of feature maps; each one is fed to both branches.
            proposals: forwarded unchanged to the pooler.

        Returns:
            ReLU-activated output of ``fc6`` applied to the flattened
            pooler result.
        """
        fused_maps = [
            self.separable_conv_12(self.separable_conv_11(level))
            + self.separable_conv_22(self.separable_conv_21(level))
            for level in x
        ]
        pooled = self.pooler(fused_maps, proposals)
        flattened = pooled.view(pooled.size(0), -1)
        return F.relu(self.fc6(flattened))
class ResNet50Conv5ROIFeatureExtractor(nn.Module):
    """
    Box feature extractor that applies a ResNet head to the pooler output.

    The head is a ``resnet.ResNetHead`` built from a single stage spec
    (index 4, 3 blocks) and configured from ``config.MODEL.RESNETS``.
    """

    def __init__(self, config):
        super(ResNet50Conv5ROIFeatureExtractor, self).__init__()

        box_head_cfg = config.MODEL.ROI_BOX_HEAD
        resolution = box_head_cfg.POOLER_RESOLUTION
        roi_pooler = Pooler(
            output_size=(resolution, resolution),
            scales=box_head_cfg.POOLER_SCALES,
            sampling_ratio=box_head_cfg.POOLER_SAMPLING_RATIO,
        )

        resnet_cfg = config.MODEL.RESNETS
        conv5_stage = resnet.StageSpec(index=4, block_count=3, return_features=False)
        conv5_head = resnet.ResNetHead(
            block_module=resnet_cfg.TRANS_FUNC,
            stages=(conv5_stage,),
            num_groups=resnet_cfg.NUM_GROUPS,
            width_per_group=resnet_cfg.WIDTH_PER_GROUP,
            stride_in_1x1=resnet_cfg.STRIDE_IN_1X1,
            stride_init=None,
            res2_out_channels=resnet_cfg.RES2_OUT_CHANNELS,
            dilation=resnet_cfg.RES5_DILATION,
        )

        self.pooler = roi_pooler
        self.head = conv5_head

    def forward(self, x, proposals):
        """Run the pooler on ``(x, proposals)`` and feed its result to the head."""
        pooled = self.pooler(x, proposals)
        return self.head(pooled)
class FPN2MLPFeatureExtractor(nn.Module):
    """
    FPN classification head: pooler followed by two fully-connected layers.

    The pooler output is flattened and passed through ``fc6`` and ``fc7``,
    each followed by a ReLU.
    """

    def __init__(self, cfg):
        super(FPN2MLPFeatureExtractor, self).__init__()

        box_head_cfg = cfg.MODEL.ROI_BOX_HEAD
        resolution = box_head_cfg.POOLER_RESOLUTION
        roi_pooler = Pooler(
            output_size=(resolution, resolution),
            scales=box_head_cfg.POOLER_SCALES,
            sampling_ratio=box_head_cfg.POOLER_SAMPLING_RATIO,
        )

        # Size of one flattened pooled region: channels * height * width.
        flat_features = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2
        hidden_dim = box_head_cfg.MLP_HEAD_DIM
        use_gn = box_head_cfg.USE_GN

        self.pooler = roi_pooler
        self.fc6 = make_fc(flat_features, hidden_dim, use_gn)
        self.fc7 = make_fc(hidden_dim, hidden_dim, use_gn)

    def forward(self, x, proposals):
        """
        Arguments:
            x: features forwarded to the pooler.
            proposals: forwarded unchanged to the pooler.

        Returns:
            Output of ``fc7`` after ReLU.
        """
        pooled = self.pooler(x, proposals)
        hidden = pooled.view(pooled.size(0), -1)
        for fc in (self.fc6, self.fc7):
            hidden = F.relu(fc(hidden))
        return hidden
class FPNXconv1fcFeatureExtractor(nn.Module):
    """
    FPN classification head: pooler, a stack of 3x3 convolutions, one FC layer.

    The stack holds ``NUM_STACKED_CONVS`` convolutions, each optionally
    followed by group norm (when ``USE_GN`` is set) and always by a ReLU.
    """

    def __init__(self, cfg):
        super(FPNXconv1fcFeatureExtractor, self).__init__()

        box_head_cfg = cfg.MODEL.ROI_BOX_HEAD
        resolution = box_head_cfg.POOLER_RESOLUTION
        self.pooler = Pooler(
            output_size=(resolution, resolution),
            scales=box_head_cfg.POOLER_SCALES,
            sampling_ratio=box_head_cfg.POOLER_SAMPLING_RATIO,
        )

        use_gn = box_head_cfg.USE_GN
        channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        conv_head_dim = box_head_cfg.CONV_HEAD_DIM
        num_stacked_convs = box_head_cfg.NUM_STACKED_CONVS
        dilation = box_head_cfg.DILATION

        stack = []
        for _ in range(num_stacked_convs):
            stack.append(
                nn.Conv2d(
                    channels,
                    conv_head_dim,
                    kernel_size=3,
                    stride=1,
                    padding=dilation,
                    dilation=dilation,
                    # The conv bias is dropped when group norm follows.
                    bias=not use_gn,
                )
            )
            # Every conv after the first consumes the previous conv's output.
            channels = conv_head_dim
            if use_gn:
                stack.append(group_norm(channels))
            stack.append(nn.ReLU(inplace=True))

        self.xconvs = nn.Sequential(*stack)

        # Initialise only after the whole stack exists, one conv at a time.
        for layer in self.xconvs.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.normal_(layer.weight, std=0.01)
            if not use_gn:
                nn.init.constant_(layer.bias, 0)

        flat_features = conv_head_dim * resolution ** 2
        fc_dim = box_head_cfg.MLP_HEAD_DIM
        self.fc6 = make_fc(flat_features, fc_dim, use_gn=False)

    def forward(self, x, proposals):
        """
        Arguments:
            x: features forwarded to the pooler.
            proposals: forwarded unchanged to the pooler.

        Returns:
            ReLU-activated output of ``fc6`` applied to the flattened
            convolution stack output.
        """
        pooled = self.pooler(x, proposals)
        conv_out = self.xconvs(pooled)
        flattened = conv_out.view(conv_out.size(0), -1)
        return F.relu(self.fc6(flattened))
def make_roi_box_feature_extractor(cfg):
    """
    Build the box feature extractor selected by the config.

    Looks up ``cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR`` in
    ``registry.ROI_BOX_FEATURE_EXTRACTORS`` and calls the entry with ``cfg``.
    """
    extractor_name = cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
    build_extractor = registry.ROI_BOX_FEATURE_EXTRACTORS[extractor_name]
    return build_extractor(cfg)