# ---------------------------------------------------------------------------
# Model/diffusion factory helpers plus MANO-based hand-joint optimization.
#
# NOTE(review): most MDM_Ours_V* imports were deliberately left commented out
# by the original author; branches in create_model_and_diffusion that
# reference the commented-out classes (MDM, MDM_Ours, V3-V9, V11, V13, V14)
# will raise NameError if taken.  Kept as-is to preserve behavior — confirm
# which model versions still exist in model.mdm_ours before uncommenting.
# ---------------------------------------------------------------------------
# from model.mdm import MDM
# from model.mdm_ours import MDM as MDM_Ours
# from model.mdm_ours import MDMV3 as MDM_Ours_V3
# from model.mdm_ours import MDMV4 as MDM_Ours_V4
# from model.mdm_ours import MDMV5 as MDM_Ours_V5
# from model.mdm_ours import MDMV6 as MDM_Ours_V6
# from model.mdm_ours import MDMV7 as MDM_Ours_V7
# from model.mdm_ours import MDMV8 as MDM_Ours_V8
# from model.mdm_ours import MDMV9 as MDM_Ours_V9
from model.mdm_ours import MDMV10 as MDM_Ours_V10
# from model.mdm_ours import MDMV11 as MDM_Ours_V11
from model.mdm_ours import MDMV12 as MDM_Ours_V12
# from model.mdm_ours import MDMV13 as MDM_Ours_V13
# from model.mdm_ours import MDMV14 as MDM_Ours_V14
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from utils.parser_util import get_cond_mode
import torch
from torch import optim, nn
import torch.nn.functional as F
from manopth.manolayer import ManoLayer
import numpy as np
import trimesh
import os
from diffusion.respace_ours import SpacedDiffusion as SpacedDiffusion_Ours
from diffusion.respace_ours import SpacedDiffusionV2 as SpacedDiffusion_OursV2
from diffusion.respace_ours import SpacedDiffusionV3 as SpacedDiffusion_OursV3
from diffusion.respace_ours import SpacedDiffusionV4 as SpacedDiffusion_OursV4
from diffusion.respace_ours import SpacedDiffusionV5 as SpacedDiffusion_OursV5
from diffusion.respace_ours import SpacedDiffusionV6 as SpacedDiffusion_OursV6
from diffusion.respace_ours import SpacedDiffusionV7 as SpacedDiffusion_OursV7
from diffusion.respace_ours import SpacedDiffusionV9 as SpacedDiffusion_OursV9


def batched_index_select_ours(values, indices, dim=1):
    """Batched gather: select entries of `values` along axis `dim` by `indices`.

    `values` and `indices` share their leading (batch) axes up to `dim`;
    the result has shape indices.shape + values.shape[dim + 1:].
    """
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # Broadcast indices over the trailing value dims so gather sees full rows.
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)
    dim += value_expand_len
    return values.gather(dim, indices)


def gaussian_entropy(logvar):
    """Entropy of a diagonal Gaussian given per-dimension log-variances.

    logvar: (bsz, d) tensor; returns a (bsz,) tensor of entropies.
    """
    const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
    ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
    return ent


def standard_normal_logprob(z):
    """Elementwise log-density of `z` under a standard normal.

    NOTE(review): the constant uses `dim = z.size(-1)` so each element carries
    the full -0.5*d*log(2*pi) offset — callers appear to rely on this exact
    form; kept as-is.
    """
    dim = z.size(-1)
    log_z = -0.5 * dim * np.log(2 * np.pi)
    return log_z - z.pow(2) / 2


def load_multiple_models_fr_path(model_path, model):
    """Load parameter subsets from several checkpoints into one model, in place.

    model_path: "setting1:path1;setting2:path2" — each setting name selects
    which parameter-name substrings are copied from that checkpoint.
    Raises ValueError for an unknown setting name.
    """
    model_paths = model_path.split(";")
    print(f"Loading multiple models with split model_path: {model_paths}")
    setting_to_model_path = {}
    for cur_path in model_paths:
        cur_setting_nm, cur_model_path = cur_path.split(':')
        setting_to_model_path[cur_setting_nm] = cur_model_path
    loaded_dict = {}
    for cur_setting in setting_to_model_path:
        cur_model_path = setting_to_model_path[cur_setting]
        cur_model_state_dict = torch.load(cur_model_path, map_location='cpu')
        if cur_setting == 'diff_realbasejtsrel':
            interested_keys = [
                'real_basejtsrel_input_process', 'real_basejtsrel_sequence_pos_encoder',
                'real_basejtsrel_seqTransEncoder', 'real_basejtsrel_embed_timestep',
                'real_basejtsrel_sequence_pos_denoising_encoder',
                'real_basejtsrel_denoising_seqTransEncoder', 'real_basejtsrel_output_process'
            ]
        elif cur_setting == 'diff_basejtsrel':
            interested_keys = [
                'avg_joints_sequence_input_process', 'joints_offset_input_process',
                'sequence_pos_encoder', 'seqTransEncoder', 'logvar_seqTransEncoder',
                'embed_timestep', 'basejtsrel_denoising_embed_timestep',
                'sequence_pos_denoising_encoder', 'basejtsrel_denoising_seqTransEncoder',
                'basejtsrel_glb_denoising_latents_trans_layer',
                'avg_joint_sequence_output_process', 'joint_offset_output_process',
                'output_process'
            ]
        elif cur_setting == 'diff_realbasejtsrel_to_joints':
            interested_keys = [
                'real_basejtsrel_to_joints_input_process',
                'real_basejtsrel_to_joints_sequence_pos_encoder',
                'real_basejtsrel_to_joints_seqTransEncoder',
                'real_basejtsrel_to_joints_embed_timestep',
                'real_basejtsrel_to_joints_sequence_pos_denoising_encoder',
                'real_basejtsrel_to_joints_denoising_seqTransEncoder',
                'real_basejtsrel_to_joints_output_process',
            ]
        else:
            raise ValueError(f"cur_setting:{cur_setting} Not implemented yet")
        # Substring match: any checkpoint key containing an interested key is kept.
        for k in cur_model_state_dict:
            for cur_inter_key in interested_keys:
                if cur_inter_key in k:
                    loaded_dict[k] = cur_model_state_dict[k]
    model_dict = model.state_dict()
    model_dict.update(loaded_dict)
    model.load_state_dict(model_dict)


