File size: 1,205 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from torch import nn
from torch.nn import functional as F

from maskrcnn_benchmark import layers


class KeypointRCNNPredictor(nn.Module):
    """Predict per-keypoint score maps from RoI features.

    A learned transposed convolution maps the incoming features to one output
    channel per keypoint. Its result is then upsampled again by a fixed
    bilinear interpolation (factor ``self.up_scale``).

    Config values read from ``cfg``:
        * ``MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS[-1]``: used as the input
          channel count of the transposed convolution (presumably the width of
          the last conv layer in the keypoint head; confirm against the head).
        * ``MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES``: used as the output channel
          count, i.e. the number of keypoints.
    """

    def __init__(self, cfg):
        super(KeypointRCNNPredictor, self).__init__()
        head_cfg = cfg.MODEL.ROI_KEYPOINT_HEAD
        in_channels = head_cfg.CONV_LAYERS[-1]
        num_kps = head_cfg.NUM_CLASSES

        # kernel 4, stride 2, padding 4 // 2 - 1 == 1. Under the standard
        # transposed-conv size formula (default output_padding / dilation),
        # (H - 1) * 2 - 2 * 1 + 4 == 2 * H, so this exactly doubles H and W.
        # Assumes layers.ConvTranspose2d mirrors torch.nn.ConvTranspose2d here.
        kernel_size = 4
        deconv = layers.ConvTranspose2d(
            in_channels,
            num_kps,
            kernel_size,
            stride=2,
            padding=kernel_size // 2 - 1,
        )
        # Re-initialise the default weights: He-normal (fan_out, relu gain) for
        # the kernel, zeros for the bias.
        nn.init.kaiming_normal_(deconv.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(deconv.bias, 0)
        self.kps_score_lowres = deconv

        # Fixed (non-learned) bilinear upsampling factor applied in forward().
        self.up_scale = 2

    def forward(self, x):
        """Return keypoint score maps for the RoI features ``x``.

        ``x`` is presumably an (N, C, H, W) tensor whose C matches
        ``CONV_LAYERS[-1]``; TODO confirm against the calling head.

        The output has ``NUM_CLASSES`` channels. No activation is applied, so
        the values are raw scores. Its spatial size is the deconv output
        scaled by ``self.up_scale`` with bilinear interpolation
        (``align_corners=False``).
        """
        lowres_scores = self.kps_score_lowres(x)
        return layers.interpolate(
            lowres_scores,
            scale_factor=self.up_scale,
            mode="bilinear",
            align_corners=False,
        )


# Registry of keypoint predictor classes. make_roi_keypoint_predictor looks
# up the string in cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR here.
_ROI_KEYPOINT_PREDICTOR = dict(KeypointRCNNPredictor=KeypointRCNNPredictor)


def make_roi_keypoint_predictor(cfg):
    """Instantiate the ROI keypoint predictor selected by the config.

    Looks up ``cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR`` in
    ``_ROI_KEYPOINT_PREDICTOR`` and calls the registered class with ``cfg``.

    Args:
        cfg: config object. It must provide
            ``MODEL.ROI_KEYPOINT_HEAD.PREDICTOR`` plus whatever fields the
            chosen predictor reads.

    Returns:
        The predictor instance returned by the registered class.

    Raises:
        KeyError: if the configured predictor name is not registered. The
            message names the bad value and lists the registered names, so
            config typos are easy to spot. The exception type is the same as a
            plain dict lookup failure, so existing handlers keep working.
    """
    predictor_name = cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR
    if predictor_name not in _ROI_KEYPOINT_PREDICTOR:
        raise KeyError(
            "Unknown ROI keypoint predictor {!r}; registered predictors: {}".format(
                predictor_name, sorted(_ROI_KEYPOINT_PREDICTOR)
            )
        )
    predictor_cls = _ROI_KEYPOINT_PREDICTOR[predictor_name]
    return predictor_cls(cfg)