Spaces:
Sleeping
Sleeping
from torch import nn | |
from torch.nn import functional as F | |
from maskrcnn_benchmark.modeling.poolers import Pooler | |
from maskrcnn_benchmark.layers import Conv2d | |
from maskrcnn_benchmark.layers import ConvTranspose2d | |
class KeypointRCNNFeatureExtractor(nn.Module):
    """ROI feature extractor for the keypoint head.

    Pools features for each proposal with a ``Pooler`` and then runs them
    through a stack of 3x3 convolutions, applying a ReLU after each one.
    """

    def __init__(self, cfg):
        """Build the pooler and the convolution stack from ``cfg``.

        Reads ``POOLER_RESOLUTION``, ``POOLER_SCALES``,
        ``POOLER_SAMPLING_RATIO`` and ``CONV_LAYERS`` from
        ``cfg.MODEL.ROI_KEYPOINT_HEAD``, and the input channel count from
        ``cfg.MODEL.BACKBONE.OUT_CHANNELS``.
        """
        super(KeypointRCNNFeatureExtractor, self).__init__()

        head_cfg = cfg.MODEL.ROI_KEYPOINT_HEAD
        side = head_cfg.POOLER_RESOLUTION
        self.pooler = Pooler(
            output_size=(side, side),
            scales=head_cfg.POOLER_SCALES,
            sampling_ratio=head_cfg.POOLER_SAMPLING_RATIO,
        )

        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        # Names of the registered conv modules, in the order forward() applies them.
        self.blocks = []
        for index, out_channels in enumerate(head_cfg.CONV_LAYERS, start=1):
            name = "conv_fcn" + str(index)
            conv = Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)
            self.add_module(name, conv)
            self.blocks.append(name)
            # The next conv consumes what this one produces.
            in_channels = out_channels

    def forward(self, x, proposals):
        """Pool ``x`` over ``proposals`` and apply every conv block with ReLU."""
        features = self.pooler(x, proposals)
        for name in self.blocks:
            conv = getattr(self, name)
            features = F.relu(conv(features))
        return features
class KeypointRCNNFeature2XZoomExtractor(nn.Module):
    """ROI keypoint feature extractor with a transposed-conv upscale step.

    Same structure as a plain pooled conv stack (``Pooler`` followed by 3x3
    convolutions with ReLU), except that a stride-2 transposed convolution
    named ``conv_fcn_upscale`` is inserted into the stack immediately before
    the conv whose 1-based index equals ``len(CONV_LAYERS) // 2``.
    """

    def __init__(self, cfg):
        """Build the pooler, the conv stack and the upscale layer from ``cfg``.

        Reads ``POOLER_RESOLUTION``, ``POOLER_SCALES``,
        ``POOLER_SAMPLING_RATIO`` and ``CONV_LAYERS`` from
        ``cfg.MODEL.ROI_KEYPOINT_HEAD``, and the input channel count from
        ``cfg.MODEL.BACKBONE.OUT_CHANNELS``.

        Note: with fewer than two entries in ``CONV_LAYERS`` the midpoint
        index is 0, which no 1-based layer index matches, so no upscale layer
        is created.
        """
        super(KeypointRCNNFeature2XZoomExtractor, self).__init__()

        resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
        scales = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES
        sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
        self.pooler = Pooler(
            output_size=(resolution, resolution),
            scales=scales,
            sampling_ratio=sampling_ratio,
        )

        input_features = cfg.MODEL.BACKBONE.OUT_CHANNELS
        layers = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS
        next_feature = input_features
        # Names of the registered modules, in the order forward() applies them.
        self.blocks = []
        # 1-based index of the conv in front of which the upscale is inserted.
        upscale_at = len(layers) // 2
        for layer_idx, layer_features in enumerate(layers, 1):
            layer_name = "conv_fcn{}".format(layer_idx)
            module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1)
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(module.bias, 0)
            self.add_module(layer_name, module)
            if layer_idx == upscale_at:
                deconv_kernel = 4
                # The upscale is appended to self.blocks *before* this conv,
                # so at forward time it sees the previous layer's output.
                # Its channel count must therefore be `next_feature` (this
                # conv's input width), not `layer_features`; the two only
                # coincide when CONV_LAYERS is uniform around the midpoint.
                kps_upscale = ConvTranspose2d(
                    next_feature,
                    next_feature,
                    deconv_kernel,
                    stride=2,
                    padding=deconv_kernel // 2 - 1,
                )
                nn.init.kaiming_normal_(kps_upscale.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(kps_upscale.bias, 0)
                self.add_module("conv_fcn_upscale", kps_upscale)
                self.blocks.append("conv_fcn_upscale")
            next_feature = layer_features
            self.blocks.append(layer_name)

    def forward(self, x, proposals):
        """Pool ``x`` over ``proposals`` and apply every block with ReLU.

        Blocks are applied in the order recorded in ``self.blocks``; the
        upscale layer, when present, is followed by a ReLU like the convs.
        """
        x = self.pooler(x, proposals)
        for layer_name in self.blocks:
            x = F.relu(getattr(self, layer_name)(x))
        return x
# Registry of keypoint feature extractor classes, keyed by the name that
# cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR is looked up with.
_ROI_KEYPOINT_FEATURE_EXTRACTORS = dict(
    KeypointRCNNFeatureExtractor=KeypointRCNNFeatureExtractor,
    KeypointRCNNFeature2XZoomExtractor=KeypointRCNNFeature2XZoomExtractor,
)
def make_roi_keypoint_feature_extractor(cfg):
    """Instantiate the keypoint feature extractor selected by ``cfg``.

    The class is looked up in ``_ROI_KEYPOINT_FEATURE_EXTRACTORS`` under
    ``cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR`` and constructed with
    ``cfg``. An unregistered name raises ``KeyError`` from the lookup.
    """
    extractor_name = cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR
    extractor_cls = _ROI_KEYPOINT_FEATURE_EXTRACTORS[extractor_name]
    return extractor_cls(cfg)