IsshikiHugh's picture
feat: CPU demo
5ac1897
from lib.kits.basic import *
from lib.utils.camera import perspective_projection
from lib.modeling.losses import compute_poses_angle_prior_loss
from .utils import (
gmof,
guess_cam_z,
estimate_kp2d_scale,
get_kp_active_jids,
get_params_active_q_masks,
)
def build_closure(
    self,
    cfg,
    optimizer,
    inputs,
    focal_length : float,
    gt_kp2d,
    log_data,
):
    '''
    Build the closure that the optimizer calls to re-evaluate the fitting objective.

    ### Args
    - self
        - The fitting object. The closure reads `self.skel_model`, `self.device`,
          `self.tb_logger` and `self.n_samples` from it.
    - cfg
        - The optimization config. `cfg.parts` selects the active parts and `cfg.losses`
          holds the loss configuration; both are resolved with `instantiate`.
    - optimizer
        - The optimizer whose gradients are zeroed at the start of every closure call.
    - inputs: Dict[str, torch.Tensor]
        - The variables being optimized: 'poses_orient', 'poses_body', 'betas' and 'cam_t'.
    - focal_length: float
        - The focal length used for the perspective projection (same value for x and y).
    - gt_kp2d: torch.Tensor (B, J, 3)
        - The target 2D keypoints; the last channel is the confidence.
    - log_data: dict
        - Updated in place with logging data when `self.tb_logger` is not None.

    ### Returns
    - closure: Callable[[], float]
        - Zeroes the gradients, evaluates the losses, back-propagates them and returns
          the weighted loss as a Python float.
    '''
    B = len(gt_kp2d)

    act_parts = instantiate(cfg.parts)
    act_q_masks = None
    if not (act_parts == 'all' or 'all' in act_parts):
        act_q_masks = get_params_active_q_masks(act_parts)

    # The loss configuration is the same for every closure call, so resolve it once here
    # instead of on every evaluation (optimizers such as L-BFGS may call the closure
    # several times per step).
    loss_cfg = instantiate(cfg.losses)

    # Shortcuts for the inference of the skeleton model.
    def inference_skel(inputs):
        poses_active = torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1)  # (B, 46)
        if act_q_masks is not None:
            # Inactive entries come from a detached copy, so (assuming a 0/1 mask) the pose
            # values are unchanged while gradients only reach the masked-in entries.
            poses_hidden = poses_active.clone().detach()  # (B, 46)
            poses = poses_active * act_q_masks + poses_hidden * (1 - act_q_masks)  # (B, 46)
        else:
            poses = poses_active
        skel_params = {
            'poses' : poses,            # (B, 46)
            'betas' : inputs['betas'],  # (B, 10)
        }
        skel_output = self.skel_model(**skel_params, skelmesh=False)
        return skel_params, skel_output

    # Estimate the camera depth as an initialization if depth loss is enabled.
    gs_cam_z = None
    if 'w_depth' in cfg.losses:
        with torch.no_grad():
            _, skel_output = inference_skel(inputs)
            gs_cam_z = guess_cam_z(
                pd_kp3d      = skel_output.joints,
                gt_kp2d      = gt_kp2d,
                focal_length = focal_length,
            )

    # Prepare the focal length for the perspective projection. It never changes between
    # closure calls, so it is moved to the target device only once.
    focal_length_xy = np.ones((B, 2)) * focal_length  # (B, 2)
    focal_length_xy = to_tensor(focal_length_xy, device=self.device)  # (B, 2)

    def closure():
        optimizer.zero_grad()

        # 📦 Data preparation.
        with PM.time_monitor('SKEL-forward'):
            skel_params, skel_output = inference_skel(inputs)

        with PM.time_monitor('reproj'):
            pd_kp2d = perspective_projection(
                points       = to_tensor(skel_output.joints, device=self.device),
                translation  = to_tensor(inputs['cam_t'], device=self.device),
                focal_length = focal_length_xy,
            )

        with PM.time_monitor('compute_losses'):
            loss, losses = compute_losses(
                # Loss configuration.
                loss_cfg  = loss_cfg,
                parts     = act_parts,
                # Data inputs.
                gt_kp2d   = gt_kp2d,
                pd_kp2d   = pd_kp2d,
                pd_params = skel_params,
                pd_cam_z  = inputs['cam_t'][:, 2],
                gs_cam_z  = gs_cam_z,
            )

        with PM.time_monitor('visualization'):
            # The visualization data is only consumed by the logger, so skip it otherwise.
            if self.tb_logger is not None:
                # Confidence-weighted mean squared 2D error per sample. It is only logged
                # (detached), so it is kept out of the autograd graph.
                with torch.no_grad():
                    kp2d_conf = gt_kp2d[..., 2]  # (B, J)
                    kp2d_err = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * kp2d_conf  # (B, J)
                    kp2d_err = kp2d_err.sum(dim=-1) / (torch.sum(kp2d_conf, dim=-1) + 1e-6)  # (B,)
                # Store logging data.
                log_data.update({
                    'losses'      : losses,
                    'pd_kp2d'     : pd_kp2d[:self.n_samples].detach().clone(),
                    'pd_verts'    : skel_output.skin_verts[:self.n_samples].detach().clone(),
                    'cam_t'       : inputs['cam_t'][:self.n_samples].detach().clone(),
                    'optim_betas' : inputs['betas'][:self.n_samples].detach().clone(),
                    'kp2d_err'    : kp2d_err[:self.n_samples].detach().clone(),
                })

        with PM.time_monitor('backwards'):
            loss.backward()
        return loss.item()

    return closure
