Spaces:
Sleeping
Sleeping
import copy | |
import pdb | |
import os | |
import math | |
from typing import List | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch import Tensor | |
from util import box_ops | |
from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz | |
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy, | |
get_world_size, interpolate, | |
is_dist_avail_and_initialized, inverse_sigmoid) | |
from .backbones import build_backbone | |
from .matcher import build_matcher | |
from .transformer import build_transformer | |
from .utils import PoseProjector, sigmoid_focal_loss, MLP | |
from .postprocesses import PostProcess_SMPLX, PostProcess_aios | |
from .postprocesses import PostProcess_SMPLX_Multi as PostProcess_SMPLX | |
from .postprocesses import PostProcess_SMPLX_Multi_Box | |
from .postprocesses import PostProcess_SMPLX_Multi_Infer, PostProcess_SMPLX_Multi_Infer_Box | |
from .criterion_smplx import SetCriterion, SetCriterion_Box | |
from ..registry import MODULE_BUILD_FUNCS | |
from detrsmpl.core.conventions.keypoints_mapping import convert_kps | |
from detrsmpl.models.body_models.builder import build_body_model | |
from util.human_models import smpl_x | |
from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idxs_by_part | |
import numpy as np | |
import random | |
from detrsmpl.utils.geometry import (rot6d_to_rotmat) | |
from detrsmpl.utils.transforms import rotmat_to_aa | |
import cv2 | |
from config.config import cfg | |
class AiOSSMPLX(nn.Module): | |
def __init__(
        self,
        backbone,
        transformer,
        num_classes,
        num_queries,
        aux_loss=False,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=False,
        fix_refpoints_hw=-1,
        num_feature_levels=1,
        nheads=8,
        two_stage_type='no',
        dec_pred_class_embed_share=False,
        dec_pred_bbox_embed_share=False,
        dec_pred_pose_embed_share=False,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_batch_gt_fuse=False,
        dn_labelbook_size=100,
        dn_attn_mask_type_list=None,
        cls_no_bias=False,
        num_group=100,
        num_body_points=17,
        num_hand_points=10,
        num_face_points=10,
        num_box_decoder_layers=2,
        num_hand_face_decoder_layers=4,
        body_model=None,
        train=True,
        inference=False,
        focal_length=None,
        camera_3d_size=2.5):
    """Build the AiOS SMPL-X model: input projections plus all per-layer heads.

    The per-decoder-layer prediction heads (class, body/hand/face boxes,
    body/hand/face 2D keypoints, SMPL pose/betas/camera, SMPL-X hand pose,
    expression and jaw) are created here and several of them are also
    attached to ``transformer.decoder`` so the decoder can use them.

    Args:
        backbone: feature extractor; ``backbone.num_channels`` gives the
            channel count of each output level.
        transformer: must expose ``d_model``, ``num_decoder_layers`` and a
            ``decoder`` attribute that the heads are attached to.
        num_classes: number of detection classes (width of class heads).
        num_queries: number of object queries (stored only).
        dec_pred_*_embed_share: when True, all decoder layers share ONE
            head instance; otherwise each layer gets a deep copy.
        dn_attn_mask_type_list: subset of ``'match2dn'``, ``'dn2dn'``,
            ``'group2group'``; defaults to ``['group2group']``.
        num_box_decoder_layers: decoder layers before the keypoint/SMPL heads
            start; SMPL heads get ``num_decoder_layers - num_box_decoder_layers``
            layers.
        num_hand_face_decoder_layers: same role for the hand/face keypoint
            heads.
        body_model: body-model config dict passed to ``build_body_model``;
            defaults to an SMPL-X config with ``keypoint_dst='smplx_137'``.
            Ignored when ``inference`` is True (a fixed SMPL-X config with
            ``keypoint_dst='smplx'`` and ``num_betas=10`` is used instead).
        train: selects ``self.smpl_convention`` ('smplx' if True else 'h36m').
        focal_length: ``[fx, fy]``, defaults to ``[5000., 5000.]``.
        camera_3d_size: scale used by ``get_camera_trans``.

    Raises:
        ValueError: if ``body_model['type']`` is neither SMPL nor SMPLX.
        AssertionError: on unsupported configuration (``query_dim != 4``,
            unknown dn mask types, ``dn_batch_gt_fuse``, ``iter_update``
            False, unknown ``two_stage_type``, multi-stage with a single
            feature level, or inconsistent two-stage sharing flags).
    """
    super().__init__()
    # Resolve defaults here (not in the signature) so separate instances never
    # share one mutable list/dict object.
    if dn_attn_mask_type_list is None:
        dn_attn_mask_type_list = ['group2group']
    if body_model is None:
        body_model = dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            keypoint_dst='smplx_137',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True)
    if focal_length is None:
        focal_length = [5000., 5000.]

    self.num_queries = num_queries
    self.transformer = transformer
    self.num_classes = num_classes
    self.hidden_dim = hidden_dim = transformer.d_model
    self.num_feature_levels = num_feature_levels
    self.nheads = nheads
    self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
    self.num_body_points = num_body_points
    self.num_hand_points = num_hand_points
    self.num_face_points = num_face_points
    self.num_whole_body_points = (
        num_body_points + 2 * num_hand_points + num_face_points)
    self.num_box_decoder_layers = num_box_decoder_layers
    self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.focal_length = focal_length
    self.camera_3d_size = camera_3d_size
    self.inference = inference
    self.smpl_convention = 'smplx' if train else 'h36m'

    # query dim
    self.query_dim = query_dim
    assert query_dim == 4
    self.random_refpoints_xy = random_refpoints_xy
    self.fix_refpoints_hw = fix_refpoints_hw

    # denoising (dn) training settings
    self.dn_number = dn_number
    self.dn_box_noise_scale = dn_box_noise_scale
    self.dn_label_noise_ratio = dn_label_noise_ratio
    self.dn_batch_gt_fuse = dn_batch_gt_fuse
    self.dn_labelbook_size = dn_labelbook_size
    self.dn_attn_mask_type_list = dn_attn_mask_type_list
    assert all(
        i in ['match2dn', 'dn2dn', 'group2group']
        for i in dn_attn_mask_type_list)
    assert not dn_batch_gt_fuse

    # Body model (always built, frozen). Inference overrides the config.
    if inference:
        body_model = dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            num_betas=10,
            keypoint_dst='smplx',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True)
    self.body_model = build_body_model(body_model)
    for param in self.body_model.parameters():
        param.requires_grad = False

    # Input projection layers: one 1x1 conv per backbone level, then extra
    # stride-2 3x3 convs until num_feature_levels levels exist.
    if num_feature_levels > 1:
        num_backbone_outs = len(backbone.num_channels)
        input_proj_list = []
        for level in range(num_backbone_outs):
            in_channels = backbone.num_channels[level]
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
        for _ in range(num_feature_levels - num_backbone_outs):
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels,
                              hidden_dim,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            in_channels = hidden_dim
        self.input_proj = nn.ModuleList(input_proj_list)
    else:
        assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
        self.input_proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(backbone.num_channels[-1],
                          hidden_dim,
                          kernel_size=1),
                nn.GroupNorm(32, hidden_dim),
            )
        ])

    self.backbone = backbone
    self.aux_loss = aux_loss
    self.box_pred_damping = None
    self.iter_update = iter_update
    assert iter_update, 'Why not iter_update?'
    self.dec_pred_class_embed_share = dec_pred_class_embed_share
    self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share

    # ---- small local helpers (no RNG use, so init order is unaffected) ----
    def _zero_last(mlp):
        # Zero the last linear layer so the head initially predicts 0.
        nn.init.constant_(mlp.layers[-1].weight.data, 0)
        nn.init.constant_(mlp.layers[-1].bias.data, 0)
        return mlp

    def _repeat(module, num, share):
        # `num` references to one module if `share`, else `num` deep copies.
        if share:
            return [module for _ in range(num)]
        return [copy.deepcopy(module) for _ in range(num)]

    def _repeat_split(early, late, num, share, num_early=2):
        # First `num_early` entries use `early`, the rest use `late`.
        return [
            (early if share else copy.deepcopy(early)) if i < num_early
            else (late if share else copy.deepcopy(late))
            for i in range(num)
        ]

    num_dec = transformer.num_decoder_layers
    num_after_box = num_dec - num_box_decoder_layers
    num_after_hf = num_dec - num_hand_face_decoder_layers

    # NOTE: modules below are constructed in the exact original order so
    # parameter initialisation consumes the RNG identically.

    # ---- class head (focal-loss prior bias) ----
    _class_embed = nn.Linear(hidden_dim,
                             num_classes,
                             bias=(not cls_no_bias))
    if not cls_no_bias:
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
    class_embed_layerlist = _repeat(_class_embed, num_dec,
                                    dec_pred_class_embed_share)

    # ---- body / hand / face boxes ----
    _bbox_embed = _zero_last(MLP(hidden_dim, hidden_dim, 4, 3))
    self.num_group = num_group
    box_body_embed_layerlist = _repeat(_bbox_embed, num_dec,
                                       dec_pred_bbox_embed_share)

    _bbox_hand_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    _bbox_hand_hw_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    box_hand_embed_layerlist = _repeat(_bbox_hand_embed, num_after_box + 1,
                                       dec_pred_pose_embed_share)
    box_hand_hw_embed_layerlist = _repeat(_bbox_hand_hw_embed, num_after_box,
                                          dec_pred_pose_embed_share)

    _bbox_face_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    _bbox_face_hw_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    box_face_embed_layerlist = _repeat(_bbox_face_embed, num_after_box + 1,
                                       dec_pred_pose_embed_share)
    box_face_hw_embed_layerlist = _repeat(_bbox_face_hw_embed, num_after_box,
                                          dec_pred_pose_embed_share)

    # ---- body / hand / face 2D keypoints ----
    _pose_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    num_pose_layers = (num_after_box + 1
                       if num_body_points == 17 else num_after_box)
    pose_embed_layerlist = _repeat(_pose_embed, num_pose_layers,
                                   dec_pred_pose_embed_share)
    # The *_hw_embed keypoint heads are not zero-initialised and are always
    # shared across layers, independent of the sharing flags.
    _pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
    pose_hw_embed_layerlist = _repeat(_pose_hw_embed, num_after_box, True)

    _pose_hand_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    pose_hand_embed_layerlist = _repeat(_pose_hand_embed, num_after_hf + 1,
                                        dec_pred_pose_embed_share)
    _pose_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
    pose_hand_hw_embed_layerlist = _repeat(_pose_hand_hw_embed, num_after_hf,
                                           True)

    _pose_face_embed = _zero_last(MLP(hidden_dim, hidden_dim, 2, 3))
    pose_face_embed_layerlist = _repeat(_pose_face_embed, num_after_hf + 1,
                                        dec_pred_pose_embed_share)
    _pose_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
    pose_face_hw_embed_layerlist = _repeat(_pose_face_hw_embed, num_after_hf,
                                           True)

    # ---- SMPL pose / betas / camera ----
    body_model_type = body_model['type'].upper()
    if body_model_type == 'SMPL':
        self.body_model_joint_num = 24
    elif body_model_type == 'SMPLX':
        self.body_model_joint_num = 22
    else:
        # body_model is a dict, so index it (attribute access would raise
        # AttributeError and hide this error).
        raise ValueError(
            f"Only supports SMPL or SMPLX, but get {body_model['type']}")
    smpl_in_dim = hidden_dim * (self.num_body_points + 4)
    _smpl_pose_embed = _zero_last(
        MLP(smpl_in_dim, hidden_dim, self.body_model_joint_num * 6, 3))
    smpl_pose_embed_layerlist = _repeat(_smpl_pose_embed, num_after_box,
                                        dec_pred_bbox_embed_share)
    _smpl_beta_embed = _zero_last(MLP(smpl_in_dim, hidden_dim, 10, 3))
    smpl_beta_embed_layerlist = _repeat(_smpl_beta_embed, num_after_box,
                                        dec_pred_bbox_embed_share)
    _cam_embed = _zero_last(MLP(smpl_in_dim, hidden_dim, 3, 3))
    cam_embed_layerlist = _repeat(_cam_embed, num_after_box,
                                  dec_pred_bbox_embed_share)

    # ---- SMPL-X hand pose / expression / jaw ----
    # The first 2 post-box layers use the "_2_3" head, the rest the "_4_5"
    # head (the split at 2 is hard-coded, as in the original).
    _smplx_hand_pose_embed_layer_2_3 = _zero_last(
        MLP(hidden_dim, hidden_dim, 15 * 6, 3))
    _smplx_hand_pose_embed_layer_4_5 = _zero_last(
        MLP(hidden_dim * (self.num_hand_points + 3), hidden_dim, 15 * 6, 3))
    smplx_hand_pose_embed_layerlist = _repeat_split(
        _smplx_hand_pose_embed_layer_2_3, _smplx_hand_pose_embed_layer_4_5,
        num_after_box, dec_pred_bbox_embed_share)

    _smplx_expression_embed_layer_2_3 = _zero_last(
        MLP(hidden_dim, hidden_dim, 10, 3))
    # NOTE(review): sized with num_hand_points while the jaw head uses
    # num_face_points; identical with the defaults (10/10) — confirm intent.
    _smplx_expression_embed_layer_4_5 = _zero_last(
        MLP(hidden_dim * (self.num_hand_points + 2), hidden_dim, 10, 3))
    smplx_expression_embed_layerlist = _repeat_split(
        _smplx_expression_embed_layer_2_3, _smplx_expression_embed_layer_4_5,
        num_after_box, dec_pred_bbox_embed_share)

    _smplx_jaw_embed_2_3 = _zero_last(MLP(hidden_dim * 1, hidden_dim, 6, 3))
    _smplx_jaw_embed_4_5 = _zero_last(
        MLP(hidden_dim * (self.num_face_points + 2), hidden_dim, 6, 3))
    smplx_jaw_embed_layerlist = _repeat_split(
        _smplx_jaw_embed_2_3, _smplx_jaw_embed_4_5,
        num_after_box, dec_pred_bbox_embed_share)

    # ---- register heads (order kept: it fixes parameters() order) ----
    self.bbox_embed = nn.ModuleList(box_body_embed_layerlist)
    self.class_embed = nn.ModuleList(class_embed_layerlist)
    self.pose_embed = nn.ModuleList(pose_embed_layerlist)
    self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist)
    self.transformer.decoder.bbox_embed = self.bbox_embed
    self.transformer.decoder.pose_embed = self.pose_embed
    self.transformer.decoder.pose_hw_embed = self.pose_hw_embed
    self.transformer.decoder.class_embed = self.class_embed
    # smpl
    self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist)
    self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist)
    self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist)
    # hands
    self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist)
    self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist)
    self.pose_hand_embed = nn.ModuleList(pose_hand_embed_layerlist)
    self.pose_hand_hw_embed = nn.ModuleList(pose_hand_hw_embed_layerlist)
    self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed
    self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed
    self.transformer.decoder.pose_hand_embed = self.pose_hand_embed
    self.transformer.decoder.pose_hand_hw_embed = self.pose_hand_hw_embed
    # face
    self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist)
    self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist)
    self.pose_face_embed = nn.ModuleList(pose_face_embed_layerlist)
    self.pose_face_hw_embed = nn.ModuleList(pose_face_hw_embed_layerlist)
    self.transformer.decoder.bbox_face_embed = self.bbox_face_embed
    self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed
    self.transformer.decoder.pose_face_embed = self.pose_face_embed
    self.transformer.decoder.pose_face_hw_embed = self.pose_face_hw_embed
    # smplx
    self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist)
    self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist)
    self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist)
    # decoder layer/point bookkeeping
    self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers
    self.transformer.decoder.num_body_points = num_body_points
    self.transformer.decoder.num_hand_points = num_hand_points
    self.transformer.decoder.num_face_points = num_face_points

    # ---- two stage ----
    self.two_stage_type = two_stage_type
    assert two_stage_type in [
        'no', 'standard'
    ], 'unknown param {} of two_stage_type'.format(two_stage_type)
    if two_stage_type != 'no':
        if two_stage_bbox_embed_share:
            assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
            self.transformer.enc_out_bbox_embed = _bbox_embed
        else:
            self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
        if two_stage_class_embed_share:
            assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
            self.transformer.enc_out_class_embed = _class_embed
        else:
            self.transformer.enc_out_class_embed = copy.deepcopy(
                _class_embed)
        self.refpoint_embed = None
    self._reset_parameters()
