# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from torch import nn
from torch.nn import functional as F

from .hourglass import Hourglass
from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor
from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.modeling.make_layers import make_conv3x3


def _build_mask_pooler(cfg):
    """Create the ROI pooler configured by ``cfg.MODEL.ROI_MASK_HEAD``.

    Reads ``POOLER_RESOLUTION`` (used for both output dimensions),
    ``POOLER_SCALES`` and ``POOLER_SAMPLING_RATIO``.
    """
    mask_cfg = cfg.MODEL.ROI_MASK_HEAD
    side = mask_cfg.POOLER_RESOLUTION
    pooler_scales = mask_cfg.POOLER_SCALES
    ratio = mask_cfg.POOLER_SAMPLING_RATIO
    return Pooler(
        output_size=(side, side),
        scales=pooler_scales,
        sampling_ratio=ratio,
    )


class MaskRCNNFPNFeatureExtractor(nn.Module):
    """
    Mask-head feature extractor: ROI pooling followed by a stack of
    ``make_conv3x3`` layers, each passed through a ReLU.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: configuration node. The following entries are read:
                MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION / POOLER_SCALES /
                POOLER_SAMPLING_RATIO (pooler settings),
                MODEL.BACKBONE.OUT_CHANNELS (input channels of the first
                conv), MODEL.ROI_MASK_HEAD.USE_GN, CONV_LAYERS (output
                channels of each conv, in order) and DILATION.
        """
        super(MaskRCNNFPNFeatureExtractor, self).__init__()

        self.pooler = _build_mask_pooler(cfg)
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS

        head_cfg = cfg.MODEL.ROI_MASK_HEAD
        use_gn = head_cfg.USE_GN
        conv_widths = head_cfg.CONV_LAYERS
        dilation = head_cfg.DILATION

        # Layers are registered as named submodules ("mask_fcn1", ...);
        # self.blocks only keeps the names, in application order.
        self.blocks = []
        for index, out_channels in enumerate(conv_widths, 1):
            name = "mask_fcn{}".format(index)
            conv = make_conv3x3(
                in_channels,
                out_channels,
                dilation=dilation,
                stride=1,
                use_gn=use_gn,
            )
            self.add_module(name, conv)
            self.blocks.append(name)
            # The next conv consumes what this one produces.
            in_channels = out_channels

    def forward(self, x, proposals):
        """
        Pool ``x`` over ``proposals`` and run the result through every
        registered layer, applying ReLU after each one.

        Arguments:
            x: features handed to the pooler (presumably the backbone/FPN
                feature maps -- confirm against the caller)
            proposals: regions handed to the pooler

        Returns:
            the activation produced after the last layer (the pooler output
            itself when no layers are configured)
        """
        features = self.pooler(x, proposals)
        for name in self.blocks:
            layer = getattr(self, name)
            features = F.relu(layer(features))
        return features


class HourglassFPNFeatureExtractor(nn.Module):
    """
    Mask-head feature extractor: ROI pooling followed by a stack of
    ``Hourglass`` modules, each passed through a ReLU.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: configuration node. The following entries are read:
                MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION / POOLER_SCALES /
                POOLER_SAMPLING_RATIO (pooler settings),
                MODEL.BACKBONE.OUT_CHANNELS, MODEL.ROI_MASK_HEAD.USE_GN,
                CONV_LAYERS (feature count given to each Hourglass) and
                HG_SCALE (first argument of each Hourglass).
        """
        super(HourglassFPNFeatureExtractor, self).__init__()

        self.pooler = _build_mask_pooler(cfg)
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS

        head_cfg = cfg.MODEL.ROI_MASK_HEAD
        use_gn = head_cfg.USE_GN
        hg_widths = head_cfg.CONV_LAYERS
        hg_scale = head_cfg.HG_SCALE

        # The first hourglass is built with hg_widths[0] features and is fed
        # the pooled backbone features, so the two sizes are required to match.
        assert in_channels == hg_widths[0]

        # Modules are registered as named submodules ("mask_hg1", ...);
        # self.blocks only keeps the names, in application order.
        self.blocks = []
        for index, width in enumerate(hg_widths, 1):
            name = "mask_hg{}".format(index)
            hourglass = Hourglass(hg_scale, width, gn=use_gn)
            self.add_module(name, hourglass)
            self.blocks.append(name)

    def forward(self, x, proposals):
        """
        Pool ``x`` over ``proposals`` and run the result through every
        registered hourglass, applying ReLU after each one.

        Arguments:
            x: features handed to the pooler (presumably the backbone/FPN
                feature maps -- confirm against the caller)
            proposals: regions handed to the pooler

        Returns:
            the activation produced after the last hourglass
        """
        features = self.pooler(x, proposals)
        for name in self.blocks:
            layer = getattr(self, name)
            features = F.relu(layer(features))
        return features


# Maps the value of cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR to the class
# that implements it.
_ROI_MASK_FEATURE_EXTRACTORS = {
    "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor,
    "MaskRCNNFPNFeatureExtractor": MaskRCNNFPNFeatureExtractor,
    "HourglassFPNFeatureExtractor": HourglassFPNFeatureExtractor,
}


def make_roi_mask_feature_extractor(cfg):
    """
    Instantiate the mask feature extractor selected by
    ``cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR``.

    Raises ``KeyError`` when the configured name is not registered in
    ``_ROI_MASK_FEATURE_EXTRACTORS``.
    """
    extractor_name = cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR
    extractor_cls = _ROI_MASK_FEATURE_EXTRACTORS[extractor_name]
    return extractor_cls(cfg)