def compute_losses(
    loss_cfg  : Dict[str, Union[bool, float]],
    parts     : List[str],
    gt_kp2d   : torch.Tensor,
    pd_kp2d   : torch.Tensor,
    pd_params : Dict[str, torch.Tensor],
    pd_cam_z  : torch.Tensor,
    gs_cam_z  : Optional[torch.Tensor] = None,
):
    '''
    Compute the weighted fitting objective and a per-term breakdown for logging.

    Each enabled term (a truthy `w_xxx` in `loss_cfg`) contributes `w_xxx ** 2` times the
    batch mean of its per-sample value. If no term is enabled, the returned loss is a
    constant zero tensor with no gradient.

    ### Args
    - loss_cfg: Dict[str, Union[bool, float]]
        - Special option flags (`f_xxx`) or loss weights (`w_xxx`).
    - parts: List[str]
        - The list of the active joint parts groups.
        - Among {'all', 'torso', 'torso-lite', 'limbs', 'head', 'limbs_proximal', 'limbs_distal'}.
    - gt_kp2d: torch.Tensor (B, 44, 3)
        - The ground-truth 2D keypoints with confidence.
    - pd_kp2d: torch.Tensor (B, 44, 2)
        - The predicted 2D keypoints.
    - pd_params: Dict[str, torch.Tensor]
        - poses: torch.Tensor (B, 46)
        - betas: torch.Tensor (B, 10)
    - pd_cam_z: torch.Tensor (B,)
        - The predicted camera depth translation.
    - gs_cam_z: Optional[torch.Tensor] (B,)
        - The guessed camera depth translation. Required when `w_depth` is enabled.

    ### Returns
    - loss: torch.Tensor (,)
        - The weighted loss value for optimization.
    - losses: Dict[str, float]
        - The dictionary of the loss values for logging.

    ### Raises
    - ValueError
        - If the depth loss is enabled but `gs_cam_z` is None.
    '''
    losses = {}
    loss = torch.tensor(0.0, device=gt_kp2d.device)

    def _weighted_term(name, weight, per_sample):
        # Reduce over the batch once: the same mean feeds both the objective and the log.
        term_mean = per_sample.mean()  # (,)
        losses[name] = (weight ** 2) * term_mean.item()  # float
        return (weight ** 2) * term_mean  # (,)

    kp2d_conf = gt_kp2d[:, :, 2]   # (B, J)
    gt_kp2d = gt_kp2d[:, :, :2]    # (B, J, 2)

    # Special option flags.
    f_normalize_kp2d = loss_cfg.get('f_normalize_kp2d', False)
    if f_normalize_kp2d:
        scale2mean = loss_cfg.get('f_normalize_kp2d_to_mean', False)
        scale2d = estimate_kp2d_scale(gt_kp2d)  # (B,)
        # Both keypoint sets are divided by the scale estimated from the ground truth.
        pd_kp2d = pd_kp2d / (scale2d[:, None, None] + 1e-6)  # (B, J, 2)
        gt_kp2d = gt_kp2d / (scale2d[:, None, None] + 1e-6)  # (B, J, 2)
        if scale2mean:
            # Rescale to the batch-mean scale so magnitudes stay comparable to pixels.
            scale2d_mean = scale2d.mean()
            pd_kp2d = pd_kp2d * scale2d_mean  # (B, J, 2)
            gt_kp2d = gt_kp2d * scale2d_mean  # (B, J, 2)

    # Mask the keypoints.
    act_jids = get_kp_active_jids(parts)
    kp2d_conf = kp2d_conf[:, act_jids]    # (B, J)
    gt_kp2d = gt_kp2d[:, act_jids, :]     # (B, J, 2)
    pd_kp2d = pd_kp2d[:, act_jids, :]     # (B, J, 2)

    # Calculate weighted losses.
    w_depth = loss_cfg.get('w_depth', None)
    w_reprojection = loss_cfg.get('w_reprojection', None)
    w_shape_prior = loss_cfg.get('w_shape_prior', None)
    w_angle_prior = loss_cfg.get('w_angle_prior', None)

    if w_depth:
        # An explicit check (unlike `assert`) still runs under `python -O`.
        if gs_cam_z is None:
            raise ValueError('The guessed camera depth is required for the depth loss.')
        depth_loss = (gs_cam_z - pd_cam_z).pow(2)  # (B,)
        loss += _weighted_term('depth', w_depth, depth_loss)  # (,)

    if w_reprojection:
        reproj_err_j = gmof(pd_kp2d - gt_kp2d).sum(dim=-1)  # (B, J)
        reproj_err_j = kp2d_conf.pow(2) * reproj_err_j      # (B, J)
        reproj_loss = reproj_err_j.sum(-1)                   # (B,)
        loss += _weighted_term('reprojection', w_reprojection, reproj_loss)  # (,)

    if w_shape_prior:
        shape_prior_loss = pd_params['betas'].pow(2).sum(dim=-1)  # (B,)
        loss += _weighted_term('shape_prior', w_shape_prior, shape_prior_loss)  # (,)

    if w_angle_prior:
        w_angle_prior *= loss_cfg.get('w_angle_prior_scale', 1.0)
        angle_prior_loss = compute_poses_angle_prior_loss(pd_params['poses'])  # (B,)
        loss += _weighted_term('angle_prior', w_angle_prior, angle_prior_loss)  # (,)

    losses['weighted_sum'] = loss.item()  # float
    return loss, losses