# lib/modeling/callbacks/skelify_spin.py
from lib.kits.basic import *
from concurrent.futures import ThreadPoolExecutor
from lightning_fabric.utilities.rank_zero import _get_rank
from lib.data.augmentation.skel import rot_skel_on_plane
from lib.utils.data import to_tensor, to_numpy
from lib.utils.camera import perspective_projection, estimate_camera_trans
from lib.utils.vis import Wis3D
from lib.modeling.optim.skelify.skelify import SKELify
from lib.body_models.common import make_SKEL
DEBUG = False
DEBUG_ROUND = False
class SKELifySPIN(pl.Callback):
'''
Call SKELify to optimize the prediction results.
Here we have several concepts of data: gt, opgt, pd, (res), bpgt.
1. `gt`: Static Ground Truth: loaded from the static training datasets. They might be
real ground truth (like 2D keypoints) or pseudo ground truth (like SKEL parameters),
and are expected to be gradually replaced by better pseudo ground truth over the
iterations.
2. `opgt`: Old Pseudo Ground Truth: the better ground truth among the static datasets
(the `gt` above) and the dynamic datasets maintained in this callback; they serve as
the labels for training the network.
3. `pd`: Predicted Results: the network outputs, which will be optimized later. After
optimization, they are referred to as `res` (Results from optimization).
4. `bpgt`: Better Pseudo Ground Truth: the optimized results, stored both in an extra
file and in memory. These are the highest-quality data, picked among the static
ground truth, the cached better pseudo ground truth, and the predicted-then-optimized
data.
'''
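# A minimal usage sketch (hypothetical config paths; the real Hydra configs live
# elsewhere in the repo), assuming the callback is registered on a Lightning trainer:
#
#     spin_cb = SKELifySPIN(cfg=cfg.callbacks.skelify_spin, skelify=cfg.skelify)
#     trainer = pl.Trainer(callbacks=[spin_cb], max_steps=...)
#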
# TODO: For now, only kp2d is used to evaluate the optimization quality, since not all data provide kp3d.
# TODO: In the future, kp3d should also be considered, i.e., use it whenever it is available.
def __init__(
self,
cfg : DictConfig,
skelify : DictConfig,
**kwargs,
):
super().__init__()
self.interval = cfg.interval
self.B = cfg.batch_size
self.kb_pr = cfg.get('max_batches_per_round', None) # only the latest k batches are SPINed each round
self.better_pgt_fn = Path(cfg.better_pgt_fn) # load it before training
self.skip_warm_up_steps = cfg.skip_warm_up_steps
self.update_better_pgt = cfg.update_better_pgt
self.skelify_cfg = skelify
# The threshold to determine whether a result is valid. (In case some data
# had no parameters at first but were updated with bad parameters.)
self.valid_betas_threshold = cfg.valid_betas_threshold
self._init_pd_dict()
self.better_pgt = None
def on_train_batch_start(self, trainer, pl_module, raw_batch, batch_idx):
# Lazy initialization for better pgt.
if self.better_pgt is None:
self._init_better_pgt()
# GPU_monitor.snapshot('GPU-Mem-Before-Train-Before-SPIN-Update')
device = pl_module.device
batch = raw_batch['img_ds']
if not self.update_better_pgt:
return
# 1. Compose the data from batches.
seq_key_list = batch['__key__']
batch_do_flip_list = batch['augm_args']['do_flip']
sample_uid_list = [
f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
for seq_key, do_flip in zip(seq_key_list, batch_do_flip_list)
]
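# E.g., a sample with key 'h36m_s1_000001' (hypothetical) is cached under the uid
# 'h36m_s1_000001_flip' when horizontally flipped and 'h36m_s1_000001_orig' otherwise,
# so the two augmented views are tracked as distinct samples.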
# 2. Update the labels from better_pgt.
for i, sample_uid in enumerate(sample_uid_list):
if sample_uid in self.better_pgt['poses'].keys():
batch['raw_skel_params']['poses'][i] = to_tensor(self.better_pgt['poses'][sample_uid], device=device) # (46,)
batch['raw_skel_params']['betas'][i] = to_tensor(self.better_pgt['betas'][sample_uid], device=device) # (10,)
batch['has_skel_params']['poses'][i] = self.better_pgt['has_poses'][sample_uid] # 0 or 1
batch['has_skel_params']['betas'][i] = self.better_pgt['has_betas'][sample_uid] # 0 or 1
batch['updated_by_spin'][i] = True # add information for inspection
# get_logger().trace(f'Update the pseudo-gt for {sample_uid}.')
# GPU_monitor.snapshot('GPU-Mem-Before-Train-After-SPIN-Update')
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# GPU_monitor.snapshot('GPU-Mem-After-Train-Before-SPIN-Update')
# Since the network predictions may be far from the ground truth before the model is
# well trained, we skip some warm-up steps to avoid meaningless optimization.
if trainer.global_step > self.skip_warm_up_steps or DEBUG_ROUND:
# Collect the prediction results.
self._save_pd(batch['img_ds'], outputs)
if (self.interval > 0 and trainer.global_step % self.interval == 0) or DEBUG_ROUND:
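# E.g., with a (hypothetical) interval of 1000, a SPIN round runs every 1000 global steps.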
torch.cuda.empty_cache()
with PM.time_monitor('SPIN'):
self._spin(trainer.logger, pl_module.device)
torch.cuda.empty_cache()
# GPU_monitor.snapshot('GPU-Mem-After-Train-After-SPIN-Update')
# GPU_monitor.report_latest(k=4)
def _init_pd_dict(self):
''' Memory clean-up for each SPIN round; values are stored as numpy arrays to save GPU memory. '''
self.cache = {
# Things to identify one sample.
'seq_key_list' : [],
# Things for comparison.
'opgt_poses_list' : [],
'opgt_betas_list' : [],
'opgt_has_poses_list' : [],
'opgt_has_betas_list' : [],
# Things for optimization and self-iteration.
'gt_kp2d_list' : [],
'pd_poses_list' : [],
'pd_betas_list' : [],
'pd_cam_t_list' : [],
'do_flip_list' : [],
'rot_deg_list' : [],
'do_extreme_crop_list': [], # if the extreme crop is applied, we don't update the pseudo-gt
# Things for visualization.
'img_patch': [],
# No gt_cam_t_list.
}
def _format_pd(self):
''' Convert the cached lists to numpy arrays, keeping only the latest samples. '''
if self.kb_pr is None:
last_k = len(self.cache['seq_key_list'])
else:
last_k = self.kb_pr * self.B # the latest k samples to be optimized.
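# E.g., with (hypothetical) max_batches_per_round = 8 and batch_size = 64,
# only the newest 8 * 64 = 512 cached samples enter this SPIN round.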
self.cache['seq_key_list'] = to_numpy(self.cache['seq_key_list'])[-last_k:]
self.cache['opgt_poses_list'] = to_numpy(self.cache['opgt_poses_list'])[-last_k:]
self.cache['opgt_betas_list'] = to_numpy(self.cache['opgt_betas_list'])[-last_k:]
self.cache['opgt_has_poses_list'] = to_numpy(self.cache['opgt_has_poses_list'])[-last_k:]
self.cache['opgt_has_betas_list'] = to_numpy(self.cache['opgt_has_betas_list'])[-last_k:]
self.cache['gt_kp2d_list'] = to_numpy(self.cache['gt_kp2d_list'])[-last_k:]
self.cache['pd_poses_list'] = to_numpy(self.cache['pd_poses_list'])[-last_k:]
self.cache['pd_betas_list'] = to_numpy(self.cache['pd_betas_list'])[-last_k:]
self.cache['pd_cam_t_list'] = to_numpy(self.cache['pd_cam_t_list'])[-last_k:]
self.cache['do_flip_list'] = to_numpy(self.cache['do_flip_list'])[-last_k:]
self.cache['rot_deg_list'] = to_numpy(self.cache['rot_deg_list'])[-last_k:]
self.cache['do_extreme_crop_list'] = to_numpy(self.cache['do_extreme_crop_list'])[-last_k:]
if DEBUG:
self.cache['img_patch'] = to_numpy(self.cache['img_patch'])[-last_k:]
def _save_pd(self, batch, outputs):
''' Save all the prediction results and related labels from the outputs. '''
B = len(batch['__key__'])
self.cache['seq_key_list'].extend(batch['__key__']) # (NS,)
self.cache['opgt_poses_list'].extend(to_numpy(batch['raw_skel_params']['poses'])) # (NS, 46)
self.cache['opgt_betas_list'].extend(to_numpy(batch['raw_skel_params']['betas'])) # (NS, 10)
self.cache['opgt_has_poses_list'].extend(to_numpy(batch['has_skel_params']['poses'])) # (NS,) 0 or 1
self.cache['opgt_has_betas_list'].extend(to_numpy(batch['has_skel_params']['betas'])) # (NS,) 0 or 1
self.cache['gt_kp2d_list'].extend(to_numpy(batch['kp2d'])) # (NS, 44, 3)
self.cache['pd_poses_list'].extend(to_numpy(outputs['pd_params']['poses']))
self.cache['pd_betas_list'].extend(to_numpy(outputs['pd_params']['betas']))
self.cache['pd_cam_t_list'].extend(to_numpy(outputs['pd_cam_t']))
self.cache['do_flip_list'].extend(to_numpy(batch['augm_args']['do_flip']))
self.cache['rot_deg_list'].extend(to_numpy(batch['augm_args']['rot_deg']))
self.cache['do_extreme_crop_list'].extend(to_numpy(batch['augm_args']['do_extreme_crop']))
if DEBUG:
img_patch = batch['img_patch'].clone().permute(0, 2, 3, 1) # (NS, 256, 256, 3)
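# Undo the ImageNet normalization (the mean/std below are the standard ImageNet
# statistics) and rescale to [0, 255] so the patches can be visualized directly.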
mean = torch.tensor([0.485, 0.456, 0.406], device=img_patch.device).reshape(1, 1, 1, 3)
std = torch.tensor([0.229, 0.224, 0.225], device=img_patch.device).reshape(1, 1, 1, 3)
img_patch = 255 * (img_patch * std + mean)
self.cache['img_patch'].extend(to_numpy(img_patch).astype(np.uint8)) # (NS, 256, 256, 3)
def _init_better_pgt(self):
''' DDP-adaptable initialization: each rank maintains its own better-pgt cache file. '''
self.rank = _get_rank()
get_logger().info(f'Initializing better pgt cache @ rank {self.rank}')
if self.rank is not None:
self.better_pgt_fn = Path(f'{self.better_pgt_fn}_r{self.rank}')
get_logger().info(f'Redirecting better pgt cache to {self.better_pgt_fn}')
if self.better_pgt_fn.exists():
better_pgt_z = np.load(self.better_pgt_fn, allow_pickle=True)
self.better_pgt = {k: better_pgt_z[k].item() for k in better_pgt_z.files}
else:
self.better_pgt = {'poses': {}, 'betas': {}, 'has_poses': {}, 'has_betas': {}}
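# On disk, `better_pgt` is a single .npz of four pickled dicts keyed by sample uid,
# e.g. better_pgt['poses'][uid] -> (46,) array, better_pgt['betas'][uid] -> (10,) array.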
def _spin(self, tb_logger, device):
skelify : SKELify = instantiate(self.skelify_cfg, tb_logger=tb_logger, device=device, _recursive_=False)
skel_model = skelify.skel_model
self._format_pd()
# 1. Make up the cache to run SKELify.
with PM.time_monitor('preparation'):
sample_uid_list = [
f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
for seq_key, do_flip in zip(self.cache['seq_key_list'], self.cache['do_flip_list'])
]
all_gt_kp2d = self.cache['gt_kp2d_list'] # (NS, 44, 3), with confidence
all_init_poses = self.cache['pd_poses_list'] # (NS, 46)
all_init_betas = self.cache['pd_betas_list'] # (NS, 10)
all_init_cam_t = self.cache['pd_cam_t_list'] # (NS, 3)
all_do_extreme_crop = self.cache['do_extreme_crop_list'] # (NS,)
all_res_poses = []
all_res_betas = []
all_res_cam_t = []
all_res_kp2d_err = [] # the evaluated 2D keypoint error of each optimized sample
# 2. Run SKELify optimization here to get better results.
with PM.time_monitor('SKELify') as tm:
get_logger().info(f'Start to run SKELify optimization. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
n_samples = len(self.cache['seq_key_list'])
n_round = (n_samples - 1) // self.B + 1
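# Ceil division: e.g., 130 samples with B = 64 give 3 rounds of sizes 64, 64, and 2.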
get_logger().info(f'Running SKELify optimization for {n_samples} samples in {n_round} rounds.')
for rid in range(n_round):
sid = rid * self.B
eid = min(sid + self.B, n_samples)
gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device)
init_poses = to_tensor(all_init_poses[sid:eid], device=device)
init_betas = to_tensor(all_init_betas[sid:eid], device=device)
init_cam_t = to_tensor(all_init_cam_t[sid:eid], device=device)
# Run the SKELify optimization.
outputs = skelify(
gt_kp2d = gt_kp2d_with_conf,
init_poses = init_poses,
init_betas = init_betas,
init_cam_t = init_cam_t,
img_patch = self.cache['img_patch'][sid:eid] if DEBUG else None,
)
# Store the results.
all_res_poses.extend(to_numpy(outputs['poses'])) # (~NS, 46)
all_res_betas.extend(to_numpy(outputs['betas'])) # (~NS, 10)
all_res_cam_t.extend(to_numpy(outputs['cam_t'])) # (~NS, 3)
all_res_kp2d_err.extend(to_numpy(outputs['kp2d_err'])) # (~NS,)
tm.tick(f'SKELify round {rid} finished.')
# 3. Initialize the uninitialized better pseudo-gt with old ground truth.
with PM.time_monitor('init_bpgt'):
get_logger().info(f'Initializing bpgt. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
for i in range(n_samples):
sample_uid = sample_uid_list[i]
# Note: check membership in the per-uid sub-dict; the top-level keys are just
# 'poses'/'betas'/..., so `sample_uid not in self.better_pgt` would always be True.
if sample_uid not in self.better_pgt['poses'].keys():
self.better_pgt['poses'][sample_uid] = self.cache['opgt_poses_list'][i]
self.better_pgt['betas'][sample_uid] = self.cache['opgt_betas_list'][i]
self.better_pgt['has_poses'][sample_uid] = self.cache['opgt_has_poses_list'][i]
self.better_pgt['has_betas'][sample_uid] = self.cache['opgt_has_betas_list'][i]
# 4. Update the results.
with PM.time_monitor('upd_bpgt'):
upd_cnt = 0 # Count the number of updated samples.
get_logger().info(f'Update the results. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
for rid in range(n_round):
torch.cuda.empty_cache()
sid = rid * self.B
eid = min(sid + self.B, n_samples)
focal_length = np.ones(2) * 5000 / 256 # TODO: These values should be loaded from configuration files.
focal_length = focal_length.reshape(1, 2).repeat(eid - sid, axis=0) # (B, 2); numpy-style repeat along axis 0
gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device) # (B, 44, 3)
rot_deg = to_tensor(self.cache['rot_deg_list'][sid:eid], device=device)
# 4.1. Prepare the better pseudo-gt and the results.
res_betas = to_tensor(all_res_betas[sid:eid], device=device) # (B, 10)
res_poses_after_augm = to_tensor(all_res_poses[sid:eid], device=device) # (B, 46)
res_poses_before_augm = rot_skel_on_plane(res_poses_after_augm, -rot_deg) # undo the augmentation rotation
res_kp2d_err = to_tensor(all_res_kp2d_err[sid:eid], device=device) # (B,)
cur_do_extreme_crop = all_do_extreme_crop[sid:eid]
# 4.2. Evaluate the quality of the existing better pseudo-gt.
uids = sample_uid_list[sid:eid] # [sid ~ eid] -> sample_uids
bpgt_betas = to_tensor([self.better_pgt['betas'][uid] for uid in uids], device=device)
bpgt_poses_before_augm = to_tensor([self.better_pgt['poses'][uid] for uid in uids], device=device)
bpgt_poses_after_augm = rot_skel_on_plane(bpgt_poses_before_augm.clone(), rot_deg) # re-apply the augmentation rotation
skel_outputs = skel_model(poses=bpgt_poses_after_augm, betas=bpgt_betas, skelmesh=False)
bpgt_kp3d = skel_outputs.joints.detach() # (B, 44, 3)
bpgt_est_cam_t = estimate_camera_trans(
S = bpgt_kp3d,
joints_2d = gt_kp2d_with_conf.clone(),
focal_length = 5000,
img_size = 256,
) # estimate the camera translation from the inferred 3D keypoints and the GT 2D keypoints
bpgt_reproj_kp2d = perspective_projection(
points = to_tensor(bpgt_kp3d, device=device),
translation = to_tensor(bpgt_est_cam_t, device=device),
focal_length = to_tensor(focal_length, device=device),
)
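# The 2D keypoints are presumably in patch-normalized coordinates, hence the
# 5000 / 256 focal length above; per keypoint, the projection is roughly
# x = fx * (X + tx) / (Z + tz), and similarly for y.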
bpgt_kp2d_err = SKELify.eval_kp2d_err(gt_kp2d_with_conf, bpgt_reproj_kp2d) # (B,)
valid_betas_mask = res_betas.abs().max(dim=-1)[0] < self.valid_betas_threshold # (B,)
better_mask = res_kp2d_err < bpgt_kp2d_err # (B,)
upd_mask = torch.logical_and(valid_betas_mask, better_mask) # (B,)
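# Keep a sample only if its optimized betas stay plausible AND its reprojection
# error improved over the current pseudo-gt; e.g., with a (hypothetical)
# valid_betas_threshold of 3.0, degenerate fits with extreme betas are rejected
# even when they reproject better.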
upd_ids = torch.arange(eid-sid, device=device)[upd_mask] # uids -> ids
# Update one by one.
for upd_id in upd_ids:
# `uid` is the dynamic-dataset unique id; `id` indexes within the current round's batch.
# Note: ids start from zero and apply to the [sid:eid] slices directly, so both
# `all_res_poses[upd_id]` and `res_poses_before_augm[upd_id - sid]` would be wrong.
if cur_do_extreme_crop[upd_id]:
# Skip the extreme crop data.
continue
sample_uid = uids[upd_id]
self.better_pgt['poses'][sample_uid] = to_numpy(res_poses_before_augm[upd_id])
self.better_pgt['betas'][sample_uid] = to_numpy(res_betas[upd_id])
self.better_pgt['has_poses'][sample_uid] = 1. # If updated, then must have.
self.better_pgt['has_betas'][sample_uid] = 1. # If updated, then must have.
upd_cnt += 1
get_logger().info(f'Update {upd_cnt} samples among all {n_samples} samples.')
# 5. [Async] Save the results.
with PM.time_monitor('async_dumping'):
# TODO: Use lock and other techniques to achieve a better submission system.
# TODO: We need to design a better way to solve the synchronization problem.
if hasattr(self, 'dump_thread'):
self.dump_thread.result() # Wait for the previous dump to finish.
if not hasattr(self, 'dump_executor'):
# Keep a persistent single-worker executor: exiting a `with ThreadPoolExecutor()`
# block implicitly waits for the submitted task, which would make the dump synchronous.
self.dump_executor = ThreadPoolExecutor(max_workers=1)
self.dump_thread = self.dump_executor.submit(lambda: np.savez(self.better_pgt_fn, **self.better_pgt))
# 6. Clean up the memory.
del skelify, skel_model
self._init_pd_dict()