# ----------------------------------------------------------------------------
# Adapted from https://github.com/IDEA-Research/ED-Pose/ \
# tree/master/models/edpose
# Original licence: IDEA License 1.0
# ----------------------------------------------------------------------------
import copy
import math
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import BaseModule, ModuleList, constant_init
from mmengine.structures import InstanceData
from torch import Tensor, nn

from mmpose.models.utils import inverse_sigmoid
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
                                 OptSampleList, Predictions)

from .base_transformer_head import TransformerHead
from .transformers.deformable_detr_layers import (
    DeformableDetrTransformerDecoderLayer, DeformableDetrTransformerEncoder)
from .transformers.utils import FFN, PositionEmbeddingSineHW


class EDPoseDecoder(BaseModule):
    """Transformer decoder of EDPose: `Explicit Box Detection Unifies
    End-to-End Multi-Person Pose Estimation`.

    Args:
        layer_cfg (ConfigDict): the config of each decoder
            layer. All the layers will share the same config.
        num_layers (int): Number of decoder layers.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`.
        embed_dims (int): Dimensions of the embeddings.
        query_dim (int): Dimensions of the queries (2 or 4).
        num_feature_levels (int): Number of feature levels.
        num_box_decoder_layers (int): Number of box decoder layers.
        num_keypoints (int): Number of body keypoints in the dataset.
        num_dn (int): Number of denoising queries.
        num_group (int): Number of query groups.
    """
| def __init__(self, | |
| layer_cfg, | |
| num_layers, | |
| return_intermediate, | |
| embed_dims: int = 256, | |
| query_dim=4, | |
| num_feature_levels=1, | |
| num_box_decoder_layers=2, | |
| num_keypoints=17, | |
| num_dn=100, | |
| num_group=100): | |
| super().__init__() | |
| self.layer_cfg = layer_cfg | |
| self.num_layers = num_layers | |
| self.embed_dims = embed_dims | |
| assert return_intermediate, 'support return_intermediate only' | |
| self.return_intermediate = return_intermediate | |
| assert query_dim in [ | |
| 2, 4 | |
| ], 'query_dim should be 2/4 but {}'.format(query_dim) | |
| self.query_dim = query_dim | |
| self.num_feature_levels = num_feature_levels | |
| self.layers = ModuleList([ | |
| DeformableDetrTransformerDecoderLayer(**self.layer_cfg) | |
| for _ in range(self.num_layers) | |
| ]) | |
| self.norm = nn.LayerNorm(self.embed_dims) | |
| self.ref_point_head = FFN(self.query_dim // 2 * self.embed_dims, | |
| self.embed_dims, self.embed_dims, 2) | |
| self.num_keypoints = num_keypoints | |
| self.query_scale = None | |
| self.bbox_embed = None | |
| self.class_embed = None | |
| self.pose_embed = None | |
| self.pose_hw_embed = None | |
| self.num_box_decoder_layers = num_box_decoder_layers | |
| self.box_pred_damping = None | |
| self.num_group = num_group | |
| self.rm_detach = None | |
| self.num_dn = num_dn | |
| self.hw = nn.Embedding(self.num_keypoints, 2) | |
| self.keypoint_embed = nn.Embedding(self.num_keypoints, embed_dims) | |
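        # After query expansion, queries are laid out in `num_group` groups
        # of (num_keypoints + 1) entries: [box, kpt_1, ..., kpt_K].
        # `kpt_index` keeps every position that is not a group leader,
        # i.e. the keypoint queries only.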
| self.kpt_index = [ | |
| x for x in range(self.num_group * (self.num_keypoints + 1)) | |
| if x % (self.num_keypoints + 1) != 0 | |
| ] | |
| def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, | |
| reference_points: Tensor, spatial_shapes: Tensor, | |
| level_start_index: Tensor, valid_ratios: Tensor, | |
| humandet_attn_mask: Tensor, human2pose_attn_mask: Tensor, | |
| **kwargs) -> Tuple[Tensor]: | |
| """Forward function of decoder | |
| Args: | |
| query (Tensor): The input queries, has shape (bs, num_queries, | |
| dim). | |
| value (Tensor): The input values, has shape (bs, num_value, dim). | |
| key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` | |
| input. ByteTensor, has shape (bs, num_value). | |
| reference_points (Tensor): The initial reference, has shape | |
| (bs, num_queries, 4) with the last dimension arranged as | |
| (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has | |
| shape (bs, num_queries, 2) with the last dimension arranged | |
| as (cx, cy). | |
| spatial_shapes (Tensor): Spatial shapes of features in all levels, | |
| has shape (num_levels, 2), last dimension represents (h, w). | |
| level_start_index (Tensor): The start index of each level. | |
| A tensor has shape (num_levels, ) and can be represented | |
| as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. | |
| valid_ratios (Tensor): The ratios of the valid width and the valid | |
| height relative to the width and the height of features in all | |
| levels, has shape (bs, num_levels, 2). | |
| Returns: | |
| Tuple[Tuple[Tensor]]: Outputs of Deformable Transformer Decoder. | |
| - output (Tuple[Tensor]): Output embeddings of the last decoder, | |
| each has shape (num_decoder_layers, num_queries, bs, embed_dims) | |
| - reference_points (Tensor): The reference of the last decoder | |
| layer, each has shape (num_decoder_layers, bs, num_queries, 4). | |
| The coordinates are arranged as (cx, cy, w, h) | |
| """ | |
| output = query | |
| attn_mask = humandet_attn_mask | |
| intermediate = [] | |
| intermediate_reference_points = [reference_points] | |
| effect_num_dn = self.num_dn if self.training else 0 | |
| inter_select_number = self.num_group | |
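        # Denoising queries occupy the first `effect_num_dn` positions of the
        # query sequence and are only present during training.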
| for layer_id, layer in enumerate(self.layers): | |
| if reference_points.shape[-1] == 4: | |
| reference_points_input = \ | |
| reference_points[:, :, None] * \ | |
| torch.cat([valid_ratios, valid_ratios], -1)[None, :] | |
| else: | |
| assert reference_points.shape[-1] == 2 | |
| reference_points_input = \ | |
| reference_points[:, :, None] * \ | |
| valid_ratios[None, :] | |
| query_sine_embed = self.get_proposal_pos_embed( | |
| reference_points_input[:, :, 0, :]) # nq, bs, 256*2 | |
| query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 | |
| output = layer( | |
| output.transpose(0, 1), | |
| query_pos=query_pos.transpose(0, 1), | |
| value=value.transpose(0, 1), | |
| key_padding_mask=key_padding_mask, | |
| spatial_shapes=spatial_shapes, | |
| level_start_index=level_start_index, | |
| valid_ratios=valid_ratios, | |
| reference_points=reference_points_input.transpose( | |
| 0, 1).contiguous(), | |
| self_attn_mask=attn_mask, | |
| **kwargs) | |
| output = output.transpose(0, 1) | |
| intermediate.append(self.norm(output)) | |
| # human update | |
| if layer_id < self.num_box_decoder_layers: | |
| delta_unsig = self.bbox_embed[layer_id](output) | |
| new_reference_points = delta_unsig + inverse_sigmoid( | |
| reference_points) | |
| new_reference_points = new_reference_points.sigmoid() | |
| # query expansion | |
| if layer_id == self.num_box_decoder_layers - 1: | |
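                # At the last box decoder layer, keep the top `num_group`
                # human queries (ranked by classification score) and expand
                # each of them into one box query plus `num_keypoints`
                # keypoint queries built from `keypoint_embed`.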
| dn_output = output[:effect_num_dn] | |
| dn_new_reference_points = new_reference_points[:effect_num_dn] | |
| class_unselected = self.class_embed[layer_id]( | |
| output)[effect_num_dn:] | |
| topk_proposals = torch.topk( | |
| class_unselected.max(-1)[0], inter_select_number, dim=0)[1] | |
| new_reference_points_for_box = torch.gather( | |
| new_reference_points[effect_num_dn:], 0, | |
| topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) | |
| new_output_for_box = torch.gather( | |
| output[effect_num_dn:], 0, | |
| topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) | |
| bs = new_output_for_box.shape[1] | |
| new_output_for_keypoint = new_output_for_box[:, None, :, :] \ | |
| + self.keypoint_embed.weight[None, :, None, :] | |
| if self.num_keypoints == 17: | |
| delta_xy = self.pose_embed[-1](new_output_for_keypoint)[ | |
| ..., :2] | |
| else: | |
| delta_xy = self.pose_embed[0](new_output_for_keypoint)[ | |
| ..., :2] | |
| keypoint_xy = (inverse_sigmoid( | |
| new_reference_points_for_box[..., :2][:, None]) + | |
| delta_xy).sigmoid() | |
| num_queries, _, bs, _ = keypoint_xy.shape | |
| keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze( | |
| -2).repeat(num_queries, 1, bs, 1).sigmoid() | |
| keypoint_wh = keypoint_wh_weight * \ | |
| new_reference_points_for_box[..., 2:][:, None] | |
| new_reference_points_for_keypoint = torch.cat( | |
| (keypoint_xy, keypoint_wh), dim=-1) | |
| new_reference_points = torch.cat( | |
| (new_reference_points_for_box.unsqueeze(1), | |
| new_reference_points_for_keypoint), | |
| dim=1).flatten(0, 1) | |
| output = torch.cat( | |
| (new_output_for_box.unsqueeze(1), new_output_for_keypoint), | |
| dim=1).flatten(0, 1) | |
| new_reference_points = torch.cat( | |
| (dn_new_reference_points, new_reference_points), dim=0) | |
| output = torch.cat((dn_output, output), dim=0) | |
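                # from now on use the human-to-keypoint attention mask:
                # keypoint queries attend only within their own
                # (box + keypoints) group, while box queries can still
                # attend to each other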
| attn_mask = human2pose_attn_mask | |
| # human-to-keypoints update | |
| if layer_id >= self.num_box_decoder_layers: | |
| effect_num_dn = self.num_dn if self.training else 0 | |
| inter_select_number = self.num_group | |
| ref_before_sigmoid = inverse_sigmoid(reference_points) | |
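                # every (num_keypoints + 1)-th query is the box query of its
                # group; the remaining queries are keypoint queries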
| output_bbox_dn = output[:effect_num_dn] | |
| output_bbox_norm = output[effect_num_dn:][0::( | |
| self.num_keypoints + 1)] | |
| ref_before_sigmoid_bbox_dn = \ | |
| ref_before_sigmoid[:effect_num_dn] | |
| ref_before_sigmoid_bbox_norm = \ | |
| ref_before_sigmoid[effect_num_dn:][0::( | |
| self.num_keypoints + 1)] | |
| delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn) | |
| delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm) | |
| outputs_unsig_dn = delta_unsig_dn + ref_before_sigmoid_bbox_dn | |
| outputs_unsig_norm = delta_unsig_norm + \ | |
| ref_before_sigmoid_bbox_norm | |
| new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid() | |
| new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid( | |
| ) | |
| output_kpt = output[effect_num_dn:].index_select( | |
| 0, torch.tensor(self.kpt_index, device=output.device)) | |
| delta_xy_unsig = self.pose_embed[layer_id - | |
| self.num_box_decoder_layers]( | |
| output_kpt) | |
| outputs_unsig = ref_before_sigmoid[ | |
| effect_num_dn:].index_select( | |
| 0, torch.tensor(self.kpt_index, | |
| device=output.device)).clone() | |
| delta_hw_unsig = self.pose_hw_embed[ | |
| layer_id - self.num_box_decoder_layers]( | |
| output_kpt) | |
| outputs_unsig[..., :2] += delta_xy_unsig[..., :2] | |
| outputs_unsig[..., 2:] += delta_hw_unsig | |
| new_reference_points_for_keypoint = outputs_unsig.sigmoid() | |
| bs = new_reference_points_for_box_norm.shape[1] | |
| new_reference_points_norm = torch.cat( | |
| (new_reference_points_for_box_norm.unsqueeze(1), | |
| new_reference_points_for_keypoint.view( | |
| -1, self.num_keypoints, bs, 4)), | |
| dim=1).flatten(0, 1) | |
| new_reference_points = torch.cat( | |
| (new_reference_points_for_box_dn, | |
| new_reference_points_norm), | |
| dim=0) | |
| reference_points = new_reference_points.detach() | |
| intermediate_reference_points.append(reference_points) | |
| decoder_outputs = [itm_out.transpose(0, 1) for itm_out in intermediate] | |
| reference_points = [ | |
| itm_refpoint.transpose(0, 1) | |
| for itm_refpoint in intermediate_reference_points | |
| ] | |
| return decoder_outputs, reference_points | |
    def get_proposal_pos_embed(self,
                               pos_tensor: Tensor,
                               temperature: int = 10000,
                               num_pos_feats: int = 128) -> Tensor:
| """Get the position embedding of the proposal. | |
| Args: | |
| pos_tensor (Tensor): Not normalized proposals, has shape | |
| (bs, num_queries, 4) with the last dimension arranged as | |
| (cx, cy, w, h). | |
| temperature (int, optional): The temperature used for scaling the | |
| position embedding. Defaults to 10000. | |
| num_pos_feats (int, optional): The feature dimension for each | |
| position along x, y, w, and h-axis. Note the final returned | |
| dimension for each position is 4 times of num_pos_feats. | |
| Default to 128. | |
| Returns: | |
| Tensor: The position embedding of proposal, has shape | |
| (bs, num_queries, num_pos_feats * 4), with the last dimension | |
| arranged as (cx, cy, w, h) | |
| """ | |
| scale = 2 * math.pi | |
| dim_t = torch.arange( | |
| num_pos_feats, dtype=torch.float32, device=pos_tensor.device) | |
| dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) | |
| x_embed = pos_tensor[:, :, 0] * scale | |
| y_embed = pos_tensor[:, :, 1] * scale | |
| pos_x = x_embed[:, :, None] / dim_t | |
| pos_y = y_embed[:, :, None] / dim_t | |
| pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), | |
| dim=3).flatten(2) | |
| pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), | |
| dim=3).flatten(2) | |
| if pos_tensor.size(-1) == 2: | |
| pos = torch.cat((pos_y, pos_x), dim=2) | |
| elif pos_tensor.size(-1) == 4: | |
| w_embed = pos_tensor[:, :, 2] * scale | |
| pos_w = w_embed[:, :, None] / dim_t | |
| pos_w = torch.stack( | |
| (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), | |
| dim=3).flatten(2) | |
| h_embed = pos_tensor[:, :, 3] * scale | |
| pos_h = h_embed[:, :, None] / dim_t | |
| pos_h = torch.stack( | |
| (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), | |
| dim=3).flatten(2) | |
| pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) | |
| else: | |
| raise ValueError('Unknown pos_tensor shape(-1):{}'.format( | |
| pos_tensor.size(-1))) | |
| return pos | |


class EDPoseOutHead(BaseModule):
    """Final head of EDPose: `Explicit Box Detection Unifies End-to-End
    Multi-Person Pose Estimation`.

    Args:
        num_classes (int): The number of classes.
        num_keypoints (int): The number of body keypoints in the dataset.
        num_queries (int): The number of queries.
        cls_no_bias (bool): Whether to remove the bias of the classification
            embedding layer.
        embed_dims (int): The dimensions of the embeddings.
        as_two_stage (bool, optional): Whether to generate the proposal
            from the outputs of encoder. Defaults to `False`.
        refine_queries_num (int): The number of refined (denoising) queries
            after the decoder.
        num_box_decoder_layers (int): The number of box decoder layers.
        num_group (int): The number of groups.
        num_pred_layer (int): The number of the prediction layers.
            Defaults to 6.
        dec_pred_class_embed_share (bool): Whether to share parameters
            for all the class prediction layers. Defaults to `False`.
        dec_pred_bbox_embed_share (bool): Whether to share parameters
            for all the bbox prediction layers. Defaults to `False`.
        dec_pred_pose_embed_share (bool): Whether to share parameters
            for all the pose prediction layers. Defaults to `False`.
    """
| def __init__(self, | |
| num_classes, | |
| num_keypoints: int = 17, | |
| num_queries: int = 900, | |
| cls_no_bias: bool = False, | |
| embed_dims: int = 256, | |
| as_two_stage: bool = False, | |
| refine_queries_num: int = 100, | |
| num_box_decoder_layers: int = 2, | |
| num_group: int = 100, | |
| num_pred_layer: int = 6, | |
| dec_pred_class_embed_share: bool = False, | |
| dec_pred_bbox_embed_share: bool = False, | |
| dec_pred_pose_embed_share: bool = False, | |
| **kwargs): | |
| super().__init__() | |
| self.embed_dims = embed_dims | |
| self.as_two_stage = as_two_stage | |
| self.num_classes = num_classes | |
| self.refine_queries_num = refine_queries_num | |
| self.num_box_decoder_layers = num_box_decoder_layers | |
| self.num_keypoints = num_keypoints | |
| self.num_queries = num_queries | |
| # prepare pred layers | |
| self.dec_pred_class_embed_share = dec_pred_class_embed_share | |
| self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share | |
| self.dec_pred_pose_embed_share = dec_pred_pose_embed_share | |
| # prepare class & box embed | |
| _class_embed = nn.Linear( | |
| self.embed_dims, self.num_classes, bias=(not cls_no_bias)) | |
| if not cls_no_bias: | |
| prior_prob = 0.01 | |
| bias_value = -math.log((1 - prior_prob) / prior_prob) | |
| _class_embed.bias.data = torch.ones(self.num_classes) * bias_value | |
| _bbox_embed = FFN(self.embed_dims, self.embed_dims, 4, 3) | |
| _pose_embed = FFN(self.embed_dims, self.embed_dims, 2, 3) | |
| _pose_hw_embed = FFN(self.embed_dims, self.embed_dims, 2, 3) | |
| self.num_group = num_group | |
| if dec_pred_bbox_embed_share: | |
| box_embed_layerlist = [_bbox_embed for i in range(num_pred_layer)] | |
| else: | |
| box_embed_layerlist = [ | |
| copy.deepcopy(_bbox_embed) for i in range(num_pred_layer) | |
| ] | |
| if dec_pred_class_embed_share: | |
| class_embed_layerlist = [ | |
| _class_embed for i in range(num_pred_layer) | |
| ] | |
| else: | |
| class_embed_layerlist = [ | |
| copy.deepcopy(_class_embed) for i in range(num_pred_layer) | |
| ] | |
| if num_keypoints == 17: | |
| if dec_pred_pose_embed_share: | |
| pose_embed_layerlist = [ | |
| _pose_embed | |
| for i in range(num_pred_layer - num_box_decoder_layers + 1) | |
| ] | |
| else: | |
| pose_embed_layerlist = [ | |
| copy.deepcopy(_pose_embed) | |
| for i in range(num_pred_layer - num_box_decoder_layers + 1) | |
| ] | |
| else: | |
| if dec_pred_pose_embed_share: | |
| pose_embed_layerlist = [ | |
| _pose_embed | |
| for i in range(num_pred_layer - num_box_decoder_layers) | |
| ] | |
| else: | |
| pose_embed_layerlist = [ | |
| copy.deepcopy(_pose_embed) | |
| for i in range(num_pred_layer - num_box_decoder_layers) | |
| ] | |
| pose_hw_embed_layerlist = [ | |
| _pose_hw_embed | |
| for i in range(num_pred_layer - num_box_decoder_layers) | |
| ] | |
| self.bbox_embed = nn.ModuleList(box_embed_layerlist) | |
| self.class_embed = nn.ModuleList(class_embed_layerlist) | |
| self.pose_embed = nn.ModuleList(pose_embed_layerlist) | |
| self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist) | |
| def init_weights(self) -> None: | |
| """Initialize weights of the Deformable DETR head.""" | |
| for m in self.bbox_embed: | |
| constant_init(m[-1], 0, bias=0) | |
| for m in self.pose_embed: | |
| constant_init(m[-1], 0, bias=0) | |
| def forward(self, hidden_states: List[Tensor], references: List[Tensor], | |
| mask_dict: Dict, hidden_states_enc: Tensor, | |
| referens_enc: Tensor, batch_data_samples) -> Dict: | |
| """Forward function. | |
| Args: | |
| hidden_states (Tensor): Hidden states output from each decoder | |
| layer, has shape (num_decoder_layers, bs, num_queries, dim). | |
| references (list[Tensor]): List of the reference from the decoder. | |
| Returns: | |
| tuple[Tensor]: results of head containing the following tensor. | |
| - pred_logits (Tensor): Outputs from the | |
| classification head, the socres of every bboxes. | |
| - pred_boxes (Tensor): The output boxes. | |
| - pred_keypoints (Tensor): The output keypoints. | |
| """ | |
| # update human boxes | |
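        # during training the first `effec_dn_num` queries are denoising
        # queries; at inference there are none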
| effec_dn_num = self.refine_queries_num if self.training else 0 | |
| outputs_coord_list = [] | |
| outputs_class = [] | |
| for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, | |
| layer_hs) in enumerate( | |
| zip(references[:-1], self.bbox_embed, | |
| self.class_embed, hidden_states)): | |
| if dec_lid < self.num_box_decoder_layers: | |
| layer_delta_unsig = layer_bbox_embed(layer_hs) | |
| layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid( | |
| layer_ref_sig) | |
| layer_outputs_unsig = layer_outputs_unsig.sigmoid() | |
| layer_cls = layer_cls_embed(layer_hs) | |
| outputs_coord_list.append(layer_outputs_unsig) | |
| outputs_class.append(layer_cls) | |
| else: | |
| layer_hs_bbox_dn = layer_hs[:, :effec_dn_num, :] | |
| layer_hs_bbox_norm = \ | |
| layer_hs[:, effec_dn_num:, :][:, 0::( | |
| self.num_keypoints + 1), :] | |
| bs = layer_ref_sig.shape[0] | |
| ref_before_sigmoid_bbox_dn = \ | |
| layer_ref_sig[:, : effec_dn_num, :] | |
| ref_before_sigmoid_bbox_norm = \ | |
| layer_ref_sig[:, effec_dn_num:, :][:, 0::( | |
| self.num_keypoints + 1), :] | |
| layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn) | |
| layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm) | |
| layer_outputs_unsig_dn = layer_delta_unsig_dn + \ | |
| inverse_sigmoid(ref_before_sigmoid_bbox_dn) | |
| layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid() | |
| layer_outputs_unsig_norm = layer_delta_unsig_norm + \ | |
| inverse_sigmoid(ref_before_sigmoid_bbox_norm) | |
| layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid() | |
| layer_outputs_unsig = torch.cat( | |
| (layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1) | |
| layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn) | |
| layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm) | |
| layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) | |
| outputs_class.append(layer_cls) | |
| outputs_coord_list.append(layer_outputs_unsig) | |
| # update keypoints boxes | |
| outputs_keypoints_list = [] | |
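        # indices of the keypoint queries: every position within a group of
        # (num_keypoints + 1) queries except the leading box query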
| kpt_index = [ | |
| x for x in range(self.num_group * (self.num_keypoints + 1)) | |
| if x % (self.num_keypoints + 1) != 0 | |
| ] | |
| for dec_lid, (layer_ref_sig, layer_hs) in enumerate( | |
| zip(references[:-1], hidden_states)): | |
| if dec_lid < self.num_box_decoder_layers: | |
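                # before query expansion there are no keypoint queries, so
                # output zero placeholders of shape
                # (bs, num_queries, num_keypoints * 3)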
| assert isinstance(layer_hs, torch.Tensor) | |
| bs = layer_hs.shape[0] | |
| layer_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, self.num_keypoints * 3)) | |
| outputs_keypoints_list.append(layer_res) | |
| else: | |
| bs = layer_ref_sig.shape[0] | |
| layer_hs_kpt = \ | |
| layer_hs[:, effec_dn_num:, :].index_select( | |
| 1, torch.tensor(kpt_index, device=layer_hs.device)) | |
| delta_xy_unsig = self.pose_embed[dec_lid - | |
| self.num_box_decoder_layers]( | |
| layer_hs_kpt) | |
| layer_ref_sig_kpt = \ | |
| layer_ref_sig[:, effec_dn_num:, :].index_select( | |
| 1, torch.tensor(kpt_index, device=layer_hs.device)) | |
| layer_outputs_unsig_keypoints = delta_xy_unsig + \ | |
| inverse_sigmoid(layer_ref_sig_kpt[..., :2]) | |
| vis_xy_unsig = torch.ones_like( | |
| layer_outputs_unsig_keypoints, | |
| device=layer_outputs_unsig_keypoints.device) | |
| xyv = torch.cat((layer_outputs_unsig_keypoints, | |
| vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
| dim=-1) | |
| xyv = xyv.sigmoid() | |
| layer_res = xyv.reshape( | |
| (bs, self.num_group, self.num_keypoints, 3)).flatten(2, 3) | |
| layer_res = self.keypoint_xyzxyz_to_xyxyzz(layer_res) | |
| outputs_keypoints_list.append(layer_res) | |
| dn_mask_dict = mask_dict | |
| if self.refine_queries_num > 0 and dn_mask_dict is not None: | |
| outputs_class, outputs_coord_list, outputs_keypoints_list = \ | |
| self.dn_post_process2( | |
| outputs_class, outputs_coord_list, | |
| outputs_keypoints_list, dn_mask_dict | |
| ) | |
| for _out_class, _out_bbox, _out_keypoint in zip( | |
| outputs_class, outputs_coord_list, outputs_keypoints_list): | |
| assert _out_class.shape[1] == \ | |
| _out_bbox.shape[1] == _out_keypoint.shape[1] | |
| return outputs_class[-1], outputs_coord_list[ | |
| -1], outputs_keypoints_list[-1] | |
| def keypoint_xyzxyz_to_xyxyzz(self, keypoints: torch.Tensor): | |
| """ | |
| Args: | |
| keypoints (torch.Tensor): ..., 51 | |
| """ | |
| res = torch.zeros_like(keypoints) | |
| num_points = keypoints.shape[-1] // 3 | |
| res[..., 0:2 * num_points:2] = keypoints[..., 0::3] | |
| res[..., 1:2 * num_points:2] = keypoints[..., 1::3] | |
| res[..., 2 * num_points:] = keypoints[..., 2::3] | |
| return res | |


class EDPoseHead(TransformerHead):
    """Head introduced in `Explicit Box Detection Unifies End-to-End
    Multi-Person Pose Estimation`_ by J. Yang et al. (2023). The head is
    composed of an encoder, a decoder and an out head.

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/ED-Pose>`_.

    More details can be found in the `paper
    <https://arxiv.org/pdf/2302.01593.pdf>`_ .

    Args:
        num_queries (int): Number of queries in the Transformer.
        num_feature_levels (int): Number of feature levels. Defaults to 4.
        num_keypoints (int): Number of keypoints. Defaults to 17.
        as_two_stage (bool, optional): Whether to generate the proposal
            from the outputs of encoder. Defaults to `False`.
        encoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer encoder. Defaults to None.
        decoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer decoder. Defaults to None.
        out_head (:obj:`ConfigDict` or dict, optional): Config of the
            final output head module. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for
            transformer position encoding. Defaults to None.
        denosing_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            human query denoising training strategy.
        data_decoder (:obj:`ConfigDict` or dict, optional): Config of the
            data decoder which transforms the results from the output space
            to the input space.
        dec_pred_class_embed_share (bool): Whether to share the class embed
            layer. Defaults to False.
        dec_pred_bbox_embed_share (bool): Whether to share the bbox embed
            layer. Defaults to False.
        refine_queries_num (int): Number of refined human content queries
            and their position queries.
        two_stage_keep_all_tokens (bool): Whether to keep all tokens.
    """
| def __init__(self, | |
| num_queries: int = 100, | |
| num_feature_levels: int = 4, | |
| num_keypoints: int = 17, | |
| as_two_stage: bool = False, | |
| encoder: OptConfigType = None, | |
| decoder: OptConfigType = None, | |
| out_head: OptConfigType = None, | |
| positional_encoding: OptConfigType = None, | |
| data_decoder: OptConfigType = None, | |
| denosing_cfg: OptConfigType = None, | |
| dec_pred_class_embed_share: bool = False, | |
| dec_pred_bbox_embed_share: bool = False, | |
| refine_queries_num: int = 100, | |
| two_stage_keep_all_tokens: bool = False) -> None: | |
| self.as_two_stage = as_two_stage | |
| self.num_feature_levels = num_feature_levels | |
| self.refine_queries_num = refine_queries_num | |
| self.dec_pred_class_embed_share = dec_pred_class_embed_share | |
| self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share | |
| self.two_stage_keep_all_tokens = two_stage_keep_all_tokens | |
| self.num_heads = decoder['layer_cfg']['self_attn_cfg']['num_heads'] | |
| self.num_group = decoder['num_group'] | |
| self.num_keypoints = num_keypoints | |
| self.denosing_cfg = denosing_cfg | |
| if data_decoder is not None: | |
| self.data_decoder = KEYPOINT_CODECS.build(data_decoder) | |
| else: | |
| self.data_decoder = None | |
| super().__init__( | |
| encoder=encoder, | |
| decoder=decoder, | |
| out_head=out_head, | |
| positional_encoding=positional_encoding, | |
| num_queries=num_queries) | |
| self.positional_encoding = PositionEmbeddingSineHW( | |
| **self.positional_encoding_cfg) | |
| self.encoder = DeformableDetrTransformerEncoder(**self.encoder_cfg) | |
| self.decoder = EDPoseDecoder( | |
| num_keypoints=num_keypoints, **self.decoder_cfg) | |
| self.out_head = EDPoseOutHead( | |
| num_keypoints=num_keypoints, | |
| as_two_stage=as_two_stage, | |
| refine_queries_num=refine_queries_num, | |
| **self.out_head_cfg, | |
| **self.decoder_cfg) | |
| self.embed_dims = self.encoder.embed_dims | |
| self.label_enc = nn.Embedding( | |
| self.denosing_cfg['dn_labelbook_size'] + 1, self.embed_dims) | |
| if not self.as_two_stage: | |
| self.query_embedding = nn.Embedding(self.num_queries, | |
| self.embed_dims) | |
| self.refpoint_embedding = nn.Embedding(self.num_queries, 4) | |
| self.level_embed = nn.Parameter( | |
| torch.Tensor(self.num_feature_levels, self.embed_dims)) | |
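        # the decoder refines intermediate boxes/keypoints with the same
        # prediction branches as the out head, so share the modules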
| self.decoder.bbox_embed = self.out_head.bbox_embed | |
| self.decoder.pose_embed = self.out_head.pose_embed | |
| self.decoder.pose_hw_embed = self.out_head.pose_hw_embed | |
| self.decoder.class_embed = self.out_head.class_embed | |
| if self.as_two_stage: | |
| self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) | |
| self.memory_trans_norm = nn.LayerNorm(self.embed_dims) | |
| if dec_pred_class_embed_share and dec_pred_bbox_embed_share: | |
| self.enc_out_bbox_embed = self.out_head.bbox_embed[0] | |
| else: | |
| self.enc_out_bbox_embed = copy.deepcopy( | |
| self.out_head.bbox_embed[0]) | |
| if dec_pred_class_embed_share and dec_pred_bbox_embed_share: | |
| self.enc_out_class_embed = self.out_head.class_embed[0] | |
| else: | |
| self.enc_out_class_embed = copy.deepcopy( | |
| self.out_head.class_embed[0]) | |
| def init_weights(self) -> None: | |
| """Initialize weights for Transformer and other components.""" | |
| super().init_weights() | |
| for coder in self.encoder, self.decoder: | |
| for p in coder.parameters(): | |
| if p.dim() > 1: | |
| nn.init.xavier_uniform_(p) | |
| for m in self.modules(): | |
| if isinstance(m, MultiScaleDeformableAttention): | |
| m.init_weights() | |
| if self.as_two_stage: | |
| nn.init.xavier_uniform_(self.memory_trans_fc.weight) | |
| nn.init.normal_(self.level_embed) | |
| def pre_transformer(self, | |
| img_feats: Tuple[Tensor], | |
| batch_data_samples: OptSampleList = None | |
| ) -> Tuple[Dict]: | |
| """Process image features before feeding them to the transformer. | |
| Args: | |
| img_feats (tuple[Tensor]): Multi-level features that may have | |
| different resolutions, output from neck. Each feature has | |
| shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. | |
| batch_data_samples (list[:obj:`DetDataSample`], optional): The | |
| batch data samples. It usually includes information such | |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. | |
| Defaults to None. | |
| Returns: | |
| tuple[dict]: The first dict contains the inputs of encoder and the | |
| second dict contains the inputs of decoder. | |
| - encoder_inputs_dict (dict): The keyword args dictionary of | |
| `self.encoder()`. | |
| - decoder_inputs_dict (dict): The keyword args dictionary of | |
| `self.forward_decoder()`, which includes 'memory_mask'. | |
| """ | |
| batch_size = img_feats[0].size(0) | |
| # construct binary masks for the transformer. | |
| assert batch_data_samples is not None | |
| batch_input_shape = batch_data_samples[0].batch_input_shape | |
| img_shape_list = [sample.img_shape for sample in batch_data_samples] | |
| input_img_h, input_img_w = batch_input_shape | |
| masks = img_feats[0].new_ones((batch_size, input_img_h, input_img_w)) | |
| for img_id in range(batch_size): | |
| img_h, img_w = img_shape_list[img_id] | |
| masks[img_id, :img_h, :img_w] = 0 | |
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.
| mlvl_masks = [] | |
| mlvl_pos_embeds = [] | |
| for feat in img_feats: | |
| mlvl_masks.append( | |
| F.interpolate(masks[None], | |
| size=feat.shape[-2:]).to(torch.bool).squeeze(0)) | |
| mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) | |
| feat_flatten = [] | |
| lvl_pos_embed_flatten = [] | |
| mask_flatten = [] | |
| spatial_shapes = [] | |
| for lvl, (feat, mask, pos_embed) in enumerate( | |
| zip(img_feats, mlvl_masks, mlvl_pos_embeds)): | |
| batch_size, c, h, w = feat.shape | |
| # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] | |
| feat = feat.view(batch_size, c, -1).permute(0, 2, 1) | |
| pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) | |
| lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) | |
| # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] | |
| mask = mask.flatten(1) | |
| spatial_shape = (h, w) | |
| feat_flatten.append(feat) | |
| lvl_pos_embed_flatten.append(lvl_pos_embed) | |
| mask_flatten.append(mask) | |
| spatial_shapes.append(spatial_shape) | |
| # (bs, num_feat_points, dim) | |
| feat_flatten = torch.cat(feat_flatten, 1) | |
| lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) | |
| # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) | |
| mask_flatten = torch.cat(mask_flatten, 1) | |
| spatial_shapes = torch.as_tensor( # (num_level, 2) | |
| spatial_shapes, | |
| dtype=torch.long, | |
| device=feat_flatten.device) | |
| level_start_index = torch.cat(( | |
| spatial_shapes.new_zeros((1, )), # (num_level) | |
| spatial_shapes.prod(1).cumsum(0)[:-1])) | |
| valid_ratios = torch.stack( # (bs, num_level, 2) | |
| [self.get_valid_ratio(m) for m in mlvl_masks], 1) | |
| if self.refine_queries_num > 0 or batch_data_samples is not None: | |
| input_query_label, input_query_bbox, humandet_attn_mask, \ | |
| human2pose_attn_mask, mask_dict =\ | |
| self.prepare_for_denosing( | |
| batch_data_samples, | |
| device=img_feats[0].device) | |
| else: | |
| assert batch_data_samples is None | |
| input_query_bbox = input_query_label = \ | |
| humandet_attn_mask = human2pose_attn_mask = mask_dict = None | |
| encoder_inputs_dict = dict( | |
| query=feat_flatten, | |
| query_pos=lvl_pos_embed_flatten, | |
| key_padding_mask=mask_flatten, | |
| spatial_shapes=spatial_shapes, | |
| level_start_index=level_start_index, | |
| valid_ratios=valid_ratios) | |
| decoder_inputs_dict = dict( | |
| memory_mask=mask_flatten, | |
| spatial_shapes=spatial_shapes, | |
| level_start_index=level_start_index, | |
| valid_ratios=valid_ratios, | |
| humandet_attn_mask=humandet_attn_mask, | |
| human2pose_attn_mask=human2pose_attn_mask, | |
| input_query_bbox=input_query_bbox, | |
| input_query_label=input_query_label, | |
| mask_dict=mask_dict) | |
| return encoder_inputs_dict, decoder_inputs_dict | |
| def forward_encoder(self, | |
| img_feats: Tuple[Tensor], | |
| batch_data_samples: OptSampleList = None) -> Dict: | |
| """Forward with Transformer encoder. | |
| The forward procedure is defined as: | |
| 'pre_transformer' -> 'encoder' | |
| Args: | |
| img_feats (tuple[Tensor]): Multi-level features that may have | |
| different resolutions, output from neck. Each feature has | |
| shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. | |
| batch_data_samples (list[:obj:`DetDataSample`], optional): The | |
| batch data samples. It usually includes information such | |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. | |
| Defaults to None. | |
| Returns: | |
| dict: The dictionary of encoder outputs, which includes the | |
| `memory` of the encoder output. | |
| """ | |
| encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( | |
| img_feats, batch_data_samples) | |
| memory = self.encoder(**encoder_inputs_dict) | |
| encoder_outputs_dict = dict(memory=memory, **decoder_inputs_dict) | |
| return encoder_outputs_dict | |
| def pre_decoder(self, memory: Tensor, memory_mask: Tensor, | |
| spatial_shapes: Tensor, input_query_bbox: Tensor, | |
| input_query_label: Tensor) -> Tuple[Dict, Dict]: | |
| """Prepare intermediate variables before entering Transformer decoder, | |
| such as `query` and `reference_points`. | |
| Args: | |
| memory (Tensor): The output embeddings of the Transformer encoder, | |
| has shape (bs, num_feat_points, dim). | |
| memory_mask (Tensor): ByteTensor, the padding mask of the memory, | |
| has shape (bs, num_feat_points). It will only be used when | |
| `as_two_stage` is `True`. | |
| spatial_shapes (Tensor): Spatial shapes of features in all levels, | |
| has shape (num_levels, 2), last dimension represents (h, w). | |
| It will only be used when `as_two_stage` is `True`. | |
            input_query_bbox (Tensor): Denoising bbox query for training.
            input_query_label (Tensor): Denoising label query for training.
| Returns: | |
| tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. | |
| - decoder_inputs_dict (dict): The keyword dictionary args of | |
| `self.decoder()`. | |
| - head_inputs_dict (dict): The keyword dictionary args of the | |
| bbox_head functions. | |
| """ | |
| bs, _, c = memory.shape | |
| if self.as_two_stage: | |
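            # two-stage: score the encoder memory, keep the top
            # `num_queries` proposals and use them (detached) to initialize
            # the decoder queries and reference points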
| output_memory, output_proposals = \ | |
| self.gen_encoder_output_proposals( | |
| memory, memory_mask, spatial_shapes) | |
| enc_outputs_class = self.enc_out_class_embed(output_memory) | |
| enc_outputs_coord_unact = self.enc_out_bbox_embed( | |
| output_memory) + output_proposals | |
| topk_proposals = torch.topk( | |
| enc_outputs_class.max(-1)[0], self.num_queries, dim=1)[1] | |
| topk_coords_undetach = torch.gather( | |
| enc_outputs_coord_unact, 1, | |
| topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) | |
| topk_coords_unact = topk_coords_undetach.detach() | |
| reference_points = topk_coords_unact.sigmoid() | |
| query_undetach = torch.gather( | |
| output_memory, 1, | |
| topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) | |
| query = query_undetach.detach() | |
| if input_query_bbox is not None: | |
| reference_points = torch.cat( | |
| [input_query_bbox, topk_coords_unact], dim=1).sigmoid() | |
| query = torch.cat([input_query_label, query], dim=1) | |
| if self.two_stage_keep_all_tokens: | |
| hidden_states_enc = output_memory.unsqueeze(0) | |
| referens_enc = enc_outputs_coord_unact.unsqueeze(0) | |
| else: | |
| hidden_states_enc = query_undetach.unsqueeze(0) | |
| referens_enc = topk_coords_undetach.sigmoid().unsqueeze(0) | |
| else: | |
| hidden_states_enc, referens_enc = None, None | |
| query = self.query_embedding.weight[:, None, :].repeat( | |
| 1, bs, 1).transpose(0, 1) | |
| reference_points = \ | |
| self.refpoint_embedding.weight[:, None, :].repeat(1, bs, 1) | |
| if input_query_bbox is not None: | |
| reference_points = torch.cat( | |
| [input_query_bbox, reference_points], dim=1) | |
| query = torch.cat([input_query_label, query], dim=1) | |
| reference_points = reference_points.sigmoid() | |
| decoder_inputs_dict = dict( | |
| query=query, reference_points=reference_points) | |
| head_inputs_dict = dict( | |
| hidden_states_enc=hidden_states_enc, referens_enc=referens_enc) | |
| return decoder_inputs_dict, head_inputs_dict | |
| def forward_decoder(self, memory: Tensor, memory_mask: Tensor, | |
| spatial_shapes: Tensor, level_start_index: Tensor, | |
| valid_ratios: Tensor, humandet_attn_mask: Tensor, | |
| human2pose_attn_mask: Tensor, input_query_bbox: Tensor, | |
| input_query_label: Tensor, mask_dict: Dict) -> Dict: | |
| """Forward with Transformer decoder. | |
| The forward procedure is defined as: | |
| 'pre_decoder' -> 'decoder' | |
| Args: | |
| memory (Tensor): The output embeddings of the Transformer encoder, | |
| has shape (bs, num_feat_points, dim). | |
| memory_mask (Tensor): ByteTensor, the padding mask of the memory, | |
| has shape (bs, num_feat_points). | |
| spatial_shapes (Tensor): Spatial shapes of features in all levels, | |
| has shape (num_levels, 2), last dimension represents (h, w). | |
| level_start_index (Tensor): The start index of each level. | |
| A tensor has shape (num_levels, ) and can be represented | |
| as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. | |
| valid_ratios (Tensor): The ratios of the valid width and the valid | |
| height relative to the width and the height of features in all | |
| levels, has shape (bs, num_levels, 2). | |
| humandet_attn_mask (Tensor): Human attention mask. | |
| human2pose_attn_mask (Tensor): Human to pose attention mask. | |
            input_query_bbox (Tensor): Denoising bbox query for training.
            input_query_label (Tensor): Denoising label query for training.
| Returns: | |
| dict: The dictionary of decoder outputs, which includes the | |
| `hidden_states` of the decoder output and `references` including | |
| the initial and intermediate reference_points. | |
| """ | |
| decoder_in, head_in = self.pre_decoder(memory, memory_mask, | |
| spatial_shapes, | |
| input_query_bbox, | |
| input_query_label) | |
| inter_states, inter_references = self.decoder( | |
| query=decoder_in['query'].transpose(0, 1), | |
| value=memory.transpose(0, 1), | |
| key_padding_mask=memory_mask, # for cross_attn | |
| reference_points=decoder_in['reference_points'].transpose(0, 1), | |
| spatial_shapes=spatial_shapes, | |
| level_start_index=level_start_index, | |
| valid_ratios=valid_ratios, | |
| humandet_attn_mask=humandet_attn_mask, | |
| human2pose_attn_mask=human2pose_attn_mask) | |
| references = inter_references | |
| decoder_outputs_dict = dict( | |
| hidden_states=inter_states, | |
| references=references, | |
| mask_dict=mask_dict) | |
| decoder_outputs_dict.update(head_in) | |
| return decoder_outputs_dict | |
| def forward_out_head(self, batch_data_samples: OptSampleList, | |
| hidden_states: List[Tensor], references: List[Tensor], | |
| mask_dict: Dict, hidden_states_enc: Tensor, | |
| referens_enc: Tensor) -> Tuple[Tensor]: | |
| """Forward function.""" | |
| out = self.out_head(hidden_states, references, mask_dict, | |
| hidden_states_enc, referens_enc, | |
| batch_data_samples) | |
| return out | |
| def predict(self, | |
| feats: Features, | |
| batch_data_samples: OptSampleList, | |
| test_cfg: ConfigType = {}) -> Predictions: | |
| """Predict results from features.""" | |
| input_shapes = np.array( | |
| [d.metainfo['input_size'] for d in batch_data_samples]) | |
| if test_cfg.get('flip_test', False): | |
            raise NotImplementedError(
                'flip_test is currently not supported '
                'for EDPose. Please set `model.test_cfg.flip_test=False`')
| else: | |
| pred_logits, pred_boxes, pred_keypoints = self.forward( | |
| feats, batch_data_samples) # (B, K, D) | |
| pred = self.decode( | |
| input_shapes, | |
| pred_logits=pred_logits, | |
| pred_boxes=pred_boxes, | |
| pred_keypoints=pred_keypoints) | |
| return pred | |
| def decode(self, input_shapes: np.ndarray, pred_logits: Tensor, | |
| pred_boxes: Tensor, pred_keypoints: Tensor): | |
| """Select the final top-k keypoints, and decode the results from | |
| normalize size to origin input size. | |
| Args: | |
| input_shapes (Tensor): The size of input image. | |
| pred_logits (Tensor): The result of score. | |
| pred_boxes (Tensor): The result of bbox. | |
| pred_keypoints (Tensor): The result of keypoints. | |
| Returns: | |
| """ | |
| if self.data_decoder is None: | |
| raise RuntimeError(f'The data decoder has not been set in \ | |
| {self.__class__.__name__}. ' | |
| 'Please set the data decoder configs in \ | |
| the init parameters to ' | |
| 'enable head methods `head.predict()` and \ | |
| `head.decode()`') | |
| preds = [] | |
| pred_logits = pred_logits.sigmoid() | |
| pred_logits, pred_boxes, pred_keypoints = to_numpy( | |
| [pred_logits, pred_boxes, pred_keypoints]) | |
| for input_shape, pred_logit, pred_bbox, pred_kpts in zip( | |
| input_shapes, pred_logits, pred_boxes, pred_keypoints): | |
| bboxes, keypoints, keypoint_scores = self.data_decoder.decode( | |
| input_shape, pred_logit, pred_bbox, pred_kpts) | |
| # pack outputs | |
| preds.append( | |
| InstanceData( | |
| keypoints=keypoints, | |
| keypoint_scores=keypoint_scores, | |
| bboxes=bboxes)) | |
| return preds | |
| def gen_encoder_output_proposals(self, memory: Tensor, memory_mask: Tensor, | |
| spatial_shapes: Tensor | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Generate proposals from encoded memory. The function will only be | |
| used when `as_two_stage` is `True`. | |
| Args: | |
| memory (Tensor): The output embeddings of the Transformer encoder, | |
| has shape (bs, num_feat_points, dim). | |
| memory_mask (Tensor): ByteTensor, the padding mask of the memory, | |
| has shape (bs, num_feat_points). | |
| spatial_shapes (Tensor): Spatial shapes of features in all levels, | |
| has shape (num_levels, 2), last dimension represents (h, w). | |
| Returns: | |
| tuple: A tuple of transformed memory and proposals. | |
| - output_memory (Tensor): The transformed memory for obtaining | |
| top-k proposals, has shape (bs, num_feat_points, dim). | |
| - output_proposals (Tensor): The inverse-normalized proposal, has | |
| shape (batch_size, num_keys, 4) with the last dimension arranged | |
| as (cx, cy, w, h). | |
| """ | |
| bs = memory.size(0) | |
| proposals = [] | |
| _cur = 0 # start index in the sequence of the current level | |
| for lvl, (H, W) in enumerate(spatial_shapes): | |
| mask_flatten_ = memory_mask[:, | |
| _cur:(_cur + H * W)].view(bs, H, W, 1) | |
| valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1) | |
| valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1) | |
| grid_y, grid_x = torch.meshgrid( | |
| torch.linspace( | |
| 0, H - 1, H, dtype=torch.float32, device=memory.device), | |
| torch.linspace( | |
| 0, W - 1, W, dtype=torch.float32, device=memory.device)) | |
| grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) | |
| scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2) | |
| grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale | |
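            # proposals are anchored on the feature grid with a
            # level-dependent size of 0.05 * 2**lvl (normalized coordinates)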
| wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) | |
| proposal = torch.cat((grid, wh), -1).view(bs, -1, 4) | |
| proposals.append(proposal) | |
| _cur += (H * W) | |
| output_proposals = torch.cat(proposals, 1) | |
| output_proposals_valid = ((output_proposals > 0.01) & | |
| (output_proposals < 0.99)).all( | |
| -1, keepdim=True) | |
| output_proposals = inverse_sigmoid(output_proposals) | |
| output_proposals = output_proposals.masked_fill( | |
| memory_mask.unsqueeze(-1), float('inf')) | |
| output_proposals = output_proposals.masked_fill( | |
| ~output_proposals_valid, float('inf')) | |
| output_memory = memory | |
| output_memory = output_memory.masked_fill( | |
| memory_mask.unsqueeze(-1), float(0)) | |
| output_memory = output_memory.masked_fill(~output_proposals_valid, | |
| float(0)) | |
| output_memory = self.memory_trans_fc(output_memory) | |
| output_memory = self.memory_trans_norm(output_memory) | |
        # output_memory: (bs, sum(hw), dim)
        # output_proposals: (bs, sum(hw), 4)
| return output_memory, output_proposals | |
| def default_init_cfg(self): | |
| init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] | |
| return init_cfg | |
| def prepare_for_denosing(self, targets: OptSampleList, device): | |
| """prepare for dn components in forward function.""" | |
| if not self.training: | |
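            # at inference there are no denoising queries; only build the
            # group-wise self-attention mask: keypoint queries attend within
            # their own (box + keypoints) group, and box queries can also
            # attend to the other box queries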
| bs = len(targets) | |
| attn_mask_infere = torch.zeros( | |
| bs, | |
| self.num_heads, | |
| self.num_group * (self.num_keypoints + 1), | |
| self.num_group * (self.num_keypoints + 1), | |
| device=device, | |
| dtype=torch.bool) | |
| group_bbox_kpt = (self.num_keypoints + 1) | |
| kpt_index = [ | |
| x for x in range(self.num_group * (self.num_keypoints + 1)) | |
| if x % (self.num_keypoints + 1) == 0 | |
| ] | |
| for matchj in range(self.num_group * (self.num_keypoints + 1)): | |
| sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
| ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt | |
| if sj > 0: | |
| attn_mask_infere[:, :, matchj, :sj] = True | |
| if ej < self.num_group * (self.num_keypoints + 1): | |
| attn_mask_infere[:, :, matchj, ej:] = True | |
| for match_x in range(self.num_group * (self.num_keypoints + 1)): | |
| if match_x % group_bbox_kpt == 0: | |
| attn_mask_infere[:, :, match_x, kpt_index] = False | |
| attn_mask_infere = attn_mask_infere.flatten(0, 1) | |
| return None, None, None, attn_mask_infere, None | |
| # targets, dn_scalar, noise_scale = dn_args | |
| device = targets[0]['boxes'].device | |
| bs = len(targets) | |
| refine_queries_num = self.refine_queries_num | |
| # gather gt boxes and labels | |
| gt_boxes = [t['boxes'] for t in targets] | |
| gt_labels = [t['labels'] for t in targets] | |
| gt_keypoints = [t['keypoints'] for t in targets] | |
| # repeat them | |
| def get_indices_for_repeat(now_num, target_num, device='cuda'): | |
| """ | |
| Input: | |
| - now_num: int | |
| - target_num: int | |
| Output: | |
| - indices: tensor[target_num] | |
| """ | |
| out_indice = [] | |
| base_indice = torch.arange(now_num).to(device) | |
| multiplier = target_num // now_num | |
| out_indice.append(base_indice.repeat(multiplier)) | |
| residue = target_num % now_num | |
| out_indice.append(base_indice[torch.randint( | |
| 0, now_num, (residue, ), device=device)]) | |
| return torch.cat(out_indice) | |
| gt_boxes_expand = [] | |
| gt_labels_expand = [] | |
| gt_keypoints_expand = [] | |
| for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( | |
| zip(gt_boxes, gt_labels, gt_keypoints)): | |
| num_gt_i = gt_boxes_i.shape[0] | |
| if num_gt_i > 0: | |
| indices = get_indices_for_repeat(num_gt_i, refine_queries_num, | |
| device) | |
| gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 | |
| gt_labels_expand_i = gt_labels_i[indices] | |
| gt_keypoints_expand_i = gt_keypoint_i[indices] | |
| else: | |
| # all negative samples when no gt boxes | |
| gt_boxes_expand_i = torch.rand( | |
| refine_queries_num, 4, device=device) | |
| gt_labels_expand_i = torch.ones( | |
| refine_queries_num, dtype=torch.int64, | |
| device=device) * int(self.num_classes) | |
| gt_keypoints_expand_i = torch.rand( | |
| refine_queries_num, self.num_keypoints * 3, device=device) | |
| gt_boxes_expand.append(gt_boxes_expand_i) | |
| gt_labels_expand.append(gt_labels_expand_i) | |
| gt_keypoints_expand.append(gt_keypoints_expand_i) | |
| gt_boxes_expand = torch.stack(gt_boxes_expand) | |
| gt_labels_expand = torch.stack(gt_labels_expand) | |
| gt_keypoints_expand = torch.stack(gt_keypoints_expand) | |
| knwon_boxes_expand = gt_boxes_expand.clone() | |
| knwon_labels_expand = gt_labels_expand.clone() | |
| # add noise | |
| if self.denosing_cfg['dn_label_noise_ratio'] > 0: | |
| prob = torch.rand_like(knwon_labels_expand.float()) | |
| chosen_indice = prob < self.denosing_cfg['dn_label_noise_ratio'] | |
| new_label = torch.randint_like( | |
| knwon_labels_expand[chosen_indice], 0, | |
| self.dn_labelbook_size) # randomly put a new one here | |
| knwon_labels_expand[chosen_indice] = new_label | |
| if self.denosing_cfg['dn_box_noise_scale'] > 0: | |
| diff = torch.zeros_like(knwon_boxes_expand) | |
| diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 | |
| diff[..., 2:] = knwon_boxes_expand[..., 2:] | |
| knwon_boxes_expand += torch.mul( | |
| (torch.rand_like(knwon_boxes_expand) * 2 - 1.0), | |
| diff) * self.denosing_cfg['dn_box_noise_scale'] | |
| knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) | |
| input_query_label = self.label_enc(knwon_labels_expand) | |
| input_query_bbox = inverse_sigmoid(knwon_boxes_expand) | |
| # prepare mask | |
| if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: | |
| attn_mask = torch.zeros( | |
| bs, | |
| self.num_heads, | |
| refine_queries_num + self.num_queries, | |
| refine_queries_num + self.num_queries, | |
| device=device, | |
| dtype=torch.bool) | |
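            # matching queries must not attend to the denoising queries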
| attn_mask[:, :, refine_queries_num:, :refine_queries_num] = True | |
| for idx, (gt_boxes_i, | |
| gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
| num_gt_i = gt_boxes_i.shape[0] | |
| if num_gt_i == 0: | |
| continue | |
| for matchi in range(refine_queries_num): | |
| si = (matchi // num_gt_i) * num_gt_i | |
| ei = (matchi // num_gt_i + 1) * num_gt_i | |
| if si > 0: | |
| attn_mask[idx, :, matchi, :si] = True | |
| if ei < refine_queries_num: | |
| attn_mask[idx, :, matchi, ei:refine_queries_num] = True | |
| attn_mask = attn_mask.flatten(0, 1) | |
| if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: | |
| attn_mask2 = torch.zeros( | |
| bs, | |
| self.num_heads, | |
| refine_queries_num + self.num_group * (self.num_keypoints + 1), | |
| refine_queries_num + self.num_group * (self.num_keypoints + 1), | |
| device=device, | |
| dtype=torch.bool) | |
| attn_mask2[:, :, refine_queries_num:, :refine_queries_num] = True | |
| group_bbox_kpt = (self.num_keypoints + 1) | |
| kpt_index = [ | |
| x for x in range(self.num_group * (self.num_keypoints + 1)) | |
| if x % (self.num_keypoints + 1) == 0 | |
| ] | |
| for matchj in range(self.num_group * (self.num_keypoints + 1)): | |
| sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
| ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt | |
| if sj > 0: | |
| attn_mask2[:, :, refine_queries_num:, | |
| refine_queries_num:][:, :, matchj, :sj] = True | |
| if ej < self.num_group * (self.num_keypoints + 1): | |
| attn_mask2[:, :, refine_queries_num:, | |
| refine_queries_num:][:, :, matchj, ej:] = True | |
| for match_x in range(self.num_group * (self.num_keypoints + 1)): | |
| if match_x % group_bbox_kpt == 0: | |
| attn_mask2[:, :, refine_queries_num:, | |
| refine_queries_num:][:, :, match_x, | |
| kpt_index] = False | |
| for idx, (gt_boxes_i, | |
| gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
| num_gt_i = gt_boxes_i.shape[0] | |
| if num_gt_i == 0: | |
| continue | |
| for matchi in range(refine_queries_num): | |
| si = (matchi // num_gt_i) * num_gt_i | |
| ei = (matchi // num_gt_i + 1) * num_gt_i | |
| if si > 0: | |
| attn_mask2[idx, :, matchi, :si] = True | |
| if ei < refine_queries_num: | |
| attn_mask2[idx, :, matchi, | |
| ei:refine_queries_num] = True | |
| attn_mask2 = attn_mask2.flatten(0, 1) | |
| mask_dict = { | |
| 'pad_size': refine_queries_num, | |
| 'known_bboxs': gt_boxes_expand, | |
| 'known_labels': gt_labels_expand, | |
| 'known_keypoints': gt_keypoints_expand | |
| } | |
| return input_query_label, input_query_bbox, \ | |
| attn_mask, attn_mask2, mask_dict | |
| def loss(self, | |
| feats: Tuple[Tensor], | |
| batch_data_samples: OptSampleList, | |
| train_cfg: OptConfigType = {}) -> dict: | |
| """Calculate losses from a batch of inputs and data samples.""" | |
        raise NotImplementedError(
            'Training of EDPose is not supported yet. '
            'Please stay tuned for further updates.')