def get_camera_trans(self, cam_param, input_body_shape): | |
# camera translation | |
t_xy = cam_param[:, :2] | |
gamma = torch.sigmoid(cam_param[:, 2]) # apply sigmoid to make it positive | |
k_value = torch.FloatTensor( | |
[ | |
math.sqrt( | |
self.focal_length[0] * self.focal_length[1] * self.camera_3d_size * self.camera_3d_size / | |
(input_body_shape[0] * input_body_shape[1]) | |
) | |
] | |
).cuda().view(-1) | |
t_z = k_value * gamma | |
cam_trans = torch.cat((t_xy, t_z[:, None]), 1) | |
return cam_trans | |
def _reset_parameters(self): | |
# init input_proj | |
for proj in self.input_proj: | |
nn.init.xavier_uniform_(proj[0].weight, gain=1) | |
nn.init.constant_(proj[0].bias, 0) | |
def prepare_for_dn2(self, targets): | |
if not self.training: | |
device = targets[0]['boxes'].device | |
bs = len(targets) | |
num_points = self.num_body_points + 4 | |
attn_mask2 = torch.zeros( | |
bs, | |
self.nheads, | |
self.num_group * num_points, | |
self.num_group * num_points, | |
device=device, | |
dtype=torch.bool) | |
group_bbox_kpt = num_points | |
group_nobbox_kpt = self.num_body_points | |
kpt_index = [ | |
x for x in range(self.num_group * num_points) | |
if x % num_points in [ | |
0, | |
self.num_body_points+1, | |
self.num_body_points+2, | |
self.num_body_points+3 | |
] | |
] | |
for matchj in range(self.num_group * num_points): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
if sj > 0: | |
attn_mask2[:, :, matchj, :sj] = True | |
if ej < self.num_group * num_points: | |
attn_mask2[:, :, matchj, ej:] = True | |
for match_x in range(self.num_group * num_points): | |
if match_x % group_bbox_kpt in [0, | |
self.num_body_points+1, | |
self.num_body_points+2, | |
self.num_body_points+3]: | |
attn_mask2[:,:,match_x,kpt_index]=False | |
num_points = self.num_whole_body_points + 4 | |
attn_mask3 = torch.zeros( | |
bs, | |
self.nheads, | |
self.num_group * (num_points), | |
self.num_group * (num_points), | |
device=device, | |
dtype=torch.bool) | |
group_bbox_kpt = (num_points) | |
# group_nobbox_kpt = self.num_body_points | |
kpt_index = [ | |
x for x in range(self.num_group * (num_points)) if x % (num_points) in | |
[0, | |
1+self.num_body_points, | |
2+self.num_body_points+self.num_hand_points, | |
3+self.num_body_points+self.num_hand_points*2 | |
] | |
] | |
for matchj in range(self.num_group * num_points): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
if sj > 0: | |
attn_mask3[:, :, matchj, :sj] = True | |
if ej < self.num_group * num_points: | |
attn_mask3[:, :, matchj, ej:] = True | |
for match_x in range(self.num_group * num_points): | |
if match_x % group_bbox_kpt in [ | |
0, | |
1 + self.num_body_points, | |
2 + self.num_body_points + self.num_hand_points, | |
3 + self.num_body_points + self.num_hand_points * 2]: | |
attn_mask3[:, :, match_x, kpt_index] = False | |
# num_points = self.num_whole_body_points + 4 | |
# device = targets[0]['boxes'].device | |
# bs = len(targets) | |
# attn_mask_infere = torch.zeros( | |
# bs, | |
# self.nheads, | |
# self.num_group * num_points, | |
# self.num_group * num_points, | |
# device=device, | |
# dtype=torch.bool) | |
# group_bbox_kpt = num_points | |
# group_nobbox_kpt = self.num_body_points | |
# kpt_index = [ | |
# x for x in range(self.num_group * num_points) | |
# if x % num_points == 0 | |
# ] | |
# for matchj in range(self.num_group * num_points): | |
# sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
# ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt | |
# if sj > 0: | |
# attn_mask_infere[:, :, matchj, :sj] = True | |
# if ej < self.num_group * num_points: | |
# attn_mask_infere[:, :, matchj, ej:] = True | |
# for match_x in range(self.num_group * num_points): | |
# if match_x % group_bbox_kpt == 0: | |
# attn_mask_infere[:, :, match_x, kpt_index] = False | |
# attn_mask_infere = attn_mask_infere.flatten(0, 1) | |
attn_mask2 = attn_mask2.flatten(0, 1) | |
attn_mask3 = attn_mask3.flatten(0, 1) | |
return None, None, None, attn_mask2, attn_mask3, None | |
# targets, dn_scalar, noise_scale = dn_args | |
device = targets[0]['boxes'].device | |
bs = len(targets) | |
dn_number = self.dn_number # 100 | |
dn_box_noise_scale = self.dn_box_noise_scale # 0.4 | |
dn_label_noise_ratio = self.dn_label_noise_ratio # 0.5 | |
# gather gt boxes and labels | |
gt_boxes = [t['boxes'] for t in targets] | |
gt_labels = [t['labels'] for t in targets] | |
gt_keypoints = [t['keypoints'] for t in targets] | |
# repeat them | |
def get_indices_for_repeat(now_num, target_num, device='cuda'): | |
""" | |
Input: | |
- now_num: int | |
- target_num: int | |
Output: | |
- indices: tensor[target_num] | |
""" | |
out_indice = [] | |
base_indice = torch.arange(now_num).to(device) | |
multiplier = target_num // now_num | |
out_indice.append(base_indice.repeat(multiplier)) | |
residue = target_num % now_num | |
out_indice.append(base_indice[torch.randint(0, | |
now_num, (residue, ), | |
device=device)]) | |
return torch.cat(out_indice) | |
if self.dn_batch_gt_fuse: | |
raise NotImplementedError | |
gt_boxes_bsall = torch.cat(gt_boxes) # num_boxes, 4 | |
gt_labels_bsall = torch.cat(gt_labels) | |
num_gt_bsall = gt_boxes_bsall.shape[0] | |
if num_gt_bsall > 0: | |
indices = get_indices_for_repeat(num_gt_bsall, dn_number, | |
device) | |
gt_boxes_expand = gt_boxes_bsall[indices][None].repeat( | |
bs, 1, 1) # bs, num_dn, 4 | |
gt_labels_expand = gt_labels_bsall[indices][None].repeat( | |
bs, 1) # bs, num_dn | |
else: | |
# all negative samples when no gt boxes | |
gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device) | |
gt_labels_expand = torch.ones( | |
bs, dn_number, dtype=torch.int64, device=device) * int( | |
self.num_classes) | |
else: | |
gt_boxes_expand = [] | |
gt_labels_expand = [] | |
gt_keypoints_expand = [] # here | |
for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( | |
zip(gt_boxes, gt_labels, gt_keypoints)): # idx -> batch id | |
num_gt_i = gt_boxes_i.shape[0] # instance num | |
if num_gt_i > 0: | |
indices = get_indices_for_repeat(num_gt_i, dn_number, | |
device) | |
gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 | |
gt_labels_expand_i = gt_labels_i[indices] # add smpl | |
gt_keypoints_expand_i = gt_keypoint_i[indices] | |
else: | |
# all negative samples when no gt boxes | |
gt_boxes_expand_i = torch.rand(dn_number, 4, device=device) | |
gt_labels_expand_i = torch.ones( | |
dn_number, dtype=torch.int64, device=device) * int( | |
self.num_classes) | |
gt_keypoints_expand_i = torch.rand(dn_number, | |
self.num_body_points * | |
3, | |
device=device) | |
gt_boxes_expand.append(gt_boxes_expand_i) # add smpl | |
gt_labels_expand.append(gt_labels_expand_i) | |
gt_keypoints_expand.append(gt_keypoints_expand_i) | |
gt_boxes_expand = torch.stack(gt_boxes_expand) | |
gt_labels_expand = torch.stack(gt_labels_expand) | |
gt_keypoints_expand = torch.stack(gt_keypoints_expand) | |
knwon_boxes_expand = gt_boxes_expand.clone() | |
knwon_labels_expand = gt_labels_expand.clone() | |
# add noise | |
if dn_label_noise_ratio > 0: | |
prob = torch.rand_like(knwon_labels_expand.float()) | |
chosen_indice = prob < dn_label_noise_ratio | |
new_label = torch.randint_like( | |
knwon_labels_expand[chosen_indice], 0, | |
self.dn_labelbook_size) # randomly put a new one here | |
knwon_labels_expand[chosen_indice] = new_label | |
if dn_box_noise_scale > 0: | |
diff = torch.zeros_like(knwon_boxes_expand) | |
diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 | |
diff[..., 2:] = knwon_boxes_expand[..., 2:] | |
knwon_boxes_expand += torch.mul( | |
(torch.rand_like(knwon_boxes_expand) * 2 - 1.0), | |
diff) * dn_box_noise_scale | |
knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) | |
input_query_label = self.label_enc(knwon_labels_expand) | |
input_query_bbox = inverse_sigmoid(knwon_boxes_expand) | |
# prepare mask | |
body_mask, body_kps_mask, lhand_mask, lhand_kps_mask, rhand_mask, \ | |
rhand_kps_mask, face_mask, face_kps_mask = \ | |
False, False, False, False, False, False, False, False | |
if random.random() < 0.2: | |
body_mask = True | |
if random.random() < 0.5: | |
body_kps_mask = True | |
if random.random() < 0.2: | |
lhand_mask = True | |
if random.random() < 0.5: | |
lhand_kps_mask = True | |
if random.random() < 0.2: | |
rhand_mask = True | |
if random.random() < 0.5: | |
rhand_kps_mask = True | |
if random.random() < 0.2: | |
face_mask = True | |
if random.random() < 0.5: | |
face_kps_mask = True | |
if 'group2group' in self.dn_attn_mask_type_list: | |
attn_mask = torch.zeros(bs, | |
self.nheads, | |
dn_number + self.num_queries, | |
dn_number + self.num_queries, | |
device=device, | |
dtype=torch.bool) | |
attn_mask[:, :, dn_number:, :dn_number] = True | |
for idx, (gt_boxes_i, | |
gt_labels_i) in enumerate(zip(gt_boxes, | |
gt_labels)): # for batch | |
num_gt_i = gt_boxes_i.shape[0] | |
if num_gt_i == 0: | |
continue | |
for matchi in range(dn_number): | |
si = (matchi // num_gt_i) * num_gt_i | |
ei = (matchi // num_gt_i + 1) * num_gt_i | |
if si > 0: | |
attn_mask[idx, :, matchi, :si] = True | |
if ei < dn_number: | |
attn_mask[idx, :, matchi, ei:dn_number] = True | |
attn_mask = attn_mask.flatten(0, 1) | |
if 'group2group' in self.dn_attn_mask_type_list: | |
# self.num_body_points = self.num_body_points +3 | |
inter_body_mask = [] | |
if body_mask: | |
inter_body_mask.append(0) | |
if body_kps_mask: | |
indices = sorted(random.sample(range(1, self.num_body_points+1), k=6)) | |
inter_body_mask.extend(indices) | |
if lhand_mask: | |
inter_body_mask.append(self.num_body_points+1) | |
if rhand_mask: | |
inter_body_mask.append(self.num_body_points+2) | |
if face_mask: | |
inter_body_mask.append(self.num_body_points+3) | |
num_points = self.num_body_points + 4 | |
attn_mask2 = torch.zeros( | |
bs, | |
self.nheads, | |
dn_number + self.num_group * num_points, | |
dn_number + self.num_group * num_points, | |
device=device, | |
dtype=torch.bool) | |
attn_mask2[:, :, dn_number:, :dn_number] = True | |
group_bbox_kpt = num_points | |
# group_nobbox_kpt = self.num_body_points | |
kpt_index = [x for x in range(self.num_group * num_points) | |
if x % num_points in [ | |
0, self.num_body_points+1, self.num_body_points+2, self.num_body_points+3]] | |
for matchj in range(self.num_group * num_points): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
if sj > 0: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True | |
if ej < self.num_group * num_points: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True | |
if (matchj // group_bbox_kpt) == 0: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, sj:ej][..., inter_body_mask] = True | |
for match_x in range(self.num_group * num_points): | |
if match_x % group_bbox_kpt == 0 and body_mask != False: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, ::num_points]=False | |
if match_x % group_bbox_kpt == self.num_body_points + 1 and lhand_mask != False: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 1::num_points]=False | |
if match_x % group_bbox_kpt == self.num_body_points + 2 and rhand_mask != False: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 2::num_points]=False | |
if match_x % group_bbox_kpt == self.num_body_points + 3 and face_mask != False: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 3::num_points]=False | |
# if match_x % group_bbox_kpt in [0, | |
# self.num_body_points+1, | |
# self.num_body_points+2, | |
# self.num_body_points+3]: | |
# attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False | |
for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
num_gt_i = gt_boxes_i.shape[0] | |
if num_gt_i == 0: | |
continue | |
for matchi in range(dn_number): | |
si = (matchi // num_gt_i) * num_gt_i | |
ei = (matchi // num_gt_i + 1) * num_gt_i | |
if si > 0: | |
attn_mask2[idx, :, matchi, :si] = True | |
if ei < dn_number: | |
attn_mask2[idx, :, matchi, ei:dn_number] = True | |
attn_mask2 = attn_mask2.flatten(0, 1) | |
if 'group2group' in self.dn_attn_mask_type_list: | |
inter_body_mask = [] | |
if body_mask: | |
inter_body_mask.append(0) | |
if body_kps_mask: | |
indices = sorted(random.sample(range(1, self.num_body_points+1), k=6)) | |
inter_body_mask.extend(indices) | |
if lhand_mask: | |
inter_body_mask.append(self.num_body_points+1) | |
if lhand_kps_mask: | |
indices = sorted(random.sample(range(self.num_body_points+2, self.num_body_points+8), k=3)) | |
inter_body_mask.extend(indices) | |
if rhand_mask: | |
inter_body_mask.append(self.num_body_points+8) | |
if rhand_kps_mask: | |
indices = sorted(random.sample(range(self.num_body_points+9, self.num_body_points+15), k=3)) | |
inter_body_mask.extend(indices) | |
if face_mask: | |
inter_body_mask.append(self.num_body_points+15) | |
if face_kps_mask: | |
indices = sorted(random.sample(range(self.num_body_points+16, self.num_body_points+22), k=3) ) | |
inter_body_mask.extend(indices) | |
# self.num_body_points = self.num_body_points +3 | |
num_points = self.num_whole_body_points + 4 | |
attn_mask3 = torch.zeros( | |
bs, | |
self.nheads, | |
dn_number + self.num_group * (num_points), dn_number + self.num_group * (num_points), | |
device=device, dtype=torch.bool) | |
attn_mask3[:, :, dn_number:, :dn_number] = True | |
group_bbox_kpt = (num_points) | |
# group_nobbox_kpt = self.num_body_points | |
kpt_index = [ | |
x for x in range(self.num_group * (num_points)) if x % (num_points) in | |
[0, | |
1+self.num_body_points, | |
2+self.num_body_points+self.num_hand_points, | |
3+self.num_body_points+self.num_hand_points*2 | |
] | |
] | |
for matchj in range(self.num_group * num_points): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
if sj > 0: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True | |
if ej < self.num_group * num_points: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True | |
if (matchj // group_bbox_kpt) == 0: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, sj:ej][..., inter_body_mask] = True | |
for match_x in range(self.num_group * num_points): | |
if match_x % group_bbox_kpt == 0 and body_mask != False: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, ::num_points]=False | |
if match_x % group_bbox_kpt == 1 + self.num_body_points and lhand_mask != False: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 1::num_points]=False | |
if match_x % group_bbox_kpt == 2 + self.num_body_points + self.num_hand_points and rhand_mask != False: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 2::num_points]=False | |
if match_x % group_bbox_kpt == 3 + self.num_body_points + self.num_hand_points * 2 and face_mask != False: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 3::num_points]=False | |
# if match_x % group_bbox_kpt in [0, | |
# 1 + self.num_body_points, | |
# 2 + self.num_body_points + self.num_hand_points, | |
# 3 + self.num_body_points + self.num_hand_points * 2]: | |
# attn_mask3[:, :, dn_number:, dn_number:][:,:,match_x,kpt_index]=False | |
for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
num_gt_i = gt_boxes_i.shape[0] | |
if num_gt_i == 0: | |
continue | |
for matchi in range(dn_number): | |
si = (matchi // num_gt_i) * num_gt_i | |
ei = (matchi // num_gt_i + 1) * num_gt_i | |
if si > 0: | |
attn_mask3[idx, :, matchi, :si] = True | |
if ei < dn_number: | |
attn_mask3[idx, :, matchi, ei:dn_number] = True | |
attn_mask3 = attn_mask3.flatten(0, 1) | |
mask_dict = { | |
'pad_size': dn_number, | |
'known_bboxs': gt_boxes_expand, | |
'known_labels': gt_labels_expand, | |
'known_keypoints': gt_keypoints_expand | |
} | |
return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict | |
def dn_post_process2(self, outputs_class, outputs_coord,
                     outputs_body_keypoints_list, mask_dict):
    """Split the leading denoising queries off the per-layer decoder outputs.

    The first ``mask_dict['pad_size']`` entries along dim 1 of every
    per-layer class / box tensor are treated as denoising ("known") queries
    (``prepare_for_dn2`` builds ``mask_dict`` with ``'pad_size': dn_number``).
    Those slices are stored into ``mask_dict`` under ``'output_known_class'``
    and ``'output_known_coord'``. The remaining queries are returned.

    Args:
        outputs_class: iterable of per-decoder-layer class outputs, sliced
            as ``[:, query, :]`` (queries on dim 1).
        outputs_coord: iterable of per-decoder-layer box outputs, same
            layout as ``outputs_class``.
        outputs_body_keypoints_list: keypoint outputs; returned unchanged.
        mask_dict: dict containing ``'pad_size'``, or ``None``. It is
            updated in place when ``pad_size > 0``.

    Returns:
        Tuple ``(outputs_class, outputs_coord, outputs_keypoint)``. If
        ``mask_dict`` is falsy or ``pad_size`` is 0, the class/box inputs
        are returned as-is and ``mask_dict`` is not modified.
    """
    # Keypoints are never split here. Bind them before the branch so the
    # no-denoising path (mask_dict falsy or pad_size == 0) does not raise
    # UnboundLocalError at the return statement.
    outputs_keypoint = outputs_body_keypoints_list
    if mask_dict and mask_dict['pad_size'] > 0:
        pad_size = mask_dict['pad_size']
        # Denoising ("known") part: the first pad_size queries of each layer.
        output_known_class = [
            layer_cls[:, :pad_size, :] for layer_cls in outputs_class
        ]
        output_known_coord = [
            layer_coord[:, :pad_size, :] for layer_coord in outputs_coord
        ]
        # Matching part: everything after the denoising queries.
        outputs_class = [
            layer_cls[:, pad_size:, :] for layer_cls in outputs_class
        ]
        outputs_coord = [
            layer_coord[:, pad_size:, :] for layer_coord in outputs_coord
        ]
        mask_dict.update({
            'output_known_coord': output_known_coord,
            'output_known_class': output_known_class
        })
    return outputs_class, outputs_coord, outputs_keypoint
def forward(self, data_batch: NestedTensor, targets: List = None): | |
"""The forward expects a NestedTensor, which consists of: | |
- samples.tensor: batched images, of shape [batch_size x 3 x H x W] | |
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels | |
It returns a dict with the following elements: | |
- "pred_logits": the classification logits (including no-object) for all queries. | |
Shape= [batch_size x num_queries x num_classes] | |
- "pred_boxes": The normalized boxes coordinates for all queries, represented as | |
(center_x, center_y, width, height). These values are normalized in [0, 1], | |
relative to the size of each individual image (disregarding possible padding). | |
See PostProcess for information on how to retrieve the unnormalized bounding box. | |
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of | |
dictionnaries containing the two above keys for each decoder layer. | |
""" | |
if isinstance(data_batch, dict): | |
samples, targets = self.prepare_targets(data_batch) | |
# import pdb; pdb.set_trace() | |
elif isinstance(data_batch, (list, torch.Tensor)): | |
samples = nested_tensor_from_tensor_list(data_batch) | |
else: | |
samples = data_batch | |
# print(samples.data['img'].shape) | |
# exit() | |
features, poss = self.backbone(samples) | |
srcs = [] | |
masks = [] | |
for l, feat in enumerate(features): # len(features=3) | |
src, mask = feat.decompose() | |
srcs.append(self.input_proj[l](src)) | |
masks.append(mask) | |
assert mask is not None | |
if self.num_feature_levels > len(srcs): | |
_len_srcs = len(srcs) | |
for l in range(_len_srcs, self.num_feature_levels): | |
if l == _len_srcs: | |
src = self.input_proj[l](features[-1].tensors) | |
else: | |
src = self.input_proj[l](srcs[-1]) | |
m = samples.mask | |
mask = F.interpolate(m[None].float(), | |
size=src.shape[-2:]).to(torch.bool)[0] | |
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) | |
srcs.append(src) | |
masks.append(mask) | |
poss.append(pos_l) | |
if self.dn_number > 0 or targets is not None: | |
input_query_label, input_query_bbox, attn_mask,attn_mask2, attn_mask3, mask_dict =\ | |
self.prepare_for_dn2(targets) | |
else: | |
assert targets is None | |
input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None | |
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( | |
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, | |
attn_mask2, attn_mask3) | |
# update human boxes | |
effective_dn_number = self.dn_number if self.training else 0 | |
outputs_body_bbox_list = [] | |
outputs_class = [] | |
for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed, | |
layer_hs) in enumerate( | |
zip(reference[:-1], self.bbox_embed, | |
self.class_embed, hs)): | |
if dec_lid < self.num_box_decoder_layers: | |
# human det | |
layer_delta_unsig = layer_body_bbox_embed(layer_hs) | |
layer_body_box_outputs_unsig = \ | |
layer_delta_unsig + inverse_sigmoid(layer_ref_sig) | |
layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid() | |
layer_cls = layer_cls_embed(layer_hs) | |
# import mmcv | |
# import cv2 | |
# img = (data_batch['img'][0]*255).permute(1,2,0).int().detach().cpu().numpy() | |
# bbox = (box_ops.box_cxcywh_to_xyxy(layer_body_box_outputs_unsig[0][0]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) | |
# img = mmcv.imshow_bboxes(img.copy(), bbox, show=False) | |
# cv2.imwrite('test.png',img) | |
outputs_body_bbox_list.append(layer_body_box_outputs_unsig) | |
outputs_class.append(layer_cls) | |
elif dec_lid < self.num_box_decoder_layers + 2: | |
bs = layer_ref_sig.shape[0] | |
# dn body bbox | |
layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query | |
reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query | |
layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) | |
layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( | |
reference_before_sigmoid_body_bbox_dn) | |
layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() | |
# norm body bbox | |
layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ | |
:, 0::(self.num_body_points + 4), :] # norm content query | |
reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][ | |
:, 0::(self.num_body_points+ 4), :] # norm position query | |
layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) | |
layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( | |
reference_before_sigmoid_body_bbox_norm) | |
layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() | |
layer_body_box_outputs_unsig = torch.cat( | |
(layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) | |
# classfication | |
layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) | |
layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) | |
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) | |
outputs_class.append(layer_cls) | |
outputs_body_bbox_list.append(layer_body_box_outputs_unsig) | |
else: | |
bs = layer_ref_sig.shape[0] | |
# dn body bbox | |
layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query | |
reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query | |
layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) | |
layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( | |
reference_before_sigmoid_body_bbox_dn) | |
layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() | |
# norm body bbox | |
layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ | |
:, 0::(self.num_whole_body_points + 4), :] # norm content query | |
reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][ | |
:, 0::(self.num_whole_body_points + 4), :] # norm position query | |
layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) | |
layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( | |
reference_before_sigmoid_body_bbox_norm) | |
layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() | |
layer_body_box_outputs_unsig = torch.cat( | |
(layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) | |
# classfication | |
layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) | |
layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) | |
layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) | |
outputs_class.append(layer_cls) | |
outputs_body_bbox_list.append(layer_body_box_outputs_unsig) | |
# 找query | |
q_index = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)[1] | |
q_value = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)[0] | |
# update hand and face boxes | |
outputs_lhand_bbox_list = [] | |
outputs_rhand_bbox_list = [] | |
outputs_face_bbox_list = [] | |
# update keypoints boxes | |
outputs_body_keypoints_list = [] | |
outputs_body_keypoints_hw = [] | |
outputs_lhand_keypoints_list = [] | |
outputs_lhand_keypoints_hw = [] | |
outputs_rhand_keypoints_list = [] | |
outputs_rhand_keypoints_hw = [] | |
outputs_face_keypoints_list = [] | |
outputs_face_keypoints_hw = [] | |
outputs_smpl_pose_list = [] | |
outputs_smpl_lhand_pose_list = [] | |
outputs_smpl_rhand_pose_list = [] | |
outputs_smpl_expr_list = [] | |
outputs_smpl_jaw_pose_list = [] | |
outputs_smpl_beta_list = [] | |
outputs_smpl_cam_list = [] | |
# outputs_smpl_cam_f_list = [] | |
outputs_smpl_kp2d_list = [] | |
outputs_smpl_kp3d_list = [] | |
outputs_smpl_verts_list = [] | |
body_kpt_index = [ | |
x for x in range(self.num_group * (self.num_body_points + 4)) | |
if x % (self.num_body_points + 4) in range(1,self.num_body_points+1) | |
] | |
body_kpt_index_2 = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in range(1,self.num_body_points+1)) | |
] | |
lhand_bbox_index = [ | |
x for x in range(self.num_group * (self.num_body_points + 4)) | |
if x % (self.num_body_points + 4) != 1 | |
] | |
lhand_kpt_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in range( | |
self.num_body_points+2, | |
self.num_body_points+self.num_hand_points+2))] | |
rhand_bbox_index = [ | |
x for x in range(self.num_group * (self.num_body_points + 4)) | |
if x % (self.num_body_points + 4) != 2 | |
] | |
rhand_kpt_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in range( | |
self.num_body_points+self.num_hand_points+3, | |
self.num_body_points+self.num_hand_points*2+3)) | |
] | |
face_bbox_index = [ | |
x for x in range(self.num_group * (self.num_body_points + 4)) | |
if x % (self.num_body_points + 4) != 3 | |
] | |
face_kpt_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in range( | |
self.num_body_points+self.num_hand_points*2+4, | |
self.num_body_points+self.num_hand_points*2+self.num_face_points+4)) | |
] | |
# smpl pose | |
# body box, kps, lhand box | |
body_index = list(range(0,self.num_body_points+2)) | |
# rhand box and face box | |
body_index.extend( | |
[self.num_body_points + self.num_hand_points + 2, self.num_body_points + 2 * self.num_hand_points + 3] | |
) | |
smpl_pose_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in body_index) | |
] | |
# smpl lhand | |
lhand_index = list(range(self.num_body_points+1, self.num_body_points+self.num_hand_points+3)) | |
# body box | |
lhand_index.insert(0, 0) | |
smpl_lhand_pose_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in lhand_index)] | |
# smpl rhand | |
rhand_index = list(range(self.num_body_points + self.num_hand_points + 2, self.num_body_points + self.num_hand_points * 2 +3)) | |
rhand_index.insert(0,self.num_body_points+1) | |
rhand_index.insert(0,0) | |
smpl_rhand_pose_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in rhand_index)] | |
# smpl face | |
face_index = list(range(self.num_body_points + self.num_hand_points * 2 + 3, self.num_body_points + self.num_hand_points * 2 + self.num_face_points + 4)) | |
face_index.insert(0,0) | |
smpl_face_pose_index = [ | |
x for x in range(self.num_group * (self.num_whole_body_points + 4)) | |
if (x % (self.num_whole_body_points + 4) in face_index)] | |
for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)): | |
if dec_lid < self.num_box_decoder_layers: | |
assert isinstance(layer_hs, torch.Tensor) | |
bs = layer_hs.shape[0] | |
layer_body_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_body_points * 3)) # [-, 900, 42] | |
outputs_body_keypoints_list.append(layer_body_kps_res) | |
# lhand | |
layer_lhand_bbox_res = layer_hs.new_zeros( | |
(bs, self.num_queries, 4)) # [-, 900, 42] | |
outputs_lhand_bbox_list.append(layer_lhand_bbox_res) | |
layer_lhand_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_hand_points * 3)) # [-, 900, 42] | |
outputs_lhand_keypoints_list.append(layer_lhand_kps_res) | |
# rhand | |
layer_rhand_bbox_res = layer_hs.new_zeros( | |
(bs, self.num_queries, 4)) # [-, 900, 42] | |
outputs_rhand_bbox_list.append(layer_rhand_bbox_res) | |
layer_rhand_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_hand_points * 3)) # [-, 900, 42] | |
outputs_rhand_keypoints_list.append(layer_rhand_kps_res) | |
# face | |
layer_face_bbox_res = layer_hs.new_zeros( | |
(bs, self.num_queries, 4)) # [-, 900, 42] | |
outputs_face_bbox_list.append(layer_face_bbox_res) | |
layer_face_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_face_points * 3)) # [-, 900, 42] | |
outputs_face_keypoints_list.append(layer_face_kps_res) | |
# smpl or smplx | |
smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3)) | |
smpl_rhand_pose = layer_hs.new_zeros( | |
(bs, self.num_queries, 15 * 3)) | |
smpl_lhand_pose = layer_hs.new_zeros( | |
(bs, self.num_queries, 15 * 3)) | |
smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10)) | |
smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6)) | |
smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10)) | |
smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3)) | |
# smpl_cam_f = layer_hs.new_zeros((bs, self.num_queries, 1)) | |
# smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3)) | |
smpl_kp3d = layer_hs.new_zeros( | |
(bs, self.num_queries, self.num_body_points, 4)) | |
outputs_smpl_pose_list.append(smpl_pose) | |
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
outputs_smpl_expr_list.append(smpl_expr) | |
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
outputs_smpl_beta_list.append(smpl_beta) | |
outputs_smpl_cam_list.append(smpl_cam) | |
# outputs_smpl_cam_f_list.append(smpl_cam_f) | |
# outputs_smpl_kp2d_list.append(smpl_kp2d) | |
outputs_smpl_kp3d_list.append(smpl_kp3d) | |
elif dec_lid < self.num_box_decoder_layers +2: | |
bs = layer_ref_sig.shape[0] | |
layer_hs_body_kpt = \ | |
layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(body_kpt_index, device=layer_hs.device)) | |
# body kp2d | |
delta_body_kp_xy_unsig = \ | |
self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt) | |
layer_ref_sig_body_kpt = \ | |
layer_ref_sig[:,effective_dn_number:, :].index_select(1,torch.tensor(body_kpt_index,device=layer_hs.device)) | |
layer_outputs_unsig_body_keypoints = delta_body_kp_xy_unsig + inverse_sigmoid( | |
layer_ref_sig_body_kpt[..., :2]) | |
vis_xy_unsig = torch.ones_like( | |
layer_outputs_unsig_body_keypoints, | |
device=layer_outputs_unsig_body_keypoints.device) | |
xyv = torch.cat((layer_outputs_unsig_body_keypoints, | |
vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
dim=-1) | |
xyv = xyv.sigmoid() | |
# from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d | |
# img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy() | |
# gt_kp2d = xyv[0,:17] | |
# coco_kps = gt_kp2d[:,:2].reshape(17,2).detach().cpu().numpy() * data_batch['img_shape'].cpu().numpy()[0,None,None,::-1] | |
# visualize_kp2d( | |
# coco_kps, | |
# output_path='.', | |
# image_array=img.copy()[None], | |
# data_source='coco', | |
# overwrite=True) | |
layer_res = xyv.reshape( | |
(bs, self.num_group, self.num_body_points, | |
3)).flatten(2, 3) | |
layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape( | |
bs, self.num_group, self.num_body_points, 2).flatten(2, 3) | |
layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) | |
outputs_body_keypoints_list.append(layer_res) | |
outputs_body_keypoints_hw.append(layer_hw) | |
# lhand bbox | |
layer_hs_lhand_bbox = \ | |
layer_hs[:, effective_dn_number:, :][:, (self.num_body_points + 1)::(self.num_body_points + 4), :] | |
delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
layer_ref_sig_lhand_bbox = \ | |
layer_ref_sig[:,effective_dn_number:, :][ | |
:, (self.num_body_points + 1)::(self.num_body_points + 4), :].clone() | |
layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) | |
delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2] | |
layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig | |
layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() | |
outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) | |
layer_lhand_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_hand_points * 3)) # [-, 900, 42] | |
outputs_lhand_keypoints_list.append(layer_lhand_kps_res) | |
# rhand bbox | |
layer_hs_rhand_bbox = \ | |
layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 2)::(self.num_body_points + 4), :] | |
delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
layer_ref_sig_rhand_bbox = \ | |
layer_ref_sig[:,effective_dn_number:, :][ | |
:, (self.num_body_points + 2)::(self.num_body_points + 4), :].clone() | |
layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) | |
delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2] | |
layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig | |
layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() | |
outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) | |
# rhand kps | |
layer_rhand_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_hand_points * 3)) # [-, 900, 42] | |
outputs_rhand_keypoints_list.append(layer_rhand_kps_res) | |
# face bbox | |
layer_hs_face_bbox = \ | |
layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 3)::(self.num_body_points + 4), :] | |
delta_face_bbox_xy_unsig = self.bbox_face_embed[ | |
dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
layer_ref_sig_face_bbox = \ | |
layer_ref_sig[:,effective_dn_number:, :][ | |
:, (self.num_body_points + 3)::(self.num_body_points + 4), :].clone() | |
layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) | |
delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ | |
dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) | |
layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2] | |
layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig | |
layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() | |
outputs_face_bbox_list.append(layer_ref_sig_face_bbox) | |
# face kps | |
layer_face_kps_res = layer_hs.new_zeros( | |
(bs, self.num_queries, | |
self.num_face_points * 3)) # [-, 900, 42] | |
outputs_face_keypoints_list.append(layer_face_kps_res) | |
# smpl or smplx | |
bs, _, feat_dim = layer_hs.shape | |
smpl_feats = layer_hs[:, effective_dn_number:, :].reshape( | |
bs, -1, feat_dim * (self.num_body_points + 4)) | |
smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 1):: (self.num_body_points + 4), :].reshape( | |
bs, -1, feat_dim) | |
smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 2):: (self.num_body_points + 4), :].reshape( | |
bs, -1, feat_dim) | |
smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 3):: (self.num_body_points + 4), :].reshape( | |
bs, -1, feat_dim) | |
smpl_pose = self.smpl_pose_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_feats) | |
smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( | |
bs, self.num_group, self.body_model_joint_num, 3, 3) | |
smpl_lhand_pose = self.smpl_hand_pose_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) | |
smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( | |
-1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
smpl_rhand_pose = self.smpl_hand_pose_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) | |
smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( | |
-1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
smpl_jaw_pose = self.smpl_jaw_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( | |
bs, self.num_group, 1, 3, 3) | |
smpl_beta = self.smpl_beta_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_feats) | |
smpl_cam = self.smpl_cam_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_feats) | |
# smpl_cam_f = self.smpl_cam_f_embed[ | |
# dec_lid - self.num_box_decoder_layers](smpl_feats) | |
# zero | |
# smpl_lhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3) | |
# smpl_rhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3) | |
# smpl_expr = layer_hs.new_zeros(bs, self.num_group, 10) | |
smpl_expr = self.smpl_expr_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
# smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3) | |
leye_pose = torch.zeros_like(smpl_jaw_pose) | |
reye_pose = torch.zeros_like(smpl_jaw_pose) | |
if self.body_model is not None: | |
smpl_pose_ = rotmat_to_aa(smpl_pose) | |
# smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) | |
# smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) | |
smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) | |
smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) | |
smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) | |
leye_pose_ = rotmat_to_aa(leye_pose) | |
reye_pose_ = rotmat_to_aa(reye_pose) | |
pred_output = self.body_model( | |
betas=smpl_beta.reshape(-1, 10), | |
body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), | |
global_orient=smpl_pose_[:, :, 0].reshape( | |
-1, 3).unsqueeze(1), | |
left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), | |
right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), | |
leye_pose=leye_pose_, | |
reye_pose=reye_pose_, | |
jaw_pose=smpl_jaw_pose_.reshape(-1, 3), | |
# expression=smpl_expr.reshape(-1, 10), | |
expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10) | |
) | |
smpl_kp3d = pred_output['joints'].reshape( | |
bs, self.num_group, -1, 3) | |
smpl_verts = pred_output['vertices'].reshape( | |
bs, self.num_group, -1, 3) | |
# pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3) | |
outputs_smpl_pose_list.append(smpl_pose) | |
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
outputs_smpl_expr_list.append(smpl_expr) | |
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
outputs_smpl_beta_list.append(smpl_beta) | |
outputs_smpl_cam_list.append(smpl_cam) | |
# outputs_smpl_cam_f_list.append(smpl_cam_f) | |
outputs_smpl_kp3d_list.append(smpl_kp3d) | |
else: | |
bs = layer_ref_sig.shape[0] | |
layer_hs_body_kpt = \ | |
layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(body_kpt_index_2, device=layer_hs.device)) | |
# body kp2d | |
delta_body_kp_xy_unsig = \ | |
self.pose_embed[ | |
dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt) | |
layer_ref_sig_body_kpt = \ | |
layer_ref_sig[:,effective_dn_number:, :].index_select( | |
1,torch.tensor(body_kpt_index_2,device=layer_hs.device)) | |
layer_outputs_unsig_body_keypoints = \ | |
delta_body_kp_xy_unsig + inverse_sigmoid( | |
layer_ref_sig_body_kpt[..., :2]) | |
vis_xy_unsig = torch.ones_like( | |
layer_outputs_unsig_body_keypoints, | |
device=layer_outputs_unsig_body_keypoints.device) | |
xyv = torch.cat((layer_outputs_unsig_body_keypoints, | |
vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
dim=-1) | |
xyv = xyv.sigmoid() | |
layer_res = xyv.reshape( | |
(bs, self.num_group, self.num_body_points, | |
3)).flatten(2, 3) | |
layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape( | |
bs, self.num_group, self.num_body_points, 2).flatten(2, 3) | |
layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) | |
outputs_body_keypoints_list.append(layer_res) | |
outputs_body_keypoints_hw.append(layer_hw) | |
# lhand bbox | |
layer_hs_lhand_bbox = \ | |
layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :] | |
delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
layer_ref_sig_lhand_bbox = \ | |
layer_ref_sig[:,effective_dn_number:, :][ | |
:, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :].clone() | |
layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) | |
delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2] | |
layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig | |
layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() | |
outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) | |
# lhand kps | |
layer_hs_lhand_kps_res = \ | |
layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(lhand_kpt_index, device=layer_hs.device)) | |
delta_lhand_kp_xy_unsig = \ | |
self.pose_hand_embed[ | |
dec_lid - self.num_hand_face_decoder_layers](layer_hs_lhand_kps_res) | |
layer_ref_sig_lhand_kpt = \ | |
layer_ref_sig[:,effective_dn_number:, :].index_select( | |
1,torch.tensor(lhand_kpt_index,device=layer_hs.device)) | |
layer_outputs_unsig_lhand_keypoints = delta_lhand_kp_xy_unsig + inverse_sigmoid( | |
layer_ref_sig_lhand_kpt[..., :2]) | |
lhand_vis_xy_unsig = torch.ones_like( | |
layer_outputs_unsig_lhand_keypoints, | |
device=layer_outputs_unsig_lhand_keypoints.device) | |
lhand_xyv = torch.cat((layer_outputs_unsig_lhand_keypoints, | |
lhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
dim=-1) | |
lhand_xyv = lhand_xyv.sigmoid() | |
layer_lhand_kps_res = lhand_xyv.reshape( | |
(bs, self.num_group, self.num_hand_points, | |
3)).flatten(2, 3) | |
layer_lhand_hw = layer_ref_sig_lhand_kpt[..., 2:].reshape( | |
bs, self.num_group, self.num_hand_points, 2).flatten(2, 3) | |
layer_lhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_lhand_kps_res) | |
outputs_lhand_keypoints_list.append(layer_lhand_kps_res) | |
outputs_lhand_keypoints_hw.append(layer_lhand_hw) | |
# rhand bbox | |
layer_hs_rhand_bbox = \ | |
layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :] | |
delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
layer_ref_sig_rhand_bbox = \ | |
layer_ref_sig[:,effective_dn_number:, :][ | |
:, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :].clone() | |
layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) | |
delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2] | |
layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig | |
layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() | |
outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) | |
# rhand kps | |
layer_hs_rhand_kps_res = \ | |
layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(rhand_kpt_index, device=layer_hs.device)) | |
delta_rhand_kp_xy_unsig = \ | |
self.pose_hand_embed[ | |
dec_lid - self.num_hand_face_decoder_layers](layer_hs_rhand_kps_res) | |
layer_ref_sig_rhand_kpt = \ | |
layer_ref_sig[:,effective_dn_number:, :].index_select( | |
1,torch.tensor(rhand_kpt_index,device=layer_hs.device)) | |
layer_outputs_unsig_rhand_keypoints = delta_rhand_kp_xy_unsig + inverse_sigmoid( | |
layer_ref_sig_rhand_kpt[..., :2]) | |
rhand_vis_xy_unsig = torch.ones_like( | |
layer_outputs_unsig_rhand_keypoints, | |
device=layer_outputs_unsig_rhand_keypoints.device) | |
rhand_xyv = torch.cat((layer_outputs_unsig_rhand_keypoints, | |
rhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
dim=-1) | |
rhand_xyv = rhand_xyv.sigmoid() | |
layer_rhand_kps_res = rhand_xyv.reshape( | |
(bs, self.num_group, self.num_hand_points, | |
3)).flatten(2, 3) | |
layer_rhand_hw = layer_ref_sig_rhand_kpt[..., 2:].reshape( | |
bs, self.num_group, self.num_hand_points, 2).flatten(2, 3) | |
layer_rhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_rhand_kps_res) | |
outputs_rhand_keypoints_list.append(layer_rhand_kps_res) | |
outputs_rhand_keypoints_hw.append(layer_rhand_hw) | |
# face bbox | |
layer_hs_face_bbox = \ | |
layer_hs[:, effective_dn_number:, :][ | |
:, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :] | |
delta_face_bbox_xy_unsig = self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
layer_ref_sig_face_bbox = \ | |
layer_ref_sig[:,effective_dn_number:, :][ | |
:, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :].clone() | |
layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) | |
delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ | |
dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) | |
layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2] | |
layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig | |
layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() | |
outputs_face_bbox_list.append(layer_ref_sig_face_bbox) | |
# face kps | |
layer_hs_face_kps_res = \ | |
layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(face_kpt_index, device=layer_hs.device)) | |
delta_face_kp_xy_unsig = \ | |
self.pose_face_embed[ | |
dec_lid - self.num_hand_face_decoder_layers](layer_hs_face_kps_res) | |
layer_ref_sig_face_kpt = \ | |
layer_ref_sig[:,effective_dn_number:, :].index_select( | |
1,torch.tensor(face_kpt_index,device=layer_hs.device)) | |
layer_outputs_unsig_face_keypoints = delta_face_kp_xy_unsig + inverse_sigmoid( | |
layer_ref_sig_face_kpt[..., :2]) | |
face_vis_xy_unsig = torch.ones_like( | |
layer_outputs_unsig_face_keypoints, | |
device=layer_outputs_unsig_face_keypoints.device) | |
face_xyv = torch.cat((layer_outputs_unsig_face_keypoints, | |
face_vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
dim=-1) | |
face_xyv = face_xyv.sigmoid() | |
layer_face_kps_res = face_xyv.reshape( | |
(bs, self.num_group, self.num_face_points, | |
3)).flatten(2, 3) | |
layer_face_hw = layer_ref_sig_face_kpt[..., 2:].reshape( | |
bs, self.num_group, self.num_face_points, 2).flatten(2, 3) | |
layer_face_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_face_kps_res) | |
outputs_face_keypoints_list.append(layer_face_kps_res) | |
outputs_face_keypoints_hw.append(layer_face_hw) | |
# pdb.set_trace() | |
bs, _, feat_dim = layer_hs.shape | |
smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(smpl_pose_index, device=layer_hs.device) | |
).reshape(bs, -1, feat_dim * (self.num_body_points + 4)) | |
smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) | |
).reshape(bs, -1, feat_dim * (self.num_hand_points + 3)) | |
smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) | |
).reshape(bs, -1, feat_dim * (self.num_hand_points + 3)) | |
smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) | |
).reshape(bs, -1, feat_dim * (self.num_face_points + 2)) | |
smpl_pose = self.smpl_pose_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( | |
bs, self.num_group, self.body_model_joint_num, 3, 3) | |
smpl_lhand_pose = self.smpl_hand_pose_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) | |
smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( | |
-1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
smpl_rhand_pose = self.smpl_hand_pose_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) | |
smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( | |
-1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
smpl_expr = self.smpl_expr_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
smpl_jaw_pose = self.smpl_jaw_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( | |
bs, self.num_group, 1, 3, 3) | |
smpl_beta = self.smpl_beta_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
smpl_cam = self.smpl_cam_embed[ | |
dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
# smpl_cam_f = self.smpl_cam_f_embed[ | |
# dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
num_samples = smpl_beta.reshape(-1, 10).shape[0] | |
device = smpl_beta.device | |
leye_pose = torch.zeros_like(smpl_jaw_pose) | |
reye_pose = torch.zeros_like(smpl_jaw_pose) | |
if self.body_model is not None: | |
# print(smpl_pose) | |
# exit() | |
smpl_pose_ = rotmat_to_aa(smpl_pose) | |
smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) | |
smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) | |
smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) | |
leye_pose_ = rotmat_to_aa(leye_pose) | |
reye_pose_ = rotmat_to_aa(reye_pose) | |
pred_output = self.body_model( | |
betas=smpl_beta.reshape(-1, 10), | |
body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), | |
global_orient=smpl_pose_[:, :, 0].reshape( | |
-1, 3).unsqueeze(1), | |
left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), | |
right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), | |
leye_pose=leye_pose_, | |
reye_pose=reye_pose_, | |
jaw_pose=smpl_jaw_pose_.reshape(-1, 3), | |
expression=smpl_expr.reshape(-1, 10), | |
# expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10), | |
) | |
smpl_kp3d = pred_output['joints'].reshape( | |
bs, self.num_group, -1, 3) | |
smpl_verts = pred_output['vertices'].reshape( | |
bs, self.num_group, -1, 3) | |
# pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3) | |
# from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d | |
# visualize_kp3d(smpl_kp3d[0,:100].detach().cpu().numpy(), | |
# output_path='./figs/pred3d', | |
# data_source='smplx_137') | |
# import numpy as np | |
# from pytorch3d.io import save_obj | |
# save_obj( | |
# '1.obj', | |
# torch.tensor(pred_output['vertices'][0]), | |
# torch.tensor(self.body_model.faces.astype(np.float))) | |
# exit() | |
outputs_smpl_pose_list.append(smpl_pose) | |
outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
outputs_smpl_expr_list.append(smpl_expr) | |
outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
outputs_smpl_beta_list.append(smpl_beta) | |
outputs_smpl_cam_list.append(smpl_cam) | |
# outputs_smpl_cam_f_list.append(smpl_cam_f) | |
outputs_smpl_kp3d_list.append(smpl_kp3d) | |
if not self.training: | |
outputs_smpl_verts_list.append(smpl_verts) | |
dn_mask_dict = mask_dict | |
if self.dn_number > 0 and dn_mask_dict is not None: | |
outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list = self.dn_post_process2( | |
outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list, | |
dn_mask_dict) | |
dn_class_input = dn_mask_dict['known_labels'] | |
dn_bbox_input = dn_mask_dict['known_bboxs'] | |
dn_class_pred = dn_mask_dict['output_known_class'] | |
dn_bbox_pred = dn_mask_dict['output_known_coord'] | |
for idx, (_out_class, _out_bbox, _out_keypoint) in enumerate( | |
zip(outputs_class, outputs_body_bbox_list, | |
outputs_body_keypoints_list)): | |
assert _out_class.shape[1] == _out_bbox.shape[ | |
1] == _out_keypoint.shape[1] | |
out = { | |
'pred_logits': outputs_class[-1], | |
'pred_boxes': outputs_body_bbox_list[-1], | |
'pred_lhand_boxes': outputs_lhand_bbox_list[-1], | |
'pred_rhand_boxes': outputs_rhand_bbox_list[-1], | |
'pred_face_boxes': outputs_face_bbox_list[-1], | |
'pred_keypoints': outputs_body_keypoints_list[-1], | |
'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1], | |
'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1], | |
'pred_face_keypoints': outputs_face_keypoints_list[-1], | |
'pred_smpl_pose': outputs_smpl_pose_list[-1], | |
'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], | |
'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], | |
'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], | |
'pred_smpl_expr': outputs_smpl_expr_list[-1], | |
'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] | |
'pred_smpl_cam': outputs_smpl_cam_list[-1], | |
# 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1], | |
'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1] | |
} | |
if not self.training: | |
full_pose = torch.cat((outputs_smpl_pose_list[-1], | |
outputs_smpl_lhand_pose_list[-1], | |
outputs_smpl_rhand_pose_list[-1], | |
outputs_smpl_jaw_pose_list[-1]),dim=2) | |
bs,num_q,_,_,_ = full_pose.shape | |
full_pose = rotmat_to_aa(full_pose).reshape(bs,num_q,53*3) | |
out = { | |
'pred_logits': outputs_class[-1], | |
'pred_boxes': outputs_body_bbox_list[-1], | |
'pred_lhand_boxes': outputs_lhand_bbox_list[-1], | |
'pred_rhand_boxes': outputs_rhand_bbox_list[-1], | |
'pred_face_boxes': outputs_face_bbox_list[-1], | |
'pred_keypoints': outputs_body_keypoints_list[-1], | |
'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1], | |
'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1], | |
'pred_face_keypoints': outputs_face_keypoints_list[-1], | |
'pred_smpl_pose': outputs_smpl_pose_list[-1], | |
'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], | |
'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], | |
'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], | |
'pred_smpl_expr': outputs_smpl_expr_list[-1], | |
'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] | |
'pred_smpl_cam': outputs_smpl_cam_list[-1], | |
# 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1], | |
'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1], | |
'pred_smpl_verts': outputs_smpl_verts_list[-1], | |
'pred_smpl_fullpose': full_pose | |
} | |
if self.dn_number > 0 and dn_mask_dict is not None: | |
out.update({ | |
'dn_class_input': dn_class_input, | |
'dn_bbox_input': dn_bbox_input, | |
'dn_class_pred': dn_class_pred[-1], | |
'dn_bbox_pred': dn_bbox_pred[-1], | |
'num_tgt': dn_mask_dict['pad_size'] | |
}) | |
if self.aux_loss: | |
out['aux_outputs'] = \ | |
self._set_aux_loss( | |
outputs_class, | |
outputs_body_bbox_list, | |
outputs_lhand_bbox_list, | |
outputs_rhand_bbox_list, | |
outputs_face_bbox_list, | |
outputs_body_keypoints_list, | |
outputs_lhand_keypoints_list, | |
outputs_rhand_keypoints_list, | |
outputs_face_keypoints_list, | |
outputs_smpl_pose_list, | |
outputs_smpl_rhand_pose_list, | |
outputs_smpl_lhand_pose_list, | |
outputs_smpl_jaw_pose_list, | |
outputs_smpl_expr_list, | |
outputs_smpl_beta_list, | |
outputs_smpl_cam_list, | |
# outputs_smpl_cam_f_list, | |
outputs_smpl_kp3d_list | |
) # with key pred_logits, pred_bbox, pred_keypoints | |
if self.dn_number > 0 and dn_mask_dict is not None: | |
assert len(dn_class_pred[:-1]) == len( | |
dn_bbox_pred[:-1]) == len(out['aux_outputs']) | |
for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip( | |
out['aux_outputs'], dn_class_pred, dn_bbox_pred): | |
aux_out.update({ | |
'dn_class_input': dn_class_input, | |
'dn_bbox_input': dn_bbox_input, | |
'dn_class_pred': dn_class_pred_i, | |
'dn_bbox_pred': dn_bbox_pred_i, | |
'num_tgt': dn_mask_dict['pad_size'] | |
}) | |
# for encoder output | |
if hs_enc is not None: | |
interm_coord = ref_enc[-1] | |
interm_class = self.transformer.enc_out_class_embed(hs_enc[-1]) | |
interm_pose = torch.zeros_like(outputs_body_keypoints_list[0]) | |
out['interm_outputs'] = { | |
'pred_logits': interm_class, | |
'pred_boxes': interm_coord, | |
'pred_keypoints': interm_pose | |
} | |
return out, targets, data_batch | |
def _set_aux_loss(self,
                  outputs_class,
                  outputs_body_coord,
                  outputs_lhand_coord,
                  outputs_rhand_coord,
                  outputs_face_coord,
                  outputs_body_keypoints,
                  outputs_lhand_keypoints,
                  outputs_rhand_keypoints,
                  outputs_face_keypoints,
                  outputs_smpl_pose,
                  outputs_smpl_rhand_pose,
                  outputs_smpl_lhand_pose,
                  outputs_smpl_jaw_pose,
                  outputs_smpl_expr,
                  outputs_smpl_beta,
                  outputs_smpl_cam,
                  outputs_smpl_kp3d):
    """Pack per-decoder-layer predictions into auxiliary-output dicts.

    Every argument is a sequence indexed by decoder layer. The last entry of
    each sequence is discarded (``forward`` reports the final layer as the
    main output), and the remaining layers are walked in lockstep.

    Returns:
        list[dict]: one dict per intermediate layer, keyed by the ``'pred_*'``
        names below. Its length equals the shortest sliced sequence.
    """
    # Output key for each argument, in the same order as ``layer_sequences``.
    field_names = (
        'pred_logits',
        'pred_boxes',
        'pred_lhand_boxes',
        'pred_rhand_boxes',
        'pred_face_boxes',
        'pred_keypoints',
        'pred_lhand_keypoints',
        'pred_rhand_keypoints',
        'pred_face_keypoints',
        'pred_smpl_pose',
        'pred_smpl_rhand_pose',
        'pred_smpl_lhand_pose',
        'pred_smpl_jaw_pose',
        'pred_smpl_expr',
        'pred_smpl_beta',
        'pred_smpl_cam',
        'pred_smpl_kp3d',
    )
    layer_sequences = (
        outputs_class,
        outputs_body_coord,
        outputs_lhand_coord,
        outputs_rhand_coord,
        outputs_face_coord,
        outputs_body_keypoints,
        outputs_lhand_keypoints,
        outputs_rhand_keypoints,
        outputs_face_keypoints,
        outputs_smpl_pose,
        outputs_smpl_rhand_pose,
        outputs_smpl_lhand_pose,
        outputs_smpl_jaw_pose,
        outputs_smpl_expr,
        outputs_smpl_beta,
        outputs_smpl_cam,
        outputs_smpl_kp3d,
    )
    # Drop the final layer from every sequence before pairing them up.
    intermediate_layers = (seq[:-1] for seq in layer_sequences)
    return [
        dict(zip(field_names, layer_values))
        for layer_values in zip(*intermediate_layers)
    ]
