from lib.kits.basic import *

from lib.utils.vis import Wis3D
from lib.utils.vis.py_renderer import render_mesh_overlay_img
from lib.utils.data import to_tensor
from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img
from lib.utils.camera import perspective_projection
from lib.body_models.abstract_skeletons import Skeleton_OpenPose25
from lib.modeling.losses import *
from lib.modeling.networks.discriminators import HSMRDiscriminator
from lib.platform.config_utils import get_PM_info_dict


def build_inference_pipeline(
    model_root : Union[Path, str],
    ckpt_fn    : Optional[Union[Path, str]] = None,
    tuned_bcb  : bool = True,
    device     : str  = 'cpu',
):
    # 1.1. Load the config file.
    if isinstance(model_root, str):
        model_root = Path(model_root)
    cfg_path = model_root / '.hydra' / 'config.yaml'
    cfg = OmegaConf.load(cfg_path)

    # 1.2. Override the PM info dict.
    PM_overrides = get_PM_info_dict()._pm_
    cfg._pm_ = PM_overrides
    get_logger(brief=True).info(f'Building inference pipeline of {cfg.exp_name}')

    # 2.1. Instantiate the pipeline. If the checkpoint contains a tuned backbone,
    #      the backbone initialization is redundant and can be skipped.
    init_bcb = not tuned_bcb
    pipeline = instantiate(cfg.pipeline, init_backbone=init_bcb, _recursive_=False)
    pipeline.set_data_adaption(data_module_name='IMG_PATCHES')

    # 2.2. Load the checkpoint.
    if ckpt_fn is None:
        ckpt_fn = model_root / 'checkpoints' / 'hsmr.ckpt'
    pipeline.load_state_dict(torch.load(ckpt_fn, map_location=device)['state_dict'])
    get_logger(brief=True).info(f'Loaded checkpoint from {ckpt_fn}.')

    pipeline.eval()
    return pipeline.to(device)
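
# Example usage (a minimal sketch; the experiment directory name is hypothetical,
# but `model_root` must contain `.hydra/config.yaml` and, by default,
# `checkpoints/hsmr.ckpt`):
#
#   pipeline = build_inference_pipeline(model_root='exps/hsmr_release', device='cuda')
#   img_patches = torch.zeros(1, 3, 256, 256, device='cuda')  # normalized (B, C, H, W) patches
#   with torch.no_grad():
#       outputs = pipeline(img_patches)
#   outputs['pd_params']['poses'].shape  # (1, 46)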


class HSMRPipeline(pl.LightningModule):

    def __init__(self, cfg: DictConfig, name: str, init_backbone=True):
        super().__init__()
        self.name = name
        self.skel_model = instantiate(cfg.SKEL)
        self.backbone = instantiate(cfg.backbone)
        self.head = instantiate(cfg.head)
        self.cfg = cfg
        if init_backbone:
            # For inference mode with tuned backbone checkpoints, the backbone
            # weights will be overwritten anyway, so we don't initialize it here.
            self._init_backbone()

        # Loss layers.
        self.kp_3d_loss = Keypoint3DLoss(loss_type='l1')
        self.kp_2d_loss = Keypoint2DLoss(loss_type='l1')
        self.params_loss = ParameterLoss()

        # Discriminator.
        self.enable_disc = self.cfg.loss_weights.get('adversarial', 0) > 0
        if self.enable_disc:
            self.discriminator = HSMRDiscriminator()
            get_logger().warning('Discriminator enabled: each batch runs two optimizer updates, so global_step advances twice per batch. Use the checkpoints carefully.')
        else:
            self.discriminator = None
            self.cfg.loss_weights.pop('adversarial', None)  # pop the adversarial term if not enabled

        # Manually control the optimization since we have an adversarial process.
        self.automatic_optimization = False

        self.set_data_adaption()

        # For visualization debug only; flip this switch manually when needed.
        if False:
            self.wis3d = Wis3D(seq_name=PM.cfg.exp_name)
        else:
            self.wis3d = None

    def set_data_adaption(self, data_module_name: Optional[str] = None):
        if data_module_name is None:
            # get_logger().warning('Data adapter schema is not defined. The input will be regarded as image patches.')
            self.adapt_batch = self._adapt_img_inference
        elif data_module_name == 'IMG_PATCHES':
            self.adapt_batch = self._adapt_img_inference
        elif data_module_name.startswith('SKEL_HSMR_V1'):
            self.adapt_batch = self._adapt_hsmr_v1
        else:
            raise ValueError(f'Unknown data module: {data_module_name}')

    def print_summary(self, max_depth=1):
        from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary
        print(ModelSummary(self, max_depth=max_depth))

    def configure_optimizers(self):
        optimizers = []

        params_main = filter(lambda p: p.requires_grad, self._params_main())
        optimizer_main = instantiate(self.cfg.optimizer, params=params_main)
        optimizers.append(optimizer_main)

        if len(self._params_disc()) > 0:
            params_disc = filter(lambda p: p.requires_grad, self._params_disc())
            optimizer_disc = instantiate(self.cfg.optimizer, params=params_disc)
            optimizers.append(optimizer_disc)

        return optimizers
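
    # Note on manual optimization: with `automatic_optimization = False`, the
    # training step below applies up to two optimizer updates per batch (main
    # parts, then discriminator), which is why global_step advances twice per
    # batch when the adversarial loss is enabled. Rough flow:
    #
    #   optimizer_main.zero_grad(); self.manual_backward(loss_main); optimizer_main.step()
    #   optimizer_disc.zero_grad(); self.manual_backward(disc_loss); optimizer_disc.step()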

    def training_step(self, raw_batch, batch_idx):
        with PM.time_monitor('training_step'):
            return self._training_step(raw_batch, batch_idx)

    def _training_step(self, raw_batch, batch_idx):
        # GPU_monitor = GPUMonitor()
        # GPU_monitor.snapshot('HSMR training start')
        batch = self.adapt_batch(raw_batch['img_ds'])
        # GPU_monitor.snapshot('HSMR adapt batch')

        # Get the optimizer(s).
        optimizers = self.optimizers(use_pl_optimizer=True)
        if isinstance(optimizers, list):
            optimizer_main, optimizer_disc = optimizers
        else:
            optimizer_main = optimizers
        # GPU_monitor.snapshot('HSMR get optimizer')

        # 1. Main parts forward pass.
        with PM.time_monitor('forward_step'):
            img_patch = to_tensor(batch['img_patch'], self.device)  # (B, C, H, W)
            B = len(img_patch)
            outputs = self.forward_step(img_patch)  # {...}
            # GPU_monitor.snapshot('HSMR forward')
            pd_skel_params = HSMRPipeline._adapt_skel_params(outputs['pd_params'])
            # GPU_monitor.snapshot('HSMR adapt SKEL params')

        # 2. [Optional] Discriminator forward pass in the main training step.
        if self.enable_disc:
            with PM.time_monitor('disc_forward'):
                pd_poses_mat, _ = self.skel_model.pose_params_to_rot(pd_skel_params['poses'])  # (B, J=24, 3, 3)
                pd_poses_body_mat = pd_poses_mat[:, 1:, :, :]  # (B, J=23, 3, 3)
                pd_betas = pd_skel_params['betas']  # (B, 10)
                disc_out = self.discriminator(
                    poses_body = pd_poses_body_mat,  # (B, J=23, 3, 3)
                    betas      = pd_betas,  # (B, 10)
                )
        else:
            disc_out = None

        # 3. Prepare the secondary products.
        with PM.time_monitor('Secondary Products Preparation'):
            # 3.1. Body model outputs.
            with PM.time_monitor('SKEL Forward'):
                skel_outputs = self.skel_model(**pd_skel_params, skelmesh=False)
                pd_kp3d = skel_outputs.joints  # (B, Q=44, 3)
                pd_skin = skel_outputs.skin_verts  # (B, V=6890, 3)

            # 3.2. Reproject the 3D joints to the 2D plane.
            with PM.time_monitor('Reprojection'):
                pd_kp2d = perspective_projection(
                    points       = pd_kp3d,  # (B, K=Q=44, 3)
                    translation  = outputs['pd_cam_t'],  # (B, 3)
                    focal_length = outputs['focal_length'] / self.cfg.policy.img_patch_size,  # (B, 2)
                )  # (B, 44, 2)

            # 3.3. Extract G.T. from inputs.
            gt_kp2d_with_conf = batch['kp2d'].clone()  # (B, 44, 3)
            gt_kp3d_with_conf = batch['kp3d'].clone()  # (B, 44, 4)

            # 3.4. Extract the G.T. skin mesh, only for visualization.
            gt_skel_params = HSMRPipeline._adapt_skel_params(batch['gt_params'])  # {poses, betas}
            gt_skel_params = {k: v[:self.cfg.logger.samples_per_record] for k, v in gt_skel_params.items()}
            skel_outputs = self.skel_model(**gt_skel_params, skelmesh=False)
            gt_skin = skel_outputs.skin_verts  # (B', V=6890, 3)
            # Validity masks; `has_gt_params` has keys {poses_orient, poses_body, betas}.
            gt_valid_body   = batch['has_gt_params']['poses_body'][:self.cfg.logger.samples_per_record]
            gt_valid_orient = batch['has_gt_params']['poses_orient'][:self.cfg.logger.samples_per_record]
            gt_valid_betas  = batch['has_gt_params']['betas'][:self.cfg.logger.samples_per_record]
            gt_valid = torch.logical_and(torch.logical_and(gt_valid_body, gt_valid_orient), gt_valid_betas)
        # GPU_monitor.snapshot('HSMR secondary products')

        # 4. Compute losses.
        with PM.time_monitor('Compute Loss'):
            loss_main, losses_main = self._compute_losses_main(
                self.cfg.loss_weights,
                pd_kp3d,  # (B, 44, 3)
                gt_kp3d_with_conf,  # (B, 44, 4)
                pd_kp2d,  # (B, 44, 2)
                gt_kp2d_with_conf,  # (B, 44, 3)
                outputs['pd_params'],  # {'poses_orient':..., 'poses_body':..., 'betas':...}
                batch['gt_params'],  # {'poses_orient':..., 'poses_body':..., 'betas':...}
                batch['has_gt_params'],
                disc_out,
            )
        # GPU_monitor.snapshot('HSMR compute losses')
        if torch.isnan(loss_main):
            get_logger().error(f'NaN detected in loss computation. Losses: {losses_main}')

        # 5. Main parts backward pass.
        with PM.time_monitor('Backward Step'):
            optimizer_main.zero_grad()
            self.manual_backward(loss_main)
            optimizer_main.step()
        # GPU_monitor.snapshot('HSMR backwards')

        # 6. [Optional] Discriminator training part.
        if self.enable_disc:
            with PM.time_monitor('Train Discriminator'):
                losses_disc = self._train_discriminator(
                    mocap_batch       = raw_batch['mocap_ds'],
                    pd_poses_body_mat = pd_poses_body_mat,
                    pd_betas          = pd_betas,
                    optimizer         = optimizer_disc,
                )
        else:
            losses_disc = {}

        # 7. Logging.
        with PM.time_monitor('Tensorboard Logging'):
            vis_data = {
                'img_patch'         : to_numpy(img_patch[:self.cfg.logger.samples_per_record]).transpose((0, 2, 3, 1)).copy(),
                'pd_kp2d'           : pd_kp2d[:self.cfg.logger.samples_per_record].clone(),
                'pd_kp3d'           : pd_kp3d[:self.cfg.logger.samples_per_record].clone(),
                'gt_kp2d_with_conf' : gt_kp2d_with_conf[:self.cfg.logger.samples_per_record].clone(),
                'gt_kp3d_with_conf' : gt_kp3d_with_conf[:self.cfg.logger.samples_per_record].clone(),
                'pd_skin'           : pd_skin[:self.cfg.logger.samples_per_record].clone(),
                'gt_skin'           : gt_skin.clone(),
                'gt_skin_valid'     : gt_valid,
                'cam_t'             : outputs['pd_cam_t'][:self.cfg.logger.samples_per_record].clone(),
                'img_key'           : batch['__key__'][:self.cfg.logger.samples_per_record],
            }
            self._tb_log(losses_main=losses_main, losses_disc=losses_disc, vis_data=vis_data)
        # GPU_monitor.snapshot('HSMR logging')

        self.log('_/loss_main', losses_main['weighted_sum'], on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=B)

        # GPU_monitor.report_all()
        return outputs
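
    # SKEL parameter layout assumed throughout this file (per the shape
    # annotations): `poses` is 46-D -- 3 global-orientation DoFs followed by
    # 43 body DoFs -- and `betas` is the 10-D shape vector. A dummy round-trip
    # through the adapter:
    #
    #   params = {'poses_orient': torch.zeros(1, 3),
    #             'poses_body'  : torch.zeros(1, 43),
    #             'betas'       : torch.zeros(1, 10)}
    #   HSMRPipeline._adapt_skel_params(params)['poses'].shape  # (1, 46)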

    def forward(self, batch):
        '''
        ### Returns
        - outputs: Dict
            - pd_kp3d: torch.Tensor, shape (B, Q=44, 3)
            - pd_kp2d: torch.Tensor, shape (B, Q=44, 2)
            - pred_keypoints_2d: torch.Tensor, shape (B, Q=44, 2)
            - pred_keypoints_3d: torch.Tensor, shape (B, Q=44, 3)
            - pd_params: Dict
                - poses: torch.Tensor, shape (B, 46)
                - betas: torch.Tensor, shape (B, 10)
            - pd_cam: torch.Tensor, shape (B, 3)
            - pd_cam_t: torch.Tensor, shape (B, 3)
            - focal_length: torch.Tensor, shape (B, 2)
        '''
        batch = self.adapt_batch(batch)

        # 1. Main parts forward pass.
        img_patch = to_tensor(batch['img_patch'], self.device)  # (B, C, H, W)
        outputs = self.forward_step(img_patch)  # {...}

        # 2. Prepare the secondary products.
        # 2.1. Body model outputs.
        pd_skel_params = HSMRPipeline._adapt_skel_params(outputs['pd_params'])
        skel_outputs = self.skel_model(**pd_skel_params, skelmesh=False)
        pd_kp3d = skel_outputs.joints  # (B, Q=44, 3)
        pd_skin_verts = skel_outputs.skin_verts.detach().cpu().clone()  # (B, V=6890, 3)

        # 2.2. Reproject the 3D joints to the 2D plane.
        pd_kp2d = perspective_projection(
            points       = to_tensor(pd_kp3d, device=self.device),  # (B, K=Q=44, 3)
            translation  = to_tensor(outputs['pd_cam_t'], device=self.device),  # (B, 3)
            focal_length = to_tensor(outputs['focal_length'], device=self.device) / self.cfg.policy.img_patch_size,  # (B, 2)
        )

        outputs['pd_kp3d'] = pd_kp3d
        outputs['pd_kp2d'] = pd_kp2d
        outputs['pred_keypoints_2d'] = pd_kp2d  # adapt HMR2.0's script
        outputs['pred_keypoints_3d'] = pd_kp3d  # adapt HMR2.0's script
        outputs['pd_params'] = pd_skel_params
        outputs['pd_skin_verts'] = pd_skin_verts
        return outputs

    def forward_step(self, x: torch.Tensor):
        '''
        Run an inference step on the model.

        ### Args
        - x: torch.Tensor, shape (B, C, H, W)
            - The input image patch.

        ### Returns
        - outputs: Dict
            - 'pd_cam': torch.Tensor, shape (B, 3)
                - The predicted camera parameters.
            - 'pd_params': Dict
                - The predicted body model parameters.
            - 'focal_length': torch.Tensor, shape (B, 2)
        '''
        # GPU_monitor = GPUMonitor()
        B = len(x)

        # 1. Extract features from the image.
        # The input size is 256*256, but ViT needs 256*192, so crop the width. TODO: make this more elegant.
        with PM.time_monitor('Backbone Forward'):
            feats = self.backbone(x[:, :, :, 32:-32])
        # GPU_monitor.snapshot('HSMR forward backbone')

        # 2. Run the head to predict the body model parameters.
        with PM.time_monitor('Predict Head Forward'):
            pd_params, pd_cam = self.head(feats)
        # GPU_monitor.snapshot('HSMR forward head')

        # 3. Transform the camera parameters to the camera translation.
        focal_length = self.cfg.policy.focal_length * torch.ones(B, 2, device=self.device, dtype=pd_cam.dtype)  # (B, 2)
        pd_cam_t = torch.stack([
            pd_cam[:, 1],
            pd_cam[:, 2],
            2 * focal_length[:, 0] / (self.cfg.policy.img_patch_size * pd_cam[:, 0] + 1e-9),
        ], dim=-1)  # (B, 3)

        # 4. Store the results.
        outputs = {
            'pd_cam'       : pd_cam,
            'pd_cam_t'     : pd_cam_t,
            'pd_params'    : pd_params,
            # 'pd_params'  : {k: v.clone() for k, v in pd_params.items()},
            'focal_length' : focal_length,  # (B, 2)
        }

        # GPU_monitor.report_all()
        return outputs
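
    # The head predicts a weak-perspective camera (s, tx, ty); step 3 above
    # lifts it to a full translation with depth t_z = 2 * f / (W * s), where W
    # is the patch size. A rough numeric check, assuming the usual policy
    # values f = 5000 and W = 256 (both come from the config, so these numbers
    # are illustrative only):
    #
    #   s = 1.0 (person fills the patch)  ->  t_z = 2 * 5000 / (256 * 1.0) ~ 39.1
    #   s = 0.5 (person fills half)       ->  t_z ~ 78.1 (twice as far away)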

    # ========== Internal Functions ==========

    def _params_main(self):
        return list(self.head.parameters()) + list(self.backbone.parameters())

    def _params_disc(self):
        if self.discriminator is None:
            return []
        else:
            return list(self.discriminator.parameters())

    @staticmethod
    def _adapt_skel_params(params: Dict):
        ''' Change the parameters formed like {poses_orient, poses_body, betas} to {poses, betas}. '''
        adapted_params = {}

        if 'poses' in params.keys():
            adapted_params['poses'] = params['poses']
        elif 'poses_orient' in params.keys() and 'poses_body' in params.keys():
            poses_orient = params['poses_orient']  # (B, 3)
            poses_body = params['poses_body']  # (B, 43)
            adapted_params['poses'] = torch.cat([poses_orient, poses_body], dim=1)  # (B, 46)
        else:
            raise ValueError(f'Cannot find the poses parameters among {list(params.keys())}.')

        if 'betas' in params.keys():
            adapted_params['betas'] = params['betas']  # (B, 10)
        else:
            raise ValueError(f'Cannot find the betas parameters among {list(params.keys())}.')

        return adapted_params

    def _init_backbone(self):
        # 1. Load the backbone weights.
        get_logger().info(f'Loading backbone weights from {self.cfg.backbone_ckpt}')
        state_dict = torch.load(self.cfg.backbone_ckpt, map_location='cpu')['state_dict']
        if 'backbone.cls_token' in state_dict.keys():
            state_dict = {k: v for k, v in state_dict.items() if 'backbone' in k and 'cls_token' not in k}
        state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}
        # Use `strict=False` so that key mismatches are reported below instead of raising.
        missing, unexpected = self.backbone.load_state_dict(state_dict, strict=False)
        if len(missing) > 0:
            get_logger().warning(f'Missing keys in backbone: {missing}')
        if len(unexpected) > 0:
            get_logger().warning(f'Unexpected keys in backbone: {unexpected}')

        # 2. Freeze the backbone if needed.
        if self.cfg.get('freeze_backbone', False):
            self.backbone.eval()
            self.backbone.requires_grad_(False)

    def _compute_losses_main(
        self,
        loss_weights : Dict,
        pd_kp3d      : torch.Tensor,
        gt_kp3d      : torch.Tensor,
        pd_kp2d      : torch.Tensor,
        gt_kp2d      : torch.Tensor,
        pd_params    : Dict,
        gt_params    : Dict,
        has_params   : Dict,
        disc_out     : Optional[torch.Tensor] = None,
        *args, **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        ''' Compute the weighted losses according to the config file. '''
        # 1. Preparation.
        with PM.time_monitor('Preparation'):
            B = len(pd_kp3d)
            gt_skel_params = HSMRPipeline._adapt_skel_params(gt_params)  # {poses, betas}
            pd_skel_params = HSMRPipeline._adapt_skel_params(pd_params)  # {poses, betas}
            gt_betas = gt_skel_params['betas'].reshape(-1, 10)
            pd_betas = pd_skel_params['betas'].reshape(-1, 10)
            gt_poses = gt_skel_params['poses'].reshape(-1, 46)
            pd_poses = pd_skel_params['poses'].reshape(-1, 46)

        # 2. Keypoints losses.
        with PM.time_monitor('kp2d & kp3d Loss'):
            kp2d_loss = self.kp_2d_loss(pd_kp2d, gt_kp2d) / B
            kp3d_loss = self.kp_3d_loss(pd_kp3d, gt_kp3d) / B

        # 3. Prior losses.
        with PM.time_monitor('Prior Loss'):
            prior_loss = compute_poses_angle_prior_loss(pd_poses).mean()  # (,)

        # 4. Parameters losses.
        if self.cfg.sp_poses_repr == 'rotation_matrix':
            with PM.time_monitor('q2mat'):
                gt_poses_mat, _ = self.skel_model.pose_params_to_rot(gt_poses)  # (B, J=24, 3, 3)
                pd_poses_mat, _ = self.skel_model.pose_params_to_rot(pd_poses)  # (B, J=24, 3, 3)
                gt_poses = gt_poses_mat.reshape(-1, 24*3*3)  # (B, 24*3*3)
                pd_poses = pd_poses_mat.reshape(-1, 24*3*3)  # (B, 24*3*3)

        with PM.time_monitor('Parameters Loss'):
            poses_orient_loss = self.params_loss(pd_poses[:, :9], gt_poses[:, :9], has_params['poses_orient']) / B
            poses_body_loss = self.params_loss(pd_poses[:, 9:], gt_poses[:, 9:], has_params['poses_body']) / B
            betas_loss = self.params_loss(pd_betas, gt_betas, has_params['betas']) / B

        # 5. Collect the main losses.
        with PM.time_monitor('Accumulate'):
            losses = {
                'kp3d'         : kp3d_loss,  # (,)
                'kp2d'         : kp2d_loss,  # (,)
                'prior'        : prior_loss,  # (,)
                'poses_orient' : poses_orient_loss,  # (,)
                'poses_body'   : poses_body_loss,  # (,)
                'betas'        : betas_loss,  # (,)
            }

        # 6. Consider the adversarial loss.
        if disc_out is not None:
            with PM.time_monitor('Adversarial Loss'):
                adversarial_loss = ((disc_out - 1.0) ** 2).sum() / B  # (,)
                losses['adversarial'] = adversarial_loss

        with PM.time_monitor('Accumulate'):
            loss = torch.tensor(0., device=self.device)
            for k, v in losses.items():
                loss += v * loss_weights[k]
            losses = {k: v.item() for k, v in losses.items()}
            losses['weighted_sum'] = loss.item()

        return loss, losses
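
    # The adversarial terms follow the least-squares GAN objective. With D(x)
    # the discriminator score, the main parts minimize (D(x_pd) - 1)^2 (pushing
    # predictions toward the "real" label, as in step 6 above), while the
    # discriminator in `_train_discriminator` below minimizes
    # D(x_pd)^2 + (D(x_mc) - 1)^2, i.e. label 0 for predicted parameters and
    # label 1 for real MoCap parameters.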

    def _train_discriminator(self, mocap_batch, pd_poses_body_mat, pd_betas, optimizer):
        '''
        Train the discriminator using the regressed body model parameters and the realistic MoCap data.

        ### Args
        - mocap_batch: Dict
            - 'poses_body': torch.Tensor, shape (B, 43)
            - 'betas': torch.Tensor, shape (B, 10)
        - pd_poses_body_mat: torch.Tensor, shape (B, J=23, 3, 3)
        - pd_betas: torch.Tensor, shape (B, 10)
        - optimizer: torch.optim.Optimizer

        ### Returns
        - losses: Dict
            - 'pd_disc': float
            - 'mc_disc': float
        '''
        pd_B = len(pd_poses_body_mat)
        mc_B = len(mocap_batch['poses_body'])
        if pd_B != mc_B:
            get_logger().warning(f'pd_B: {pd_B} != mc_B: {mc_B}')

        # 1. Extract the realistic 3D MoCap label.
        mc_poses_body = mocap_batch['poses_body']  # (B, 43)
        padding_zeros = mc_poses_body.new_zeros(mc_B, 3)  # (B, 3)
        mc_poses = torch.cat([padding_zeros, mc_poses_body], dim=1)  # (B, 46)
        mc_poses_mat, _ = self.skel_model.pose_params_to_rot(mc_poses)  # (B, J=24, 3, 3)
        mc_poses_body_mat = mc_poses_mat[:, 1:, :, :]  # (B, J=23, 3, 3)
        mc_betas = mocap_batch['betas']  # (B, 10)

        # 2. Forward pass.
        # Discriminator forward pass for the predicted data.
        pd_disc_out = self.discriminator(pd_poses_body_mat.detach(), pd_betas.detach())
        pd_disc_loss = ((pd_disc_out - 0.0) ** 2).sum() / pd_B  # (,)
        # Discriminator forward pass for the realistic MoCap data.
        mc_disc_out = self.discriminator(mc_poses_body_mat, mc_betas)
        mc_disc_loss = ((mc_disc_out - 1.0) ** 2).sum() / pd_B  # (,)  TODO: This 'pd_B' is from HMR2, not sure if it's a bug.

        # 3. Backward pass.
        disc_loss = self.cfg.loss_weights.adversarial * (pd_disc_loss + mc_disc_loss)
        optimizer.zero_grad()
        self.manual_backward(disc_loss)
        optimizer.step()

        return {
            'pd_disc': pd_disc_loss.item(),
            'mc_disc': mc_disc_loss.item(),
        }
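
    # Coordinate convention relied on by the visualization below: 2D keypoints
    # live in patch-normalized coordinates, roughly [-0.5, 0.5] (note the
    # `focal_length / img_patch_size` division at reprojection time), so pixel
    # coordinates are recovered with
    #
    #   kp2d_px = (kp2d + 0.5) * self.cfg.policy.img_patch_size
    #
    # e.g. kp2d = 0.0 maps to the patch center: 128 px for a 256 px patch.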

    @rank_zero_only
    def _tb_log(self, losses_main: Dict, losses_disc: Dict, vis_data: Dict, mode: str = 'train'):
        ''' Write the logging information to the TensorBoard. '''
        if self.logger is None:
            return
        if self.global_step != 1 and self.global_step % self.cfg.logger.interval != 0:
            return

        # 1. Losses.
        summary_writer = self.logger.experiment
        for loss_name, loss_val in losses_main.items():
            summary_writer.add_scalar(f'{mode}/losses_main/{loss_name}', loss_val, self.global_step)
        for loss_name, loss_val in losses_disc.items():
            summary_writer.add_scalar(f'{mode}/losses_disc/{loss_name}', loss_val, self.global_step)

        # 2. Visualization.
        try:
            pelvis_id = 39

            # 2.1. Visualize 3D information.
            self.wis3d.add_motion_mesh(
                verts = vis_data['pd_skin'] - vis_data['pd_kp3d'][:, pelvis_id:pelvis_id+1],  # center the mesh
                faces = self.skel_model.skin_f,
                name  = 'pd_skin',
            )
            self.wis3d.add_motion_mesh(
                verts = vis_data['gt_skin'] - vis_data['gt_kp3d_with_conf'][:, pelvis_id:pelvis_id+1, :3],  # center the mesh
                faces = self.skel_model.skin_f,
                name  = 'gt_skin',
            )
            self.wis3d.add_motion_skel(
                joints = vis_data['pd_kp3d'] - vis_data['pd_kp3d'][:, pelvis_id:pelvis_id+1],
                bones  = Skeleton_OpenPose25.bones,
                colors = Skeleton_OpenPose25.bone_colors,
                name   = 'pd_kp3d',
            )
            aligned_gt_kp3d = vis_data['gt_kp3d_with_conf']
            aligned_gt_kp3d[..., :3] -= vis_data['gt_kp3d_with_conf'][:, pelvis_id:pelvis_id+1, :3]
            self.wis3d.add_motion_skel(
                joints = aligned_gt_kp3d,
                bones  = Skeleton_OpenPose25.bones,
                colors = Skeleton_OpenPose25.bone_colors,
                name   = 'gt_kp3d',
            )
        except Exception as e:
            get_logger().error(f'Failed to visualize the current performance on wis3d: {e}')

        try:
            # 2.2. Visualize 2D information.
            if vis_data['img_patch'] is not None:
                # Overlay the skin mesh of the results on the original image.
                imgs_spliced = []
                for i, img_patch in enumerate(vis_data['img_patch']):
                    # Denormalize the image patch. TODO: make this more elegant.
                    img_mean = to_numpy(OmegaConf.to_container(self.cfg.policy.img_mean))[None, None]  # (1, 1, 3)
                    img_std = to_numpy(OmegaConf.to_container(self.cfg.policy.img_std))[None, None]  # (1, 1, 3)
                    img_patch = ((img_mean + img_patch * img_std) * 255).astype(np.uint8)
                    img_patch_raw = annotate_img(img_patch, 'raw')

                    img_with_mesh = render_mesh_overlay_img(
                        faces      = self.skel_model.skin_f,
                        verts      = vis_data['pd_skin'][i].float(),
                        K4         = [self.cfg.policy.focal_length, self.cfg.policy.focal_length, 128, 128],
                        img        = img_patch,
                        Rt         = [torch.eye(3).float(), vis_data['cam_t'][i].float()],
                        mesh_color = 'pink',
                    )
                    img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')

                    img_with_gt_mesh = render_mesh_overlay_img(
                        faces      = self.skel_model.skin_f,
                        verts      = vis_data['gt_skin'][i].float(),
                        K4         = [self.cfg.policy.focal_length, self.cfg.policy.focal_length, 128, 128],
                        img        = img_patch,
                        Rt         = [torch.eye(3).float(), vis_data['cam_t'][i].float()],
                        mesh_color = 'pink',
                    )
                    valid = 'valid' if vis_data['gt_skin_valid'][i] else 'invalid'
                    img_with_gt_mesh = annotate_img(img_with_gt_mesh, f'gt_mesh_{valid}')

                    img_with_gt = annotate_img(img_patch, 'gt_kp2d')
                    gt_kp2d_with_conf = vis_data['gt_kp2d_with_conf'][i]
                    gt_kp2d_with_conf[:, :2] = (gt_kp2d_with_conf[:, :2] + 0.5) * self.cfg.policy.img_patch_size
                    img_with_gt = draw_kp2d_on_img(
                        img_with_gt,
                        gt_kp2d_with_conf,
                        Skeleton_OpenPose25.bones,
                        Skeleton_OpenPose25.bone_colors,
                    )

                    img_with_pd = annotate_img(img_patch, 'pd_kp2d')
                    pd_kp2d_vis = (vis_data['pd_kp2d'][i] + 0.5) * self.cfg.policy.img_patch_size
                    img_with_pd = draw_kp2d_on_img(
                        img_with_pd,
                        pd_kp2d_vis,
                        Skeleton_OpenPose25.bones,
                        Skeleton_OpenPose25.bone_colors,
                    )

                    img_spliced = splice_img([img_patch_raw, img_with_gt, img_with_pd, img_with_mesh, img_with_gt_mesh], grid_ids=[[0, 1, 2, 3, 4]])
                    img_spliced = annotate_img(img_spliced, vis_data['img_key'][i], pos='tl')
                    imgs_spliced.append(img_spliced)

                    try:
                        self.wis3d.set_scene_id(i)
                        self.wis3d.add_image(
                            image = img_spliced,
                            name  = 'image',
                        )
                    except Exception as e:
                        get_logger().error(f'Failed to visualize the current performance on wis3d: {e}')

                img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(vis_data['img_patch']))])
                img_final = to_tensor(img_final, device=None).permute(2, 0, 1)
                summary_writer.add_image(f'{mode}/visualization', img_final, self.global_step)
        except Exception as e:
            get_logger().error(f'Failed to visualize the current performance: {e}')
            # traceback.print_exc()

    def _adapt_hsmr_v1(self, batch):
        from lib.data.augmentation.skel import rot_skel_on_plane
        # Rotate the raw SKEL parameters by the in-plane augmentation angle.
        rot_deg = batch['augm_args']['rot_deg']  # (B,)
        skel_params = rot_skel_on_plane(batch['raw_skel_params'], rot_deg)
        batch['gt_params'] = {}
        batch['gt_params']['poses_orient'] = skel_params['poses'][:, :3]
        batch['gt_params']['poses_body'] = skel_params['poses'][:, 3:]
        batch['gt_params']['betas'] = skel_params['betas']

        has_skel_params = batch['has_skel_params']
        batch['has_gt_params'] = {}
        batch['has_gt_params']['poses_orient'] = has_skel_params['poses']
        batch['has_gt_params']['poses_body'] = has_skel_params['poses']
        batch['has_gt_params']['betas'] = has_skel_params['betas']
        return batch

    def _adapt_img_inference(self, img_patches):
        return {'img_patch': img_patches}
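
# Data-adapter contract implied by `_training_step` and the adapters above:
# `adapt_batch` must return a dict with at least 'img_patch'; the training
# adapter additionally exposes the (augmentation-rotated) SKEL labels as
#
#   batch['gt_params']     = {'poses_orient': (B, 3), 'poses_body': (B, 43), 'betas': (B, 10)}
#   batch['has_gt_params'] = {'poses_orient': (B,),   'poses_body': (B,),   'betas': (B,)}
#
# plus 'kp2d', 'kp3d', and '__key__' for the losses and logging.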