# Source: zdou0830 / desco, commit 749745d (1.21 kB).
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark import layers
class KeypointRCNNPredictor(nn.Module):
    """Map ROI keypoint-head features to one score map per keypoint.

    A stride-2 transposed convolution turns the head's last feature width
    (``cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS[-1]`` channels) into
    ``cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES`` channels. ``forward`` then
    bilinearly upsamples that result by ``self.up_scale``.
    """

    def __init__(self, cfg):
        super().__init__()
        head_cfg = cfg.MODEL.ROI_KEYPOINT_HEAD
        in_channels = head_cfg.CONV_LAYERS[-1]
        n_keypoints = head_cfg.NUM_CLASSES
        kernel_size = 4

        # With kernel 4, stride 2 and padding 1 the output is (H - 1) * 2 - 2 + 4 = 2H,
        # i.e. exactly twice the input size -- assuming layers.ConvTranspose2d follows
        # torch.nn.ConvTranspose2d's size formula (NOTE(review): confirm).
        deconv = layers.ConvTranspose2d(
            in_channels,
            n_keypoints,
            kernel_size,
            stride=2,
            padding=kernel_size // 2 - 1,
        )
        nn.init.kaiming_normal_(deconv.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(deconv.bias, 0)

        # The attribute name is part of the state_dict key for these weights; keep it.
        self.kps_score_lowres = deconv
        # Additional bilinear upsampling factor applied after the deconv in forward().
        self.up_scale = 2

    def forward(self, x):
        """Return per-keypoint score maps for the ROI features ``x``.

        Runs the transposed convolution, then bilinear interpolation by
        ``self.up_scale`` with ``align_corners=False``. Presumably ``x`` is a
        4-D (N, C, H, W) batch of ROI features -- TODO confirm against the head.
        """
        low_res = self.kps_score_lowres(x)
        return layers.interpolate(
            low_res,
            scale_factor=self.up_scale,
            mode="bilinear",
            align_corners=False,
        )
# Registry of ROI keypoint predictors, keyed by the name used in
# cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR.
_ROI_KEYPOINT_PREDICTOR = dict(KeypointRCNNPredictor=KeypointRCNNPredictor)
def make_roi_keypoint_predictor(cfg):
    """Instantiate the ROI keypoint predictor selected by the config.

    Looks up ``cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR`` in
    ``_ROI_KEYPOINT_PREDICTOR`` and calls the matching class with ``cfg``.

    Args:
        cfg: config node exposing ``MODEL.ROI_KEYPOINT_HEAD.PREDICTOR`` plus
            whatever fields the selected predictor reads.

    Returns:
        The constructed predictor.

    Raises:
        KeyError: if the configured predictor name is not registered. The
            message lists the registered names.
    """
    name = cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR
    # Only the registry lookup is guarded, so a KeyError raised inside the
    # predictor's own constructor is not misreported as an unknown name.
    try:
        predictor_cls = _ROI_KEYPOINT_PREDICTOR[name]
    except KeyError:
        raise KeyError(
            "Unknown ROI keypoint predictor {!r}; available: {}".format(
                name, sorted(_ROI_KEYPOINT_PREDICTOR)
            )
        ) from None
    return predictor_cls(cfg)