# NOTE(review): the three lines below ("Spaces:" / "Sleeping" / "Sleeping") look like
# an extraction/UI artifact, not Python; kept as comments so the module parses.
# Spaces:
# Sleeping
# Sleeping
from lib.kits.basic import * | |
from concurrent.futures import ThreadPoolExecutor | |
from lightning_fabric.utilities.rank_zero import _get_rank | |
from lib.data.augmentation.skel import rot_skel_on_plane | |
from lib.utils.data import to_tensor, to_numpy | |
from lib.utils.camera import perspective_projection, estimate_camera_trans | |
from lib.utils.vis import Wis3D | |
from lib.modeling.optim.skelify.skelify import SKELify | |
from lib.body_models.common import make_SKEL | |
DEBUG = False | |
DEBUG_ROUND = False | |
class SKELifySPIN(pl.Callback):
    '''
    Call SKELify to optimize the prediction results.

    Here we have several concepts of data: gt, opgt, pd, (res), bpgt.

    1. `gt`: Static Ground Truth: they are loaded from static training datasets,
       they might be real ground truth (like 2D keypoints), or pseudo ground truth (like SKEL
       parameters). They will be gradually replaced by the better pseudo ground truth through
       iterations in anticipation.
    2. `opgt`: Old Pseudo-Ground Truth: they are the better ground truth among static datasets
       (those called gt), and the dynamic datasets (maintained in the callbacks), and will serve
       as the labels for training the network.
    3. `pd`: Predicted Results: they are from the network outputs and will be optimized later.
       After being optimized, they will be called as `res` (Results from optimization).
    4. `bpgt`: Better Pseudo Ground Truth: they are the optimized results stored in extra file
       and in the memory. These data are the highest quality data picked among the static
       ground truth, or cached better pseudo ground truth, or the predicted & optimized data.
    '''
    # TODO: Now I only consider to use kp2d to evaluate the performance. (Because not all data provide kp3d.)
    # TODO: But we need to consider the kp3d in the future, i.e., if we have it, then use it.

    def __init__(
        self,
        cfg     : DictConfig,
        skelify : DictConfig,
        **kwargs,
    ):
        '''
        ### Args
        - cfg: DictConfig
            - Callback options: `interval`, `batch_size`, `better_pgt_fn`,
              `skip_warm_up_steps`, `update_better_pgt`, `valid_betas_threshold`,
              and optionally `max_batches_per_round`.
        - skelify: DictConfig
            - Config used to lazily instantiate the SKELify optimizer inside `_spin`.
        '''
        super().__init__()
        self.interval = cfg.interval
        self.B = cfg.batch_size
        self.kb_pr = cfg.get('max_batches_per_round', None)  # latest k batches per round are SPINed
        self.better_pgt_fn = Path(cfg.better_pgt_fn)  # load it before training
        self.skip_warm_up_steps = cfg.skip_warm_up_steps
        self.update_better_pgt = cfg.update_better_pgt
        self.skelify_cfg = skelify

        # The threshold to determine if the result is valid. (In case some data
        # don't have parameters at first but were updated to bad parameters.)
        self.valid_betas_threshold = cfg.valid_betas_threshold

        self._init_pd_dict()
        self.better_pgt = None
        # One persistent background worker for dumping the pseudo-gt archive.
        # NOTE: a `with ThreadPoolExecutor()` block joins its workers on exit,
        # which would make the intended asynchronous dump synchronous.
        self._dump_executor = ThreadPoolExecutor(max_workers=1)

    def on_train_batch_start(self, trainer, pl_module, raw_batch, batch_idx):
        ''' Overwrite the batch labels with cached better pseudo-gt before the training step. '''
        # Lazy initialization for better pgt (needs the DDP rank, unknown at __init__ time).
        if self.better_pgt is None:
            self._init_better_pgt()

        if not self.update_better_pgt:
            return

        device = pl_module.device
        batch = raw_batch['img_ds']

        # 1. Compose the sample uids: a sample is identified by its sequence key
        #    plus the flip augmentation state.
        seq_key_list = batch['__key__']
        batch_do_flip_list = batch['augm_args']['do_flip']
        sample_uid_list = [
            f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
            for seq_key, do_flip in zip(seq_key_list, batch_do_flip_list)
        ]

        # 2. Update the labels from better_pgt.
        for i, sample_uid in enumerate(sample_uid_list):
            if sample_uid in self.better_pgt['poses'].keys():
                batch['raw_skel_params']['poses'][i] = to_tensor(self.better_pgt['poses'][sample_uid], device=device)  # (46,)
                batch['raw_skel_params']['betas'][i] = to_tensor(self.better_pgt['betas'][sample_uid], device=device)  # (10,)
                batch['has_skel_params']['poses'][i] = self.better_pgt['has_poses'][sample_uid]  # 0 or 1
                batch['has_skel_params']['betas'][i] = self.better_pgt['has_betas'][sample_uid]  # 0 or 1
                batch['updated_by_spin'][i] = True  # add information for inspection

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        ''' Cache the network predictions and periodically run one SPIN optimization round. '''
        # Since the prediction from network might be far from the ground truth before well-trained,
        # we can skip some steps to avoid meaningless optimization.
        if trainer.global_step > self.skip_warm_up_steps or DEBUG_ROUND:
            # Collect the prediction results.
            self._save_pd(batch['img_ds'], outputs)
            # Parentheses make the precedence explicit; DEBUG_ROUND forces a round every step.
            if (self.interval > 0 and trainer.global_step % self.interval == 0) or DEBUG_ROUND:
                torch.cuda.empty_cache()
                with PM.time_monitor('SPIN'):
                    self._spin(trainer.logger, pl_module.device)
                torch.cuda.empty_cache()

    def _init_pd_dict(self):
        '''
        (Re)set the prediction cache; called once per SPIN round to free memory.
        Values are kept as lists/numpy (not GPU tensors) to save GPU memory.
        '''
        self.cache = {
            # Things to identify one sample.
            'seq_key_list'        : [],
            # Things for comparison.
            'opgt_poses_list'     : [],
            'opgt_betas_list'     : [],
            'opgt_has_poses_list' : [],
            'opgt_has_betas_list' : [],
            # Things for optimization and self-iteration.
            'gt_kp2d_list'        : [],
            'pd_poses_list'       : [],
            'pd_betas_list'       : [],
            'pd_cam_t_list'       : [],
            'do_flip_list'        : [],
            'rot_deg_list'        : [],
            'do_extreme_crop_list': [],  # if the extreme crop is applied, we don't update the pseudo-gt
            # Things for visualization.
            'img_patch': [],
            # No gt_cam_t_list.
        }

    def _format_pd(self):
        ''' Convert the cache to numpy and keep only the latest `kb_pr * B` samples. '''
        if self.kb_pr is None:
            last_k = len(self.cache['seq_key_list'])
        else:
            last_k = self.kb_pr * self.B  # the latest k samples to be optimized

        for key in (
            'seq_key_list', 'opgt_poses_list', 'opgt_betas_list',
            'opgt_has_poses_list', 'opgt_has_betas_list', 'gt_kp2d_list',
            'pd_poses_list', 'pd_betas_list', 'pd_cam_t_list',
            'do_flip_list', 'rot_deg_list', 'do_extreme_crop_list',
        ):
            self.cache[key] = to_numpy(self.cache[key])[-last_k:]
        if DEBUG:
            self.cache['img_patch'] = to_numpy(self.cache['img_patch'])[-last_k:]

    def _save_pd(self, batch, outputs):
        ''' Save all the prediction results and related labels from the outputs. '''
        self.cache['seq_key_list'].extend(batch['__key__'])  # (NS,)
        self.cache['opgt_poses_list'].extend(to_numpy(batch['raw_skel_params']['poses']))  # (NS, 46)
        self.cache['opgt_betas_list'].extend(to_numpy(batch['raw_skel_params']['betas']))  # (NS, 10)
        self.cache['opgt_has_poses_list'].extend(to_numpy(batch['has_skel_params']['poses']))  # (NS,) 0 or 1
        self.cache['opgt_has_betas_list'].extend(to_numpy(batch['has_skel_params']['betas']))  # (NS,) 0 or 1
        self.cache['gt_kp2d_list'].extend(to_numpy(batch['kp2d']))  # (NS, 44, 3)
        self.cache['pd_poses_list'].extend(to_numpy(outputs['pd_params']['poses']))
        self.cache['pd_betas_list'].extend(to_numpy(outputs['pd_params']['betas']))
        self.cache['pd_cam_t_list'].extend(to_numpy(outputs['pd_cam_t']))
        self.cache['do_flip_list'].extend(to_numpy(batch['augm_args']['do_flip']))
        self.cache['rot_deg_list'].extend(to_numpy(batch['augm_args']['rot_deg']))
        self.cache['do_extreme_crop_list'].extend(to_numpy(batch['augm_args']['do_extreme_crop']))

        if DEBUG:
            # De-normalize the image patches (ImageNet mean/std) for visualization.
            img_patch = batch['img_patch'].clone().permute(0, 2, 3, 1)  # (NS, 256, 256, 3)
            mean = torch.tensor([0.485, 0.456, 0.406], device=img_patch.device).reshape(1, 1, 1, 3)
            std = torch.tensor([0.229, 0.224, 0.225], device=img_patch.device).reshape(1, 1, 1, 3)
            img_patch = 255 * (img_patch * std + mean)
            self.cache['img_patch'].extend(to_numpy(img_patch).astype(np.uint8))  # (NS, 256, 256, 3)

    def _init_better_pgt(self):
        ''' DDP adaptable initialization: each rank reads/writes its own cache file. '''
        self.rank = _get_rank()
        get_logger().info(f'Initializing better pgt cache @ rank {self.rank}')
        if self.rank is not None:
            self.better_pgt_fn = Path(f'{self.better_pgt_fn}_r{self.rank}')
            get_logger().info(f'Redirecting better pgt cache to {self.better_pgt_fn}')

        if self.better_pgt_fn.exists():
            better_pgt_z = np.load(self.better_pgt_fn, allow_pickle=True)
            self.better_pgt = {k: better_pgt_z[k].item() for k in better_pgt_z.files}
        else:
            self.better_pgt = {'poses': {}, 'betas': {}, 'has_poses': {}, 'has_betas': {}}

    def _spin(self, tb_logger, device):
        '''
        Run one SPIN round: optimize the cached predictions with SKELify, then promote
        those results that beat the current pseudo-gt (lower kp2d reprojection error)
        into `better_pgt`, and dump the archive in the background.
        '''
        skelify : SKELify = instantiate(self.skelify_cfg, tb_logger=tb_logger, device=device, _recursive_=False)
        skel_model = skelify.skel_model
        self._format_pd()

        # 1. Make up the cache to run SKELify.
        with PM.time_monitor('preparation'):
            sample_uid_list = [
                f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
                for seq_key, do_flip in zip(self.cache['seq_key_list'], self.cache['do_flip_list'])
            ]
            all_gt_kp2d = self.cache['gt_kp2d_list']  # (NS, 44, 3), kp2d with confidence
            all_init_poses = self.cache['pd_poses_list']  # (NS, 46)
            all_init_betas = self.cache['pd_betas_list']  # (NS, 10)
            all_init_cam_t = self.cache['pd_cam_t_list']  # (NS, 3)
            all_do_extreme_crop = self.cache['do_extreme_crop_list']  # (NS,)
            all_res_poses = []
            all_res_betas = []
            all_res_cam_t = []
            all_res_kp2d_err = []  # the evaluation of the keypoints 2D error

        # 2. Run SKELify optimization here to get better results.
        with PM.time_monitor('SKELify') as tm:
            get_logger().info(f'Start to run SKELify optimization. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
            n_samples = len(self.cache['seq_key_list'])
            n_round = (n_samples - 1) // self.B + 1  # ceil(n_samples / B)
            get_logger().info(f'Running SKELify optimization for {n_samples} samples in {n_round} rounds.')
            for rid in range(n_round):
                sid = rid * self.B
                eid = min(sid + self.B, n_samples)
                gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device)
                init_poses = to_tensor(all_init_poses[sid:eid], device=device)
                init_betas = to_tensor(all_init_betas[sid:eid], device=device)
                init_cam_t = to_tensor(all_init_cam_t[sid:eid], device=device)

                # Run the SKELify optimization.
                outputs = skelify(
                    gt_kp2d    = gt_kp2d_with_conf,
                    init_poses = init_poses,
                    init_betas = init_betas,
                    init_cam_t = init_cam_t,
                    img_patch  = self.cache['img_patch'][sid:eid] if DEBUG else None,
                )

                # Store the results.
                all_res_poses.extend(to_numpy(outputs['poses']))  # (~NS, 46)
                all_res_betas.extend(to_numpy(outputs['betas']))  # (~NS, 10)
                all_res_cam_t.extend(to_numpy(outputs['cam_t']))  # (~NS, 3)
                all_res_kp2d_err.extend(to_numpy(outputs['kp2d_err']))  # (~NS,)
                tm.tick(f'SKELify round {rid} finished.')

        # 3. Initialize the uninitialized better pseudo-gt with old ground truth.
        with PM.time_monitor('init_bpgt'):
            get_logger().info(f'Initializing bpgt. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
            for i in range(n_samples):
                sample_uid = sample_uid_list[i]
                # Only seed samples never seen before. (Checking the top-level dict here
                # would always be True and clobber previously promoted pseudo-gt; the
                # membership test must mirror the one in `on_train_batch_start`.)
                if sample_uid not in self.better_pgt['poses'].keys():
                    self.better_pgt['poses'][sample_uid] = self.cache['opgt_poses_list'][i]
                    self.better_pgt['betas'][sample_uid] = self.cache['opgt_betas_list'][i]
                    self.better_pgt['has_poses'][sample_uid] = self.cache['opgt_has_poses_list'][i]
                    self.better_pgt['has_betas'][sample_uid] = self.cache['opgt_has_betas_list'][i]

        # 4. Update the results.
        with PM.time_monitor('upd_bpgt'):
            upd_cnt = 0  # Count the number of updated samples.
            get_logger().info(f'Update the results. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
            for rid in range(n_round):
                torch.cuda.empty_cache()
                sid = rid * self.B
                eid = min(sid + self.B, n_samples)
                focal_length = np.ones(2) * 5000 / 256  # TODO: These data should be loaded from configuration files.
                # `np.tile` replicates the (1, 2) row B times -> (B, 2).
                # (`ndarray.repeat(B, 1)` would repeat along axis 1 and give (1, 2B).)
                focal_length = np.tile(focal_length.reshape(1, 2), (eid - sid, 1))  # (B, 2)
                gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device)  # (B, 44, 3)
                rot_deg = to_tensor(self.cache['rot_deg_list'][sid:eid], device=device)

                # 4.1. Prepare the better pseudo-gt and the results.
                res_betas = to_tensor(all_res_betas[sid:eid], device=device)  # (B, 10)
                res_poses_after_augm = to_tensor(all_res_poses[sid:eid], device=device)  # (B, 46)
                res_poses_before_augm = rot_skel_on_plane(res_poses_after_augm, -rot_deg)  # recover the augmentation rotation
                res_kp2d_err = to_tensor(all_res_kp2d_err[sid:eid], device=device)  # (B,)
                cur_do_extreme_crop = all_do_extreme_crop[sid:eid]

                # 4.2. Evaluate the quality of the existing better pseudo-gt.
                uids = sample_uid_list[sid:eid]  # [sid ~ eid] -> sample_uids
                bpgt_betas = to_tensor([self.better_pgt['betas'][uid] for uid in uids], device=device)
                bpgt_poses_before_augm = to_tensor([self.better_pgt['poses'][uid] for uid in uids], device=device)
                bpgt_poses_after_augm = rot_skel_on_plane(bpgt_poses_before_augm.clone(), rot_deg)  # re-apply the augmentation rotation
                skel_outputs = skel_model(poses=bpgt_poses_after_augm, betas=bpgt_betas, skelmesh=False)
                bpgt_kp3d = skel_outputs.joints.detach()  # (B, 44, 3)
                bpgt_est_cam_t = estimate_camera_trans(
                    S            = bpgt_kp3d,
                    joints_2d    = gt_kp2d_with_conf.clone(),
                    focal_length = 5000,
                    img_size     = 256,
                )  # estimate camera translation from inference 3D keypoints and GT 2D keypoints
                bpgt_reproj_kp2d = perspective_projection(
                    points       = to_tensor(bpgt_kp3d, device=device),
                    translation  = to_tensor(bpgt_est_cam_t, device=device),
                    focal_length = to_tensor(focal_length, device=device),
                )
                bpgt_kp2d_err = SKELify.eval_kp2d_err(gt_kp2d_with_conf, bpgt_reproj_kp2d)  # (B,) — per-sample error, compared element-wise below

                # A result is promoted only if its betas look sane AND it reprojects
                # better than the current pseudo-gt.
                valid_betas_mask = res_betas.abs().max(dim=-1)[0] < self.valid_betas_threshold  # (B,)
                better_mask = res_kp2d_err < bpgt_kp2d_err  # (B,)
                upd_mask = torch.logical_and(valid_betas_mask, better_mask)  # (B,)
                upd_ids = torch.arange(eid - sid, device=device)[upd_mask]  # uids -> ids

                # Update one by one.
                for upd_id in upd_ids:
                    # `uid` for dynamic dataset unique id, `id` for in-round batch data.
                    # Notes: id starts from zeros, it should be applied to [sid ~ eid] directly.
                    # Either `all_res_poses[upd_id]` or `res_poses[upd_id - sid]` is wrong.
                    if cur_do_extreme_crop[upd_id]:
                        # Skip the extreme crop data.
                        continue
                    sample_uid = uids[upd_id]
                    self.better_pgt['poses'][sample_uid] = to_numpy(res_poses_before_augm[upd_id])
                    self.better_pgt['betas'][sample_uid] = to_numpy(res_betas[upd_id])
                    self.better_pgt['has_poses'][sample_uid] = 1.  # If updated, then must have.
                    self.better_pgt['has_betas'][sample_uid] = 1.  # If updated, then must have.
                    upd_cnt += 1
            get_logger().info(f'Update {upd_cnt} samples among all {n_samples} samples.')

        # 5. [Async] Save the results.
        with PM.time_monitor('async_dumping'):
            # TODO: Use lock and other techniques to achieve a better submission system.
            # TODO: We need to design a better way to solve the synchronization problem.
            if hasattr(self, 'dump_thread'):
                self.dump_thread.result()  # Wait for the previous dump to finish.
            # Submit to the persistent single-worker executor so the dump really runs in
            # the background. (Submitting inside `with ThreadPoolExecutor()` would block
            # here until the dump finished, defeating the purpose.)
            self.dump_thread = self._dump_executor.submit(lambda: np.savez(self.better_pgt_fn, **self.better_pgt))

        # 6. Clean up the memory.
        del skelify, skel_model
        self._init_pd_dict()