def prepare_targets(self, data_batch):
    """Build the padded image input and per-image target dicts for the model.

    Args:
        data_batch (dict): collated batch. Always read are ``'img'`` (a tensor
            whose shape unpacks as ``(batch, channels, H, W)``) and
            ``'img_shape'`` (per-image ``(h, w)`` after augmentation).
            When ``self.inference`` is false, the following are also read:
            ``'{body,face,lhand,rhand}_bbox_center'``,
            ``'{body,face,lhand,rhand}_bbox_size'``, ``'joint_img'`` and
            ``'joint_trunc'``. When it is true, only the body bbox entries
            are read.

    Returns:
        tuple: ``(NestedTensor(img_list, masks), data_batch_coco)``.
        ``masks`` is ``True`` on padded pixels and ``False`` inside each
        image's ``img_shape`` region. ``data_batch_coco`` holds one dict per
        image:
        - inference mode: ``'size'`` and ``'boxes'``;
        - otherwise: boxes, keypoints and areas for body/face/hands, plus
          ``'size'`` and ``'labels'``.
    """
    img_list = data_batch['img'].float()
    batch_size, _, input_img_h, input_img_w = img_list.shape
    device = img_list.device
    # Start fully padded (True); each image's valid region is cleared below.
    masks = torch.ones((batch_size, input_img_h, input_img_w),
                       dtype=torch.bool,
                       device=device)

    def _part_bbox(part, img_id):
        # Concatenate a part's box center and box size along the last dim.
        return torch.cat([data_batch[f'{part}_bbox_center'][img_id],
                          data_batch[f'{part}_bbox_size'][img_id]], dim=-1)

    def _normalize_kp2d(kp2d):
        # Divide x by cfg.output_hm_shape[2] and y by cfg.output_hm_shape[1]
        # (presumably heatmap width / height -- TODO confirm).
        # The result is laid out as [x0, y0, x1, y1, ..., c0, c1, ...] per
        # instance. The division is done in place on ``kp2d``.
        kp2d[:, :, 0] = kp2d[:, :, 0] / cfg.output_hm_shape[2]
        kp2d[:, :, 1] = kp2d[:, :, 1] / cfg.output_hm_shape[1]
        return torch.cat([kp2d[:, :, :2].flatten(1), kp2d[:, :, 2]], dim=-1)

    data_batch_coco = []
    for img_id in range(batch_size):
        img_h, img_w = data_batch['img_shape'][img_id]
        masks[img_id, :img_h, :img_w] = 0

        if self.inference:
            instance_body_bbox = _part_bbox('body', img_id)
            data_batch_coco.append({
                'size': data_batch['img_shape'][img_id],  # after augmentation
                'boxes': instance_body_bbox.float(),
            })
            continue

        instance_body_bbox = _part_bbox('body', img_id)
        instance_face_bbox = _part_bbox('face', img_id)
        instance_lhand_bbox = _part_bbox('lhand', img_id)
        instance_rhand_bbox = _part_bbox('rhand', img_id)

        instance_kp2d = data_batch['joint_img'][img_id].clone().float()
        # Overwrite the third channel with the truncation flags. After
        # _normalize_kp2d this channel becomes the trailing per-joint block.
        instance_kp2d[:, :, 2:] = data_batch['joint_trunc'][img_id].clone().float()

        # Run every conversion before any normalisation. _normalize_kp2d
        # writes in place, so this keeps the conversions independent of one
        # another even if convert_kps were to return views.
        body_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'coco', approximate=True)
        lhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_lhand', approximate=True)
        rhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_rhand', approximate=True)
        face_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_face', approximate=True)

        data_batch_coco.append({
            'boxes': instance_body_bbox.float(),
            'face_boxes': instance_face_bbox.float(),
            'lhand_boxes': instance_lhand_bbox.float(),
            'rhand_boxes': instance_rhand_bbox.float(),
            'keypoints': _normalize_kp2d(body_kp2d).float(),
            'lhand_keypoints': _normalize_kp2d(lhand_kp2d).float(),
            'rhand_keypoints': _normalize_kp2d(rhand_kp2d).float(),
            'face_keypoints': _normalize_kp2d(face_kp2d).float(),
            'size': data_batch['img_shape'][img_id],  # after augmentation
            # Area = product of the last two bbox columns (the size part).
            'area': instance_body_bbox[:, 2] * instance_body_bbox[:, 3],
            'lhand_area': instance_lhand_bbox[:, 2] * instance_lhand_bbox[:, 3],
            'rhand_area': instance_rhand_bbox[:, 2] * instance_rhand_bbox[:, 3],
            'face_area': instance_face_bbox[:, 2] * instance_face_bbox[:, 3],
            'labels': torch.ones(instance_body_bbox.shape[0],
                                 dtype=torch.long,
                                 device=device),
        })

    input_img = NestedTensor(img_list, masks)
    return input_img, data_batch_coco
