| from mmpose.models import HeatmapHead |
| from mmpose.registry import MODELS |
|
|
|
|
| |
| @MODELS.register_module() |
| class ExampleHead(HeatmapHead): |
| """Implements an example head. |
| |
| Implement the model head just like a normal pytorch module. |
| """ |
|
|
| def __init__(self, **kwargs) -> None: |
| print('Initializing ExampleHead...') |
| super().__init__(**kwargs) |
|
|
| def forward(self, feats): |
| """Forward the network. The input is multi scale feature maps and the |
| output is the coordinates. |
| |
| Args: |
| feats (Tuple[Tensor]): Multi scale feature maps. |
| |
| Returns: |
| Tensor: output coordinates or heatmaps. |
| """ |
| return super().forward(feats) |
|
|
| def predict(self, feats, batch_data_samples, test_cfg={}): |
| """Predict results from outputs. The behaviour of head during testing |
| should be defined in this function. |
| |
| Args: |
| feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage |
| features (or multiple multi-stage features in TTA) |
| batch_data_samples (List[:obj:`PoseDataSample`]): A list of |
| data samples for instances in a batch |
| test_cfg (dict): The runtime config for testing process. Defaults |
| to {} |
| |
| Returns: |
| Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If |
| ``test_cfg['output_heatmap']==True``, return both pose and heatmap |
| prediction; otherwise only return the pose prediction. |
| |
| The pose prediction is a list of ``InstanceData``, each contains |
| the following fields: |
| |
| - keypoints (np.ndarray): predicted keypoint coordinates in |
| shape (num_instances, K, D) where K is the keypoint number |
| and D is the keypoint dimension |
| - keypoint_scores (np.ndarray): predicted keypoint scores in |
| shape (num_instances, K) |
| |
| The heatmap prediction is a list of ``PixelData``, each contains |
| the following fields: |
| |
| - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) |
| """ |
| return super().predict(feats, batch_data_samples, test_cfg) |
|
|
| def loss(self, feats, batch_data_samples, train_cfg={}) -> dict: |
| """Calculate losses from a batch of inputs and data samples. The |
| behaviour of head during training should be defined in this function. |
| |
| Args: |
| feats (Tuple[Tensor]): The multi-stage features |
| batch_data_samples (List[:obj:`PoseDataSample`]): A list of |
| data samples for instances in a batch |
| train_cfg (dict): The runtime config for training process. |
| Defaults to {} |
| |
| Returns: |
| dict: A dictionary of losses. |
| """ |
|
|
| return super().loss(feats, batch_data_samples, train_cfg) |
|
|