from lib.kits.basic import *

from concurrent.futures import ThreadPoolExecutor
from lightning_fabric.utilities.rank_zero import _get_rank

from lib.data.augmentation.skel import rot_skel_on_plane
from lib.utils.data import to_tensor, to_numpy
from lib.utils.camera import perspective_projection, estimate_camera_trans
from lib.utils.vis import Wis3D
from lib.modeling.optim.skelify.skelify import SKELify
from lib.body_models.common import make_SKEL

DEBUG = False
DEBUG_ROUND = False


class SKELifySPIN(pl.Callback):
    '''
    Call SKELify to optimize the prediction results.
    Here we have several concepts of data: gt, opgt, pd, (res), bpgt.

    1. `gt`: Static Ground Truth: loaded from the static training datasets. They might be real
       ground truth (like 2D keypoints) or pseudo ground truth (like SKEL parameters). They are
       expected to be gradually replaced by better pseudo ground truth over the iterations.
    2. `opgt`: Old Pseudo-Ground Truth: the better labels among the static datasets (`gt`) and the
       dynamic datasets (maintained in this callback). They serve as the labels for training the
       network.
    3. `pd`: Predicted Results: the network outputs, which are optimized later. After being
       optimized, they are called `res` (Results from optimization).
    4. `bpgt`: Better Pseudo Ground Truth: the optimized results stored in an extra file and in
       memory. They are the highest-quality data picked among the static ground truth, the cached
       better pseudo ground truth, and the predicted & optimized data.
    '''

    # TODO: Now only kp2d is used to evaluate the quality. (Because not all data provide kp3d.)
    # TODO: kp3d should be considered in the future, i.e. use it whenever it is available.

    def __init__(
        self,
        cfg     : DictConfig,
        skelify : DictConfig,
        **kwargs,
    ):
        '''
        - cfg: callback settings. Fields read here: `interval`, `batch_size`,
          `max_batches_per_round` (optional), `better_pgt_fn`, `skip_warm_up_steps`,
          `update_better_pgt`, `valid_betas_threshold`.
        - skelify: config passed to `instantiate` to build a `SKELify` optimizer in each SPIN round.
        - kwargs: accepted and ignored.
        '''
        super().__init__()
        self.interval = cfg.interval
        self.B = cfg.batch_size
        self.kb_pr = cfg.get('max_batches_per_round', None)  # latest k batches per round are SPINed
        self.better_pgt_fn = Path(cfg.better_pgt_fn)  # loaded lazily before the first training batch
        self.skip_warm_up_steps = cfg.skip_warm_up_steps
        self.update_better_pgt = cfg.update_better_pgt
        self.skelify_cfg = skelify

        # The threshold to determine if the result is valid. (In case some data don't have
        # parameters at first but were updated to bad parameters.)
        self.valid_betas_threshold = cfg.valid_betas_threshold

        self._init_pd_dict()

        self.better_pgt = None
        self.dump_thread = None  # `Future` of the latest background dump, if any.

    @staticmethod
    def _make_sample_uids(seq_keys, do_flips):
        ''' Build one uid per sample: the sequence key suffixed with `_flip` or `_orig`. '''
        return [
            f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
            for seq_key, do_flip in zip(seq_keys, do_flips)
        ]

    def on_train_batch_start(self, trainer, pl_module, raw_batch, batch_idx):
        '''
        Overwrite, in place, the labels of the samples in `raw_batch['img_ds']` that have a cached
        better pseudo-gt (only when `update_better_pgt` is enabled).
        '''
        # Lazy initialization for better pgt.
        if self.better_pgt is None:
            self._init_better_pgt()

        device = pl_module.device
        batch = raw_batch['img_ds']

        if not self.update_better_pgt:
            return

        # 1. Compose the data from batches.
        sample_uid_list = self._make_sample_uids(batch['__key__'], batch['augm_args']['do_flip'])

        # 2. Update the labels from better_pgt.
        for i, sample_uid in enumerate(sample_uid_list):
            if sample_uid not in self.better_pgt['poses']:
                continue
            batch['raw_skel_params']['poses'][i] = to_tensor(self.better_pgt['poses'][sample_uid], device=device)  # (46,)
            batch['raw_skel_params']['betas'][i] = to_tensor(self.better_pgt['betas'][sample_uid], device=device)  # (10,)
            batch['has_skel_params']['poses'][i] = self.better_pgt['has_poses'][sample_uid]  # 0 or 1
            batch['has_skel_params']['betas'][i] = self.better_pgt['has_betas'][sample_uid]  # 0 or 1
            batch['updated_by_spin'][i] = True  # add information for inspection

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        '''
        After warm-up, cache this batch's predictions; every `interval` steps run a SPIN round.
        '''
        # Since the prediction from the network might be far from the ground truth before it is
        # well-trained, we skip some steps to avoid meaningless optimization.
        if trainer.global_step > self.skip_warm_up_steps or DEBUG_ROUND:
            # Collect the prediction results.
            self._save_pd(batch['img_ds'], outputs)

            if (self.interval > 0 and trainer.global_step % self.interval == 0) or DEBUG_ROUND:
                torch.cuda.empty_cache()
                with PM.time_monitor('SPIN'):
                    self._spin(trainer.logger, pl_module.device)
                torch.cuda.empty_cache()

    def on_train_end(self, trainer, pl_module):
        ''' Wait for the last background dump so it completes and its errors are not lost. '''
        self._wait_for_dump()

    def _init_pd_dict(self):
        '''
        Reset the prediction cache (memory clean up for each SPIN).
        Values are collected as numpy data to save GPU memory.
        '''
        self.cache = {
            # Things to identify one sample.
            'seq_key_list' : [],
            # Things for comparison.
            'opgt_poses_list'     : [],
            'opgt_betas_list'     : [],
            'opgt_has_poses_list' : [],
            'opgt_has_betas_list' : [],
            # Things for optimization and self-iteration.
            'gt_kp2d_list'        : [],
            'pd_poses_list'       : [],
            'pd_betas_list'       : [],
            'pd_cam_t_list'       : [],
            'do_flip_list'        : [],
            'rot_deg_list'        : [],
            'do_extreme_crop_list': [],  # if the extreme crop is applied, we don't update the pseudo-gt
            # Things for visualization.
            'img_patch': [],
            # No gt_cam_t_list.
        }

    def _format_pd(self):
        '''
        Convert every cached list to numpy, keeping only the latest `max_batches_per_round * B`
        samples (all samples when `max_batches_per_round` is not set).
        '''
        if self.kb_pr is None:
            last_k = len(self.cache['seq_key_list'])
        else:
            last_k = self.kb_pr * self.B  # the latest k samples to be optimized.

        for key in self.cache:
            if key == 'img_patch' and not DEBUG:
                continue  # image patches are only collected in DEBUG mode
            self.cache[key] = to_numpy(self.cache[key])[-last_k:]

    def _save_pd(self, batch, outputs):
        ''' Save all the prediction results and related labels from the outputs. '''
        self.cache['seq_key_list'].extend(batch['__key__'])  # (NS,)
        self.cache['opgt_poses_list'].extend(to_numpy(batch['raw_skel_params']['poses']))  # (NS, 46)
        self.cache['opgt_betas_list'].extend(to_numpy(batch['raw_skel_params']['betas']))  # (NS, 10)
        self.cache['opgt_has_poses_list'].extend(to_numpy(batch['has_skel_params']['poses']))  # (NS,) 0 or 1
        self.cache['opgt_has_betas_list'].extend(to_numpy(batch['has_skel_params']['betas']))  # (NS,) 0 or 1
        self.cache['gt_kp2d_list'].extend(to_numpy(batch['kp2d']))  # (NS, 44, 3)

        self.cache['pd_poses_list'].extend(to_numpy(outputs['pd_params']['poses']))
        self.cache['pd_betas_list'].extend(to_numpy(outputs['pd_params']['betas']))
        self.cache['pd_cam_t_list'].extend(to_numpy(outputs['pd_cam_t']))
        self.cache['do_flip_list'].extend(to_numpy(batch['augm_args']['do_flip']))
        self.cache['rot_deg_list'].extend(to_numpy(batch['augm_args']['rot_deg']))
        self.cache['do_extreme_crop_list'].extend(to_numpy(batch['augm_args']['do_extreme_crop']))

        if DEBUG:
            # Undo the ImageNet mean/std normalization to get displayable uint8 patches.
            img_patch = batch['img_patch'].clone().permute(0, 2, 3, 1)  # (NS, 256, 256, 3)
            mean = torch.tensor([0.485, 0.456, 0.406], device=img_patch.device).reshape(1, 1, 1, 3)
            std = torch.tensor([0.229, 0.224, 0.225], device=img_patch.device).reshape(1, 1, 1, 3)
            img_patch = 255 * (img_patch * std + mean)
            self.cache['img_patch'].extend(to_numpy(img_patch).astype(np.uint8))  # (NS, 256, 256, 3)

    def _init_better_pgt(self):
        '''
        DDP adaptable initialization: when a rank is known, each rank uses its own cache file.
        The cache is loaded from disk if the file exists, otherwise it starts empty.
        '''
        self.rank = _get_rank()
        get_logger().info(f'Initializing better pgt cache @ rank {self.rank}')
        if self.rank is not None:
            self.better_pgt_fn = Path(f'{self.better_pgt_fn}_r{self.rank}')
            get_logger().info(f'Redirecting better pgt cache to {self.better_pgt_fn}')

        # `np.savez` appends `.npz` to any path that does not already end with it, so normalize the
        # path here; otherwise the existence check below never finds the file that was written.
        if not self.better_pgt_fn.name.endswith('.npz'):
            self.better_pgt_fn = Path(f'{self.better_pgt_fn}.npz')

        if self.better_pgt_fn.exists():
            # The file is written by this callback itself; `allow_pickle` is required because each
            # entry is a dict stored as a 0-d object array.
            with np.load(self.better_pgt_fn, allow_pickle=True) as better_pgt_z:
                self.better_pgt = {k: better_pgt_z[k].item() for k in better_pgt_z.files}
        else:
            self.better_pgt = {'poses': {}, 'betas': {}, 'has_poses': {}, 'has_betas': {}}

    def _wait_for_dump(self):
        ''' Block until the pending background dump (if any) finishes; re-raise its exception. '''
        dump_thread = getattr(self, 'dump_thread', None)
        if dump_thread is not None:
            dump_thread.result()

    def _dump_better_pgt_async(self):
        '''
        Save a snapshot of `self.better_pgt` to `self.better_pgt_fn` on a background thread.
        The previous dump is awaited first, so at most one write is in flight.
        '''
        self._wait_for_dump()
        # Shallow-copy the inner dicts: later SPIN rounds re-assign entries of `self.better_pgt`
        # while the writer thread may still be iterating over them.
        snapshot = {k: dict(v) for k, v in self.better_pgt.items()}
        executor = ThreadPoolExecutor(max_workers=1)
        self.dump_thread = executor.submit(np.savez, self.better_pgt_fn, **snapshot)
        # Don't block (a `with` block would wait here): the submitted job still runs to completion
        # and the executor's resources are released once it is done.
        executor.shutdown(wait=False)

    def _spin(self, tb_logger, device):
        '''
        Run one SPIN round over the cached predictions.

        Steps:
        1. Prepare the cache.
        2. Optimize the cached predictions with SKELify.
        3. Initialize missing better pgt entries from the old pseudo-gt.
        4. For each sample, keep whichever of the better pgt and the optimized result has the
           lower 2D keypoint reprojection error.
        5. Dump the cache in the background.
        6. Reset the prediction cache.
        '''
        skelify : SKELify = instantiate(self.skelify_cfg, tb_logger=tb_logger, device=device, _recursive_=False)
        skel_model = skelify.skel_model

        self._format_pd()

        # 1. Make up the cache to run SKELify.
        with PM.time_monitor('preparation'):
            sample_uid_list = self._make_sample_uids(self.cache['seq_key_list'], self.cache['do_flip_list'])

            all_gt_kp2d         = self.cache['gt_kp2d_list']          # (NS, 44, 3)
            all_init_poses      = self.cache['pd_poses_list']         # (NS, 46)
            all_init_betas      = self.cache['pd_betas_list']         # (NS, 10)
            all_init_cam_t      = self.cache['pd_cam_t_list']         # (NS, 3)
            all_do_extreme_crop = self.cache['do_extreme_crop_list']  # (NS,)
            all_res_poses = []
            all_res_betas = []
            all_res_cam_t = []
            all_res_kp2d_err = []  # the evaluation of the keypoints 2D error

            n_samples = len(self.cache['seq_key_list'])
            n_round = (n_samples - 1) // self.B + 1

        # 2. Run SKELify optimization here to get better results.
        with PM.time_monitor('SKELify') as tm:
            get_logger().info(f'Start to run SKELify optimization. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
            get_logger().info(f'Running SKELify optimization for {n_samples} samples in {n_round} rounds.')
            for rid in range(n_round):
                sid = rid * self.B
                eid = min(sid + self.B, n_samples)

                gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device)
                init_poses = to_tensor(all_init_poses[sid:eid], device=device)
                init_betas = to_tensor(all_init_betas[sid:eid], device=device)
                init_cam_t = to_tensor(all_init_cam_t[sid:eid], device=device)

                # Run the SKELify optimization.
                outputs = skelify(
                    gt_kp2d    = gt_kp2d_with_conf,
                    init_poses = init_poses,
                    init_betas = init_betas,
                    init_cam_t = init_cam_t,
                    img_patch  = self.cache['img_patch'][sid:eid] if DEBUG else None,
                )

                # Store the results.
                all_res_poses.extend(to_numpy(outputs['poses']))  # (~NS, 46)
                all_res_betas.extend(to_numpy(outputs['betas']))  # (~NS, 10)
                all_res_cam_t.extend(to_numpy(outputs['cam_t']))  # (~NS, 3)
                all_res_kp2d_err.extend(to_numpy(outputs['kp2d_err']))  # (~NS,)

                tm.tick(f'SKELify round {rid} finished.')

        # 3. Initialize the uninitialized better pseudo-gt with old ground truth.
        with PM.time_monitor('init_bpgt'):
            get_logger().info(f'Initializing bgbt. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
            for i, sample_uid in enumerate(sample_uid_list):
                # Look up the per-sample dict; the top-level keys are only the field names.
                if sample_uid not in self.better_pgt['poses']:
                    self.better_pgt['poses'][sample_uid] = self.cache['opgt_poses_list'][i]
                    self.better_pgt['betas'][sample_uid] = self.cache['opgt_betas_list'][i]
                    self.better_pgt['has_poses'][sample_uid] = self.cache['opgt_has_poses_list'][i]
                    self.better_pgt['has_betas'][sample_uid] = self.cache['opgt_has_betas_list'][i]

        # 4. Update the results.
        with PM.time_monitor('upd_bpgt'):
            upd_cnt = 0  # Count the number of updated samples.
            get_logger().info(f'Update the results. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
            for rid in range(n_round):
                torch.cuda.empty_cache()
                sid = rid * self.B
                eid = min(sid + self.B, n_samples)

                # TODO: These data should be loaded from configuration files.
                focal_length = np.full((eid - sid, 2), 5000 / 256)  # (B, 2)
                gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device)  # (B, 44, 3)
                rot_deg = to_tensor(self.cache['rot_deg_list'][sid:eid], device=device)

                # 4.1. Prepare the better pseudo-gt and the results.
                res_betas = to_tensor(all_res_betas[sid:eid], device=device)  # (B, 10)
                res_poses_after_augm = to_tensor(all_res_poses[sid:eid], device=device)  # (B, 46)
                res_poses_before_augm = rot_skel_on_plane(res_poses_after_augm, -rot_deg)  # undo the augmentation rotation
                res_kp2d_err = to_tensor(all_res_kp2d_err[sid:eid], device=device)  # (B,)
                cur_do_extreme_crop = all_do_extreme_crop[sid:eid]

                # 4.2. Evaluate the quality of the existing better pseudo-gt.
                uids = sample_uid_list[sid:eid]  # [sid ~ eid] -> sample_uids
                bpgt_betas = to_tensor([self.better_pgt['betas'][uid] for uid in uids], device=device)
                bpgt_poses_before_augm = to_tensor([self.better_pgt['poses'][uid] for uid in uids], device=device)
                bpgt_poses_after_augm = rot_skel_on_plane(bpgt_poses_before_augm.clone(), rot_deg)  # apply the augmentation rotation

                skel_outputs = skel_model(poses=bpgt_poses_after_augm, betas=bpgt_betas, skelmesh=False)
                bpgt_kp3d = skel_outputs.joints.detach()  # (B, 44, 3)
                # Estimate camera translation from inferred 3D keypoints and GT 2D keypoints.
                bpgt_est_cam_t = estimate_camera_trans(
                    S            = bpgt_kp3d,
                    joints_2d    = gt_kp2d_with_conf.clone(),
                    focal_length = 5000,
                    img_size     = 256,
                )
                bpgt_reproj_kp2d = perspective_projection(
                    points       = to_tensor(bpgt_kp3d, device=device),
                    translation  = to_tensor(bpgt_est_cam_t, device=device),
                    focal_length = to_tensor(focal_length, device=device),
                )
                # Must be (B,) to be compared with `res_kp2d_err` below.
                bpgt_kp2d_err = SKELify.eval_kp2d_err(gt_kp2d_with_conf, bpgt_reproj_kp2d)

                valid_betas_mask = res_betas.abs().max(dim=-1)[0] < self.valid_betas_threshold  # (B,)
                better_mask = res_kp2d_err < bpgt_kp2d_err  # (B,)
                upd_mask = torch.logical_and(valid_betas_mask, better_mask)  # (B,)
                # Move the indices to host once instead of syncing per element inside the loop.
                upd_ids = torch.arange(eid - sid, device=device)[upd_mask].tolist()  # uids -> ids

                # Update one by one.
                for upd_id in upd_ids:
                    # `uid` for dynamic dataset unique id, `id` for in-round batch data.
                    # Notes: id starts from zero and indexes the [sid ~ eid] slice directly, so
                    # neither `all_res_poses[upd_id]` nor `res_poses[upd_id - sid]` would be right.
                    if cur_do_extreme_crop[upd_id]:  # Skip the extreme crop data.
                        continue
                    sample_uid = uids[upd_id]
                    self.better_pgt['poses'][sample_uid] = to_numpy(res_poses_before_augm[upd_id])
                    self.better_pgt['betas'][sample_uid] = to_numpy(res_betas[upd_id])
                    self.better_pgt['has_poses'][sample_uid] = 1.  # If updated, then must have.
                    self.better_pgt['has_betas'][sample_uid] = 1.  # If updated, then must have.
                    upd_cnt += 1

            get_logger().info(f'Update {upd_cnt} samples among all {n_samples} samples.')

        # 5. [Async] Save the results.
        with PM.time_monitor('async_dumping'):
            self._dump_better_pgt_async()

        # 6. Clean up the memory.
        del skelify, skel_model
        self._init_pd_dict()