def keypoints_to_scaled_bbox_bfh(
        self, keypoints, occ=None,
        body_scale=1.0, fh_scale=1.0,
        convention='smplx'):
    """Obtain scaled body, head and hand boxes in xyxy format from keypoints.

    Each box is the min/max extent of the relevant keypoints, with its
    width and height scaled around the box center.

    Args:
        keypoints (np.ndarray): keypoints of shape ``(1, n, k)`` or
            ``(n, k)`` with ``k`` = 2 or 3. Only the first two columns are
            used.
        occ (np.ndarray, optional): per-keypoint occlusion values, indexed
            with the same keypoint ids as ``keypoints`` (assumed 0/1 flags --
            TODO confirm). When given, a head/hand box gets ``conf = 0`` if
            ``occ`` averaged over that part's keypoints is ``>= 0.1``.
        body_scale (float): scale factor for the body box.
        fh_scale (float): scale factor for the head and hand boxes.
        convention (str): keypoint convention passed to
            ``get_keypoint_idxs_by_part``.

    Returns:
        list[np.ndarray]: four float32 arrays ``[xmin, ymin, xmax, ymax,
        conf]``, for ``'body'``, ``'head'``, ``'left_hand'`` and
        ``'right_hand'`` in that order. The body box always has
        ``conf = 1``.
    """
    # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3
    if keypoints.ndim == 3:
        keypoints = keypoints[0]
    if keypoints.shape[-1] != 2:
        keypoints = keypoints[:, :2]

    bboxs = []
    for body_part in ['body', 'head', 'left_hand', 'right_hand']:
        if body_part == 'body':
            scale = body_scale
            kps = keypoints
            # The body box uses all keypoints and is always marked valid.
            # Occlusion is only evaluated for parts that have an index list
            # (kp_id), so it is skipped here.
            conf = 1
        else:
            scale = fh_scale
            kp_id = get_keypoint_idxs_by_part(body_part, convention=convention)
            kps = keypoints[kp_id]
            occluded = (occ is not None
                        and np.sum(occ[kp_id]) / len(kp_id) >= 0.1)
            conf = 0 if occluded else 1

        xmin, ymin = np.amin(kps, axis=0)
        xmax, ymax = np.amax(kps, axis=0)
        width = (xmax - xmin) * scale
        height = (ymax - ymin) * scale
        x_center = 0.5 * (xmax + xmin)
        y_center = 0.5 * (ymax + ymin)
        xmin = x_center - 0.5 * width
        xmax = x_center + 0.5 * width
        ymin = y_center - 0.5 * height
        ymax = y_center + 0.5 * height
        bbox = np.stack([xmin, ymin, xmax, ymax, conf], axis=0).astype(np.float32)
        bboxs.append(bbox)
    return bboxs
