desco / maskrcnn_benchmark /modeling /roi_heads /keypoint_head /roi_keypoint_feature_extractors.py
zdou0830's picture
desco
749745d
raw
history blame
No virus
3.91 kB
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.layers import ConvTranspose2d
class KeypointRCNNFeatureExtractor(nn.Module):
    """Pool per-proposal features, then run them through a stack of 3x3 convs.

    Settings are read from ``cfg.MODEL.ROI_KEYPOINT_HEAD``:
    ``POOLER_RESOLUTION`` / ``POOLER_SCALES`` / ``POOLER_SAMPLING_RATIO``
    configure the Pooler, and ``CONV_LAYERS`` gives the output channel count
    of each conv. ``cfg.MODEL.BACKBONE.OUT_CHANNELS`` is the channel count of
    the incoming feature maps (the input of the first conv).
    """

    def __init__(self, cfg):
        super(KeypointRCNNFeatureExtractor, self).__init__()
        head_cfg = cfg.MODEL.ROI_KEYPOINT_HEAD
        side = head_cfg.POOLER_RESOLUTION
        self.pooler = Pooler(
            output_size=(side, side),
            scales=head_cfg.POOLER_SCALES,
            sampling_ratio=head_cfg.POOLER_SAMPLING_RATIO,
        )

        # Attribute names of the conv submodules, in the order forward()
        # applies them. Each conv is registered via add_module as
        # "conv_fcn<i>" (1-based) so its parameters live under that name.
        self.blocks = []
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        for idx, out_channels in enumerate(head_cfg.CONV_LAYERS, 1):
            # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
            conv = Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)
            name = "conv_fcn{}".format(idx)
            self.add_module(name, conv)
            self.blocks.append(name)
            in_channels = out_channels

    def forward(self, x, proposals):
        """Pool ``x`` over ``proposals`` and apply each conv followed by ReLU.

        Args:
            x: backbone feature maps, passed unchanged to ``self.pooler``
                (presumably a list of per-level tensors — confirm against Pooler).
            proposals: per-image boxes, passed unchanged to ``self.pooler``
                (presumably BoxList objects — confirm against Pooler).

        Returns:
            The pooled features after every conv + ReLU stage.
        """
        features = self.pooler(x, proposals)
        for name in self.blocks:
            layer = getattr(self, name)
            features = F.relu(layer(features))
        return features
class KeypointRCNNFeature2XZoomExtractor(nn.Module):
    """Keypoint feature extractor with a 2x transposed-conv upsample mid-stack.

    Uses the same Pooler and 3x3 conv stack as KeypointRCNNFeatureExtractor,
    configured from ``cfg.MODEL.ROI_KEYPOINT_HEAD`` and
    ``cfg.MODEL.BACKBONE.OUT_CHANNELS``. In addition it registers a stride-2
    ConvTranspose2d named ``conv_fcn_upscale``. forward() runs that upscale
    immediately before conv layer number ``len(CONV_LAYERS) // 2`` (1-based).
    Every stage, the upscale included, is followed by a ReLU.
    """

    def __init__(self, cfg):
        super(KeypointRCNNFeature2XZoomExtractor, self).__init__()
        resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
        scales = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES
        sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
        pooler = Pooler(
            output_size=(resolution, resolution),
            scales=scales,
            sampling_ratio=sampling_ratio,
        )
        self.pooler = pooler
        input_features = cfg.MODEL.BACKBONE.OUT_CHANNELS
        layers = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS
        # 1-based index of the conv that the upscale is placed in front of.
        # NOTE: with fewer than 2 conv layers this is 0, which never matches
        # the 1-based layer_idx below, so no upscale is inserted at all.
        upscale_idx = len(layers) // 2
        next_feature = input_features
        # Submodule attribute names in the order forward() applies them.
        self.blocks = []
        for layer_idx, layer_features in enumerate(layers, 1):
            layer_name = "conv_fcn{}".format(layer_idx)
            module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1)
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(module.bias, 0)
            self.add_module(layer_name, module)
            if layer_idx == upscale_idx:
                # The upscale's name is appended to self.blocks before this
                # conv's name, so forward() runs it on the tensor *entering*
                # this conv. That tensor has ``next_feature`` channels, not
                # ``layer_features``. Sizing the upscale with layer_features
                # crashed whenever the two differ; e.g. with 2-3 conv layers
                # it was applied to the raw pooled features. In configs where
                # they are equal, the parameter shapes are unchanged.
                deconv_kernel = 4
                # kernel 4, stride 2, padding 1 => output H/W = 2 * input H/W
                # under the standard transposed-conv size formula.
                upscale = ConvTranspose2d(
                    next_feature,
                    next_feature,
                    deconv_kernel,
                    stride=2,
                    padding=deconv_kernel // 2 - 1,
                )
                nn.init.kaiming_normal_(upscale.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(upscale.bias, 0)
                self.add_module("conv_fcn_upscale", upscale)
                self.blocks.append("conv_fcn_upscale")
            next_feature = layer_features
            self.blocks.append(layer_name)

    def forward(self, x, proposals):
        """Pool ``x`` over ``proposals`` and apply every stage followed by ReLU.

        Args:
            x: backbone feature maps, passed unchanged to ``self.pooler``
                (presumably a list of per-level tensors — confirm against Pooler).
            proposals: per-image boxes, passed unchanged to ``self.pooler``
                (presumably BoxList objects — confirm against Pooler).

        Returns:
            Features after all conv (and, if present, upscale) stages.
        """
        x = self.pooler(x, proposals)
        for layer_name in self.blocks:
            x = F.relu(getattr(self, layer_name)(x))
        return x
# Registry: maps the cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR string to
# the extractor class that make_roi_keypoint_feature_extractor instantiates.
_ROI_KEYPOINT_FEATURE_EXTRACTORS = dict(
    KeypointRCNNFeatureExtractor=KeypointRCNNFeatureExtractor,
    KeypointRCNNFeature2XZoomExtractor=KeypointRCNNFeature2XZoomExtractor,
)
def make_roi_keypoint_feature_extractor(cfg):
    """Build the keypoint-head feature extractor selected by the config.

    Args:
        cfg: config object; ``cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR``
            must name an entry of ``_ROI_KEYPOINT_FEATURE_EXTRACTORS``.

    Returns:
        An instance of the selected extractor class, constructed with ``cfg``.

    Raises:
        KeyError: if the configured name is not registered. The message
            names the offending value and lists the registered names.
    """
    name = cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR
    try:
        extractor_cls = _ROI_KEYPOINT_FEATURE_EXTRACTORS[name]
    except KeyError:
        # Keep KeyError (what callers saw before) but make the message actionable.
        raise KeyError(
            "Unknown MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR {!r}; "
            "expected one of: {}".format(
                name, ", ".join(sorted(_ROI_KEYPOINT_FEATURE_EXTRACTORS))
            )
        ) from None
    return extractor_cls(cfg)