from lib.kits.basic import *

import traceback
from tqdm import tqdm

from lib.body_models.common import make_SKEL
from lib.body_models.skel_wrapper import SKELWrapper, SKELOutput
from lib.body_models.abstract_skeletons import Skeleton_OpenPose25
from lib.utils.data import to_tensor, to_list
from lib.utils.camera import perspective_projection
from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img
from lib.utils.vis import render_mesh_overlay_img
from lib.modeling.losses import compute_poses_angle_prior_loss

from .skelify.utils import get_kp_active_j_masks

def compute_rel_change(prev_val: float, curr_val: float) -> float:
    '''
    Compute the relative change between two values.
    Copied from: https://github.com/vchoutas/smplify-x

    ### Args:
    - prev_val: float
    - curr_val: float

    ### Returns:
    - float
    '''
    return np.abs(prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1])
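
# A minimal sanity-check sketch (illustrative only, not part of the original pipeline):
# compute_rel_change(10.0, 9.0) = |10 - 9| / max(|10|, |9|, 1) = 0.1, i.e. a 10% relative
# change. `_can_early_quit` below compares this quantity against
# `cfg.early_quit_thresholds.rel` to decide whether the loss has plateaued.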

def gmof(x, sigma):
    '''
    Geman-McClure error function, to be used as a robust loss function.
    '''
    x_squared = x ** 2
    sigma_squared = sigma ** 2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)
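
# Illustrative behavior of the Geman-McClure penalty (rough numbers for intuition):
# with sigma = 100, gmof(10, 100) ≈ 99 ≈ 10**2, so small residuals stay nearly quadratic,
# while gmof(1000, 100) ≈ 9901 << 1000**2, i.e. large residuals saturate near sigma**2.
# This is what down-weights outlier keypoints in the reprojection term below.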


class SKELifyRefiner():

    def __init__(self, cfg, name='SKELify', tb_logger=None, device='cuda:0'):
        self.cfg = cfg
        self.name = name
        self.eq_thre = cfg.early_quit_thresholds
        self.tb_logger = tb_logger
        self.device = device
        self.skel_model = instantiate(cfg.skel_model).to(device)
        # Dirty implementation for visualization.
        self.render_frames = []

    def __call__(
        self,
        gt_kp2d    : Union[torch.Tensor, np.ndarray],
        init_poses : Union[torch.Tensor, np.ndarray],
        init_betas : Union[torch.Tensor, np.ndarray],
        init_cam_t : Union[torch.Tensor, np.ndarray],
        img_patch  : Optional[np.ndarray] = None,
        **kwargs
    ):
        '''
        Use optimization to fit the SKEL parameters to the 2D keypoints.

        ### Args:
        - gt_kp2d : torch.Tensor or np.ndarray, (B, J, 3)
            - The last dimension contains [x, y, conf].
            - The 2D keypoints to fit; they are defined in the [-0.5, 0.5], zero-centered patch space.
        - init_poses : torch.Tensor or np.ndarray, (B, 46)
        - init_betas : torch.Tensor or np.ndarray, (B, 10)
        - init_cam_t : torch.Tensor or np.ndarray, (B, 3)
        - img_patch : np.ndarray or None, (B, H, W, 3)
            - The image patch for visualization. H, W are defined in the normalized bounding box space.
            - If None, the visualization will simply use a black image.

        ### Returns:
        - TODO:
        '''

        # ⛩️ Prepare the input data.
        gt_kp2d    = to_tensor(gt_kp2d, device=self.device).detach().float().clone()     # (B, J, 3)
        init_poses = to_tensor(init_poses, device=self.device).detach().float().clone()  # (B, 46)
        init_betas = to_tensor(init_betas, device=self.device).detach().float().clone()  # (B, 10)
        init_cam_t = to_tensor(init_cam_t, device=self.device).detach().float().clone()  # (B, 3)

        inputs = {
            'poses_orient': init_poses[:, :3],  # (B, 3)
            'poses_body'  : init_poses[:, 3:],  # (B, 43)
            'betas'       : init_betas,         # (B, 10)
            'cam_t'       : init_cam_t,         # (B, 3)
        }

        focal_length = np.ones(2) * self.cfg.focal_length / self.cfg.img_patch_size
        focal_length = focal_length.reshape(1, 2).repeat(inputs['cam_t'].shape[0], axis=0)  # tile to (B, 2)
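        # For intuition (illustrative numbers, not taken from the configs): with a nominal
        # focal length of 5000 px and an image patch of 256 px, the normalized focal length
        # is 5000 / 256 ≈ 19.5, consistent with the [-0.5, 0.5] patch-space keypoints.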

        # ⛩️ Optimization phases, controlled by the config file.
        prev_phase_steps = 0  # accumulates the steps that are *supposed* to be done in the previous phases
        for phase_id, phase_name in enumerate(self.cfg.phases):
            phase_cfg = self.cfg.phases[phase_name]

            # Preparation: only the parameters listed for this phase are optimized.
            optim_params = []
            for k in inputs.keys():
                if k in phase_cfg.params_keys:
                    inputs[k].requires_grad = True
                    optim_params.append(inputs[k])  # (B, D)
                else:
                    inputs[k].requires_grad = False
            optimizer = instantiate(phase_cfg.optimizer, optim_params, _recursive_=True)

            def closure():
                optimizer.zero_grad()

                # Data preparation.
                cam_t = inputs['cam_t']
                skel_params = {
                    'poses'    : torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1),  # (B, 46)
                    'betas'    : inputs['betas'],  # (B, 10)
                    'skelmesh' : False,
                }

                # Optimize steps.
                skel_output = self.skel_model(**skel_params)
                pd_kp2d = perspective_projection(
                    points       = to_tensor(skel_output.joints, device=self.device),
                    translation  = to_tensor(cam_t, device=self.device),
                    focal_length = to_tensor(focal_length, device=self.device),
                )

                loss, losses = self._compute_losses(
                    act_losses = phase_cfg.losses,
                    act_parts  = phase_cfg.get('parts', 'all'),
                    gt_kp2d    = gt_kp2d,
                    pd_kp2d    = pd_kp2d,
                    pd_params  = skel_params,
                    **phase_cfg.get('weights', {}),
                )

                # Per-sample metric used to visualize the optimization process.
                _conf = gt_kp2d[..., 2]  # (B, J)
                metric = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * _conf  # (B, J)
                metric = metric.sum(dim=-1) / (torch.sum(_conf, dim=-1) + 1e-6)  # (B,)

                # Store logging data. (`log_data` is the dict created in the enclosing optimization loop.)
                if self.tb_logger is not None:
                    log_data.update({
                        'losses'      : losses,
                        'pd_kp2d'     : pd_kp2d[:self.cfg.logger.samples_per_record].detach().clone(),
                        'pd_verts'    : skel_output.skin_verts[:self.cfg.logger.samples_per_record].detach().clone(),
                        'cam_t'       : cam_t[:self.cfg.logger.samples_per_record].detach().clone(),
                        'metric'      : metric[:self.cfg.logger.samples_per_record].detach().clone(),
                        'optim_betas' : inputs['betas'][:self.cfg.logger.samples_per_record].detach().clone(),
                    })

                loss.backward()
                return loss.item()

            # Optimization loop.
            prev_loss = None
            with tqdm(range(phase_cfg.max_loop)) as bar:
                bar.set_description(f'[{phase_name}] Loss: ???')
                for i in bar:
                    log_data = {}
                    curr_loss = optimizer.step(closure)

                    # Logging.
                    if self.tb_logger is not None:
                        log_data.update({
                            'img_patch' : img_patch[:self.cfg.logger.samples_per_record] if img_patch is not None else None,
                            'gt_kp2d'   : gt_kp2d[:self.cfg.logger.samples_per_record].detach().clone(),
                        })
                        self._tb_log(prev_phase_steps + i, log_data)
                        # self._tb_log_for_report(prev_phase_steps + i, log_data)
                    bar.set_description(f'[{phase_name}] Loss: {curr_loss:.4f}')

                    if self._can_early_quit(optim_params, prev_loss, curr_loss):
                        break
                    prev_loss = curr_loss

            prev_phase_steps += phase_cfg.max_loop

        # ⛩️ Prepare the output data.
        outputs = {
            'poses': torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1).detach().clone(),  # (B, 46)
            'betas': inputs['betas'].detach().clone(),  # (B, 10)
            'cam_t': inputs['cam_t'].detach().clone(),  # (B, 3)
        }
        return outputs
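
    # Illustrative usage sketch (hypothetical tensor names and shapes; `cfg` is assumed to be
    # the project's SKELify config with `phases`, `logger`, etc. already populated):
    #   refiner = SKELifyRefiner(cfg)
    #   results = refiner(
    #       gt_kp2d    = kp2d,        # (B, J, 3), [x, y, conf] in the [-0.5, 0.5] patch space
    #       init_poses = poses_init,  # (B, 46)
    #       init_betas = betas_init,  # (B, 10)
    #       init_cam_t = cam_t_init,  # (B, 3)
    #   )
    #   # results['poses'], results['betas'], results['cam_t'] hold the refined parameters.
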
    def _compute_losses(
        self,
        act_losses : List[str],
        act_parts  : List[str],
        gt_kp2d    : torch.Tensor,
        pd_kp2d    : torch.Tensor,
        pd_params  : Dict,
        robust_sigma       : float = 100,
        shape_prior_weight : float = 5,
        angle_prior_weight : float = 15.2,
        *args, **kwargs,
    ):
        '''
        Compute the weighted losses according to the config file.
        Follows: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/smplify/losses.py#L26-L58s
        '''
        B = len(gt_kp2d)
        act_j_masks = get_kp_active_j_masks(act_parts, device=gt_kp2d.device)  # (44,)

        # Reproject the 3D keypoints to the image and compare the robust (Geman-McClure) error with the g.t. 2D keypoints.
        kp_conf = gt_kp2d[..., 2]   # (B, J)
        gt_kp2d = gt_kp2d[..., :2]  # (B, J, 2)
        reproj_err = gmof(pd_kp2d - gt_kp2d, robust_sigma)  # (B, J, 2)
        reproj_loss = ((kp_conf ** 2) * reproj_err.sum(dim=-1) * act_j_masks[None]).sum(-1)  # (B,)

        # Regularize the shape parameters.
        shape_prior_loss = (shape_prior_weight ** 2) * (pd_params['betas'] ** 2).sum(dim=-1)  # (B,)

        # Use the SKEL angle prior knowledge (e.g., rotation limitation) to regularize the optimization process.
        # TODO: Is that necessary?
        angle_prior_loss = (angle_prior_weight ** 2) * compute_poses_angle_prior_loss(pd_params['poses']).mean()  # (,)

        losses = {
            'reprojection' : reproj_loss.mean(),       # (,)
            'shape_prior'  : shape_prior_loss.mean(),  # (,)
            'angle_prior'  : angle_prior_loss,         # (,)
        }

        loss = torch.tensor(0., device=gt_kp2d.device)
        for k in act_losses:
            loss += losses[k]

        losses = {k: v.detach() for k, v in losses.items()}
        losses['sum'] = loss.detach()  # (,)

        return loss, losses
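
    # Rough equation-form summary of the objective assembled above (illustrative only; b indexes
    # the batch, j the joints):
    #   L_reproj(b) = sum_j m_j * c_{b,j}^2 * [ rho(dx_{b,j}) + rho(dy_{b,j}) ]
    #   L_shape(b)  = w_shape^2 * ||betas_b||^2
    #   L_angle     = w_angle^2 * mean_b angle_prior(poses_b)
    # where rho is the Geman-McClure function `gmof`, m_j the active-joint mask from
    # `get_kp_active_j_masks`, and c_{b,j} the keypoint confidence; only the terms listed in
    # `act_losses` are summed into the scalar loss.
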
    def _can_early_quit(self, opt_params, prev_loss, curr_loss):
        ''' Judge whether to quit the optimization process early. If yes, return True; otherwise, False. '''
        if self.cfg.early_quit_thresholds is None:
            # Never quit early.
            return False

        # Relative-change test on the loss.
        if prev_loss is not None:
            loss_rel_change = compute_rel_change(prev_loss, curr_loss)
            if loss_rel_change < self.cfg.early_quit_thresholds.rel:
                get_logger().info(f'Early quit due to relative change: {loss_rel_change:.4f} = rel({prev_loss}, {curr_loss})')
                return True

        # Absolute test on the gradient magnitudes of the optimized parameters.
        if all([
            torch.abs(param.grad.max()).item() < self.cfg.early_quit_thresholds.abs
            for param in opt_params if param.grad is not None
        ]):
            get_logger().info('Early quit due to absolute gradient threshold.')
            return True

        return False
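
    # Illustrative config sketch for the early-quit thresholds (the `rel` / `abs` field names
    # follow this file's usage; the values are made up and not from the project's configs):
    #   early_quit_thresholds:
    #     rel: 1.0e-5   # quit when the relative loss change drops below this
    #     abs: 1.0e-3   # ...or when all parameter gradients fall below this magnitude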

    def _tb_log(self, step_cnt: int, log_data: Dict, *args, **kwargs):
        ''' Write the logging information to the TensorBoard. '''
        if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval != 0:
            return

        summary_writer = self.tb_logger.experiment

        # Save losses.
        for loss_name, loss_val in log_data['losses'].items():
            summary_writer.add_scalar(f'skelify/{loss_name}', loss_val.detach().item(), step_cnt)

        # Visualization of the optimization process. TODO: Maybe we can make this more elegant.
        if log_data['img_patch'] is None:
            log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \
                                  * len(log_data['gt_kp2d'])

        if len(self.render_frames) < 1:
            self.init_v = log_data['pd_verts']
            self.init_metric = log_data['metric']
            self.init_ct = log_data['cam_t']

        # Overlay the skin mesh of the results on the original image.
        try:
            imgs_spliced = []
            for i, img_patch in enumerate(log_data['img_patch']):
                metric = log_data['metric'][i].item()

                img_with_init = render_mesh_overlay_img(
                    faces      = self.skel_model.skin_f,
                    verts      = self.init_v[i],
                    K4         = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
                    img        = img_patch,
                    Rt         = [torch.eye(3), self.init_ct[i]],
                    mesh_color = 'pink',
                )
                img_with_init = annotate_img(img_with_init, 'init')
                img_with_init = annotate_img(img_with_init, f'Quality: {self.init_metric[i].item()*1000:.3f}/1e3', pos='tl')

                img_with_mesh = render_mesh_overlay_img(
                    faces      = self.skel_model.skin_f,
                    verts      = log_data['pd_verts'][i],
                    K4         = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
                    img        = img_patch,
                    Rt         = [torch.eye(3), log_data['cam_t'][i]],
                    mesh_color = 'pink',
                )
                img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')
                betas_max = log_data['optim_betas'][i].abs().max().item()
                img_with_mesh = annotate_img(img_with_mesh, f'Quality: {metric*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl')

                img_patch_raw = annotate_img(img_patch, 'raw')

                log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size
                img_with_gt = annotate_img(img_patch, 'gt_kp2d')
                img_with_gt = draw_kp2d_on_img(
                    img_with_gt,
                    log_data['gt_kp2d'][i],
                    Skeleton_OpenPose25.bones,
                    Skeleton_OpenPose25.bone_colors,
                )

                log_data['pd_kp2d'][i] = (log_data['pd_kp2d'][i] + 0.5) * self.cfg.img_patch_size
                img_with_pd = annotate_img(img_patch, 'pd_kp2d')
                img_with_pd = draw_kp2d_on_img(
                    img_with_pd,
                    log_data['pd_kp2d'][i],
                    Skeleton_OpenPose25.bones,
                    Skeleton_OpenPose25.bone_colors,
                )

                img_spliced = splice_img(
                    img_grids = [img_patch_raw, img_with_gt, img_with_pd, img_with_init, img_with_mesh],
                    # grid_ids  = [[0, 1, 2, 3, 4]],
                    grid_ids  = [[1, 2, 3, 4]],
                )
                imgs_spliced.append(img_spliced)

            img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))])
            img_final = to_tensor(img_final, device=None).permute(2, 0, 1)  # (3, H, W)
            summary_writer.add_image('skelify/visualization', img_final, step_cnt)
            self.render_frames.append(img_final)
        except Exception as e:
            get_logger().error(f'Failed to visualize the optimization process: {e}')
            # traceback.print_exc()

    def _tb_log_for_report(self, step_cnt: int, log_data: Dict, *args, **kwargs):
        ''' Write the logging information to the TensorBoard. '''
        get_logger().warning('This logging function is just for presentation.')

        if len(self.render_frames) < 1:
            self.init_v = log_data['pd_verts']
            self.init_ct = log_data['cam_t']

        if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval != 0:
            return

        summary_writer = self.tb_logger.experiment

        # Save losses.
        for loss_name, loss_val in log_data['losses'].items():
            summary_writer.add_scalar(f'losses/{loss_name}', loss_val.detach().item(), step_cnt)

        # Visualization of the optimization process. TODO: Maybe we can make this more elegant.
        if log_data['img_patch'] is None:
            log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \
                                  * len(log_data['gt_kp2d'])

        # Overlay the skin mesh of the results on the original image.
        try:
            imgs_spliced = []
            for i, img_patch in enumerate(log_data['img_patch']):
                img_with_init = render_mesh_overlay_img(
                    faces      = self.skel_model.skin_f,
                    verts      = self.init_v[i],
                    K4         = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
                    img        = img_patch,
                    Rt         = [torch.eye(3), self.init_ct[i]],
                    mesh_color = 'pink',
                )
                img_with_init = annotate_img(img_with_init, 'init')

                img_with_mesh = render_mesh_overlay_img(
                    faces      = self.skel_model.skin_f,
                    verts      = log_data['pd_verts'][i],
                    K4         = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
                    img        = img_patch,
                    Rt         = [torch.eye(3), log_data['cam_t'][i]],
                    mesh_color = 'pink',
                )
                img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')

                img_patch_raw = annotate_img(img_patch, 'raw')

                log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size
                img_with_gt = annotate_img(img_patch, 'gt_kp2d')
                img_with_gt = draw_kp2d_on_img(
                    img_with_gt,
                    log_data['gt_kp2d'][i],
                    Skeleton_OpenPose25.bones,
                    Skeleton_OpenPose25.bone_colors,
                )

                img_spliced = splice_img([img_patch_raw, img_with_gt, img_with_init, img_with_mesh], grid_ids=[[0, 1, 2, 3]])
                imgs_spliced.append(img_spliced)

            img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))])
            img_final = to_tensor(img_final, device=None).permute(2, 0, 1)
            summary_writer.add_image('visualization', img_final, step_cnt)
            self.render_frames.append(img_final)
        except Exception as e:
            get_logger().error(f'Failed to visualize the optimization process: {e}')
            traceback.print_exc()