def build_aios_smplx(args, cfg):
    """Build the AiOS SMPL-X model together with its criterion and post-processors.

    Args:
        args: Experiment options. Model hyper-parameters and every
            ``*_loss_coef`` weight are read from it. ``no_interm_box_loss``
            and ``interm_loss_coef`` are optional (defaults ``False`` and
            ``1.0``).
        cfg: Config whose ``losses`` sequence names the base criterion
            losses. It is copied, not modified, so building more than once
            does not keep appending the dn / matching losses to it.

    Returns:
        tuple: ``(model, criterion, postprocessors, postprocessors_aios)``.
    """
    num_classes = args.num_classes  # 2
    device = torch.device(args.device)
    backbone = build_backbone(args)
    transformer = build_transformer(args)
    dn_labelbook_size = args.dn_labelbook_size
    dec_pred_class_embed_share = args.dec_pred_class_embed_share
    dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share

    # Evaluation and training read different body-model configs.
    if args.eval:
        body_model = args.body_model_test
        train = False
    else:
        body_model = args.body_model_train
        train = True

    model = AiOSSMPLX(
        backbone,
        transformer,
        num_classes=num_classes,  # 2
        num_queries=args.num_queries,  # 900
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,  # False
        fix_refpoints_hw=args.fix_refpoints_hw,  # -1
        num_feature_levels=args.num_feature_levels,  # 4
        nheads=args.nheads,  # 8
        dec_pred_class_embed_share=dec_pred_class_embed_share,  # False
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,  # False
        # two stage
        two_stage_type=args.two_stage_type,
        # box_share
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,  # False
        two_stage_class_embed_share=args.two_stage_class_embed_share,  # False
        # denoising queries are disabled entirely unless use_dn is set
        dn_number=args.dn_number if args.use_dn else 0,  # 100
        dn_box_noise_scale=args.dn_box_noise_scale,  # 0.4
        dn_label_noise_ratio=args.dn_label_noise_ratio,  # 0.5
        dn_batch_gt_fuse=args.dn_batch_gt_fuse,  # False
        dn_attn_mask_type_list=args.dn_attn_mask_type_list,
        dn_labelbook_size=dn_labelbook_size,  # 100
        cls_no_bias=args.cls_no_bias,  # False
        num_group=args.num_group,  # 100
        num_body_points=args.num_body_points,  # 17
        num_hand_points=args.num_hand_points,
        num_face_points=args.num_face_points,
        num_box_decoder_layers=args.num_box_decoder_layers,  # 2
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        body_model=body_model,
        train=train,
        inference=args.inference)
    matcher = build_matcher(args)

    # prepare weight dict
    weight_dict = {
        'loss_ce': args.cls_loss_coef,
        # bbox
        'loss_body_bbox': args.body_bbox_loss_coef,
        'loss_rhand_bbox': args.rhand_bbox_loss_coef,
        'loss_lhand_bbox': args.lhand_bbox_loss_coef,
        'loss_face_bbox': args.face_bbox_loss_coef,
        # bbox giou
        'loss_body_giou': args.body_giou_loss_coef,
        'loss_rhand_giou': args.rhand_giou_loss_coef,
        'loss_lhand_giou': args.lhand_giou_loss_coef,
        'loss_face_giou': args.face_giou_loss_coef,
        # 2d kp
        'loss_keypoints': args.keypoints_loss_coef,
        'loss_rhand_keypoints': args.rhand_keypoints_loss_coef,
        'loss_lhand_keypoints': args.lhand_keypoints_loss_coef,
        'loss_face_keypoints': args.face_keypoints_loss_coef,
        # 2d kp oks
        'loss_oks': args.oks_loss_coef,
        'loss_rhand_oks': args.rhand_oks_loss_coef,
        'loss_lhand_oks': args.lhand_oks_loss_coef,
        'loss_face_oks': args.face_oks_loss_coef,
        # smpl param
        'loss_smpl_pose_root': args.smpl_pose_loss_root_coef,
        'loss_smpl_pose_body': args.smpl_pose_loss_body_coef,
        'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef,
        'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef,
        'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef,
        'loss_smpl_beta': args.smpl_beta_loss_coef,
        'loss_smpl_expr': args.smpl_expr_loss_coef,
        # smpl kp3d ra
        'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef,
        'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef,
        'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef,
        'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef,
        # smpl kp3d
        'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef,
        'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef,
        'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef,
        'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef,
        # smpl kp2d
        'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef,
        'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef,
        'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef,
        'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef,
        # smpl kp2d ba
        'loss_smpl_body_kp2d_ba': args.smpl_body_kp2d_ba_loss_coef,
        'loss_smpl_face_kp2d_ba': args.smpl_face_kp2d_ba_loss_coef,
        'loss_smpl_lhand_kp2d_ba': args.smpl_lhand_kp2d_ba_loss_coef,
        'loss_smpl_rhand_kp2d_ba': args.smpl_rhand_kp2d_ba_loss_coef,
    }
    # Snapshot before dn terms are added; intermediate losses use this one.
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)
    if args.use_dn:
        weight_dict.update({
            'dn_loss_ce': args.dn_label_coef,
            'dn_loss_bbox': args.bbox_loss_coef * args.dn_bbox_coef,
            'dn_loss_giou': args.giou_loss_coef * args.dn_bbox_coef,
        })
    clean_weight_dict = copy.deepcopy(weight_dict)

    if args.aux_loss:
        # Auxiliary layers below num_box_decoder_layers drop every
        # keypoint / oks / smpl term and the hand/face box terms; hand/face
        # keypoint and oks terms are also dropped below
        # num_hand_face_decoder_layers.
        hand_face_box_keys = {
            'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox',
            'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou'}
        hand_face_kp_keys = {
            'loss_rhand_keypoints', 'loss_lhand_keypoints',
            'loss_face_keypoints', 'loss_rhand_oks',
            'loss_lhand_oks', 'loss_face_oks'}
        aux_weight_dict = {}
        # One suffixed copy of the weights per intermediate decoder layer.
        for i in range(args.dec_layers - 1):
            for k, v in clean_weight_dict.items():
                if i < args.num_box_decoder_layers and (
                        'keypoints' in k or 'oks' in k
                        or k in hand_face_box_keys or 'smpl' in k):
                    continue
                if (i < args.num_hand_face_decoder_layers
                        and k in hand_face_kp_keys):
                    continue
                aux_weight_dict[k + f'_{i}'] = v
        weight_dict.update(aux_weight_dict)

    if args.two_stage_type != 'no':
        no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
        interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
        # Classification keeps weight 1.0; every other intermediate term is
        # zeroed when no_interm_box_loss is set.
        interm_coef = 0.0 if no_interm_box_loss else 1.0
        _coeff_weight_dict = {
            k: (1.0 if k == 'loss_ce' else interm_coef)
            for k in clean_weight_dict_wo_dn
        }
        # Encoder-output losses skip the 2d keypoint terms.
        weight_dict.update({
            k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items() if 'keypoints' not in k
        })
        weight_dict.update({
            k + '_query_expand': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
        })

    # Copy so cfg.losses itself is never extended in place.
    losses = list(cfg.losses)
    if args.dn_number > 0:
        losses += ['dn_label', 'dn_bbox']
    losses += ['matching']

    criterion = SetCriterion(
        num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        focal_alpha=args.focal_alpha,
        losses=losses,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        num_body_points=args.num_body_points,
        num_hand_points=args.num_hand_points,
        num_face_points=args.num_face_points,
    )
    criterion.to(device)

    if args.inference:
        postprocessors = {
            'bbox':
            PostProcess_SMPLX_Multi_Infer(
                num_select=args.num_select,
                nms_iou_threshold=args.nms_iou_threshold,
                num_body_points=args.num_body_points),
        }
    else:
        postprocessors = {
            'bbox':
            PostProcess_SMPLX(
                num_select=args.num_select,
                nms_iou_threshold=args.nms_iou_threshold,
                num_body_points=args.num_body_points),
        }
    postprocessors_aios = {
        'bbox':
        PostProcess_aios(num_select=args.num_select,
                         nms_iou_threshold=args.nms_iou_threshold,
                         num_body_points=args.num_body_points),
    }
    return model, criterion, postprocessors, postprocessors_aios
class AiOSSMPLX_Box(nn.Module): | |
def __init__(
        self,
        backbone,
        transformer,
        num_classes,
        num_queries,
        aux_loss=False,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=False,
        fix_refpoints_hw=-1,
        num_feature_levels=1,
        nheads=8,
        two_stage_type='no',
        dec_pred_class_embed_share=False,
        dec_pred_bbox_embed_share=False,
        dec_pred_pose_embed_share=False,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_batch_gt_fuse=False,
        dn_labelbook_size=100,
        dn_attn_mask_type_list=['group2group'],
        cls_no_bias=False,
        num_group=100,
        num_body_points=0,
        num_hand_points=0,
        num_face_points=0,
        num_box_decoder_layers=2,
        num_hand_face_decoder_layers=4,
        body_model=dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            keypoint_dst='smplx_137',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True),
        train=True,
        inference=False,
        focal_length=[5000., 5000.],
        camera_3d_size=2.5
):
    """Build the box-based AiOS SMPL-X prediction heads.

    Creates a frozen body model, the input projections for the backbone
    features, and per-decoder-layer heads for class, body / hand / face
    boxes, and SMPL(-X) pose, shape, camera, hand pose, expression and jaw.
    Box and class heads are attached to ``transformer.decoder``.

    Raises:
        ValueError: If ``body_model['type']`` is neither SMPL nor SMPLX.
    """
    super().__init__()

    def _zero_last_layer(mlp):
        # Zero-initialise the final linear layer of an MLP head.
        nn.init.constant_(mlp.layers[-1].weight.data, 0)
        nn.init.constant_(mlp.layers[-1].bias.data, 0)
        return mlp

    def _per_layer(module, count, share):
        # One entry per decoder layer: the same instance when weights are
        # shared, otherwise an independent deep copy per layer.
        if share:
            return [module for _ in range(count)]
        return [copy.deepcopy(module) for _ in range(count)]

    def _per_layer_split(early, late, count, share):
        # Like _per_layer, but entries 0-1 use ``early`` and the rest ``late``.
        if share:
            return [early if i < 2 else late for i in range(count)]
        return [copy.deepcopy(early) if i < 2 else copy.deepcopy(late)
                for i in range(count)]

    self.num_queries = num_queries
    self.transformer = transformer
    self.num_classes = num_classes
    self.hidden_dim = hidden_dim = transformer.d_model
    self.num_feature_levels = num_feature_levels
    self.nheads = nheads
    self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
    self.num_body_points = num_body_points
    self.num_hand_points = num_hand_points
    self.num_face_points = num_face_points
    self.num_whole_body_points = num_body_points + 2*num_hand_points + num_face_points
    self.num_box_decoder_layers = num_box_decoder_layers
    self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.focal_length = focal_length
    self.camera_3d_size = camera_3d_size
    self.inference = inference
    self.smpl_convention = 'smplx' if train else 'h36m'

    # setting query dim
    self.query_dim = query_dim
    assert query_dim == 4
    self.random_refpoints_xy = random_refpoints_xy  # False
    self.fix_refpoints_hw = fix_refpoints_hw  # -1

    # for dn training
    self.dn_number = dn_number
    self.dn_box_noise_scale = dn_box_noise_scale
    self.dn_label_noise_ratio = dn_label_noise_ratio
    self.dn_batch_gt_fuse = dn_batch_gt_fuse
    self.dn_labelbook_size = dn_labelbook_size
    self.dn_attn_mask_type_list = dn_attn_mask_type_list
    assert all([
        i in ['match2dn', 'dn2dn', 'group2group']
        for i in dn_attn_mask_type_list
    ])
    assert not dn_batch_gt_fuse

    # build the (frozen) human body model; inference always uses plain SMPL-X
    if inference:
        body_model = dict(
            type='smplx',
            keypoint_src='smplx',
            num_expression_coeffs=10,
            num_betas=10,
            keypoint_dst='smplx',
            model_path='data/body_models/smplx',
            use_pca=False,
            use_face_contour=True)
    self.body_model = build_body_model(body_model)
    for param in self.body_model.parameters():
        param.requires_grad = False

    # prepare input projection layers
    if num_feature_levels > 1:
        num_backbone_outs = len(backbone.num_channels)
        input_proj_list = []
        for level in range(num_backbone_outs):
            in_channels = backbone.num_channels[level]
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
        # Levels beyond the backbone outputs use stride-2 3x3 convs: the first
        # takes the last backbone level's channels, later ones hidden_dim.
        for _ in range(num_feature_levels - num_backbone_outs):
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels,
                              hidden_dim,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            in_channels = hidden_dim
        self.input_proj = nn.ModuleList(input_proj_list)
    else:
        assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
        self.input_proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(backbone.num_channels[-1],
                          hidden_dim,
                          kernel_size=1),
                nn.GroupNorm(32, hidden_dim),
            )
        ])

    self.backbone = backbone
    self.aux_loss = aux_loss
    self.box_pred_damping = None
    self.iter_update = iter_update
    assert iter_update, 'Why not iter_update?'

    # prepare pred layers
    self.dec_pred_class_embed_share = dec_pred_class_embed_share
    self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
    num_decoder_layers = transformer.num_decoder_layers
    # Heads that only exist after the body-box decoder layers.
    num_pose_layers = num_decoder_layers - num_box_decoder_layers

    # class embed, with a prior-probability bias initialisation
    _class_embed = nn.Linear(hidden_dim,
                             num_classes,
                             bias=(not cls_no_bias))
    if not cls_no_bias:
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
    class_embed_layerlist = _per_layer(
        _class_embed, num_decoder_layers, dec_pred_class_embed_share)

    ###########################################################################
    # body bbox + l/r hand box + face box
    ###########################################################################
    # body bbox: 4 outputs per query, one head per decoder layer
    _bbox_embed = _zero_last_layer(MLP(hidden_dim, hidden_dim, 4, 3))
    self.num_group = num_group
    box_body_embed_layerlist = _per_layer(
        _bbox_embed, num_decoder_layers, dec_pred_bbox_embed_share)

    # hand bbox: a 2-output head plus a 2-output "hw" head
    _bbox_hand_embed = _zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    _bbox_hand_hw_embed = _zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    box_hand_embed_layerlist = _per_layer(
        _bbox_hand_embed, num_pose_layers + 1, dec_pred_pose_embed_share)
    box_hand_hw_embed_layerlist = _per_layer(
        _bbox_hand_hw_embed, num_pose_layers, dec_pred_pose_embed_share)

    # face bbox: same layout as the hand heads
    _bbox_face_embed = _zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    _bbox_face_hw_embed = _zero_last_layer(MLP(hidden_dim, hidden_dim, 2, 3))
    box_face_embed_layerlist = _per_layer(
        _bbox_face_embed, num_pose_layers + 1, dec_pred_pose_embed_share)
    box_face_hw_embed_layerlist = _per_layer(
        _bbox_face_hw_embed, num_pose_layers, dec_pred_pose_embed_share)

    # 1. smpl pose embed: 6 values per body joint
    model_type = body_model['type'].upper()
    if model_type == 'SMPL':
        self.body_model_joint_num = 24
    elif model_type == 'SMPLX':
        self.body_model_joint_num = 22
    else:
        raise ValueError(
            f"Only supports SMPL or SMPLX, but get {body_model['type']}")
    _smpl_pose_embed = _zero_last_layer(
        MLP(hidden_dim * 4, hidden_dim, self.body_model_joint_num * 6, 3))
    smpl_pose_embed_layerlist = _per_layer(
        _smpl_pose_embed, num_pose_layers, dec_pred_bbox_embed_share)

    # 2. smpl betas embed
    _smpl_beta_embed = _zero_last_layer(MLP(hidden_dim * 4, hidden_dim, 10, 3))
    smpl_beta_embed_layerlist = _per_layer(
        _smpl_beta_embed, num_pose_layers, dec_pred_bbox_embed_share)

    # 3. smpl cam embed
    _cam_embed = _zero_last_layer(MLP(hidden_dim * 4, hidden_dim, 3, 3))
    cam_embed_layerlist = _per_layer(
        _cam_embed, num_pose_layers, dec_pred_bbox_embed_share)

    ###########################################################################
    # smplx hand pose + expression + jaw; entries 0-1 and 2+ use separate heads
    ###########################################################################
    _smplx_hand_pose_embed_layer_2_3 = _zero_last_layer(
        MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3))
    _smplx_hand_pose_embed_layer_4_5 = _zero_last_layer(
        MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3))
    smplx_hand_pose_embed_layerlist = _per_layer_split(
        _smplx_hand_pose_embed_layer_2_3, _smplx_hand_pose_embed_layer_4_5,
        num_pose_layers, dec_pred_bbox_embed_share)

    _smplx_expression_embed_layer_2_3 = _zero_last_layer(
        MLP(hidden_dim * 2, hidden_dim, 10, 3))
    _smplx_expression_embed_layer_4_5 = _zero_last_layer(
        MLP(hidden_dim * 2, hidden_dim, 10, 3))
    smplx_expression_embed_layerlist = _per_layer_split(
        _smplx_expression_embed_layer_2_3, _smplx_expression_embed_layer_4_5,
        num_pose_layers, dec_pred_bbox_embed_share)

    _smplx_jaw_embed_2_3 = _zero_last_layer(MLP(hidden_dim * 2, hidden_dim, 6, 3))
    _smplx_jaw_embed_4_5 = _zero_last_layer(MLP(hidden_dim * 2, hidden_dim, 6, 3))
    smplx_jaw_embed_layerlist = _per_layer_split(
        _smplx_jaw_embed_2_3, _smplx_jaw_embed_4_5,
        num_pose_layers, dec_pred_bbox_embed_share)

    self.bbox_embed = nn.ModuleList(box_body_embed_layerlist)
    self.class_embed = nn.ModuleList(class_embed_layerlist)
    self.transformer.decoder.bbox_embed = self.bbox_embed
    self.transformer.decoder.class_embed = self.class_embed
    # smpl
    self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist)
    self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist)
    self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist)
    # hand boxes
    self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist)
    self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist)
    self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed
    self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed
    # face boxes
    self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist)
    self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist)
    self.transformer.decoder.bbox_face_embed = self.bbox_face_embed
    self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed
    # smplx
    self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist)
    self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist)
    self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist)
    self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers
    self.transformer.decoder.num_body_points = num_body_points
    self.transformer.decoder.num_hand_points = num_hand_points
    self.transformer.decoder.num_face_points = num_face_points

    # two stage
    self.two_stage_type = two_stage_type
    assert two_stage_type in [
        'no', 'standard'
    ], 'unknown param {} of two_stage_type'.format(two_stage_type)
    if two_stage_type != 'no':
        if two_stage_bbox_embed_share:
            assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
            self.transformer.enc_out_bbox_embed = _bbox_embed
        else:
            self.transformer.enc_out_bbox_embed = copy.deepcopy(
                _bbox_embed)
        if two_stage_class_embed_share:
            assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
            self.transformer.enc_out_class_embed = _class_embed
        else:
            self.transformer.enc_out_class_embed = copy.deepcopy(
                _class_embed)
        self.refpoint_embed = None
    self._reset_parameters()