def load_model_wo_clip(model, state_dict):
    """Load a state dict, tolerating (only) missing CLIP sub-module weights.

    missing_keys: present in the model but absent from state_dict;
    unexpected_keys: present in state_dict but absent from the model.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    assert all([k.startswith('clip_model.') for k in missing_keys])


def create_model_and_diffusion(args, data):
    """Instantiate the model variant selected by `args` plus its diffusion.

    NOTE(review): branches using MDM, MDM_Ours, V3-V9, V11, V13, V14 refer to
    imports that are currently commented out at the top of this file and will
    raise NameError if taken.  Only V10 and V12 are importable as written.
    """
    if args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist", "ambient_obj_base_rel_dist"]:
        model = MDM_Ours(**get_model_args(args, data))
    elif args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist_we"]:
        model = MDM_Ours_V3(**get_model_args(args, data))
    elif args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist_we_wj"]:
        model = MDM_Ours_V4(**get_model_args(args, data))
    elif args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist_we_wj_latents"]:
        if args.diff_spatial:
            if args.pred_joints_offset:
                if args.diff_joint_quants:
                    model = MDM_Ours_V13(**get_model_args(args, data))
                elif args.diff_hand_params:
                    model = MDM_Ours_V14(**get_model_args(args, data))
                else:
                    if args.finetune_with_cond:
                        print(f"Using MDM ours V12!!!!")
                        model = MDM_Ours_V12(**get_model_args(args, data))
                    else:
                        print(f"Using MDM ours V10!!!!")
                        model = MDM_Ours_V10(**get_model_args(args, data))
            else:
                print(f"Using MDM ours V9!!!!")
                model = MDM_Ours_V9(**get_model_args(args, data))
        elif args.diff_latents:
            print(f"Using MDM ours V11!!!!")
            model = MDM_Ours_V11(**get_model_args(args, data))
        elif args.use_sep_models:
            if args.use_vae:
                if args.pred_basejtsrel_avgjts:
                    print(f"Using MDM ours V8!!!!")
                    model = MDM_Ours_V8(**get_model_args(args, data))
                else:
                    model = MDM_Ours_V7(**get_model_args(args, data))
            else:
                model = MDM_Ours_V6(**get_model_args(args, data))
        else:
            model = MDM_Ours_V5(**get_model_args(args, data))
    else:
        model = MDM(**get_model_args(args, data))
    # create_gaussian_diffusion is defined elsewhere in this module.
    diffusion = create_gaussian_diffusion(args)
    return model, diffusion


def get_model_args(args, data):
    """Build the keyword-argument dict for the MDM model constructors."""
    clip_version = 'ViT-B/32'
    action_emb = 'tensor'
    cond_mode = get_cond_mode(args)
    if hasattr(data.dataset, 'num_actions'):
        num_actions = data.dataset.num_actions
    else:
        num_actions = 1
    # SMPL defaults; overridden per dataset below.
    data_rep = 'rot6d'
    njoints = 25
    nfeats = 6
    if args.dataset in ['humanml']:
        data_rep = 'hml_vec'
        njoints = 263
        nfeats = 1
    elif args.dataset in ['motion_ours']:
        data_rep = 'xyz'
        njoints = 21
        nfeats = 3
    elif args.dataset == 'kit':
        data_rep = 'hml_vec'
        njoints = 251
        nfeats = 1
    return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats,
            'num_actions': num_actions, 'translation': True, 'pose_rep': 'rot6d',
            'glob': True, 'glob_rot': True, 'latent_dim': args.latent_dim,
            'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4,
            'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep,
            'cond_mode': cond_mode, 'cond_mask_prob': args.cond_mask_prob,
            'action_emb': action_emb, 'arch': args.arch,
            'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version,
            'dataset': args.dataset, 'args': args}


def _run_mano(mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws):
    """Forward MANO and reshape outputs to (bsz, ws, N, 3).

    Betas are shared across the window, so they are tiled to (bsz*ws, 10).
    The 0.001 scale converts MANO output units (presumably mm — TODO confirm)
    into the meters used by the rest of the pipeline.
    """
    hand_verts, hand_joints = mano_layer(
        torch.cat([rot_var, theta_var], dim=-1),
        beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10),
        transl_var)
    hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001
    hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
    return hand_verts, hand_joints


def _rel_and_signed_dists(hand_joints, base_pts, base_normals):
    """Per (joint, base-point) offsets and their signed normal distances.

    Returns rel: bsz x ws x nnj x nnb x 3 and signed: bsz x ws x nnj x nnb.
    """
    rel = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
    signed = torch.sum(rel * base_normals.unsqueeze(1).unsqueeze(1), dim=-1)
    return rel, signed


def _fitting_terms(rel, signed, rel_base_pts_to_joints, dists_base_pts_to_joints, device):
    """L2 data terms tying MANO joints to the sampled relations/distances.

    dist_e is a zero placeholder when no target distances are given; it is
    created on `device` so the final loss sum stays on one device (this was
    missing in two of the original loops — device-mismatch bug).
    """
    rel_e = torch.sum((rel - rel_base_pts_to_joints) ** 2, dim=-1).mean()
    if dists_base_pts_to_joints is not None:
        dist_e = torch.sum((signed - dists_base_pts_to_joints) ** 2, dim=-1).mean()
    else:
        dist_e = torch.zeros((1,), dtype=torch.float32).to(device).mean()
    return rel_e, dist_e


def _prior_terms(theta_var, beta_var, bsz, ws):
    """Temporal pose smoothness plus Gaussian priors on MANO shape/pose."""
    pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:],
                                      theta_var.view(bsz, ws, -1)[:, :-1])
    shape_prior_loss = torch.mean(beta_var ** 2)
    pose_prior_loss = torch.mean(theta_var ** 2)
    return pose_smoothness_loss, shape_prior_loss, pose_prior_loss


def _log_losses(i_iter, loss, pose_smoothness_loss, shape_prior_loss, pose_prior_loss,
                signed_dist_e, rel_e, dist_e, dist_dec_jts_to_sampled_pts, avg_masks=None):
    """Print the per-iteration loss breakdown (avg_masks only when provided)."""
    print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
    print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
    print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
    print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
    print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
    print('\trel_e Loss: {}'.format(rel_e.item()))
    print('\tdist_e Loss: {}'.format(dist_e.item()))
    print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
    if avg_masks is not None:
        print('\tAvg masks: {}'.format(avg_masks.item()))


def optimize_sampled_hand_joints(sampled_joints, rel_base_pts_to_joints,
                                 dists_base_pts_to_joints, base_pts, base_normals):
    """Fit MANO parameters to diffusion-sampled hand joints.

    Args:
        sampled_joints: bsz x ws x nnj x 3 sampled joint positions.
        rel_base_pts_to_joints: bsz x ws x nnj x nnb x 3 target offsets from
            object base points to joints.
        dists_base_pts_to_joints: bsz x ws x nnj x nnb target signed
            distances, or None to skip that data term.
        base_pts / base_normals: bsz x nnb x 3 object base points / normals.

    Runs three Adam stages: (1) coarse global rot/transl only, (2) all MANO
    parameters, (3) all parameters with the penetration penalty enabled
    (signed_dist_e_coeff = 1).  Saves the final vertices to
    "optimized_verts.npy" and returns the detached optimized joints.
    """
    bsz, ws, nnj = sampled_joints.shape[:3]
    device = sampled_joints.device
    coarse_lr = 0.1
    num_iters = 100
    mano_path = "/data1/sim/mano_models/mano/models"  # TODO: make configurable
    # Base points/normals tiled per window frame: bsz x ws x nnb x 3.
    base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    signed_dist_e_coeff = 0.0  # penetration term disabled for stages 1-2

    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang'
    ).to(device)

    # Randomly initialised optimization variables (betas shared per sequence).
    beta_var = torch.randn([bsz, 10]).to(device)
    rot_var = torch.randn([bsz * ws, 3]).to(device)
    theta_var = torch.randn([bsz * ws, 24]).to(device)
    transl_var = torch.randn([bsz * ws, 3]).to(device)
    beta_var.requires_grad_()
    rot_var.requires_grad_()
    theta_var.requires_grad_()
    transl_var.requires_grad_()

    ## === stage 1: coarse fit of global rotation / translation === ##
    opt = optim.Adam([rot_var, transl_var], lr=coarse_lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = _run_mano(mano_layer, rot_var, theta_var,
                                            beta_var, transl_var, bsz, ws)
        rel, signed = _rel_and_signed_dists(hand_joints, base_pts, base_normals)
        rel_e, dist_e = _fitting_terms(rel, signed, rel_base_pts_to_joints,
                                       dists_base_pts_to_joints, device)
        # Penetration mask: negative signed distance AND within 5cm of a base pt.
        signed_dist_mask = signed < 0.
        l2_close_mask = torch.sqrt(torch.sum(rel ** 2, dim=-1)) < 0.05
        signed_dist_mask = ((signed_dist_mask.float() + l2_close_mask.float()) > 1.5).detach()
        avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean()
        # Both factors of the original product are the same dot product here,
        # i.e. the penalty is the squared signed distance on masked pairs.
        signed_dist_e = signed * signed
        signed_dist_e = torch.sum(signed_dist_e[signed_dist_mask]) / torch.clamp(
            torch.sum(signed_dist_mask.float()), min=1e-5).item()
        pose_smoothness_loss, shape_prior_loss, pose_prior_loss = _prior_terms(
            theta_var, beta_var, bsz, ws)
        # Keep the decoded joints close to the sampled ones.
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1).mean()
        loss = (pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001
                + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff
                + rel_e + dist_e + dist_dec_jts_to_sampled_pts)
        loss.backward()
        opt.step()
        _log_losses(i_iter, loss, pose_smoothness_loss, shape_prior_loss,
                    pose_prior_loss, signed_dist_e, rel_e, dist_e,
                    dist_dec_jts_to_sampled_pts)

    ## === stage 2: fine fit of all MANO parameters === ##
    fine_lr = 0.1
    num_iters = 1000
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = _run_mano(mano_layer, rot_var, theta_var,
                                            beta_var, transl_var, bsz, ws)
        rel, signed = _rel_and_signed_dists(hand_joints, base_pts, base_normals)
        rel_e, dist_e = _fitting_terms(rel, signed, rel_base_pts_to_joints,
                                       dists_base_pts_to_joints, device)
        # Nearest-base-point penetration penalty ("strategy 1").
        dist_rhand_joints_to_base_pts = torch.sum(
            (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1)
        minn_dists_to_base_pts, minn_dists_idxes = torch.min(
            dist_rhand_joints_to_base_pts, dim=-1)
        # bsz x ws x nnj x 3 nearest base point / normal per joint.
        nearest_base_pts = batched_index_select_ours(
            base_pts_exp, indices=minn_dists_idxes, dim=2)
        nearest_base_normals = batched_index_select_ours(
            base_normals_exp, indices=minn_dists_idxes, dim=2)
        rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
        signed_dist_joints_to_base_pts = torch.sum(
            rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1)
        # Penalise joints behind their nearest base point and within 5cm of it.
        signed_dist_mask = signed_dist_joints_to_base_pts < 0.
        l2_close_mask = torch.sqrt(
            torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)) < 0.05
        signed_dist_mask = (signed_dist_mask.float() + l2_close_mask.float()) > 1.5
        signed_dist_e = torch.sum(
            -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask]
        ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item()
        pose_smoothness_loss, shape_prior_loss, pose_prior_loss = _prior_terms(
            theta_var, beta_var, bsz, ws)
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1).mean()
        loss = (pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001
                + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff
                + rel_e + dist_e + dist_dec_jts_to_sampled_pts)
        loss.backward()
        opt.step()
        _log_losses(i_iter, loss, pose_smoothness_loss, shape_prior_loss,
                    pose_prior_loss, signed_dist_e, rel_e, dist_e,
                    dist_dec_jts_to_sampled_pts)

    ## === stage 3: refine with the penetration energy enabled === ##
    signed_dist_e_coeff = 1.0
    fine_lr = 0.1
    num_iters = 1000
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = _run_mano(mano_layer, rot_var, theta_var,
                                            beta_var, transl_var, bsz, ws)
        rel, signed = _rel_and_signed_dists(hand_joints, base_pts, base_normals)
        rel_e, dist_e = _fitting_terms(rel, signed, rel_base_pts_to_joints,
                                       dists_base_pts_to_joints, device)
        signed_dist_mask = signed < 0.
        l2_close_mask = torch.sqrt(torch.sum(rel ** 2, dim=-1)) < 0.05
        signed_dist_mask = ((signed_dist_mask.float() + l2_close_mask.float()) > 1.5).detach()
        # "Strategy 3": weight the signed distance by -||rel||^2 so deeper /
        # closer penetrations are pushed out harder.
        dot_rel_with_rel = torch.sum(-1.0 * rel * rel, dim=-1)
        avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean()
        signed_dist_e = dot_rel_with_rel * signed
        signed_dist_e = torch.sum(signed_dist_e[signed_dist_mask]) / torch.clamp(
            torch.sum(signed_dist_mask.float()), min=1e-5).item()
        pose_smoothness_loss, shape_prior_loss, pose_prior_loss = _prior_terms(
            theta_var, beta_var, bsz, ws)
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1).mean()
        loss = (pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001
                + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff
                + rel_e + dist_e + dist_dec_jts_to_sampled_pts)
        loss.backward()
        opt.step()
        _log_losses(i_iter, loss, pose_smoothness_loss, shape_prior_loss,
                    pose_prior_loss, signed_dist_e, rel_e, dist_e,
                    dist_dec_jts_to_sampled_pts, avg_masks=avg_masks)

    ''' returning sampled_joints '''
    sampled_joints = hand_joints
    np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy())
    print(f"Optimized verts saved to optimized_verts.npy")
    return sampled_joints.detach()


def get_obj_trimesh_list(obj_verts, obj_faces):
    """Build one trimesh.Trimesh per batch item from (verts, faces) pairs.

    Tensors are detached and moved to CPU numpy before mesh construction.
    """
    tot_trimeshes = []
    tot_n = len(obj_verts)
    for i_obj in range(tot_n):
        cur_obj_verts, cur_obj_faces = obj_verts[i_obj], obj_faces[i_obj]
        if isinstance(cur_obj_verts, torch.Tensor):
            cur_obj_verts = cur_obj_verts.detach().cpu().numpy()
        if isinstance(cur_obj_faces, torch.Tensor):
            cur_obj_faces = cur_obj_faces.detach().cpu().numpy()
        # NOTE(review): `use_embree` is passed through Trimesh kwargs —
        # confirm the installed trimesh version honours it.
        cur_obj_mesh = trimesh.Trimesh(vertices=cur_obj_verts, faces=cur_obj_faces,
                                       process=False, use_embree=True)
        tot_trimeshes.append(cur_obj_mesh)
    return tot_trimeshes


def judge_penetrated_points(obj_mesh, subj_pts):
    """Label which subject points lie inside each batch item's object mesh.

    Args:
        obj_mesh: list (length bsz) of trimesh.Trimesh objects.
        subj_pts: bsz x ... x 3 tensor of query points.

    Returns:
        Float tensor shaped like subj_pts without its last axis (1.0 = point
        inside the mesh), on subj_pts' device.
    """
    tot_pts_inside_objmesh_labels = []
    nn_bsz = len(obj_mesh)
    for i_bsz in range(nn_bsz):
        cur_obj_mesh = obj_mesh[i_bsz]
        cur_subj_pts = subj_pts[i_bsz].detach().cpu().numpy()
        ori_subj_pts_shape = cur_subj_pts.shape
        # trimesh.contains expects a flat (N, 3) array.
        if len(cur_subj_pts.shape) > 2:
            cur_subj_pts = cur_subj_pts.reshape(
                cur_subj_pts.shape[0] * cur_subj_pts.shape[1], 3)
        pts_inside_objmesh = cur_obj_mesh.contains(cur_subj_pts)
        pts_inside_objmesh = pts_inside_objmesh.astype(np.float32)
        ### reshape inside_objmesh labels back to the original point layout ###
        pts_inside_objmesh = pts_inside_objmesh.reshape(*ori_subj_pts_shape[:-1])
        tot_pts_inside_objmesh_labels.append(pts_inside_objmesh)
    tot_pts_inside_objmesh_labels = np.stack(tot_pts_inside_objmesh_labels, axis=0)
    tot_pts_inside_objmesh_labels = torch.from_numpy(tot_pts_inside_objmesh_labels).float()
    return tot_pts_inside_objmesh_labels.to(subj_pts.device)


# TODO: other optimization strategies? e.g.
sequential optimziation> # def optimize_sampled_hand_joints_wobj(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces): # sampled_joints: bsz x ws x nnj x 3 # signed distances # smoothness # tot_n_objs # tot_obj_trimeshes = get_obj_trimesh_list(obj_verts, obj_faces) ## TODO: write the collect function for object verts, normals, faces ## ### A simple penetration resolving strategy is as follows: #### 1) get vertices in the object; 2) get nearest base points (for simplicity); 3) project the vertex to the base point #### ## 1) for joints only; ## 2) for vertices; ## 3) for vertices ## ## TODO: optimzie the resolvign strategy stated above ## bsz, ws, nnj = sampled_joints.shape[:3] device = sampled_joints.device coarse_lr = 0.1 num_iters = 100 # if i_iter > 0 else 1 ## nn-coarse-iters for global transformations # mano_path = "/data1/sim/mano_models/mano/models" # obj_verts: bsz x nnobjverts x base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() signed_dist_e_coeff = 1.0 signed_dist_e_coeff = 0.0 ### start optimization ### # setup MANO layer mano_layer = ManoLayer( flat_hand_mean=True, side='right', mano_root=mano_path, # mano_path for the mano model # ncomps=24, use_pca=True, root_rot_mode='axisang', joint_rot_mode='axisang' ).to(device) ## random init variables ## beta_var = torch.randn([bsz, 10]).to(device) rot_var = torch.randn([bsz * ws, 3]).to(device) theta_var = torch.randn([bsz * ws, 24]).to(device) transl_var = torch.randn([bsz * ws, 3]).to(device) beta_var.requires_grad_() rot_var.requires_grad_() theta_var.requires_grad_() transl_var.requires_grad_() opt = optim.Adam([rot_var, transl_var], lr=coarse_lr) for i_iter in range(num_iters): opt.zero_grad() # mano_layer # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), 
transl_var) hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001 ### === e1 should be close to predicted values === ### # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bs zx ws x nnj x nnb # signed_dist_base_pts_to_hand_joints = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) rel_e = torch.sum( (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1 ).mean() if dists_base_pts_to_joints is not None: dist_e = torch.sum( (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1 ).mean() else: dist_e = torch.zeros((1,), dtype=torch.float32).to(device).mean() ### === e2 the signed distances to nearest points should not be negative to the neareste === ### ## base_pts: bsz x nn_base_pts x 3 ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## ## bsz x ws x nnj x nnb ## ''' strategy 2: use all base pts, rel, dists for resolving ''' # rel_base_pts_to_hand_joints: bsz x ws x nnj x nnb x 3 # signed_dist_mask = signed_dist_base_pts_to_hand_joints < 0. 
l2_dist_rel_joints_to_base_pts_mask = torch.sqrt( torch.sum(rel_base_pts_to_hand_joints ** 2, dim=-1) ) < 0.05 signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_base_pts_mask.float()) > 1.5 dot_rel_with_normals = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) signed_dist_mask = signed_dist_mask.detach() # detach the mask # # dot_rel_with_normals: bsz x ws x nnj x nnb avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() signed_dist_e = dot_rel_with_normals * signed_dist_base_pts_to_hand_joints signed_dist_e = torch.sum( signed_dist_e[signed_dist_mask] ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item() ###### ====== get loss for signed distances ==== ### ''' strategy 2: use all base pts, rel, dists for resolving ''' ''' strategy 1: use nearest base pts, rel, dists for resolving ''' # dist_rhand_joints_to_base_pts = torch.sum( # (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 # ) # # minn_dists_idxes: bsz x ws x nnj --> # minn_dists_to_base_pts, minn_dists_idxes = torch.min( # dist_rhand_joints_to_base_pts, dim=-1 # ) # # base_pts: bsz x nn_base_pts x 3 # # # base_pts: bsz x ws x nn_base_pts x 3 # # # bsz x ws x nnj # # object verts and object faces # # ## other than the sampling process; not # # bsz x ws x nnj x 3 ## # nearest_base_pts = batched_index_select_ours( # base_pts_exp, indices=minn_dists_idxes, dim=2 # ) # # bsz x ws x nnj x 3 # # base normalse # # nearest_base_normals = batched_index_select_ours( # base_normals_exp, indices=minn_dists_idxes, dim=2 # ) # # bsz x ws x nnj x 3 # # the nearest distance points may be of some ambiguous # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # # bsz x ws x nnj # # signed_dist_joints_to_base_pts = torch.sum( # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # # should not be negative # signed_dist_mask = signed_dist_joints_to_base_pts < 0. 
# l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( # torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) # ) < 0.05 # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5 # ### ==== mean of signed distances ==== ### # signed_dist_e = torch.sum( # penetration # -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask] # ) / torch.clamp( # torch.sum(signed_dist_mask.float()), min=1e-5 # ).item() ''' strategy 1: use nearest base pts, rel, dists for resolving ''' ## === e3 smoothness and prior losses === ## pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1]) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) ## === e3 smoothness and prior losses === ## ## === e4 hand joints should be close to sampled hand joints === ## dist_dec_jts_to_sampled_pts = torch.sum( (hand_joints - sampled_joints) ** 2, dim=-1 ).mean() ### signed distance coeff -> the distance coeff # loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item())) print('\trel_e Loss: {}'.format(rel_e.item())) print('\tdist_e Loss: {}'.format(dist_e.item())) print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item())) fine_lr = 0.1 num_iters = 1000 opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr) for i_iter in range(num_iters): opt.zero_grad() # mano_layer # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), 
beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001 ### === e1 should be close to predicted values === ### # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bs zx ws x nnj x nnb # signed_dist_base_pts_to_hand_joints = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) rel_e = torch.sum( (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1 ).mean() # dists_base_pts_to_joints ## dists_base_pts_to_joints ## if dists_base_pts_to_joints is not None: ## dists_base_pts_to_joints ## dist_e = torch.sum( (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1 ).mean() else: dist_e = torch.zeros((1,), dtype=torch.float32).mean() ### === e2 the signed distances to nearest points should not be negative to the neareste === ### ## base_pts: bsz x nn_base_pts x 3 ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## ## bsz x ws x nnj x nnb ## dist_rhand_joints_to_base_pts = torch.sum( (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 ) # minn_dists_idxes: bsz x ws x nnj --> minn_dists_to_base_pts, minn_dists_idxes = torch.min( dist_rhand_joints_to_base_pts, dim=-1 ) # base_pts: bsz x nn_base_pts x 3 # # base_pts: bsz x ws x nn_base_pts x 3 # # bsz x ws x nnj # base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() # bsz x ws x nnj x 3 ## nearest_base_pts = batched_index_select_ours( base_pts_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # nearest_base_normals = batched_index_select_ours( base_normals_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # bsz x ws x nnj # signed_dist_joints_to_base_pts = torch.sum( rel_joints_to_nearest_base_pts * nearest_base_normals, 
dim=-1 ) # should not be negative signed_dist_mask = signed_dist_joints_to_base_pts < 0. l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) ) < 0.05 signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5 ### ==== mean of signed distances ==== ### signed_dist_e = torch.sum( -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask] ) / torch.clamp( torch.sum(signed_dist_mask.float()), min=1e-5 ).item() ## === e3 smoothness and prior losses === ## pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1]) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) ## === e3 smoothness and prior losses === ## ## === e4 hand joints should be close to sampled hand joints === ## dist_dec_jts_to_sampled_pts = torch.sum( (hand_joints - sampled_joints) ** 2, dim=-1 ).mean() loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item())) print('\trel_e Loss: {}'.format(rel_e.item())) print('\tdist_e Loss: {}'.format(dist_e.item())) print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item())) # tot_obj_trimeshes ### refine the optimization with signed energy ## signed_dist_e_coeff = 1.0 # fine_lr = 0.1 # num_iters = 1000 # num_iters = 100 # reinement # opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr) for i_iter in range(num_iters): # opt.zero_grad() # mano_layer # hand_verts, hand_joints = 
mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001 ### === e1 should be close to predicted values === ### # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bs zx ws x nnj x nnb # signed_dist_base_pts_to_hand_joints = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) rel_e = torch.sum( (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1 ).mean() # dists_base_pts_to_joints ## dists_base_pts_to_joints ## if dists_base_pts_to_joints is not None: ## dists_base_pts_to_joints ## dist_e = torch.sum( (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1 ).mean() else: dist_e = torch.zeros((1,), dtype=torch.float32).mean() ''' strategy 2: use all base pts, rel, dists for resolving ''' # # rel_base_pts_to_hand_joints: bsz x ws x nnj x nnb x 3 # # signed_dist_mask = signed_dist_base_pts_to_hand_joints < 0. 
# l2_dist_rel_joints_to_base_pts_mask = torch.sqrt( # torch.sum(rel_base_pts_to_hand_joints ** 2, dim=-1) # ) < 0.05 # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_base_pts_mask.float()) > 1.5 # ## === dot rel with normals === ## # # dot_rel_with_normals = torch.sum( # # rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # # ) # ## === dot rel with normals === ## # ## === dot rel with rel, strategy 3 === ## # dot_rel_with_normals = torch.sum( # -1.0 * rel_base_pts_to_hand_joints * rel_base_pts_to_hand_joints, dim=-1 # ) # ## === dot rel with rel, strategy 3 === ## # signed_dist_mask = signed_dist_mask.detach() # detach the mask # # # dot_rel_with_normals: bsz x ws x nnj x nnb # avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() # signed_dist_e = dot_rel_with_normals * signed_dist_base_pts_to_hand_joints # signed_dist_e = torch.sum( # signed_dist_e[signed_dist_mask] # ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item() # ###### ====== get loss for signed distances ==== ### ''' strategy 2: use all base pts, rel, dists for resolving ''' ## use all base pts ## { # hard projections for ''' strategy 1: use nearest base pts, rel, dists for resolving ''' # ### === e2 the signed distances to nearest points should not be negative to the neareste === ### # ## base_pts: bsz x nn_base_pts x 3 # ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## # ## bsz x ws x nnj x nnb ## # dist_rhand_joints_to_base_pts = torch.sum( # (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 # ) # # minn_dists_idxes: bsz x ws x nnj --> # minn_dists_to_base_pts, minn_dists_idxes = torch.min( # dist_rhand_joints_to_base_pts, dim=-1 # ) # # # base_pts: bsz x nn_base_pts x 3 # # # base_pts: bsz x ws x nn_base_pts x 3 # # # bsz x ws x nnj # # base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() # # bsz x ws x nnj x 3 ## # nearest_base_pts = batched_index_select_ours( # base_pts_exp, 
indices=minn_dists_idxes, dim=2 # ) # # bsz x ws x nnj x 3 # # nearest_base_normals = batched_index_select_ours( # base_normals_exp, indices=minn_dists_idxes, dim=2 # ) # # bsz x ws x nnj x 3 # # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # # bsz x ws x nnj # # signed_dist_joints_to_base_pts = torch.sum( # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # # should not be negative # signed_dist_mask = signed_dist_joints_to_base_pts < 0. # ## === luojisiwei and others === ## # # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( # # torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) # # ) < 0.05 # ## === luojisiwei and others === ## # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( # torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) # ) < 0.1 # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5 # ### ==== mean of signed distances ==== ### # # signed_dist_e = torch.sum( # # -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask] # # ) / torch.clamp( # # torch.sum(signed_dist_mask.float()), min=1e-5 # # ).item() # # signed_dist_joints_to_base_pts: bsz x ws x nnj # -> disstances # signed_dist_joints_to_base_pts = signed_dist_joints_to_base_pts.detach() # # ## penetraition resolving --- strategy # dot_rel_with_normals = torch.sum( # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # signed_dist_mask = signed_dist_mask.detach() # detach the mask # # # bsz x ws x nnj --> the loss term # ## signed distances 3 #### isgned distance 3 ### # ## dotrelwithnormals, ## # # # signed_dist_mask -> the distances # # dot_rel_with_normals: bsz x ws x nnj x nnb # avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() # signed_dist_e = dot_rel_with_normals * signed_dist_joints_to_base_pts # signed_dist_e = torch.sum( # signed_dist_e[signed_dist_mask] # ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item() # ###### ====== get loss for 
signed distances ==== ### ''' strategy 1: use nearest base pts, rel, dists for resolving ''' } # bsz x ws x nnj # --> objmesh insides pts labels pts_inside_objmesh_labels = judge_penetrated_points(tot_obj_trimeshes, hand_joints) pts_inside_objmesh_labels_mask = pts_inside_objmesh_labels.bool() # { # hard projections for ''' strategy 1: use nearest base pts, rel, dists for resolving ''' ### === e2 the signed distances to nearest points should not be negative to the neareste === ### ## base_pts: bsz x nn_base_pts x 3 ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## ## bsz x ws x nnj x nnb ## dist_rhand_joints_to_base_pts = torch.sum( (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 ) # minn_dists_idxes: bsz x ws x nnj --> # base_pts minn_dists_to_base_pts, minn_dists_idxes = torch.min( dist_rhand_joints_to_base_pts, dim=-1 ) # base_pts: bsz x nn_base_pts x 3 # # base_pts: bsz x ws x nn_base_pts x 3 # # bsz x ws x nnj # base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() # bsz x ws x nnj x 3 ## # simple penetration ## nearest_base_pts = batched_index_select_ours( base_pts_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # nearest_base_normals = batched_index_select_ours( base_normals_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # bsz x ws x nnj # # signed_dist_joints_to_base_pts = torch.sum( # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # # should not be negative # signed_dist_mask = signed_dist_joints_to_base_pts < 0. 
## === luojisiwei and others === ## # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( # torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) # ) < 0.05 ## === luojisiwei and others === ## ##### ===== GET l2_distance mask ===== ##### # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( # torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) # ) < 0.1 # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5 ##### ===== GET l2_distance mask ===== ##### ### ==== mean of signed distances ==== ### # signed_dist_e = torch.sum( # -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask] # ) / torch.clamp( # torch.sum(signed_dist_mask.float()), min=1e-5 # ).item() # signed_dist_joints_to_base_pts: bsz x ws x nnj # -> disstances signed_dist_joints_to_base_pts = signed_dist_joints_to_base_pts.detach() # # dot rel # penetraition resolving --- strategy # dot_rel_with_normals = torch.sum( # dot rhand joints with normals # # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # dot_rel_with_normals = torch.sum( # dot rhand joints with normals # -rel_joints_to_nearest_base_pts * rel_joints_to_nearest_base_pts, dim=-1 ) #### Get masks for penetrated joint points #### # signed_dist_mask = (signed_dist_mask.float() + pts_inside_objmesh_labels_mask.float()) > 1.5 signed_dist_mask = pts_inside_objmesh_labels_mask # bsz x ws x nnj signed_dist_mask = signed_dist_mask.detach() # detach the mask # # bsz x ws x nnj --> the loss term ## signed distances 3 #### isgned distance 3 ### ## dotrelwithnormals, ## # # signed_dist_mask -> the distances # dot_rel_with_normals: bsz x ws x nnj x nnb # avg over windows and batches # avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() ## get singed distance energies ### ## projection ## # signed_dist_e = dot_rel_with_normals * signed_dist_joints_to_base_pts signed_dist_e = -1.0 * dot_rel_with_normals signed_dist_e = torch.sum( signed_dist_e[signed_dist_mask] ) / 
torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item() ###### ====== get loss for signed distances ==== ### ''' strategy 1: use nearest base pts, rel, dists for resolving ''' # cannot mask in some caes # change of isgned distances # # intersection spline ## judeg whether inside the object and only project those one inside of the object ## === e3 smoothness and prior losses === ## pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1]) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) ## === e3 smoothness and prior losses === ## #### ==== sv_dict ==== #### sv_dict = { 'pts_inside_objmesh_labels_mask': pts_inside_objmesh_labels_mask.detach().cpu().numpy(), 'hand_joints': hand_joints.detach().cpu().numpy(), 'obj_verts': [cur_verts.detach().cpu().numpy() for cur_verts in obj_verts], 'obj_faces': [cur_faces.detach().cpu().numpy() for cur_faces in obj_faces], 'base_pts': base_pts.detach().cpu().numpy(), 'base_normals': base_normals.detach().cpu().numpy(), # bsz x nnb x 3 -> bsz x nnb x 3 -> base normals # 'nearest_base_pts': nearest_base_pts.detach().cpu().numpy(), # bsz x ws x nnj x 3 # 'nearest_base_normals': nearest_base_normals.detach().cpu().numpy(), # bsz x ws x nnj x 3 --> base normals and pts } # sv_dict_folder = "/data1/sim/mdm/tmp_saving" os.makedirs(sv_dict_folder, exist_ok=True) sv_dict_fn = os.path.join(sv_dict_folder, f"optim_iter_{i_iter}.npy") np.save(sv_dict_fn, sv_dict) print(f"Obj and subj saved to {sv_dict_fn}") #### ==== sv_dict ==== #### ## === e4 hand joints should be close to sampled hand joints === ## dist_dec_jts_to_sampled_pts = torch.sum( (hand_joints - sampled_joints) ** 2, dim=-1 ).mean() # shoudl take a # how to proejct the jvertex # hwo to project the veretex # weighted sum of the projectiondirection # weights of each base point # atraction field -> should be able to learn the penetration resolving strategy # stochestic penetration resolving 
strategy # loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item())) print('\trel_e Loss: {}'.format(rel_e.item())) print('\tdist_e Loss: {}'.format(dist_e.item())) print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item())) # avg_masks print('\tAvg masks: {}'.format(avg_masks.item())) ''' returning sampled_joints ''' sampled_joints = hand_joints np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy()) print(f"Optimized verts saved to optimized_verts.npy") return sampled_joints.detach() # TODO: other optimization strategies? e.g. 
sequential optimziation> # def optimize_sampled_hand_joints_wobj_v2(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces): # sampled_joints: bsz x ws x nnj x 3 # # sampled_joints: bsz x ws x nnj x 3 # obj trimeshes # tot_obj_trimeshes = get_obj_trimesh_list(obj_verts, obj_faces) ## TODO: write the collect function for object verts, normals, faces ## ### A simple penetration resolving strategy is as follows: #### 1) get vertices in the object; 2) get nearest base points (for simplicity); 3) project the vertex to the base point #### ## 1) for joints only; ## 2) for vertices; ## 3) for vertices; ## TODO: optimzie the resolvign strategy stated above ## bsz, ws, nnj = sampled_joints.shape[:3] device = sampled_joints.device coarse_lr = 0.1 num_iters = 100 # if i_iter > 0 else 1 ## nn-coarse-iters for global transformations # mano_path = "/data1/sim/mano_models/mano/models" # obj_verts: bsz x nnobjverts x base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() signed_dist_e_coeff = 1.0 signed_dist_e_coeff = 0.0 ### start optimization ### # setup MANO layer mano_layer = ManoLayer( flat_hand_mean=True, side='right', mano_root=mano_path, # mano_path for the mano model # ncomps=24, use_pca=True, root_rot_mode='axisang', joint_rot_mode='axisang' ).to(device) ## random init variables ## beta_var = torch.randn([bsz, 10]).to(device) rot_var = torch.randn([bsz * ws, 3]).to(device) theta_var = torch.randn([bsz * ws, 24]).to(device) transl_var = torch.randn([bsz * ws, 3]).to(device) beta_var.requires_grad_() rot_var.requires_grad_() theta_var.requires_grad_() transl_var.requires_grad_() opt = optim.Adam([rot_var, transl_var], lr=coarse_lr) for i_iter in range(num_iters): opt.zero_grad() # mano_layer # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, ws, 
1).view(-1, 10), transl_var) hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001 ### === e1 should be close to predicted values === ### # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bs zx ws x nnj x nnb # signed_dist_base_pts_to_hand_joints = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) rel_e = torch.sum( (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1 ).mean() if dists_base_pts_to_joints is not None: dist_e = torch.sum( (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1 ).mean() else: dist_e = torch.zeros((1,), dtype=torch.float32).to(device).mean() ### === e2 the signed distances to nearest points should not be negative to the neareste === ### ## base_pts: bsz x nn_base_pts x 3 ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## ## bsz x ws x nnj x nnb ## ''' strategy 2: use all base pts, rel, dists for resolving ''' # rel_base_pts_to_hand_joints: bsz x ws x nnj x nnb x 3 # signed_dist_mask = signed_dist_base_pts_to_hand_joints < 0. 
l2_dist_rel_joints_to_base_pts_mask = torch.sqrt( torch.sum(rel_base_pts_to_hand_joints ** 2, dim=-1) ) < 0.05 signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_base_pts_mask.float()) > 1.5 dot_rel_with_normals = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) signed_dist_mask = signed_dist_mask.detach() # detach the mask # # dot_rel_with_normals: bsz x ws x nnj x nnb avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() signed_dist_e = dot_rel_with_normals * signed_dist_base_pts_to_hand_joints signed_dist_e = torch.sum( signed_dist_e[signed_dist_mask] ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item() ###### ====== get loss for signed distances ==== ### ''' strategy 2: use all base pts, rel, dists for resolving ''' ''' strategy 1: use nearest base pts, rel, dists for resolving ''' # dist_rhand_joints_to_base_pts = torch.sum( # (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 # ) # # minn_dists_idxes: bsz x ws x nnj --> # minn_dists_to_base_pts, minn_dists_idxes = torch.min( # dist_rhand_joints_to_base_pts, dim=-1 # ) # # base_pts: bsz x nn_base_pts x 3 # # # base_pts: bsz x ws x nn_base_pts x 3 # # # bsz x ws x nnj # # object verts and object faces # # ## other than the sampling process; not # # bsz x ws x nnj x 3 ## # nearest_base_pts = batched_index_select_ours( # base_pts_exp, indices=minn_dists_idxes, dim=2 # ) # # bsz x ws x nnj x 3 # # base normalse # # nearest_base_normals = batched_index_select_ours( # base_normals_exp, indices=minn_dists_idxes, dim=2 # ) # # bsz x ws x nnj x 3 # # the nearest distance points may be of some ambiguous # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # # bsz x ws x nnj # # signed_dist_joints_to_base_pts = torch.sum( # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # # should not be negative # signed_dist_mask = signed_dist_joints_to_base_pts < 0. 
# l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( # torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) # ) < 0.05 # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5 # ### ==== mean of signed distances ==== ### # signed_dist_e = torch.sum( # penetration # -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask] # ) / torch.clamp( # torch.sum(signed_dist_mask.float()), min=1e-5 # ).item() ''' strategy 1: use nearest base pts, rel, dists for resolving ''' ## === e3 smoothness and prior losses === ## pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1]) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) ## === e3 smoothness and prior losses === ## ## === e4 hand joints should be close to sampled hand joints === ## dist_dec_jts_to_sampled_pts = torch.sum( (hand_joints - sampled_joints) ** 2, dim=-1 ).mean() ### signed distance coeff -> the distance coeff # loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item())) print('\trel_e Loss: {}'.format(rel_e.item())) print('\tdist_e Loss: {}'.format(dist_e.item())) print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item())) fine_lr = 0.1 num_iters = 1000 opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr) for i_iter in range(num_iters): opt.zero_grad() # mano_layer # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), 
beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001 ### === e1 should be close to predicted values === ### # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bs zx ws x nnj x nnb # signed_dist_base_pts_to_hand_joints = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) rel_e = torch.sum( (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1 ).mean() # dists_base_pts_to_joints ## dists_base_pts_to_joints ## if dists_base_pts_to_joints is not None: ## dists_base_pts_to_joints ## dist_e = torch.sum( (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1 ).mean() else: dist_e = torch.zeros((1,), dtype=torch.float32).mean() ### === e2 the signed distances to nearest points should not be negative to the neareste === ### ## base_pts: bsz x nn_base_pts x 3 ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## ## bsz x ws x nnj x nnb ## dist_rhand_joints_to_base_pts = torch.sum( (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 ) # minn_dists_idxes: bsz x ws x nnj --> minn_dists_to_base_pts, minn_dists_idxes = torch.min( dist_rhand_joints_to_base_pts, dim=-1 ) # base_pts: bsz x nn_base_pts x 3 # # base_pts: bsz x ws x nn_base_pts x 3 # # bsz x ws x nnj # base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() # bsz x ws x nnj x 3 ## nearest_base_pts = batched_index_select_ours( base_pts_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # nearest_base_normals = batched_index_select_ours( base_normals_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # bsz x ws x nnj # signed_dist_joints_to_base_pts = torch.sum( rel_joints_to_nearest_base_pts * nearest_base_normals, 
dim=-1 ) # should not be negative signed_dist_mask = signed_dist_joints_to_base_pts < 0. l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt( torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1) ) < 0.05 signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5 ### ==== mean of signed distances ==== ### signed_dist_e = torch.sum( -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask] ) / torch.clamp( torch.sum(signed_dist_mask.float()), min=1e-5 ).item() ## === e3 smoothness and prior losses === ## pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1]) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) ## === e3 smoothness and prior losses === ## ## === e4 hand joints should be close to sampled hand joints === ## dist_dec_jts_to_sampled_pts = torch.sum( (hand_joints - sampled_joints) ** 2, dim=-1 ).mean() loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item())) print('\trel_e Loss: {}'.format(rel_e.item())) print('\tdist_e Loss: {}'.format(dist_e.item())) print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item())) # tot_obj_trimeshes ### refine the optimization with signed energy ## # # signed_dist_jts_to_nearest_base_pts = [] # tot_nearest_base_pts = [] # tot_nearest_base_normals = [] signed_dist_e_coeff = 1.0 # fine_lr = 0.1 # num_iters = 1000 # num_iters = 100 # reinement # opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], 
lr=fine_lr) for i_iter in range(num_iters): # opt.zero_grad() # mano_layer # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1), beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var) hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x nn hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001 ### === e1 should be close to predicted values === ### # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bs zx ws x nnj x nnb # signed_dist_base_pts_to_hand_joints = torch.sum( rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) rel_e = torch.sum( (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1 ).mean() # dists_base_pts_to_joints ## dists_base_pts_to_joints ## if dists_base_pts_to_joints is not None: ## dists_base_pts_to_joints ## dist_e = torch.sum( (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1 ).mean() else: dist_e = torch.zeros((1,), dtype=torch.float32).mean() ### ==== inside the objemesh labels ==== ### # bsz x ws x nnj # --> objmesh insides pts labels # pts_inside_objmesh_labels = judge_penetrated_points(tot_obj_trimeshes, hand_joints) pts_inside_objmesh_labels_mask = pts_inside_objmesh_labels.bool() # { # hard projections for ''' strategy 1: use nearest base pts, rel, dists for resolving ''' ### === e2 the signed distances to nearest points should not be negative to the neareste === ### ## base_pts: bsz x nn_base_pts x 3 ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ## ## bsz x ws x nnj x nnb ## dist_rhand_joints_to_base_pts = torch.sum( (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1 ) # minn_dists_idxes: bsz x ws x nnj # # base_pts minn_dists_to_base_pts, minn_dists_idxes = torch.min( dist_rhand_joints_to_base_pts, dim=-1 ) # base_pts: bsz x nn_base_pts x 3 # # base_pts: bsz x ws x nn_base_pts x 3 # # bsz x ws x nnj # base_pts_exp = 
base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous() # bsz x ws x nnj x 3 ## # simple penetration ## nearest_base_pts = batched_index_select_ours( base_pts_exp, indices=minn_dists_idxes, dim=2 ) # bsz x ws x nnj x 3 # # nearest_base_normals = batched_index_select_ours( base_normals_exp, indices=minn_dists_idxes, dim=2 ) tot_masks = [] tot_base_pts = [] tot_base_normals = [] tot_base_signed_dists = [] ## === nearest base pts === ## for i_bsz in range(nearest_base_pts.size(0)): # masks, base_pts, base_normals for each frame here # cur_bsz_ cur_bsz_masks = [pts_inside_objmesh_labels_mask[i_bsz][0]] cur_bsz_base_pts = [nearest_base_pts[i_bsz][0]] cur_bsz_base_normals = [nearest_base_normals[i_bsz][0]] # nnjts # ## st frame signed dist ## cur_bsz_st_frame_signed_dist = torch.sum( (hand_joints[i_bsz][0] - cur_bsz_base_pts[0]) * cur_bsz_base_normals[0], dim=-1 ) cur_bsz_signed_dist = [cur_bsz_st_frame_signed_dist] for i_fr in range(1, nearest_base_pts.size(1)): cur_bsz_cur_fr_jts = hand_joints[i_bsz][i_fr] # cur_bsz_cur_fr_base_pts = nearest_base_pts # cur_fr_jts - cur_bsz_cur_fr_prev_fr_signed_dist = torch.sum( (cur_bsz_cur_fr_jts - cur_bsz_base_pts[-1]) * cur_bsz_base_normals[-1], dim=-1 ) # nnjts # cur cur_bsz_cur_fr_mask = ((cur_bsz_signed_dist[-1] >= 0.).float() + (cur_bsz_cur_fr_prev_fr_signed_dist < 0.).float()) > 1.5 cur_bsz_cur_fr_base_pts = nearest_base_pts[i_bsz][i_fr].clone() cur_bsz_cur_fr_base_pts[cur_bsz_cur_fr_mask] = cur_bsz_base_pts[-1][cur_bsz_cur_fr_mask] cur_bsz_cur_fr_base_normals = nearest_base_normals[i_bsz][i_fr].clone() # ### curbsz curfr base normals; ### # cur_bsz_cur_fr_base_normals[cur_bsz_cur_fr_mask] = cur_bsz_base_normals[-1][cur_bsz_cur_fr_mask] cur_bsz_cur_fr_signed_dist = torch.sum( (cur_bsz_cur_fr_jts - cur_bsz_cur_fr_base_pts) * cur_bsz_cur_fr_base_normals, dim=-1 ) cur_bsz_cur_fr_signed_dist[cur_bsz_cur_fr_mask] = 0. 
# ot the bes points ### for masks ### cur_bsz_masks.append(cur_bsz_cur_fr_mask) cur_bsz_base_pts.append(cur_bsz_cur_fr_base_pts) cur_bsz_base_normals.append(cur_bsz_cur_fr_base_normals) # cur_bsz_masks = torch.stack(cur_bsz_masks, dim=0) cur_bsz_base_pts = torch.stack(cur_bsz_base_pts, dim=0) cur_bsz_base_normals = torch.stack(cur_bsz_base_normals, dim=0) cur_bsz_signed_dist = torch.stack(cur_bsz_signed_dist, dim=0) tot_masks.append(cur_bsz_masks) tot_base_pts.append(cur_bsz_base_pts) tot_base_normals.append(cur_bsz_base_normals) tot_base_signed_dists.append(cur_bsz_signed_dist) # masks; tot_masks = torch.stack(tot_masks, dim=0) tot_base_pts = torch.stack(tot_base_pts, dim=0) tot_base_normals = torch.stack(tot_base_normals, dim=0) tot_base_signed_dists = torch.stack(tot_base_signed_dists, dim=0) # nearest_base_pts = tot_base_pts.clone() # tot base pts nearest_base_normals = tot_base_normals.clone() pts_inside_objmesh_labels_mask = tot_masks.clone() # if len() # bsz x ws x nnj x 3 # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts # signed_dist_joints_to_base_pts: bsz x ws x nnj # -> disstances # signed_dist_joints_to_base_pts = signed_dist_joints_to_base_pts.detach() # # dot rel # penetraition resolving --- strategy # dot_rel_with_normals = torch.sum( # dot rhand joints with normals # # rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1 # ) # dot_rel_with_normals = torch.sum( # dot rhand joints with normals # -rel_joints_to_nearest_base_pts * rel_joints_to_nearest_base_pts, dim=-1 ) #### Get masks for penetrated joint points #### # signed_dist_mask = (signed_dist_mask.float() + pts_inside_objmesh_labels_mask.float()) > 1.5 signed_dist_mask = pts_inside_objmesh_labels_mask # bsz x ws x nnj signed_dist_mask = signed_dist_mask.detach() # detach the mask # # bsz x ws x nnj --> the loss term ## signed distances 3 #### isgned distance 3 ### ## dotrelwithnormals, ## # # signed_dist_mask -> the distances # dot_rel_with_normals: bsz x ws x nnj x 
nnb # avg over windows and batches # avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() ## get singed distance energies ### ## projection ## # signed_dist_e = dot_rel_with_normals * signed_dist_joints_to_base_pts ### dot_rel_with_normals --> signed_dist_e = -1.0 * dot_rel_with_normals signed_dist_e = torch.sum( signed_dist_e[signed_dist_mask] ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item() ###### ====== get loss for signed distances ==== ### ''' strategy 1: use nearest base pts, rel, dists for resolving ''' # cannot mask in some caes # change of isgned distances # # intersection spline ## judeg whether inside the object and only project those one inside of the object ## === e3 smoothness and prior losses === ## pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1]) shape_prior_loss = torch.mean(beta_var**2) pose_prior_loss = torch.mean(theta_var**2) ## === e3 smoothness and prior losses === ## # points to object vertices #### ==== sv_dict ==== #### sv_dict = { 'pts_inside_objmesh_labels_mask': pts_inside_objmesh_labels_mask.detach().cpu().numpy(), 'hand_joints': hand_joints.detach().cpu().numpy(), 'obj_verts': [cur_verts.detach().cpu().numpy() for cur_verts in obj_verts], 'obj_faces': [cur_faces.detach().cpu().numpy() for cur_faces in obj_faces], 'base_pts': base_pts.detach().cpu().numpy(), 'base_normals': base_normals.detach().cpu().numpy(), # bsz x nnb x 3 -> bsz x nnb x 3 -> base normals # 'nearest_base_pts': nearest_base_pts.detach().cpu().numpy(), # bsz x ws x nnj x 3 # 'nearest_base_normals': nearest_base_normals.detach().cpu().numpy(), # bsz x ws x nnj x 3 --> base normals and pts } # sv_dict_folder = "/data1/sim/mdm/tmp_saving" os.makedirs(sv_dict_folder, exist_ok=True) sv_dict_fn = os.path.join(sv_dict_folder, f"optim_iter_{i_iter}.npy") np.save(sv_dict_fn, sv_dict) print(f"Obj and subj saved to {sv_dict_fn}") #### ==== sv_dict ==== #### ## === e4 hand joints should be close 
to sampled hand joints === ## dist_dec_jts_to_sampled_pts = torch.sum( (hand_joints - sampled_joints) ** 2, dim=-1 ).mean() # shoudl take a # how to proejct the jvertex # hwo to project the veretex # weighted sum of the projectiondirection # weights of each base point # atraction field -> should be able to learn the penetration resolving strategy # stochestic penetration resolving strategy # loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tShape Prior Loss: {}'.format(shape_prior_loss.item())) print('\tPose Prior Loss: {}'.format(pose_prior_loss.item())) print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item())) print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item())) print('\trel_e Loss: {}'.format(rel_e.item())) print('\tdist_e Loss: {}'.format(dist_e.item())) print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item())) # avg_masks print('\tAvg masks: {}'.format(avg_masks.item())) ''' returning sampled_joints ''' sampled_joints = hand_joints np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy()) print(f"Optimized verts saved to optimized_verts.npy") return sampled_joints.detach() ## def create_gaussian_diffusion(args): ## create guassian diffusion ## # default params predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! steps = 1000 # scale_beta = 1. # no scaling timestep_respacing = '' # can be used for ddim sampling, we don't use it. 
learn_sigma = False # learn sigma # rescale_timesteps = False ## noose schedule; steps; scale_beta ## ## MSE ## betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) loss_type = gd.LossType.MSE if not timestep_respacing: timestep_respacing = [steps] print(f"dataset: {args.dataset}, rep_type: {args.rep_type}") if args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist", "ambient_obj_base_rel_dist"]: print(f"here! dataset: {args.dataset}, rep_type: {args.rep_type}") cur_spaced_diffusion_model = SpacedDiffusion_Ours # SpacedDiffusion_OursV2 elif args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist_we"]: cur_spaced_diffusion_model = SpacedDiffusion_OursV2 elif args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist_we_wj"]: cur_spaced_diffusion_model = SpacedDiffusion_OursV3 # SpacedDiffusion_OursV4 elif args.dataset in ['motion_ours'] and args.rep_type in ["obj_base_rel_dist_we_wj_latents"]: if args.diff_joint_quants: cur_spaced_diffusion_model = SpacedDiffusion_OursV7 elif args.diff_hand_params: cur_spaced_diffusion_model = SpacedDiffusion_OursV9 else: if args.diff_spatial: cur_spaced_diffusion_model = SpacedDiffusion_OursV5 elif args.diff_latents: cur_spaced_diffusion_model = SpacedDiffusion_OursV6 else: cur_spaced_diffusion_model = SpacedDiffusion_OursV4 else: cur_spaced_diffusion_model = SpacedDiffusion ### ==== predict xstart other than the noise in the model === ### return cur_spaced_diffusion_model( use_timesteps=space_timesteps(steps, timestep_respacing), betas=betas, model_mean_type=( gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X ), model_var_type=( ## use fixed sigmas / variances ## ( gd.ModelVarType.FIXED_LARGE if not args.sigma_small else gd.ModelVarType.FIXED_SMALL # fixed small # ) if not learn_sigma ## use learned sigmas ## else gd.ModelVarType.LEARNED_RANGE ), ## modelvartype ## loss_type=loss_type, ## loss_type ## 
rescale_timesteps=rescale_timesteps, lambda_vel=args.lambda_vel, lambda_rcxyz=args.lambda_rcxyz, ## lambda lambda_fc=args.lambda_fc, # motion_to_rep denoising_stra=args.denoising_stra, inter_optim=args.inter_optim, args=args, ) ### from decoded energies to optimized joints ### ## latent variables ## ## encoded energies ## from energies calculated from perturbed energies ## ## decoded energies should also match the clean energy term ## ## and those values should be all denormed ## def optimize_joints_according_to_e(dec_joints, base_pts, base_normals, dec_e): # dec_e_along_normals: bsz x (ws - 1) x nnj x nnb dec_e_along_normals = dec_e['dec_e_along_normals'] # dec_e_vt_normals: bsz x (ws - 1) x nnj x nnb dec_e_vt_normals = dec_e['dec_e_vt_normals'] nn_iters = 10 coarse_lr = 0.001 dec_joints.requires_grad_() opt = optim.Adam([dec_joints], lr=coarse_lr) for i_iter in range(nn_iters): # dec_joints: bsz x ws x nnj x 3 # base_pts: bsz x nnb x 3 k_f = 1. # bsz x ws x nnj x nnb x 3 # denormed_rel_base_pts_to_rhand_joints = dec_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) k_f = 1. 
## l2 rel base pts to pert rhand joints ## # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1) ### att_forces ## att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # # bsz x (ws - 1) x nnj x nnb # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # bsz x (ws - 1) x nnj x 3 --> displacements s# denormed_rhand_joints_disp = dec_joints[:, 1:, :, :] - dec_joints[:, :-1, :, :] # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 ) # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum( rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1 )) k_a = 1. k_b = 1. 
### bsz x (ws - 1) x nnj x nnb ### e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal) # (ws - 1) x nnj x nnb # -> dist vt normals # ## e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # loss_cur_e_pred_e_along_normals = ((e_disp_rel_to_base_along_normals - dec_e_along_normals) ** 2).mean() loss_cur_e_pred_e_vt_normals = ((e_disp_rel_to_baes_vt_normals - dec_e_vt_normals) ** 2).mean() loss = loss_cur_e_pred_e_along_normals + loss_cur_e_pred_e_vt_normals opt.zero_grad() loss.backward() opt.step() print('Iter {}: {}'.format(i_iter, loss.item()), flush=True) print('\tloss_cur_e_pred_e_along_normals: {}'.format(loss_cur_e_pred_e_along_normals.item())) print('\tloss_cur_e_pred_e_vt_normals: {}'.format(loss_cur_e_pred_e_vt_normals.item())) return dec_joints.detach()