import torch
import numpy as np
import pickle as pkl

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
# from priors.pose_prior_35 import Prior
# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
from priors.shape_prior import ShapePrior
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa
from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS


class Loss(torch.nn.Module):
    """Combined training loss (image to 3D dog model, with breed relations).

    The forward pass adds up weighted terms for: reprojected 2D keypoints,
    reprojected silhouette, shape regularisation, an optional 3D shape target
    for some breeds, a pose regulariser on a few joints, a normalizing-flow
    pose prior, a side-movement penalty for leg joints, breed classification,
    a triplet loss on the latent breed embedding and a focal-length
    regulariser.
    """

    def __init__(self, data_info, nf_version=None):
        """Create the loss criteria and load the per-breed shape targets.

        Args:
            data_info: object exposing ``keypoint_weights`` (one weight per
                keypoint); stored on the module and registered as a buffer.
            nf_version: version passed to ``NormalizingFlowPrior``. When it is
                ``None`` no pose prior is created, so the 'poseprior' term
                must stay disabled in ``forward``.
        """
        super(Loss, self).__init__()
        self.criterion_regr = torch.nn.MSELoss()    # takes the mean
        self.criterion_class = torch.nn.CrossEntropyLoss()
        self.data_info = data_info
        self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
        # Triplet index tensors; built lazily by prepare_anchor_pos_neg().
        self.l_anchor = None
        self.l_pos = None
        self.l_neg = None
        self._triplet_batch_size = None
        if nf_version is not None:
            self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
        self.shape_prior = ShapePrior(UNITY_SMAL_SHAPE_PRIOR_DOGS)
        self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1)

        # load 3d data for the unity dogs (an optional shape prior for 11 breeds)
        # NOTE: pickle can execute arbitrary code while loading; this path must
        # only ever point to a trusted file.
        with open(UNITY_SMAL_SHAPE_PRIOR_DOGS, 'rb') as f:
            data = pkl.load(f)
        dog_betas_unity = data['dogs_betas']

        def betas_row(row):
            return torch.tensor(dog_betas_unity[row, :]).float()

        # breed index -> target betas (breed 84 is the average of two models)
        self.dog_betas_unity = {
            29: betas_row(0),
            91: betas_row(1),
            84: torch.tensor(0.5*dog_betas_unity[3, :] + 0.5*dog_betas_unity[14, :]).float(),
            85: betas_row(5),
            28: betas_row(6),
            94: betas_row(7),
            92: betas_row(8),
            95: betas_row(10),
            20: betas_row(11),
            83: betas_row(12),
            99: betas_row(16),
        }

    @staticmethod
    def _zero(device):
        """Return a fresh float32 scalar zero on ``device`` (placeholder loss)."""
        return torch.zeros((), device=device)

    def prepare_anchor_pos_neg(self, batch_size, device):
        """Build and cache the index tensors used by the triplet loss.

        Samples ``2k`` and ``2k + 1`` are treated as one pair. For every pair
        and every sample outside of it, one triplet is created: the negative is
        the outside sample, the anchor is the pair member with the same parity
        as the negative and the positive is the other pair member. The result
        is stored in ``self.l_anchor``, ``self.l_pos`` and ``self.l_neg``
        (int64 tensors on ``device``).

        Args:
            batch_size: number of samples in the batch; expected to be even,
                otherwise the last pair refers to an index past the batch.
            device: device on which the index tensors are created.
        """
        l_anchor = []
        l_pos = []
        l_neg = []
        for ind in range(0, batch_size, 2):
            for ind2 in range(batch_size):
                if ind2 == ind or ind2 == ind + 1:
                    continue    # members of the pair itself are never negatives
                if ind2 % 2 == 0:
                    l_anchor.append(ind)
                    l_pos.append(ind + 1)
                else:
                    l_anchor.append(ind + 1)
                    l_pos.append(ind)
                l_neg.append(ind2)
        self.l_anchor = torch.tensor(l_anchor, dtype=torch.int64, device=device)
        self.l_pos = torch.tensor(l_pos, dtype=torch.int64, device=device)
        self.l_neg = torch.tensor(l_neg, dtype=torch.int64, device=device)
        # Remembered so that forward() can detect a change of the batch size.
        self._triplet_batch_size = batch_size
        return

    def _shape_regularization(self, output_reproj, target_dict, weight_dict):
        """Return the already weighted shape regularisation term.

        NOTE(review): the statements that produced ``loss_shape_weighted`` were
        missing from the reviewed copy of this file (a span of ``forward`` was
        lost, apparently everything between a ``<`` and a later ``>``). This
        helper is a reconstruction: the keys ``weight_dict['shape_options']``
        and ``weight_dict['shape']``, the option names and the 'limbs7'
        coefficients are NOT visible in the surviving text -- TODO confirm
        against version control before relying on it.
        """
        loss_shape_weighted_list = [self._zero(output_reproj['keyp_2d'].device)]
        for ind_sp, sp in enumerate(weight_dict['shape_options']):
            weight_sp = weight_dict['shape'][ind_sp]
            if sp == 'smal':
                # prior on the shape coefficients (betas)
                loss_shape_tmp = self.shape_prior(output_reproj['betas'])
            elif sp == 'limbs':
                loss_shape_tmp = torch.mean((output_reproj['betas_limbs'])**2)
            elif sp == 'limbs7':
                # presumably one coefficient per limb scale -- TODO confirm values
                limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
                limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device)
                loss_shape_tmp = torch.mean((output_reproj['betas_limbs'] * limb_coeffs[None, :])**2)
            else:
                raise NotImplementedError
            loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
        return torch.stack(loss_shape_weighted_list).sum()

    def forward(self, output_reproj, target_dict, weight_dict=None):
        """Compute the total loss and its weighted components.

        Args:
            output_reproj: dict with the network predictions. Keys read here:
                'keyp_2d', 'silh', 'betas', 'betas_limbs', 'pose_rot6d',
                'normflow_z', 'dog_breed', 'z', 'flength' (some of them only
                when the corresponding loss term is enabled).
            target_dict: dict with the targets. Keys read here: 'tpts' (last
                axis: x, y, weight), 'silh' and 'breed_index'.
            weight_dict: dict with one weight per loss term. 'keyp', 'silh',
                'shapedirs', 'class', 'breed', 'breed_options' and 'flength'
                are required; 'models3d', 'pose_0', 'poseprior' (together with
                'poseprior_options'), 'poselegssidemovement' and 'partseg' are
                optional and count as 0 when missing. The dict is not
                modified.

        Returns:
            Tuple ``(loss, loss_dict)``: the total loss as a tensor and a dict
            with the Python float value of the total and of every weighted
            term.

        Raises:
            TypeError: if ``weight_dict`` is ``None``.
            NotImplementedError: if a term without implementation is enabled
                ('shapedirs', 'partseg', unknown pose prior option).
        """
        # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
        # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
        if weight_dict is None:
            raise TypeError('weight_dict must be a dict with the loss weights, got None')
        # Work on a copy: defaults for optional terms are filled in below and
        # must not leak into the caller's dict, otherwise a term that was absent
        # in one call looks "present" in the next one.
        weight_dict = dict(weight_dict)
        batch_size = output_reproj['keyp_2d'].shape[0]
        device = output_reproj['keyp_2d'].device

        # loss on reprojected keypoints
        output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2))
        # rescales target coordinates by (256 - 1) / 64 (presumably heatmap to
        # image resolution -- TODO confirm)
        target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
        weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
        keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
        valid = weights_resh > 0
        kp_dist = ((output_kp_resh - target_kp_resh)[valid]**2).sum(axis=1).sqrt()
        kp_weight_sum = (weights_resh[valid]*keyp_w_resh[valid]).sum()
        loss_keyp = ((kp_dist*weights_resh[valid])*keyp_w_resh[valid]).sum() / \
            torch.clamp(kp_weight_sum, min=1e-5)

        # loss on reprojected silhouette
        silh_pred = output_reproj['silh']
        silh_target = target_dict['silh'][:, None, :, :]
        assert silh_pred.shape == silh_target.shape
        with torch.no_grad():
            thr_silh = 20
            diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
            diff_x = diff.reshape((batch_size, -1))
            weights_resh_x = weights_resh.reshape((batch_size, -1))
            unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
        loss_silh_bs = ((silh_pred - silh_target) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (silh_pred.shape[2]*silh_pred.shape[3])
        # Only samples whose mean keypoint distance is below the threshold
        # contribute to the silhouette term.
        # NOTE(review): the comparison and the reduction were lost in the
        # reviewed copy; '< thr_silh' follows from the otherwise unused
        # threshold, '.sum() / batch_size' is a reconstruction -- TODO confirm.
        loss_silh = loss_silh_bs[unweighted_kp_mean_dist < thr_silh].sum() / batch_size

        # shape regularization (see the review note in the helper)
        loss_shape_weighted = self._shape_regularization(output_reproj, target_dict, weight_dict)

        # 3D loss for dogs for which a target shape (betas) is available
        # NOTE(review): the first statements of this section were lost as well;
        # they are rebuilt from the surviving loop body and its 'else' branch.
        loss_models3d = self._zero(output_reproj['betas'].device)
        weight_dict.setdefault('models3d', 0)
        if weight_dict['models3d'] > 0:
            n_betas = output_reproj['betas'].shape[1]
            n_betas_limbs = output_reproj['betas_limbs'].shape[1]
            breed_indices = target_dict['breed_index'].detach().cpu()
            for ind_dog in range(breed_indices.shape[0]):
                # Tensor.item() replaces np.asscalar, which was removed from NumPy
                breed_index = breed_indices[ind_dog].item()
                if breed_index in self.dog_betas_unity:
                    betas_target = self.dog_betas_unity[breed_index][:n_betas].to(output_reproj['betas'].device)
                    betas_output = output_reproj['betas'][ind_dog, :]
                    betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :]
                    # out-of-place add: the accumulator starts as a plain zero tensor
                    loss_models3d = loss_models3d + \
                        ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (n_betas + n_betas_limbs)

        # shape regularization loss on shapedirs
        #   -> in the current version shapedirs are kept fixed, so we don't need those losses
        if weight_dict['shapedirs'] > 0:
            raise NotImplementedError
        loss_shapedirs = self._zero(output_reproj['betas'].device)

        # prior on back joints (not used in cvpr 2022 paper)
        #   -> elementwise MSE loss on all 6 coefficients of 6d rotation representation
        weight_dict.setdefault('pose_0', 0)
        if weight_dict['pose_0'] > 0:
            pred_pose_rot6d = output_reproj['pose_rot6d']
            n_joints = pred_pose_rot6d.shape[1]
            w_rj_np = np.zeros((n_joints))
            w_rj_np[[2, 3, 4, 5]] = 1.0     # back
            w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device)
            # target 6d vector for every joint (labelled 'zero_rot' by the author)
            zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, n_joints, 1))
            loss_pose = self.criterion_regr(pred_pose_rot6d*w_rj[None, :, None], zero_rot*w_rj[None, :, None])
        else:
            loss_pose = self._zero(device)

        # pose prior
        #   -> we did experiment with different pose priors, for example:
        #       * similar to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py,
        #             https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py)
        #       * vae
        #       * normalizing flow pose prior
        #   -> our cvpr 2022 paper uses the normalizing flow pose prior as implemented below
        weight_dict.setdefault('poseprior', 0)
        if weight_dict['poseprior'] > 0:
            pred_pose_rot6d = output_reproj['pose_rot6d']
            poseprior_options = weight_dict['poseprior_options']
            # NOTE(review): if 'poseprior_options' is a plain string, the first
            # test also matches 'normalizing_flow_tiger_logprob' (substring) and
            # the second branch can never be reached -- confirm it is a list.
            if 'normalizing_flow_tiger' in poseprior_options:
                nf_loss_type = 'square'
            elif 'normalizing_flow_tiger_logprob' in poseprior_options:
                nf_loss_type = 'neg_log_prob'
            else:
                raise NotImplementedError
            normflow_z = output_reproj['normflow_z']
            if normflow_z is not None:
                loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(normflow_z, type=nf_loss_type)
            else:
                loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type=nf_loss_type)
        else:
            loss_poseprior = self._zero(device)

        # add a prior which penalizes side-movement angles for legs
        weight_dict.setdefault('poselegssidemovement', 0)
        if weight_dict['poselegssidemovement'] > 0:
            # The rotation matrices are computed here, so this term no longer
            # depends on the pose prior branch having run.
            pred_pose = rot6d_to_rotmat(output_reproj['pose_rot6d'].reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
            leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20])   # front, back
            leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24])     # front, back
            vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype)
            vec[2] = -1
            x0_legs_left = pred_pose[:, leg_indices_left, :, :].reshape((-1, 3, 3))@vec
            x0_legs_right = pred_pose[:, leg_indices_right, :, :].reshape((-1, 3, 3))@vec
            # use the component of the vector which points to the side
            loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
        else:
            loss_poselegssidemovement = self._zero(device)

        # dog breed classification loss
        dog_breed_gt = target_dict['breed_index']
        dog_breed_pred = output_reproj['dog_breed']
        loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt)

        # dog breed relationship loss
        #   -> we did experiment with many other options, but none was significantly better
        if '4' in weight_dict['breed_options']:     # we have pairs of dogs of the same breed
            assert weight_dict['breed'] > 0
            z = output_reproj['z']
            # go through all pairs and compare them to each other sample;
            # rebuild the cached indices whenever batch size or device changed
            if (self.l_anchor is None
                    or getattr(self, '_triplet_batch_size', None) != batch_size
                    or self.l_anchor.device != z.device):
                self.prepare_anchor_pos_neg(batch_size, z.device)
            if self.l_anchor.numel() > 0:
                anchor = torch.index_select(z, 0, self.l_anchor)
                positive = torch.index_select(z, 0, self.l_pos)
                negative = torch.index_select(z, 0, self.l_neg)
                loss_breed = self.criterion_triplet(anchor, positive, negative)
            else:
                # a single pair has no negatives; avoid the mean over nothing
                loss_breed = self._zero(device)
        else:
            loss_breed = self._zero(device)

        # regularization for focal length
        loss_flength = torch.mean(output_reproj['flength']**2)

        # bodypart segmentation loss
        weight_dict.setdefault('partseg', 0)
        if weight_dict['partseg'] > 0:
            raise NotImplementedError
        loss_partseg = self._zero(device)

        # weight and combine losses
        loss_keyp_weighted = loss_keyp * weight_dict['keyp']
        loss_silh_weighted = loss_silh * weight_dict['silh']
        loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs']
        loss_pose_weighted = loss_pose * weight_dict['pose_0']
        loss_class_weighted = loss_class * weight_dict['class']
        loss_breed_weighted = loss_breed * weight_dict['breed']
        loss_flength_weighted = loss_flength * weight_dict['flength']
        loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior']
        loss_partseg_weighted = loss_partseg * weight_dict['partseg']
        loss_models3d_weighted = loss_models3d * weight_dict['models3d']
        loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement']

        ####################################################################################################
        loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \
            loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \
            loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted
        ####################################################################################################

        loss_dict = {
            'loss': loss.item(),
            'loss_keyp_weighted': loss_keyp_weighted.item(),
            'loss_silh_weighted': loss_silh_weighted.item(),
            'loss_shape_weighted': loss_shape_weighted.item(),
            'loss_shapedirs_weighted': loss_shapedirs_weighted.item(),
            'loss_pose0_weighted': loss_pose_weighted.item(),
            'loss_class_weighted': loss_class_weighted.item(),
            'loss_breed_weighted': loss_breed_weighted.item(),
            'loss_flength_weighted': loss_flength_weighted.item(),
            'loss_poseprior_weighted': loss_poseprior_weighted.item(),
            'loss_partseg_weighted': loss_partseg_weighted.item(),
            'loss_models3d_weighted': loss_models3d_weighted.item(),
            'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item(),
        }

        return loss, loss_dict