Spaces:
Sleeping
Sleeping
from torch import nn | |
from torch.nn import functional as F | |
from maskrcnn_benchmark import layers | |
class KeypointRCNNPredictor(nn.Module):
    """Predict per-keypoint logit maps from ROI keypoint features.

    A stride-2 transposed convolution maps the input features to one channel
    per keypoint. A 2x bilinear interpolation then follows, for a total
    spatial upsampling factor of 4.

    Args:
        cfg: config node. Reads ``MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES`` (number
            of output maps) and, when ``in_channels`` is not given,
            ``MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS``.
        in_channels (int, optional): channel count of the incoming feature
            map. Defaults to ``cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS[-1]``
            (presumably the width of the last keypoint-head conv layer —
            confirm against the feature extractor in use). Pass it explicitly
            when the preceding extractor's output width differs.
    """

    def __init__(self, cfg, in_channels=None):
        super(KeypointRCNNPredictor, self).__init__()
        if in_channels is None:
            # Original behaviour: infer the input width from the head config.
            in_channels = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS[-1]
        num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES

        deconv_kernel = 4
        # With kernel 4, stride 2 and padding 1, the transposed-conv output
        # size is (H - 1) * 2 - 2 + 4 = 2H, i.e. an exact 2x upsample
        # (assuming layers.ConvTranspose2d follows torch.nn.ConvTranspose2d).
        self.kps_score_lowres = layers.ConvTranspose2d(
            in_channels,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        # Kaiming-normal weights (fan_out, relu gain) and zero bias.
        nn.init.kaiming_normal_(
            self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
        )
        nn.init.constant_(self.kps_score_lowres.bias, 0)

        # Extra bilinear upsampling factor applied in forward().
        self.up_scale = 2
        # Number of channels produced by forward(), for downstream consumers.
        self.out_channels = num_keypoints

    def forward(self, x):
        """Return keypoint logit maps upsampled 4x relative to ``x``.

        Args:
            x: ROI feature tensor, presumably (N, in_channels, H, W) — confirm
                against the keypoint feature extractor.

        Returns:
            Tensor with ``out_channels`` channels and spatial size about
            (4H, 4W): 2x from the transposed conv, then ``up_scale``x from
            bilinear interpolation.
        """
        x = self.kps_score_lowres(x)
        return layers.interpolate(
            x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
        )
# Registry of ROI keypoint predictor classes, keyed by the name used in
# cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR (looked up by
# make_roi_keypoint_predictor).
_ROI_KEYPOINT_PREDICTOR = {
    "KeypointRCNNPredictor": KeypointRCNNPredictor,
}
def make_roi_keypoint_predictor(cfg):
    """Build the keypoint predictor named by the config.

    Args:
        cfg: config node. ``cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR`` must be a
            key of ``_ROI_KEYPOINT_PREDICTOR``.

    Returns:
        The predictor instance, constructed as ``predictor_cls(cfg)``.

    Raises:
        KeyError: if the configured predictor name is not registered. The
            message lists the registered names.
    """
    name = cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR
    try:
        predictor_cls = _ROI_KEYPOINT_PREDICTOR[name]
    except KeyError:
        # Same exception type as before, but with an actionable message
        # instead of just echoing the bad key.
        raise KeyError(
            "Unknown ROI keypoint predictor {!r}; available: {}".format(
                name, sorted(_ROI_KEYPOINT_PREDICTOR)
            )
        ) from None
    return predictor_cls(cfg)