| | |
| |
|
| | from typing import Optional |
| | from torch import nn |
| |
|
| | from detectron2.config import CfgNode |
| |
|
| | from .cse.embedder import Embedder |
| | from .filter import DensePoseDataFilter |
| |
|
| |
|
def build_densepose_predictor(cfg: CfgNode, input_channels: int):
    """
    Instantiate the DensePose predictor selected by the configuration.

    The predictor class is looked up by the name stored in
    ``cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME`` in
    ``DENSEPOSE_PREDICTOR_REGISTRY`` and constructed with ``(cfg, input_channels)``.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of the registered DensePose predictor class
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid an import cycle with the predictors package -- TODO confirm.
    from .predictors import DENSEPOSE_PREDICTOR_REGISTRY

    head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
    predictor_cls = DENSEPOSE_PREDICTOR_REGISTRY.get(head_cfg.PREDICTOR_NAME)
    return predictor_cls(cfg, input_channels)
| |
|
| |
|
def build_densepose_data_filter(cfg: CfgNode):
    """
    Construct the DensePose data filter used to select training data.

    Args:
        cfg (CfgNode): configuration options

    Return:
        Callable: list(Tensor), list(Instances) -> list(Tensor), list(Instances)
        A ``DensePoseDataFilter`` built from ``cfg``. It takes feature tensors
        and proposals as input and returns the filtered features and proposals.
    """
    return DensePoseDataFilter(cfg)
| |
|
| |
|
def build_densepose_head(cfg: CfgNode, input_channels: int):
    """
    Instantiate the DensePose ROI head selected by the configuration.

    The head class is looked up by the name stored in
    ``cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME`` in ``ROI_DENSEPOSE_HEAD_REGISTRY``
    and constructed with ``(cfg, input_channels)``.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of the registered DensePose head class
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid an import cycle -- TODO confirm.
    from .roi_heads.registry import ROI_DENSEPOSE_HEAD_REGISTRY

    head_cls = ROI_DENSEPOSE_HEAD_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME)
    return head_cls(cfg, input_channels)
| |
|
| |
|
def build_densepose_losses(cfg: CfgNode):
    """
    Instantiate the DensePose loss selected by the configuration.

    The loss class is looked up by the name stored in
    ``cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME`` in ``DENSEPOSE_LOSS_REGISTRY``
    and constructed with ``cfg``.

    Args:
        cfg (CfgNode): configuration options
    Return:
        An instance of the registered DensePose loss class
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid an import cycle -- TODO confirm.
    from .losses import DENSEPOSE_LOSS_REGISTRY

    loss_cls = DENSEPOSE_LOSS_REGISTRY.get(cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME)
    return loss_cls(cfg)
| |
|
| |
|
def build_densepose_embedder(cfg: CfgNode) -> Optional[nn.Module]:
    """
    Build the embedder used to embed mesh vertices into an embedding space.

    The embedder contains sub-embedders, one for each mesh ID. It is created
    only when ``cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS`` is truthy
    (e.g. a non-empty mapping).

    Args:
        cfg (CfgNode): configuration options
    Return:
        An ``Embedder`` module, or ``None`` when no embedders are configured
    """
    embedders_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS
    return Embedder(cfg) if embedders_cfg else None
| |
|