# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import torch
from torch import nn

from detectron2.config import CfgNode
from detectron2.layers import ConvTranspose2d, interpolate

from ...structures import DensePoseChartPredictorOutput
from ..utils import initialize_module_params
from .registry import DENSEPOSE_PREDICTOR_REGISTRY


@DENSEPOSE_PREDICTOR_REGISTRY.register()
class DensePoseChartPredictor(nn.Module):
    """
    Final layers of a DensePose model for chart-based (patch-based) prediction.

    Consumes the DensePose head outputs and emits four tensors describing the
    DensePose estimate for the predefined body parts (patches / charts):
     * coarse segmentation, shape [N, K, Hout, Wout]
     * fine segmentation, shape [N, C, Hout, Wout]
     * U coordinates, shape [N, C, Hout, Wout]
     * V coordinates, shape [N, C, Hout, Wout]
    with
     - N: number of instances
     - K: number of coarse segmentation channels
          (2 = foreground / background,
           15 = one of 14 body parts / background)
     - C: number of fine segmentation channels
          (24 fine body parts / background)
     - Hout, Wout: height and width of the predictions
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Build the four prediction branches from configuration options

        Args:
            cfg (CfgNode): configuration options; the values are read from
                `cfg.MODEL.ROI_DENSEPOSE_HEAD`
            input_channels (int): size of the input tensor along the channel
                dimension
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        num_coarse_channels = head_cfg.NUM_COARSE_SEGM_CHANNELS
        # one output channel per patch plus one extra channel (background)
        num_fine_channels = head_cfg.NUM_PATCHES + 1
        deconv_kernel = head_cfg.DECONV_KERNEL

        # Every branch is a stride-2 transposed convolution sharing the same
        # kernel size and padding; only the output channel count differs.
        deconv_kwargs = {
            "stride": 2,
            "padding": int(deconv_kernel / 2 - 1),
        }

        # NOTE: the branches are created in a fixed order (coarse, fine, U, V);
        # keep it, since it determines submodule registration order.
        # coarse segmentation
        self.ann_index_lowres = ConvTranspose2d(
            input_channels, num_coarse_channels, deconv_kernel, **deconv_kwargs
        )
        # fine segmentation
        self.index_uv_lowres = ConvTranspose2d(
            input_channels, num_fine_channels, deconv_kernel, **deconv_kwargs
        )
        # U
        self.u_lowres = ConvTranspose2d(
            input_channels, num_fine_channels, deconv_kernel, **deconv_kwargs
        )
        # V
        self.v_lowres = ConvTranspose2d(
            input_channels, num_fine_channels, deconv_kernel, **deconv_kwargs
        )

        # factor applied by `interp2d` to the low-resolution branch outputs
        self.scale_factor = head_cfg.UP_SCALE
        initialize_module_params(self)

    def interp2d(self, tensor_nchw: torch.Tensor):
        """
        Upscale a tensor with bilinear interpolation

        Args:
            tensor_nchw (tensor): tensor of shape (N, C, H, W)
        Return:
            tensor of shape (N, C, Hout, Wout), where Hout and Wout result
            from applying `self.scale_factor` to H and W
        """
        upscaled = interpolate(
            tensor_nchw,
            scale_factor=self.scale_factor,
            mode="bilinear",
            align_corners=False,
        )
        return upscaled

    def forward(self, head_outputs: torch.Tensor):
        """
        Run the four prediction branches on the DensePose head outputs

        Args:
            head_outputs (tensor): DensePose head outputs, tensor of shape
                [N, D, H, W]
        Return:
            An instance of DensePoseChartPredictorOutput holding the upscaled
            coarse segmentation, fine segmentation, U and V predictions
        """
        # Each branch: low-resolution deconvolution followed by bilinear upscaling.
        coarse_segm = self.interp2d(self.ann_index_lowres(head_outputs))
        fine_segm = self.interp2d(self.index_uv_lowres(head_outputs))
        u_coords = self.interp2d(self.u_lowres(head_outputs))
        v_coords = self.interp2d(self.v_lowres(head_outputs))
        return DensePoseChartPredictorOutput(
            coarse_segm=coarse_segm,
            fine_segm=fine_segm,
            u=u_coords,
            v=v_coords,
        )