desco / maskrcnn_benchmark /modeling /roi_heads /mask_head /roi_mask_feature_extractors.py
zdou0830's picture
desco
749745d
raw
history blame
No virus
4.11 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from torch import nn
from torch.nn import functional as F
from .hourglass import Hourglass
from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor
from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.modeling.make_layers import make_conv3x3
class MaskRCNNFPNFeatureExtractor(nn.Module):
    """
    Mask-head feature extractor: a Pooler followed by a chain of layers
    built with make_conv3x3, each followed by a ReLU in forward().
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: configuration object. Reads MODEL.BACKBONE.OUT_CHANNELS and,
                under MODEL.ROI_MASK_HEAD, POOLER_RESOLUTION, POOLER_SCALES,
                POOLER_SAMPLING_RATIO, USE_GN, CONV_LAYERS and DILATION.
        """
        super().__init__()

        # The pooler output is square: the same resolution on both sides.
        side = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        self.pooler = Pooler(
            output_size=(side, side),
            scales=cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES,
            sampling_ratio=cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO,
        )

        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        group_norm = cfg.MODEL.ROI_MASK_HEAD.USE_GN
        conv_widths = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS
        conv_dilation = cfg.MODEL.ROI_MASK_HEAD.DILATION

        # Attribute names of the registered conv layers, in execution order.
        self.blocks = []
        for position, width in enumerate(conv_widths, start=1):
            name = "mask_fcn{}".format(position)
            conv = make_conv3x3(
                in_channels,
                width,
                dilation=conv_dilation,
                stride=1,
                use_gn=group_norm,
            )
            self.add_module(name, conv)
            self.blocks.append(name)
            # The next layer consumes what this one produces.
            in_channels = width

    def forward(self, x, proposals):
        """
        Pool `x` for `proposals` with self.pooler, then apply every layer
        named in self.blocks in order, with a ReLU after each one.
        """
        features = self.pooler(x, proposals)
        for name in self.blocks:
            layer = getattr(self, name)
            features = F.relu(layer(features))
        return features
class HourglassFPNFeatureExtractor(nn.Module):
    """
    Mask-head feature extractor: a Pooler followed by a chain of Hourglass
    modules, each followed by a ReLU in forward().
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: configuration object. Reads MODEL.BACKBONE.OUT_CHANNELS and,
                under MODEL.ROI_MASK_HEAD, POOLER_RESOLUTION, POOLER_SCALES,
                POOLER_SAMPLING_RATIO, USE_GN, CONV_LAYERS and HG_SCALE.

        Raises:
            AssertionError: if the first entry of CONV_LAYERS differs from
                MODEL.BACKBONE.OUT_CHANNELS.
        """
        super(HourglassFPNFeatureExtractor, self).__init__()

        input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS
        layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS

        # Validate before building any submodule. This is an explicit check
        # rather than an `assert` statement so it still runs under
        # `python -O`; AssertionError is kept so the failure type is the
        # same as before.
        if input_size != layers[0]:
            raise AssertionError(
                "HourglassFPNFeatureExtractor requires "
                "MODEL.ROI_MASK_HEAD.CONV_LAYERS[0] ({}) to equal "
                "MODEL.BACKBONE.OUT_CHANNELS ({})".format(layers[0], input_size)
            )

        resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES
        sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        self.pooler = Pooler(
            output_size=(resolution, resolution),
            scales=scales,
            sampling_ratio=sampling_ratio,
        )

        use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
        scale = cfg.MODEL.ROI_MASK_HEAD.HG_SCALE

        # Attribute names of the registered Hourglass modules, in execution
        # order.
        # NOTE(review): only layers[0] is checked against the input width;
        # presumably every entry must match because each Hourglass is built
        # from a single channel count -- confirm against Hourglass.
        self.blocks = []
        for layer_idx, layer_features in enumerate(layers, 1):
            layer_name = "mask_hg{}".format(layer_idx)
            module = Hourglass(scale, layer_features, gn=use_gn)
            self.add_module(layer_name, module)
            self.blocks.append(layer_name)

    def forward(self, x, proposals):
        """
        Pool `x` for `proposals` with self.pooler, then apply every module
        named in self.blocks in order, with a ReLU after each one.
        """
        x = self.pooler(x, proposals)
        for layer_name in self.blocks:
            x = F.relu(getattr(self, layer_name)(x))
        return x
# Registry of mask feature extractor classes, keyed by the string that
# make_roi_mask_feature_extractor reads from
# cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR.
_ROI_MASK_FEATURE_EXTRACTORS = {
"ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor,
"MaskRCNNFPNFeatureExtractor": MaskRCNNFPNFeatureExtractor,
"HourglassFPNFeatureExtractor": HourglassFPNFeatureExtractor,
}
def make_roi_mask_feature_extractor(cfg):
    """
    Instantiate the mask feature extractor selected by the config.

    Arguments:
        cfg: configuration object; cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR
            must be a key of _ROI_MASK_FEATURE_EXTRACTORS.

    Returns:
        The registered class called with `cfg` as its only argument.

    Raises:
        KeyError: if the configured name is not in the registry.
    """
    extractor_name = cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR
    extractor_cls = _ROI_MASK_FEATURE_EXTRACTORS[extractor_name]
    return extractor_cls(cfg)