# Copyright (c) Facebook, Inc. and its affiliates.

from typing import Optional
from torch import nn

from detectron2.config import CfgNode

from .cse.embedder import Embedder
from .filter import DensePoseDataFilter


def build_densepose_predictor(cfg: CfgNode, input_channels: int):
    """
    Create an instance of DensePose predictor based on configuration options.

    The predictor class is looked up by name in ``DENSEPOSE_PREDICTOR_REGISTRY``
    using ``cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME`` and is constructed
    with ``(cfg, input_channels)``.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose predictor
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid an import cycle with the predictors package (NOTE(review): confirm).
    from .predictors import DENSEPOSE_PREDICTOR_REGISTRY

    predictor_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME
    return DENSEPOSE_PREDICTOR_REGISTRY.get(predictor_name)(cfg, input_channels)


def build_densepose_data_filter(cfg: CfgNode) -> DensePoseDataFilter:
    """
    Build DensePose data filter which selects data for training.

    Args:
        cfg (CfgNode): configuration options

    Return:
        Callable: list(Tensor), list(Instances) -> list(Tensor), list(Instances)
        An instance of DensePose filter, which takes feature tensors and proposals
        as an input and returns filtered features and proposals
    """
    return DensePoseDataFilter(cfg)


def build_densepose_head(cfg: CfgNode, input_channels: int):
    """
    Build DensePose head based on configuration options.

    The head class is looked up by name in ``ROI_DENSEPOSE_HEAD_REGISTRY``
    using ``cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME`` and is constructed with
    ``(cfg, input_channels)``.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose head
    """
    # Deferred import, same pattern as in build_densepose_predictor.
    from .roi_heads.registry import ROI_DENSEPOSE_HEAD_REGISTRY

    head_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME
    return ROI_DENSEPOSE_HEAD_REGISTRY.get(head_name)(cfg, input_channels)


def build_densepose_losses(cfg: CfgNode):
    """
    Build DensePose loss based on configuration options.

    The loss class is looked up by name in ``DENSEPOSE_LOSS_REGISTRY``
    using ``cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME`` and is constructed
    with ``cfg``.

    Args:
        cfg (CfgNode): configuration options
    Return:
        An instance of DensePose loss
    """
    # Deferred import, same pattern as in build_densepose_predictor.
    from .losses import DENSEPOSE_LOSS_REGISTRY

    loss_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME
    return DENSEPOSE_LOSS_REGISTRY.get(loss_name)(cfg)


def build_densepose_embedder(cfg: CfgNode) -> Optional[nn.Module]:
    """
    Build embedder used to embed mesh vertices into an embedding space.
    Embedder contains sub-embedders, one for each mesh ID.

    Args:
        cfg (CfgNode): configuration options
    Return:
        Embedding module, or ``None`` when
        ``cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS`` is falsy
        (e.g. empty, i.e. no embedders configured)
    """
    if cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS:
        return Embedder(cfg)
    return None