def get_camera_trans(self, cam_param, input_body_shape):
    """Convert predicted camera parameters into a camera translation.

    Args:
        cam_param (torch.Tensor): ``(N, C)`` predictions with ``C >= 3``.
            Columns 0-1 are used directly as the x/y translation; column 2
            is passed through a sigmoid and scales the depth.
        input_body_shape (Sequence[float]): Input image size; only the
            product of its first two entries is used.

    Returns:
        torch.Tensor: ``(N, 3)`` translation ``[t_x, t_y, t_z]`` on the same
        device and dtype as ``cam_param``.
    """
    t_xy = cam_param[:, :2]
    gamma = torch.sigmoid(cam_param[:, 2])  # apply sigmoid to make it positive
    k_value = math.sqrt(
        self.focal_length[0] * self.focal_length[1]
        * self.camera_3d_size * self.camera_3d_size
        / (input_body_shape[0] * input_body_shape[1]))
    # Build the constant on the input's device/dtype rather than forcing
    # `.cuda()`, so CPU and non-default-GPU inputs work too.
    t_z = gamma.new_tensor([k_value]) * gamma
    cam_trans = torch.cat((t_xy, t_z[:, None]), 1)
    return cam_trans
def _reset_parameters(self): | |
# init input_proj | |
for proj in self.input_proj: | |
nn.init.xavier_uniform_(proj[0].weight, gain=1) | |
nn.init.constant_(proj[0].bias, 0) | |
def prepare_for_dn2(self, targets): | |
if not self.training: | |
device = targets[0]['boxes'].device | |
bs = len(targets) | |
num_points = 4 | |
attn_mask2 = torch.zeros( | |
bs, | |
self.nheads, | |
self.num_group * 4, | |
self.num_group * 4, | |
device=device, | |
dtype=torch.bool) | |
group_bbox_kpt = 4 | |
# body bbox index | |
kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
for matchj in range(self.num_group * 4): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
# for each instance, they should associate with their query (body hand face) | |
if sj > 0: | |
attn_mask2[:, :, matchj, :sj] = True | |
if ej < self.num_group * 4: | |
attn_mask2[:, :, matchj, ej:] = True | |
for match_x in range(self.num_group * 4): | |
if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
# each query (hand face body) should associate with all body query | |
attn_mask2[:,:,match_x, kpt_index]=False | |
num_points = 4 | |
attn_mask3 = torch.zeros( | |
bs, | |
self.nheads, | |
self.num_group * 4, | |
self.num_group * 4, | |
device=device, | |
dtype=torch.bool) | |
group_bbox_kpt = 4 | |
kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
for matchj in range(self.num_group * 4): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
# for each instance, they should associate with their query (body hand face) | |
if sj > 0: | |
attn_mask3[:, :, matchj, :sj] = True | |
if ej < self.num_group * 4: | |
attn_mask3[:, :, matchj, ej:] = True | |
for match_x in range(self.num_group * 4): | |
if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
# each query (hand face body) should associate with all body query | |
attn_mask3[:, :, match_x, kpt_index] = False | |
attn_mask2 = attn_mask2.flatten(0, 1) | |
attn_mask3 = attn_mask3.flatten(0, 1) | |
return None, None, None, attn_mask2, attn_mask3, None | |
# targets, dn_scalar, noise_scale = dn_args | |
device = targets[0]['boxes'].device | |
bs = len(targets) | |
dn_number = self.dn_number # 100 | |
dn_box_noise_scale = self.dn_box_noise_scale # 0.4 | |
dn_label_noise_ratio = self.dn_label_noise_ratio # 0.5 | |
# gather gt boxes and labels | |
gt_boxes = [t['boxes'] for t in targets] | |
gt_labels = [t['labels'] for t in targets] | |
gt_keypoints = [t['keypoints'] for t in targets] | |
# repeat them | |
def get_indices_for_repeat(now_num, target_num, device='cuda'): | |
""" | |
Input: | |
- now_num: int | |
- target_num: int | |
Output: | |
- indices: tensor[target_num] | |
""" | |
out_indice = [] | |
base_indice = torch.arange(now_num).to(device) | |
multiplier = target_num // now_num | |
out_indice.append(base_indice.repeat(multiplier)) | |
residue = target_num % now_num | |
out_indice.append(base_indice[torch.randint(0, | |
now_num, (residue, ), | |
device=device)]) | |
return torch.cat(out_indice) | |
if self.dn_batch_gt_fuse: | |
raise NotImplementedError | |
gt_boxes_bsall = torch.cat(gt_boxes) # num_boxes, 4 | |
gt_labels_bsall = torch.cat(gt_labels) | |
num_gt_bsall = gt_boxes_bsall.shape[0] | |
if num_gt_bsall > 0: | |
indices = get_indices_for_repeat(num_gt_bsall, dn_number, | |
device) | |
gt_boxes_expand = gt_boxes_bsall[indices][None].repeat( | |
bs, 1, 1) # bs, num_dn, 4 | |
gt_labels_expand = gt_labels_bsall[indices][None].repeat( | |
bs, 1) # bs, num_dn | |
else: | |
# all negative samples when no gt boxes | |
gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device) | |
gt_labels_expand = torch.ones( | |
bs, dn_number, dtype=torch.int64, device=device) * int( | |
self.num_classes) | |
else: | |
gt_boxes_expand = [] | |
gt_labels_expand = [] | |
gt_keypoints_expand = [] # here | |
for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( | |
zip(gt_boxes, gt_labels, gt_keypoints)): # idx -> batch id | |
num_gt_i = gt_boxes_i.shape[0] # instance num | |
if num_gt_i > 0: | |
indices = get_indices_for_repeat(num_gt_i, dn_number, | |
device) | |
gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 | |
gt_labels_expand_i = gt_labels_i[indices] # add smpl | |
gt_keypoints_expand_i = gt_keypoint_i[indices] | |
else: | |
# all negative samples when no gt boxes | |
gt_boxes_expand_i = torch.rand(dn_number, 4, device=device) | |
gt_labels_expand_i = torch.ones( | |
dn_number, dtype=torch.int64, device=device) * int( | |
self.num_classes) | |
gt_keypoints_expand_i = torch.rand(dn_number, | |
self.num_body_points * | |
3, | |
device=device) | |
gt_boxes_expand.append(gt_boxes_expand_i) # add smpl | |
gt_labels_expand.append(gt_labels_expand_i) | |
gt_keypoints_expand.append(gt_keypoints_expand_i) | |
gt_boxes_expand = torch.stack(gt_boxes_expand) | |
gt_labels_expand = torch.stack(gt_labels_expand) | |
gt_keypoints_expand = torch.stack(gt_keypoints_expand) | |
knwon_boxes_expand = gt_boxes_expand.clone() | |
knwon_labels_expand = gt_labels_expand.clone() | |
# add noise | |
if dn_label_noise_ratio > 0: | |
prob = torch.rand_like(knwon_labels_expand.float()) | |
chosen_indice = prob < dn_label_noise_ratio | |
new_label = torch.randint_like( | |
knwon_labels_expand[chosen_indice], 0, | |
self.dn_labelbook_size) # randomly put a new one here | |
knwon_labels_expand[chosen_indice] = new_label | |
if dn_box_noise_scale > 0: | |
diff = torch.zeros_like(knwon_boxes_expand) | |
diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 | |
diff[..., 2:] = knwon_boxes_expand[..., 2:] | |
knwon_boxes_expand += torch.mul( | |
(torch.rand_like(knwon_boxes_expand) * 2 - 1.0), | |
diff) * dn_box_noise_scale | |
knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) | |
input_query_label = self.label_enc(knwon_labels_expand) | |
input_query_bbox = inverse_sigmoid(knwon_boxes_expand) | |
# prepare mask | |
if 'group2group' in self.dn_attn_mask_type_list: | |
attn_mask = torch.zeros(bs, | |
self.nheads, | |
dn_number + self.num_queries, | |
dn_number + self.num_queries, | |
device=device, | |
dtype=torch.bool) | |
attn_mask[:, :, dn_number:, :dn_number] = True | |
for idx, (gt_boxes_i, gt_labels_i) in enumerate( | |
zip(gt_boxes, gt_labels)): # for batch | |
num_gt_i = gt_boxes_i.shape[0] | |
if num_gt_i == 0: | |
continue | |
for matchi in range(dn_number): | |
si = (matchi // num_gt_i) * num_gt_i | |
ei = (matchi // num_gt_i + 1) * num_gt_i | |
if si > 0: | |
attn_mask[idx, :, matchi, :si] = True | |
if ei < dn_number: | |
attn_mask[idx, :, matchi, ei:dn_number] = True | |
attn_mask = attn_mask.flatten(0, 1) | |
if 'group2group' in self.dn_attn_mask_type_list: | |
# self.num_body_points = self.num_body_points +3 | |
num_points = 4 | |
attn_mask2 = torch.zeros( | |
bs, | |
self.nheads, | |
dn_number + self.num_group * 4, | |
dn_number + self.num_group * 4, | |
device=device, | |
dtype=torch.bool) | |
attn_mask2[:, :, dn_number:, :dn_number] = True | |
group_bbox_kpt = 4 | |
for matchj in range(self.num_group * 4): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
# for each instance, they should associate their body, hand, and face bbox | |
if sj > 0: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True | |
if ej < self.num_group * 4: | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True | |
# body bbox index | |
kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
for match_x in range(self.num_group * 4): | |
if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
# for each instance, they should associate their each query with | |
# other instances' body query | |
attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False | |
for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
num_gt_i = gt_boxes_i.shape[0] | |
if num_gt_i == 0: | |
continue | |
for matchi in range(dn_number): | |
si = (matchi // num_gt_i) * num_gt_i | |
ei = (matchi // num_gt_i + 1) * num_gt_i | |
if si > 0: | |
attn_mask2[idx, :, matchi, :si] = True | |
if ei < dn_number: | |
attn_mask2[idx, :, matchi, ei:dn_number] = True | |
attn_mask2 = attn_mask2.flatten(0, 1) | |
if 'group2group' in self.dn_attn_mask_type_list: | |
num_points = 4 | |
attn_mask3 = torch.zeros( | |
bs, | |
self.nheads, | |
dn_number + self.num_group * 4, dn_number + self.num_group * 4, | |
device=device, dtype=torch.bool) | |
attn_mask3[:, :, dn_number:, :dn_number] = True | |
group_bbox_kpt = 4 | |
for matchj in range(self.num_group * 4): | |
sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
# for each instance, they should associate their body, hand, and face bbox | |
if sj > 0: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True | |
if ej < self.num_group * 4: | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True | |
kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
for match_x in range(self.num_group * 4): | |
if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
# for each instance, they should associate their each query with | |
# other instances' body query | |
attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False | |
for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
num_gt_i = gt_boxes_i.shape[0] | |
if num_gt_i == 0: | |
continue | |
for matchi in range(dn_number): | |
si = (matchi // num_gt_i) * num_gt_i | |
ei = (matchi // num_gt_i + 1) * num_gt_i | |
if si > 0: | |
attn_mask3[idx, :, matchi, :si] = True | |
if ei < dn_number: | |
attn_mask3[idx, :, matchi, ei:dn_number] = True | |
attn_mask3 = attn_mask3.flatten(0, 1) | |
mask_dict = { | |
'pad_size': dn_number, | |
'known_bboxs': gt_boxes_expand, | |
'known_labels': gt_labels_expand, | |
'known_keypoints': gt_keypoints_expand | |
} | |
return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict | |
def dn_post_process2(self, outputs_class, outputs_coord, mask_dict):
    """Separate denoising (DN) predictions from the regular query predictions.

    The first ``mask_dict['pad_size']`` queries of every decoder layer are the
    DN queries. When ``mask_dict`` is truthy and ``pad_size`` is positive:
    - those leading slices are stored back into ``mask_dict`` under
      ``'output_known_coord'`` and ``'output_known_class'``;
    - only the remaining queries are returned.

    Otherwise the inputs are returned unchanged and ``mask_dict`` is not
    modified.

    Args:
        outputs_class: per-layer class predictions. Each element is sliced as
            ``x[:, a:b, :]``, i.e. queries are assumed to lie along dim 1.
        outputs_coord: per-layer box predictions, sliced the same way.
        mask_dict: dict built by the DN preparation step, or None.

    Returns:
        ``(outputs_class, outputs_coord)`` without the DN queries.
    """
    if not mask_dict or not (mask_dict['pad_size'] > 0):
        # No DN queries were prepended, so there is nothing to strip.
        return outputs_class, outputs_coord

    pad = mask_dict['pad_size']
    # Leading `pad` queries -> DN ("known") predictions.
    known_class = [layer_cls[:, :pad, :] for layer_cls in outputs_class]
    known_coord = [layer_box[:, :pad, :] for layer_box in outputs_coord]
    # Remaining queries -> the regular (non-DN) predictions.
    regular_class = [layer_cls[:, pad:, :] for layer_cls in outputs_class]
    regular_coord = [layer_box[:, pad:, :] for layer_box in outputs_coord]

    mask_dict.update({
        'output_known_coord': known_coord,
        'output_known_class': known_class,
    })
    return regular_class, regular_coord
def forward(self, data_batch: NestedTensor, targets: List = None):
    """Run detection plus SMPL-X regression on a batch of images.

    Args:
        data_batch: one of
            - dict: passed to ``self.prepare_targets``, which returns the
              image ``NestedTensor`` and the targets; the ``targets``
              argument is then overwritten;
            - list or torch.Tensor of images: wrapped with
              ``nested_tensor_from_tensor_list``;
            - anything else: used directly as the ``NestedTensor``.
              ``samples.mask`` is the padding mask, True on padded pixels.
        targets: optional list of per-image target dicts. Forwarded to
            ``self.prepare_for_dn2`` to build denoising (DN) queries.

    Returns:
        tuple ``(out, targets, data_batch)``. ``out`` is a dict with:
        - 'pred_logits', 'pred_boxes': class logits and sigmoid-normalised
          body boxes of the last decoder layer. DN queries are removed when
          denoising is active.
        - 'pred_lhand_boxes', 'pred_rhand_boxes', 'pred_face_boxes': part
          boxes of the last decoder layer.
        - 'pred_smpl_pose', 'pred_smpl_lhand_pose', 'pred_smpl_rhand_pose',
          'pred_smpl_jaw_pose': rotation matrices (..., 3, 3) from the
          SMPL-X layers.
        - 'pred_smpl_expr', 'pred_smpl_beta', 'pred_smpl_cam',
          'pred_smpl_kp3d'.
        - eval only: 'pred_smpl_verts' and 'pred_smpl_fullpose' (axis-angle,
          53 * 3 values per instance).
        - 'dn_class_input', 'dn_bbox_input', 'dn_class_pred',
          'dn_bbox_pred', 'num_tgt': present when denoising is active.
        - 'aux_outputs': one dict per earlier decoder layer, present when
          ``self.aux_loss``.
        - 'interm_outputs': present when the transformer returns encoder
          outputs (``hs_enc is not None``).
    """
    # ---- Normalise the input into a NestedTensor (plus targets for dict batches) ----
    if isinstance(data_batch, dict):
        samples, targets = self.prepare_targets(data_batch)
    elif isinstance(data_batch, (list, torch.Tensor)):
        samples = nested_tensor_from_tensor_list(data_batch)
    else:
        samples = data_batch

    # ---- Backbone and multi-scale input projection ----
    features, poss = self.backbone(samples)
    srcs = []
    masks = []
    for l, feat in enumerate(features):  # one entry per backbone output level
        src, mask = feat.decompose()
        srcs.append(self.input_proj[l](src))
        masks.append(mask)
        assert mask is not None
    # If the transformer expects more levels than the backbone produced, build
    # the extra ones with the remaining input_proj layers:
    # - the first extra level comes from the last raw backbone map;
    # - each later one comes from the previously projected level.
    if self.num_feature_levels > len(srcs):
        _len_srcs = len(srcs)
        for l in range(_len_srcs, self.num_feature_levels):
            if l == _len_srcs:
                src = self.input_proj[l](features[-1].tensors)
            else:
                src = self.input_proj[l](srcs[-1])
            m = samples.mask
            # Resize the image padding mask to this level's spatial size.
            mask = F.interpolate(m[None].float(),
                                 size=src.shape[-2:]).to(torch.bool)[0]
            # Positional encoding for the new level, computed by self.backbone[1].
            pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
            srcs.append(src)
            masks.append(mask)
            poss.append(pos_l)

    # ---- Denoising (DN) queries and self-attention masks ----
    # prepare_for_dn2 is called whenever dn_number > 0, i.e. also at inference
    # with targets=None. NOTE(review): prepare_for_dn2 has an early return that
    # yields (None, None, None, attn_mask2, attn_mask3, None). The condition
    # guarding it is not visible here; confirm it covers the targets=None case.
    if self.dn_number > 0 or targets is not None:
        input_query_label, input_query_bbox, attn_mask,attn_mask2, attn_mask3, mask_dict =\
            self.prepare_for_dn2(targets)
    else:
        assert targets is None
        input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None
    # hs / reference are consumed per decoder layer below.
    # init_box_proposal is not used in this method.
    hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
        srcs, masks, input_query_bbox, poss, input_query_label, attn_mask,
        attn_mask2, attn_mask3)

    # ---- Body boxes and class logits for every decoder layer ----
    # update human boxes
    # DN queries are prepended only while training, so at eval nothing is
    # sliced off the front of hs / reference.
    effective_dn_number = self.dn_number if self.training else 0
    outputs_body_bbox_list = []
    outputs_class = []
    for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed,
                  layer_hs) in enumerate(
                      zip(reference[:-1], self.bbox_embed,
                          self.class_embed, hs)):
        # layer_ref_sig (and the reference_before_sigmoid_* slices taken from
        # it) are in sigmoid space. inverse_sigmoid maps them back before the
        # predicted delta is added, and the result is squashed with sigmoid.
        if dec_lid < self.num_box_decoder_layers:
            # human det: at these layers every query is refined as a body box.
            layer_delta_unsig = layer_body_bbox_embed(layer_hs)
            layer_body_box_outputs_unsig = \
                layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
            layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid()
            layer_cls = layer_cls_embed(layer_hs)
            outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
            outputs_class.append(layer_cls)
        elif dec_lid < self.num_box_decoder_layers + 2:
            bs = layer_ref_sig.shape[0]
            # dn body bbox
            layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :]  # dn content query
            reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :]  # dn position query
            layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
            layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
                reference_before_sigmoid_body_bbox_dn)
            layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
            # norm body bbox: the first (body) query of each instance group.
            # NOTE(review): the stride here is num_body_points + 4, while the
            # hand/face slicing further down uses a stride of 4. The two agree
            # only if num_body_points == 0; confirm against the config.
            layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
                :, 0::(self.num_body_points + 4), :]  # norm content query
            reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][
                :, 0::(self.num_body_points+ 4), :]  # norm position query
            layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
            layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
                reference_before_sigmoid_body_bbox_norm)
            layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
            # DN boxes first, then the regular body boxes, matching the query order.
            layer_body_box_outputs_unsig = torch.cat(
                (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
            # classfication
            layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
            layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
            layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
            outputs_class.append(layer_cls)
            outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
        else:
            bs = layer_ref_sig.shape[0]
            # dn body bbox
            layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :]  # dn content query
            reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :]  # dn position query
            layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
            layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
                reference_before_sigmoid_body_bbox_dn)
            layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
            # norm body bbox. Same as the branch above, but strided by
            # num_whole_body_points + 4 (same NOTE(review) about stride 4 applies).
            layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
                :, 0::(self.num_whole_body_points + 4), :]  # norm content query
            reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][
                :, 0::(self.num_whole_body_points + 4), :]  # norm position query
            layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
            layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
                reference_before_sigmoid_body_bbox_norm)
            layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
            layer_body_box_outputs_unsig = torch.cat(
                (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
            # classfication
            layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
            layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
            layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
            outputs_class.append(layer_cls)
            outputs_body_bbox_list.append(layer_body_box_outputs_unsig)

    # ---- Hand / face boxes and SMPL-X parameters for every decoder layer ----
    # update hand and face boxes
    outputs_lhand_bbox_list = []
    outputs_rhand_bbox_list = []
    outputs_face_bbox_list = []
    # update keypoints boxes
    # NOTE(review): usage of the lists below in this method:
    # - the *_keypoints_hw lists and outputs_smpl_kp2d_list are never appended to;
    # - the hand/face keypoint lists are never put into `out`;
    # - outputs_body_keypoints_list is only used for the interm_outputs placeholder.
    outputs_body_keypoints_list = []
    outputs_body_keypoints_hw = []
    outputs_lhand_keypoints_list = []
    outputs_lhand_keypoints_hw = []
    outputs_rhand_keypoints_list = []
    outputs_rhand_keypoints_hw = []
    outputs_face_keypoints_list = []
    outputs_face_keypoints_hw = []
    outputs_smpl_pose_list = []
    outputs_smpl_lhand_pose_list = []
    outputs_smpl_rhand_pose_list = []
    outputs_smpl_expr_list = []
    outputs_smpl_jaw_pose_list = []
    outputs_smpl_beta_list = []
    outputs_smpl_cam_list = []
    outputs_smpl_kp2d_list = []
    outputs_smpl_kp3d_list = []
    outputs_smpl_verts_list = []
    # Each human instance owns 4 consecutive non-DN queries:
    # 0 = body, 1 = left hand, 2 = right hand, 3 = face.
    # The index lists pick which of those queries feed each SMPL-X head. The
    # selected features are concatenated per instance by the reshape to
    # feat_dim * len(index) further down.
    # smpl pose: body box, lhand box, rhand box, face box
    body_index = [0, 1, 2, 3]
    smpl_pose_index = [
        x for x in range(self.num_group * 4) if (x % 4 in body_index)]
    # smpl lhand: body + left-hand queries
    lhand_index = [0, 1]
    smpl_lhand_pose_index = [
        x for x in range(self.num_group * 4) if (x % 4 in lhand_index)]
    # smpl rhand: body + right-hand queries
    rhand_index = [0, 2]
    smpl_rhand_pose_index = [
        x for x in range(self.num_group * 4) if (x % 4 in rhand_index)]
    # smpl face: body + face queries
    face_index = [0, 3]
    smpl_face_pose_index = [
        x for x in range(self.num_group * 4) if (x % 4 in face_index)]
    for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)):
        if dec_lid < self.num_box_decoder_layers:
            # Box-only layers: no part / SMPL-X heads run yet. Zero
            # placeholders sized by num_queries keep every per-layer list
            # aligned with the decoder layers.
            # NOTE: these placeholders are flat (e.g. pose has
            # body_model_joint_num * 3 values), unlike the rotation-matrix
            # outputs of the later layers.
            assert isinstance(layer_hs, torch.Tensor)
            bs = layer_hs.shape[0]
            layer_body_kps_res = layer_hs.new_zeros(
                (bs, self.num_queries,
                 self.num_body_points * 3))  # (bs, num_queries, num_body_points * 3)
            outputs_body_keypoints_list.append(layer_body_kps_res)
            # lhand
            layer_lhand_bbox_res = layer_hs.new_zeros(
                (bs, self.num_queries, 4))  # (bs, num_queries, 4)
            outputs_lhand_bbox_list.append(layer_lhand_bbox_res)
            layer_lhand_kps_res = layer_hs.new_zeros(
                (bs, self.num_queries,
                 self.num_hand_points * 3))  # (bs, num_queries, num_hand_points * 3)
            outputs_lhand_keypoints_list.append(layer_lhand_kps_res)
            # rhand
            layer_rhand_bbox_res = layer_hs.new_zeros(
                (bs, self.num_queries, 4))  # (bs, num_queries, 4)
            outputs_rhand_bbox_list.append(layer_rhand_bbox_res)
            layer_rhand_kps_res = layer_hs.new_zeros(
                (bs, self.num_queries,
                 self.num_hand_points * 3))  # (bs, num_queries, num_hand_points * 3)
            outputs_rhand_keypoints_list.append(layer_rhand_kps_res)
            # face
            layer_face_bbox_res = layer_hs.new_zeros(
                (bs, self.num_queries, 4))  # (bs, num_queries, 4)
            outputs_face_bbox_list.append(layer_face_bbox_res)
            layer_face_kps_res = layer_hs.new_zeros(
                (bs, self.num_queries,
                 self.num_face_points * 3))  # (bs, num_queries, num_face_points * 3)
            outputs_face_keypoints_list.append(layer_face_kps_res)
            # smpl or smplx
            smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3))
            smpl_rhand_pose = layer_hs.new_zeros(
                (bs, self.num_queries, 15 * 3))
            smpl_lhand_pose = layer_hs.new_zeros(
                (bs, self.num_queries, 15 * 3))
            smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10))
            smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6))
            smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10))
            smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3))
            # smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3))
            smpl_kp3d = layer_hs.new_zeros(
                (bs, self.num_queries, self.num_body_points, 4))
            outputs_smpl_pose_list.append(smpl_pose)
            outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
            outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
            outputs_smpl_expr_list.append(smpl_expr)
            outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
            outputs_smpl_beta_list.append(smpl_beta)
            outputs_smpl_cam_list.append(smpl_cam)
            # outputs_smpl_kp2d_list.append(smpl_kp2d)
            outputs_smpl_kp3d_list.append(smpl_kp3d)
        elif dec_lid < self.num_box_decoder_layers +2:
            # The two layers right after the box-only layers. Part boxes are
            # refined from their reference boxes. SMPL-X parameters are
            # predicted, but the body model below is driven with zero hand
            # poses and zero expression.
            bs = layer_ref_sig.shape[0]
            # lhand bbox: xy delta from bbox_hand_embed, wh delta from
            # bbox_hand_hw_embed, both added in inverse-sigmoid space.
            layer_hs_lhand_bbox = \
                layer_hs[:, effective_dn_number:, :][:, 1::4, :]
            delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox)
            layer_ref_sig_lhand_bbox = \
                layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone()
            layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox)
            delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
                dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox)
            layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2]
            layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig
            layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid()
            outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox)
            # rhand bbox (same hand heads as lhand)
            layer_hs_rhand_bbox = \
                layer_hs[:, effective_dn_number:, :][:, 2::4, :]
            delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[
                dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox)
            layer_ref_sig_rhand_bbox = \
                layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone()
            layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox)
            delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
                dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox)
            layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2]
            layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig
            layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid()
            outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox)
            # face bbox
            layer_hs_face_bbox = \
                layer_hs[:, effective_dn_number:, :][:, 3::4, :]
            delta_face_bbox_xy_unsig = self.bbox_face_embed[
                dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox)
            layer_ref_sig_face_bbox = \
                layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone()
            layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox)
            delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[
                dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox)
            layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2]
            layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig
            layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid()
            outputs_face_bbox_list.append(layer_ref_sig_face_bbox)
            # smpl or smplx
            # Gather the selected non-DN queries and concatenate the features
            # of each instance into one vector:
            # - body-pose features: 4 queries per instance (feat_dim * 4);
            # - hand and face features: 2 queries per instance (feat_dim * 2).
            bs, _, feat_dim = layer_hs.shape
            smpl_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 4)
            smpl_lhand_pose_feats = \
                layer_hs[:, effective_dn_number:, :].index_select(
                    1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device)
                ).reshape(bs, -1, feat_dim * 2)
            smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 2)
            smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_face_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 2)
            # Pose heads emit 6D rotations; convert them to (bs, num_group, J, 3, 3) rotation matrices.
            smpl_pose = self.smpl_pose_embed[
                dec_lid - self.num_box_decoder_layers](smpl_feats)
            smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape(
                bs, self.num_group, self.body_model_joint_num, 3, 3)
            smpl_lhand_pose = self.smpl_hand_pose_embed[
                dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats)
            smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape(
                -1, 6)).reshape(bs, self.num_group, 15, 3, 3)
            smpl_rhand_pose = self.smpl_hand_pose_embed[
                dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats)
            smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape(
                -1, 6)).reshape(bs, self.num_group, 15, 3, 3)
            smpl_jaw_pose = self.smpl_jaw_embed[
                dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
            smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape(
                bs, self.num_group, 1, 3, 3)
            smpl_beta = self.smpl_beta_embed[
                dec_lid - self.num_box_decoder_layers](smpl_feats)
            smpl_cam = self.smpl_cam_embed[
                dec_lid - self.num_box_decoder_layers](smpl_feats)
            smpl_expr = self.smpl_expr_embed[
                dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
            # smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3)
            # Eye poses are always identity rotations here.
            leye_pose = torch.zeros_like(smpl_jaw_pose)
            reye_pose = torch.zeros_like(smpl_jaw_pose)
            if self.body_model is not None:
                smpl_pose_ = rotmat_to_aa(smpl_pose)
                # At these layers the body model is driven with zero hand
                # poses instead of the predicted ones (predictions are still
                # returned through the output lists).
                # smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose)
                # smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose)
                smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3)
                smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3)
                smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose)
                leye_pose_ = rotmat_to_aa(leye_pose)
                reye_pose_ = rotmat_to_aa(reye_pose)
                # Joint 0 is the global orientation; joints 1..21 are the body pose.
                pred_output = self.body_model(
                    betas=smpl_beta.reshape(-1, 10),
                    body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3),
                    global_orient=smpl_pose_[:, :, 0].reshape(
                        -1, 3).unsqueeze(1),
                    left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3),
                    right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3),
                    leye_pose=leye_pose_,
                    reye_pose=reye_pose_,
                    jaw_pose=smpl_jaw_pose_.reshape(-1, 3),
                    # Expression is zeroed at these layers (predicted expr not fed to the model).
                    # expression=smpl_expr.reshape(-1, 10),
                    expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10)
                )
                smpl_kp3d = pred_output['joints'].reshape(
                    bs, self.num_group, -1, 3)
                smpl_verts = pred_output['vertices'].reshape(
                    bs, self.num_group, -1, 3)
                # pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3)
            # NOTE(review): if self.body_model is None, smpl_kp3d keeps the
            # value from an earlier iteration and smpl_verts is never assigned.
            outputs_smpl_pose_list.append(smpl_pose)
            outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
            outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
            outputs_smpl_expr_list.append(smpl_expr)
            outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
            outputs_smpl_beta_list.append(smpl_beta)
            outputs_smpl_cam_list.append(smpl_cam)
            outputs_smpl_kp3d_list.append(smpl_kp3d)
        else:
            # Remaining layers: same part-box refinement as above. The body
            # model now uses the predicted hand poses and expression.
            bs = layer_ref_sig.shape[0]
            # lhand bbox
            layer_hs_lhand_bbox = \
                layer_hs[:, effective_dn_number:, :][:, 1::4, :]
            delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[
                dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox)
            layer_ref_sig_lhand_bbox = \
                layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone()
            layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox)
            delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
                dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox)
            layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2]
            layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig
            layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid()
            outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox)
            # rhand bbox
            layer_hs_rhand_bbox = \
                layer_hs[:, effective_dn_number:, :][:, 2::4, :]
            delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[
                dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox)
            layer_ref_sig_rhand_bbox = \
                layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone()
            layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox)
            delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[
                dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox)
            layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2]
            layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig
            layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid()
            outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox)
            # face bbox
            layer_hs_face_bbox = \
                layer_hs[:, effective_dn_number:, :][:, 3::4, :]
            delta_face_bbox_xy_unsig = \
                self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox)
            layer_ref_sig_face_bbox = \
                layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone()
            layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox)
            delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[
                dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox)
            layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2]
            layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig
            layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid()
            outputs_face_bbox_list.append(layer_ref_sig_face_bbox)
            # Per-instance feature concatenation, as in the branch above.
            bs, _, feat_dim = layer_hs.shape
            smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 4)
            smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 2)
            smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 2)
            smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select(
                1, torch.tensor(smpl_face_pose_index, device=layer_hs.device)
            ).reshape(bs, -1, feat_dim * 2)
            smpl_pose = self.smpl_pose_embed[
                dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
            smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape(
                bs, self.num_group, self.body_model_joint_num, 3, 3)
            smpl_lhand_pose = self.smpl_hand_pose_embed[
                dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats)
            smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape(
                -1, 6)).reshape(bs, self.num_group, 15, 3, 3)
            smpl_rhand_pose = self.smpl_hand_pose_embed[
                dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats)
            smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape(
                -1, 6)).reshape(bs, self.num_group, 15, 3, 3)
            smpl_expr = self.smpl_expr_embed[
                dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
            smpl_jaw_pose = self.smpl_jaw_embed[
                dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats)
            smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape(
                bs, self.num_group, 1, 3, 3)
            smpl_beta = self.smpl_beta_embed[
                dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
            smpl_cam = self.smpl_cam_embed[
                dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats)
            # num_samples and device are not used below.
            num_samples = smpl_beta.reshape(-1, 10).shape[0]
            device = smpl_beta.device
            leye_pose = torch.zeros_like(smpl_jaw_pose)
            reye_pose = torch.zeros_like(smpl_jaw_pose)
            if self.body_model is not None:
                smpl_pose_ = rotmat_to_aa(smpl_pose)
                smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose)
                smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose)
                smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose)
                leye_pose_ = rotmat_to_aa(leye_pose)
                reye_pose_ = rotmat_to_aa(reye_pose)
                pred_output = self.body_model(
                    betas=smpl_beta.reshape(-1, 10),
                    body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3),
                    global_orient=smpl_pose_[:, :, 0].reshape(
                        -1, 3).unsqueeze(1),
                    left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3),
                    right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3),
                    leye_pose=leye_pose_,
                    reye_pose=reye_pose_,
                    jaw_pose=smpl_jaw_pose_.reshape(-1, 3),
                    expression=smpl_expr.reshape(-1, 10),
                    # expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10),
                )
                smpl_kp3d = pred_output['joints'].reshape(
                    bs, self.num_group, -1, 3)
                smpl_verts = pred_output['vertices'].reshape(
                    bs, self.num_group, -1, 3)
            outputs_smpl_pose_list.append(smpl_pose)
            outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
            outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
            outputs_smpl_expr_list.append(smpl_expr)
            outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
            outputs_smpl_beta_list.append(smpl_beta)
            outputs_smpl_cam_list.append(smpl_cam)
            outputs_smpl_kp3d_list.append(smpl_kp3d)
            # NOTE(review): smpl_verts is only assigned when self.body_model is
            # not None; otherwise this eval-only append raises NameError.
            if not self.training:
                outputs_smpl_verts_list.append(smpl_verts)

    # ---- Split DN predictions off the body boxes / logits ----
    dn_mask_dict = mask_dict
    if self.dn_number > 0 and dn_mask_dict is not None:
        outputs_class, outputs_body_bbox_list = self.dn_post_process2(
            outputs_class, outputs_body_bbox_list, dn_mask_dict)
        dn_class_input = dn_mask_dict['known_labels']
        dn_bbox_input = dn_mask_dict['known_bboxs']
        dn_class_pred = dn_mask_dict['output_known_class']
        dn_bbox_pred = dn_mask_dict['output_known_coord']

    # Sanity check: logits and boxes of every layer cover the same queries.
    for idx, (_out_class, _out_bbox) in enumerate(zip(outputs_class, outputs_body_bbox_list)):
        assert _out_class.shape[1] == _out_bbox.shape[1]

    # ---- Assemble the outputs of the last decoder layer ----
    out = {
        'pred_logits': outputs_class[-1],
        'pred_boxes': outputs_body_bbox_list[-1],
        'pred_lhand_boxes': outputs_lhand_bbox_list[-1],
        'pred_rhand_boxes': outputs_rhand_bbox_list[-1],
        'pred_face_boxes': outputs_face_bbox_list[-1],
        'pred_smpl_pose': outputs_smpl_pose_list[-1],
        'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1],
        'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1],
        'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1],
        'pred_smpl_expr': outputs_smpl_expr_list[-1],
        'pred_smpl_beta': outputs_smpl_beta_list[-1],  # (bs, num_group, 10) at SMPL-X layers
        'pred_smpl_cam': outputs_smpl_cam_list[-1],
        'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1]
    }
    if not self.training:
        # Eval additionally exposes vertices and the full axis-angle pose:
        # body joints + 15 lhand + 15 rhand + 1 jaw, concatenated on the joint
        # axis (dim=2).
        # NOTE: the 53 below assumes body_model_joint_num == 22, matching the
        # 1 global-orient + 21 body joints used for the body model calls above.
        full_pose = torch.cat((outputs_smpl_pose_list[-1],
                               outputs_smpl_lhand_pose_list[-1],
                               outputs_smpl_rhand_pose_list[-1],
                               outputs_smpl_jaw_pose_list[-1]),dim=2)
        bs,num_q,_,_,_ = full_pose.shape
        full_pose = rotmat_to_aa(full_pose).reshape(bs,num_q,53*3)
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_body_bbox_list[-1],
            'pred_lhand_boxes': outputs_lhand_bbox_list[-1],
            'pred_rhand_boxes': outputs_rhand_bbox_list[-1],
            'pred_face_boxes': outputs_face_bbox_list[-1],
            'pred_smpl_pose': outputs_smpl_pose_list[-1],
            'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1],
            'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1],
            'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1],
            'pred_smpl_expr': outputs_smpl_expr_list[-1],
            'pred_smpl_beta': outputs_smpl_beta_list[-1],  # (bs, num_group, 10) at SMPL-X layers
            'pred_smpl_cam': outputs_smpl_cam_list[-1],
            'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1],
            'pred_smpl_verts': outputs_smpl_verts_list[-1],
            'pred_smpl_fullpose': full_pose
        }
    if self.dn_number > 0 and dn_mask_dict is not None:
        out.update({
            'dn_class_input': dn_class_input,
            'dn_bbox_input': dn_bbox_input,
            'dn_class_pred': dn_class_pred[-1],
            'dn_bbox_pred': dn_bbox_pred[-1],
            'num_tgt': dn_mask_dict['pad_size']
        })
    # ---- Auxiliary (per earlier decoder layer) outputs ----
    if self.aux_loss:
        out['aux_outputs'] = \
            self._set_aux_loss(
                outputs_class,
                outputs_body_bbox_list,
                outputs_lhand_bbox_list,
                outputs_rhand_bbox_list,
                outputs_face_bbox_list,
                outputs_smpl_pose_list,
                outputs_smpl_rhand_pose_list,
                outputs_smpl_lhand_pose_list,
                outputs_smpl_jaw_pose_list,
                outputs_smpl_expr_list,
                outputs_smpl_beta_list,
                outputs_smpl_cam_list,
                outputs_smpl_kp3d_list
            )  # one dict per decoder layer except the last
        if self.dn_number > 0 and dn_mask_dict is not None:
            # Attach the matching per-layer DN predictions. zip stops at the
            # shorter aux list, so the last DN layer (already in `out`) is skipped.
            assert len(dn_class_pred[:-1]) == len(
                dn_bbox_pred[:-1]) == len(out['aux_outputs'])
            for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip(
                    out['aux_outputs'], dn_class_pred, dn_bbox_pred):
                aux_out.update({
                    'dn_class_input': dn_class_input,
                    'dn_bbox_input': dn_bbox_input,
                    'dn_class_pred': dn_class_pred_i,
                    'dn_bbox_pred': dn_bbox_pred_i,
                    'num_tgt': dn_mask_dict['pad_size']
                })
    # for encoder output
    if hs_enc is not None:
        interm_coord = ref_enc[-1]
        interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
        # Zero keypoint placeholder.
        # NOTE(review): indexing [0] requires num_box_decoder_layers >= 1.
        interm_pose = torch.zeros_like(outputs_body_keypoints_list[0])
        out['interm_outputs'] = {
            'pred_logits': interm_class,
            'pred_boxes': interm_coord,
            'pred_keypoints': interm_pose
        }
    return out, targets, data_batch
