Spaces:
Sleeping
Sleeping
from lib.kits.basic import * | |
from lib.utils.camera import perspective_projection | |
from lib.modeling.losses import compute_poses_angle_prior_loss | |
from .utils import ( | |
gmof, | |
guess_cam_z, | |
estimate_kp2d_scale, | |
get_kp_active_jids, | |
get_params_active_q_masks, | |
) | |
def build_closure( | |
self, | |
cfg, | |
optimizer, | |
inputs, | |
focal_length : float, | |
gt_kp2d, | |
log_data, | |
): | |
B = len(gt_kp2d) | |
act_parts = instantiate(cfg.parts) | |
act_q_masks = None | |
if not (act_parts == 'all' or 'all' in act_parts): | |
act_q_masks = get_params_active_q_masks(act_parts) | |
# Shortcuts for the inference of the skeleton model. | |
def inference_skel(inputs): | |
poses_active = torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1) # (B, 46) | |
if act_q_masks is not None: | |
poses_hidden = poses_active.clone().detach() # (B, 46) | |
poses = poses_active * act_q_masks + poses_hidden * (1 - act_q_masks) # (B, 46) | |
else: | |
poses = poses_active | |
skel_params = { | |
'poses' : poses, # (B, 46) | |
'betas' : inputs['betas'], # (B, 10) | |
} | |
skel_output = self.skel_model(**skel_params, skelmesh=False) | |
return skel_params, skel_output | |
# Estimate the camera depth as an initialization if depth loss is enabled. | |
gs_cam_z = None | |
if 'w_depth' in cfg.losses: | |
with torch.no_grad(): | |
_, skel_output = inference_skel(inputs) | |
gs_cam_z = guess_cam_z( | |
pd_kp3d = skel_output.joints, | |
gt_kp2d = gt_kp2d, | |
focal_length = focal_length, | |
) | |
# Prepare the focal length for the perspective projection. | |
focal_length_xy = np.ones((B, 2)) * focal_length # (B, 2) | |
def closure(): | |
optimizer.zero_grad() | |
# 📦 Data preparation. | |
with PM.time_monitor('SKEL-forward'): | |
skel_params, skel_output = inference_skel(inputs) | |
with PM.time_monitor('reproj'): | |
pd_kp2d = perspective_projection( | |
points = to_tensor(skel_output.joints, device=self.device), | |
translation = to_tensor(inputs['cam_t'], device=self.device), | |
focal_length = to_tensor(focal_length_xy, device=self.device), | |
) | |
with PM.time_monitor('compute_losses'): | |
loss, losses = compute_losses( | |
# Loss configuration. | |
loss_cfg = instantiate(cfg.losses), | |
parts = act_parts, | |
# Data inputs. | |
gt_kp2d = gt_kp2d, | |
pd_kp2d = pd_kp2d, | |
pd_params = skel_params, | |
pd_cam_z = inputs['cam_t'][:, 2], | |
gs_cam_z = gs_cam_z, | |
) | |
with PM.time_monitor('visualization'): | |
VISUALIZE = True | |
if VISUALIZE: | |
# For visualize the optimization process. | |
kp2d_err = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * gt_kp2d[..., 2] # (B, J) | |
kp2d_err = kp2d_err.sum(dim=-1) / (torch.sum(gt_kp2d[..., 2], dim=-1) + 1e-6) # (B,) | |
# Store logging data. | |
if self.tb_logger is not None: | |
log_data.update({ | |
'losses' : losses, | |
'pd_kp2d' : pd_kp2d[:self.n_samples].detach().clone(), | |
'pd_verts' : skel_output.skin_verts[:self.n_samples].detach().clone(), | |
'cam_t' : inputs['cam_t'][:self.n_samples].detach().clone(), | |
'optim_betas' : inputs['betas'][:self.n_samples].detach().clone(), | |
'kp2d_err' : kp2d_err[:self.n_samples].detach().clone(), | |
}) | |
with PM.time_monitor('backwards'): | |
loss.backward() | |
return loss.item() | |
return closure | |
def compute_losses( | |
loss_cfg : Dict[str, Union[bool, float]], | |
parts : List[str], | |
gt_kp2d : torch.Tensor, | |
pd_kp2d : torch.Tensor, | |
pd_params : Dict[str, torch.Tensor], | |
pd_cam_z : torch.Tensor, | |
gs_cam_z : Optional[torch.Tensor] = None, | |
): | |
''' | |
### Args | |
- loss_cfg: Dict[str, Union[bool, float]] | |
- Special option flags (`f_xxx`) or loss weights (`w_xxx`). | |
- parts: List[str] | |
- The list of the active joint parts groups. | |
- Among {'all', 'torso', 'torso-lite', 'limbs', 'head', 'limbs_proximal', 'limbs_distal'}. | |
- gt_kp2d: torch.Tensor (B, 44, 3) | |
- The ground-truth 2D keypoints with confidence. | |
- pd_kp2d: torch.Tensor (B, 44, 2) | |
- The predicted 2D keypoints. | |
- pd_params: Dict[str, torch.Tensor] | |
- poses: torch.Tensor (B, 46) | |
- betas: torch.Tensor (B, 10) | |
- pd_cam_z: torch.Tensor (B,) | |
- The predicted camera depth translation. | |
- gs_cam_z: Optional[torch.Tensor] (B,) | |
- The guessed camera depth translation. | |
### Returns | |
- loss: torch.Tensor (,) | |
- The weighted loss value for optimization. | |
- losses: Dict[str, float] | |
- The dictionary of the loss values for logging. | |
''' | |
losses = {} | |
loss = torch.tensor(0.0, device=gt_kp2d.device) | |
kp2d_conf = gt_kp2d[:, :, 2] # (B, J) | |
gt_kp2d = gt_kp2d[:, :, :2] # (B, J, 2) | |
# Special option flags. | |
f_normalize_kp2d = loss_cfg.get('f_normalize_kp2d', False) | |
if f_normalize_kp2d: | |
scale2mean = loss_cfg.get('f_normalize_kp2d_to_mean', False) | |
scale2d = estimate_kp2d_scale(gt_kp2d) # (B,) | |
pd_kp2d = pd_kp2d / (scale2d[:, None, None] + 1e-6) # (B, J, 2) | |
gt_kp2d = gt_kp2d / (scale2d[:, None, None] + 1e-6) # (B, J, 2) | |
if scale2mean: | |
scale2d_mean = scale2d.mean() | |
pd_kp2d = pd_kp2d * scale2d_mean # (B, J, 2) | |
gt_kp2d = gt_kp2d * scale2d_mean # (B, J, 2) | |
# Mask the keypoints. | |
act_jids = get_kp_active_jids(parts) | |
kp2d_conf = kp2d_conf[:, act_jids] # (B, J) | |
gt_kp2d = gt_kp2d[:, act_jids, :] # (B, J, 2) | |
pd_kp2d = pd_kp2d[:, act_jids, :] # (B, J, 2) | |
# Calculate weighted losses. | |
w_depth = loss_cfg.get('w_depth', None) | |
w_reprojection = loss_cfg.get('w_reprojection', None) | |
w_shape_prior = loss_cfg.get('w_shape_prior', None) | |
w_angle_prior = loss_cfg.get('w_angle_prior', None) | |
if w_depth: | |
assert gs_cam_z is not None, 'The guessed camera depth is required for the depth loss.' | |
depth_loss = (gs_cam_z - pd_cam_z).pow(2) # (B,) | |
loss += (w_depth ** 2) * depth_loss.mean() # (,) | |
losses['depth'] = (w_depth ** 2) * depth_loss.mean().item() # float | |
if w_reprojection: | |
reproj_err_j = gmof(pd_kp2d - gt_kp2d).sum(dim=-1) # (B, J) | |
reproj_err_j = kp2d_conf.pow(2) * reproj_err_j # (B, J) | |
reproj_loss = reproj_err_j.sum(-1) # (B,) | |
loss += (w_reprojection ** 2) * reproj_loss.mean() # (,) | |
losses['reprojection'] = (w_reprojection ** 2) * reproj_loss.mean().item() # float | |
if w_shape_prior: | |
shape_prior_loss = pd_params['betas'].pow(2).sum(dim=-1) # (B,) | |
loss += (w_shape_prior ** 2) * shape_prior_loss.mean() # (,) | |
losses['shape_prior'] = (w_shape_prior ** 2) * shape_prior_loss.mean().item() # float | |
if w_angle_prior: | |
w_angle_prior *= loss_cfg.get('w_angle_prior_scale', 1.0) | |
angle_prior_loss = compute_poses_angle_prior_loss(pd_params['poses']) # (B,) | |
loss += (w_angle_prior ** 2) * angle_prior_loss.mean() # (,) | |
losses['angle_prior'] = (w_angle_prior ** 2) * angle_prior_loss.mean().item() # float | |
losses['weighted_sum'] = loss.item() # float | |
return loss, losses |