Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
from torch import nn | |
from torch.nn import functional as F | |
from .hourglass import Hourglass | |
from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor | |
from maskrcnn_benchmark.modeling.poolers import Pooler | |
from maskrcnn_benchmark.layers import Conv2d | |
from maskrcnn_benchmark.modeling.make_layers import make_conv3x3 | |
class MaskRCNNFPNFeatureExtractor(nn.Module):
    """
    Mask-head feature extractor for FPN backbones.

    Pools per-proposal features with ``Pooler`` and then applies a stack of
    3x3 conv layers (registered as ``mask_fcn1``, ``mask_fcn2``, ...), each
    followed by a ReLU.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: config node. Reads ``MODEL.BACKBONE.OUT_CHANNELS`` and
                ``MODEL.ROI_MASK_HEAD.{POOLER_RESOLUTION, POOLER_SCALES,
                POOLER_SAMPLING_RATIO, USE_GN, CONV_LAYERS, DILATION}``.
        """
        super(MaskRCNNFPNFeatureExtractor, self).__init__()

        resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES
        sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        self.pooler = Pooler(
            output_size=(resolution, resolution),
            scales=scales,
            sampling_ratio=sampling_ratio,
        )

        input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS
        use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
        layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS
        dilation = cfg.MODEL.ROI_MASK_HEAD.DILATION

        # Each conv is registered via add_module, so its parameters live under
        # "mask_fcn{i}" in the state_dict; self.blocks only records the names
        # in the order forward() applies them.
        next_feature = input_size
        self.blocks = []
        for layer_idx, layer_features in enumerate(layers, 1):
            layer_name = "mask_fcn{}".format(layer_idx)
            module = make_conv3x3(
                next_feature,
                layer_features,
                dilation=dilation,
                stride=1,
                use_gn=use_gn,
            )
            self.add_module(layer_name, module)
            next_feature = layer_features
            self.blocks.append(layer_name)

        # Channel count of forward()'s output, following the same in/out
        # bookkeeping as next_feature above (the backbone width when
        # CONV_LAYERS is empty).
        self.out_channels = next_feature

    def forward(self, x, proposals):
        """
        Arguments:
            x: backbone feature maps, passed straight to the pooler
                (presumably one tensor per FPN level — not checked here).
            proposals: per-image boxes, passed straight to the pooler.

        Returns:
            the pooled features after every ``mask_fcn`` conv + ReLU.
        """
        x = self.pooler(x, proposals)
        for layer_name in self.blocks:
            x = F.relu(getattr(self, layer_name)(x))
        return x
class HourglassFPNFeatureExtractor(nn.Module):
    """
    Mask-head feature extractor for FPN backbones built from Hourglass blocks.

    Pools per-proposal features with ``Pooler`` and then applies a stack of
    ``Hourglass`` modules (registered as ``mask_hg1``, ``mask_hg2``, ...),
    each followed by a ReLU.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: config node. Reads ``MODEL.BACKBONE.OUT_CHANNELS`` and
                ``MODEL.ROI_MASK_HEAD.{POOLER_RESOLUTION, POOLER_SCALES,
                POOLER_SAMPLING_RATIO, USE_GN, CONV_LAYERS, HG_SCALE}``.

        Raises:
            ValueError: if ``CONV_LAYERS`` is non-empty and its first entry
                differs from ``MODEL.BACKBONE.OUT_CHANNELS``.
        """
        super(HourglassFPNFeatureExtractor, self).__init__()

        input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS
        layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS

        # The first Hourglass block is given a single width (layers[0]) and
        # is fed the pooled backbone features directly, so the two must
        # match. Validate explicitly (an assert is stripped under -O) and
        # before building any modules. An empty CONV_LAYERS is allowed and
        # yields the pooled features unchanged, like the FPN conv extractor.
        if layers and layers[0] != input_size:
            raise ValueError(
                "HourglassFPNFeatureExtractor: MODEL.ROI_MASK_HEAD.CONV_LAYERS[0] "
                "({}) must equal MODEL.BACKBONE.OUT_CHANNELS ({})".format(
                    layers[0], input_size
                )
            )
        # NOTE(review): later blocks presumably also need layers[i] to match
        # the previous block's output width; that is not checked here —
        # confirm against Hourglass.

        resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES
        sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO
        self.pooler = Pooler(
            output_size=(resolution, resolution),
            scales=scales,
            sampling_ratio=sampling_ratio,
        )

        use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
        scale = cfg.MODEL.ROI_MASK_HEAD.HG_SCALE

        # Blocks are registered via add_module, so their parameters live
        # under "mask_hg{i}" in the state_dict; self.blocks keeps the names
        # in application order.
        self.blocks = []
        for layer_idx, layer_features in enumerate(layers, 1):
            layer_name = "mask_hg{}".format(layer_idx)
            module = Hourglass(scale, layer_features, gn=use_gn)
            self.add_module(layer_name, module)
            self.blocks.append(layer_name)

    def forward(self, x, proposals):
        """
        Arguments:
            x: backbone feature maps, passed straight to the pooler
                (presumably one tensor per FPN level — not checked here).
            proposals: per-image boxes, passed straight to the pooler.

        Returns:
            the pooled features after every ``mask_hg`` block + ReLU.
        """
        x = self.pooler(x, proposals)
        for layer_name in self.blocks:
            x = F.relu(getattr(self, layer_name)(x))
        return x
# Registry mapping each MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR config string to
# the extractor class that make_roi_mask_feature_extractor instantiates.
_ROI_MASK_FEATURE_EXTRACTORS = dict(
    ResNet50Conv5ROIFeatureExtractor=ResNet50Conv5ROIFeatureExtractor,
    MaskRCNNFPNFeatureExtractor=MaskRCNNFPNFeatureExtractor,
    HourglassFPNFeatureExtractor=HourglassFPNFeatureExtractor,
)
def make_roi_mask_feature_extractor(cfg):
    """
    Build the mask-head feature extractor selected by the config.

    Arguments:
        cfg: config node; ``MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR`` must name
            an entry of ``_ROI_MASK_FEATURE_EXTRACTORS``.

    Returns:
        the registered extractor class constructed with ``cfg``.

    Raises:
        KeyError: if the configured name is not registered. The message
            lists the registered names so the config typo is easy to spot.
    """
    name = cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR
    try:
        extractor_cls = _ROI_MASK_FEATURE_EXTRACTORS[name]
    except KeyError:
        # Keep the KeyError type callers may already expect, but say what
        # went wrong instead of surfacing only the bare key.
        raise KeyError(
            "Unknown MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR {!r}; "
            "expected one of: {}".format(
                name, ", ".join(sorted(_ROI_MASK_FEATURE_EXTRACTORS))
            )
        ) from None
    return extractor_cls(cfg)