def _set_aux_loss(self,
                  outputs_class,
                  outputs_body_coord,
                  outputs_lhand_coord,
                  outputs_rhand_coord,
                  outputs_face_coord,
                  outputs_smpl_pose,
                  outputs_smpl_rhand_pose,
                  outputs_smpl_lhand_pose,
                  outputs_smpl_jaw_pose,
                  outputs_smpl_expr,
                  outputs_smpl_beta,
                  outputs_smpl_cam,
                  outputs_smpl_kp3d):
    """Group per-decoder-layer predictions into auxiliary-output dicts.

    Every ``outputs_*`` argument is a per-layer sequence. The final entry
    of each sequence is dropped, and one dict is built for every remaining
    layer index. Iteration stops at the shortest (truncated) sequence.

    Returns:
        list[dict]: one dict per layer except the last, mapping the
        ``pred_*`` keys to that layer's predictions.
    """
    named_outputs = (
        ('pred_logits', outputs_class),
        ('pred_boxes', outputs_body_coord),
        ('pred_lhand_boxes', outputs_lhand_coord),
        ('pred_rhand_boxes', outputs_rhand_coord),
        ('pred_face_boxes', outputs_face_coord),
        ('pred_smpl_pose', outputs_smpl_pose),
        ('pred_smpl_rhand_pose', outputs_smpl_rhand_pose),
        ('pred_smpl_lhand_pose', outputs_smpl_lhand_pose),
        ('pred_smpl_jaw_pose', outputs_smpl_jaw_pose),
        ('pred_smpl_expr', outputs_smpl_expr),
        ('pred_smpl_beta', outputs_smpl_beta),
        ('pred_smpl_cam', outputs_smpl_cam),
        ('pred_smpl_kp3d', outputs_smpl_kp3d),
    )
    field_names = [name for name, _ in named_outputs]
    # Drop the final layer from every sequence; only earlier layers are aux.
    without_last = [per_layer[:-1] for _, per_layer in named_outputs]

    aux_outputs = []
    for layer_values in zip(*without_last):
        aux_outputs.append(dict(zip(field_names, layer_values)))
    return aux_outputs
def prepare_targets(self, data_batch):
    """Build the padded image NestedTensor and per-image targets for a batch.

    Args:
        data_batch (dict): reads ``'img'`` (a 4-D tensor unpacked as
            (B, C, H, W)) and ``'img_shape'`` (per-image ``(h, w)``).
            When ``self.inference`` is False it also reads
            ``'{body,face,lhand,rhand}_bbox_center'`` /
            ``'{body,face,lhand,rhand}_bbox_size'``, ``'joint_img'`` and
            ``'joint_trunc'``; in inference mode only the body bbox entries
            are read.

    Returns:
        tuple: ``(NestedTensor(img, masks), targets)``. ``masks`` is a bool
        tensor of shape (B, H, W) that is False inside each image's
        ``img_shape`` region and True elsewhere. ``targets`` is a list with
        one dict per image.
    """
    img_list = data_batch['img'].float()
    batch_size, _, input_img_h, input_img_w = img_list.shape
    device = img_list.device
    # Start fully masked (True) and clear each image's valid region below.
    masks = torch.ones((batch_size, input_img_h, input_img_w),
                       dtype=torch.bool,
                       device=device)

    def _part_bbox(part, img_id):
        # Concatenate a part's bbox center and size along the last dim.
        return torch.cat([data_batch[f'{part}_bbox_center'][img_id],
                          data_batch[f'{part}_bbox_size'][img_id]], dim=-1)

    def _normalize_kp2d(kps):
        # Divide x by output_hm_shape[2] and y by output_hm_shape[1] (in
        # place), then lay each instance out as the flattened (x, y) pairs
        # followed by the third channel.
        kps[:, :, 0] = kps[:, :, 0] / cfg.output_hm_shape[2]
        kps[:, :, 1] = kps[:, :, 1] / cfg.output_hm_shape[1]
        return torch.cat([kps[:, :, :2].flatten(1), kps[:, :, 2]], dim=-1)

    data_batch_coco = []
    for img_id in range(batch_size):
        img_h, img_w = data_batch['img_shape'][img_id]
        masks[img_id, :img_h, :img_w] = False

        body_bbox = _part_bbox('body', img_id)
        if self.inference:
            data_batch_coco.append({
                'size': data_batch['img_shape'][img_id],  # after augmentation
                'boxes': body_bbox.float(),
            })
            continue

        face_bbox = _part_bbox('face', img_id)
        lhand_bbox = _part_bbox('lhand', img_id)
        rhand_bbox = _part_bbox('rhand', img_id)

        instance_kp2d = data_batch['joint_img'][img_id].clone().float()
        # The third channel is overwritten with the truncation mask.
        instance_kp2d[:, :, 2:] = data_batch['joint_trunc'][img_id].float()

        # Convert every part before normalizing any of them, because
        # normalization mutates the converted tensors in place.
        body_kp2d, lhand_kp2d, rhand_kp2d, face_kp2d = [
            convert_kps(instance_kp2d, 'smplx_137', dst, approximate=True)[0]
            for dst in ('coco', 'smplx_lhand', 'smplx_rhand', 'smplx_face')
        ]

        data_batch_coco.append({
            'boxes': body_bbox.float(),
            'face_boxes': face_bbox.float(),
            'lhand_boxes': lhand_bbox.float(),
            'rhand_boxes': rhand_bbox.float(),
            'keypoints': _normalize_kp2d(body_kp2d).float(),
            'lhand_keypoints': _normalize_kp2d(lhand_kp2d).float(),
            'rhand_keypoints': _normalize_kp2d(rhand_kp2d).float(),
            'face_keypoints': _normalize_kp2d(face_kp2d).float(),
            'size': data_batch['img_shape'][img_id],  # after augmentation
            # Area is the product of the size columns (indices 2 and 3).
            'area': body_bbox[:, 2] * body_bbox[:, 3],
            'lhand_area': lhand_bbox[:, 2] * lhand_bbox[:, 3],
            'rhand_area': rhand_bbox[:, 2] * rhand_bbox[:, 3],
            'face_area': face_bbox[:, 2] * face_bbox[:, 3],
            'labels': torch.ones(body_bbox.shape[0],
                                 dtype=torch.long,
                                 device=device),
        })

    input_img = NestedTensor(img_list, masks)
    return input_img, data_batch_coco
def keypoints_to_scaled_bbox_bfh(
        self, keypoints, occ=None,
        body_scale=1.0, fh_scale=1.0,
        convention='smplx'):
    """Compute scaled xyxy boxes with a confidence flag for body, head and hands.

    Args:
        keypoints (np.ndarray): keypoints shaped (1, n, k) or (n, k); only
            the first two columns (x, y) are used.
        occ (np.ndarray, optional): per-keypoint occlusion values, indexed
            with the same part indices as ``keypoints``. When given, a head
            or hand box gets confidence 0 if ``np.sum`` of its occlusion
            values divided by its keypoint count is >= 0.1. Defaults to
            None (every box gets confidence 1).
        body_scale (float): scale applied to the whole-body box width/height.
        fh_scale (float): scale applied to the head and hand boxes.
        convention (str): keypoint convention passed to
            ``get_keypoint_idxs_by_part``.

    Returns:
        list[np.ndarray]: four float32 arrays ``[xmin, ymin, xmax, ymax, conf]``
        in the order body, head, left_hand, right_hand.
    """
    # Supported shapes: (1, n, k) or (n, k) with k = 2 or 3.
    if keypoints.ndim == 3:
        keypoints = keypoints[0]
    if keypoints.shape[-1] != 2:
        keypoints = keypoints[:, :2]

    bboxs = []
    for body_part in ('body', 'head', 'left_hand', 'right_hand'):
        if body_part == 'body':
            scale = body_scale
            kps = keypoints
            # The body box is always confident. It has no part index list,
            # so the occlusion test must not run for it (previously it
            # indexed ``occ`` with an unassigned ``kp_id`` and crashed).
            conf = 1
        else:
            scale = fh_scale
            kp_id = get_keypoint_idxs_by_part(body_part, convention=convention)
            kps = keypoints[kp_id]
            if occ is not None and np.sum(occ[kp_id]) / len(kp_id) >= 0.1:
                conf = 0
            else:
                conf = 1

        xmin, ymin = np.amin(kps, axis=0)
        xmax, ymax = np.amax(kps, axis=0)
        # Scale width/height about the box center.
        width = (xmax - xmin) * scale
        height = (ymax - ymin) * scale
        x_center = 0.5 * (xmax + xmin)
        y_center = 0.5 * (ymax + ymin)
        xmin = x_center - 0.5 * width
        xmax = x_center + 0.5 * width
        ymin = y_center - 0.5 * height
        ymax = y_center + 0.5 * height
        bbox = np.stack([xmin, ymin, xmax, ymax, conf], axis=0).astype(np.float32)
        bboxs.append(bbox)
    return bboxs
def build_aios_smplx_box(args, cfg):
    """Build the box-only AiOS SMPL-X model with its criterion and post-processors.

    Args:
        args: option namespace read via attribute access (model
            hyper-parameters, loss coefficients, post-processing options).
        cfg: config object; only ``cfg.losses`` is read. It is copied, never
            mutated.

    Returns:
        tuple: ``(model, criterion, postprocessors, postprocessors_aios)``.
    """
    num_classes = args.num_classes
    device = torch.device(args.device)
    backbone = build_backbone(args)
    transformer = build_transformer(args)
    dn_labelbook_size = args.dn_labelbook_size
    dec_pred_class_embed_share = args.dec_pred_class_embed_share
    dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share

    if args.eval:
        body_model = args.body_model_test
        train = False
    else:
        body_model = args.body_model_train
        train = True

    model = AiOSSMPLX_Box(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,
        fix_refpoints_hw=args.fix_refpoints_hw,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_class_embed_share=dec_pred_class_embed_share,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        # two stage
        two_stage_type=args.two_stage_type,
        # box_share
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        dn_number=args.dn_number if args.use_dn else 0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_batch_gt_fuse=args.dn_batch_gt_fuse,
        dn_attn_mask_type_list=args.dn_attn_mask_type_list,
        dn_labelbook_size=dn_labelbook_size,
        cls_no_bias=args.cls_no_bias,
        num_group=args.num_group,
        num_body_points=0,
        num_hand_points=0,
        num_face_points=0,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        body_model=body_model,
        train=train,
        inference=args.inference)
    matcher = build_matcher(args)

    # prepare weight dict
    weight_dict = {
        'loss_ce': args.cls_loss_coef,
        # bbox
        'loss_body_bbox': args.body_bbox_loss_coef,
        'loss_rhand_bbox': args.rhand_bbox_loss_coef,
        'loss_lhand_bbox': args.lhand_bbox_loss_coef,
        'loss_face_bbox': args.face_bbox_loss_coef,
        # bbox giou
        'loss_body_giou': args.body_giou_loss_coef,
        'loss_rhand_giou': args.rhand_giou_loss_coef,
        'loss_lhand_giou': args.lhand_giou_loss_coef,
        'loss_face_giou': args.face_giou_loss_coef,
        # smpl param
        'loss_smpl_pose_root': args.smpl_pose_loss_root_coef,
        'loss_smpl_pose_body': args.smpl_pose_loss_body_coef,
        'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef,
        'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef,
        'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef,
        'loss_smpl_beta': args.smpl_beta_loss_coef,
        'loss_smpl_expr': args.smpl_expr_loss_coef,
        # smpl kp3d ra
        'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef,
        'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef,
        'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef,
        'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef,
        # smpl kp3d
        'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef,
        'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef,
        'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef,
        'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef,
        # smpl kp2d
        'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef,
        'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef,
        'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef,
        'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef,
    }
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)

    if args.use_dn:
        weight_dict.update({
            'dn_loss_ce': args.dn_label_coef,
            'dn_loss_bbox': args.bbox_loss_coef * args.dn_bbox_coef,
            'dn_loss_giou': args.giou_loss_coef * args.dn_bbox_coef,
        })
    clean_weight_dict = copy.deepcopy(weight_dict)

    if args.aux_loss:
        # Losses not applied to the early box-only decoder layers.
        part_box_losses = {
            'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox',
            'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou'}
        # Losses not applied before the hand/face decoder layers.
        hand_face_kp_losses = {
            'loss_rhand_keypoints', 'loss_lhand_keypoints',
            'loss_face_keypoints', 'loss_rhand_oks',
            'loss_lhand_oks', 'loss_face_oks'}
        aux_weight_dict = {}
        # One suffixed copy per layer index 0 .. dec_layers - 2.
        for i in range(args.dec_layers - 1):
            for k, v in clean_weight_dict.items():
                if i < args.num_box_decoder_layers and (
                        'keypoints' in k or 'oks' in k or 'smpl' in k
                        or k in part_box_losses):
                    continue
                if (i < args.num_hand_face_decoder_layers
                        and k in hand_face_kp_losses):
                    continue
                aux_weight_dict[f'{k}_{i}'] = v
        weight_dict.update(aux_weight_dict)

    if args.two_stage_type != 'no':
        no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
        interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
        # Classification keeps weight 1; every other loss is zeroed when
        # no_interm_box_loss is set.
        interm_scale = 1.0 if not no_interm_box_loss else 0.0
        _coeff_weight_dict = {
            k: 1.0 if k == 'loss_ce' else interm_scale
            for k in clean_weight_dict_wo_dn
        }
        interm_weight_dict = {
            f'{k}_interm': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items() if 'keypoints' not in k
        }
        interm_weight_dict.update({
            f'{k}_query_expand': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
        })
        weight_dict.update(interm_weight_dict)

    # Copy so repeated builds never append to the shared cfg.losses list.
    losses = list(cfg.losses)
    if args.dn_number > 0:
        losses += ['dn_label', 'dn_bbox']
    losses += ['matching']

    criterion = SetCriterion_Box(
        num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        focal_alpha=args.focal_alpha,
        losses=losses,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        num_body_points=0,
        num_hand_points=0,
        num_face_points=0,
    )
    criterion.to(device)

    if args.inference:
        postprocessors = {
            'bbox':
            PostProcess_SMPLX_Multi_Infer_Box(
                num_select=args.num_select,
                nms_iou_threshold=args.nms_iou_threshold,
                num_body_points=0),
        }
    else:
        postprocessors = {
            'bbox':
            PostProcess_SMPLX_Multi_Box(
                num_select=args.num_select,
                nms_iou_threshold=args.nms_iou_threshold,
                num_body_points=0),
        }
    postprocessors_aios = {
        'bbox':
        PostProcess_aios(num_select=args.num_select,
                         nms_iou_threshold=args.nms_iou_threshold,
                         num_body_points=0),
    }
    return model, criterion, postprocessors, postprocessors_aios