import torch import numpy as np import pickle as pkl import os import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) # from priors.pose_prior_35 import Prior # from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior from priors.shape_prior import ShapePrior from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa, geodesic_loss_R from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch from priors.shape_prior import ShapePrior from configs.SMAL_configs import SMAL_MODEL_CONFIG from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss class LossRef(torch.nn.Module): def __init__(self, smal_model_type, data_info, nf_version=None): super(LossRef, self).__init__() self.criterion_regr = torch.nn.MSELoss() # takes the mean self.criterion_class = torch.nn.CrossEntropyLoss() class_weights_isflat = torch.tensor([12, 2]) self.criterion_class_isflat = torch.nn.CrossEntropyLoss(weight=class_weights_isflat) self.criterion_l1 = torch.nn.L1Loss() self.geodesic_loss = geodesic_loss_R(reduction='mean') self.gc_loss_on_mesh = LossGConMesh() self.data_info = data_info self.smal_model_type = smal_model_type self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :]) # if nf_version is not None: # self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version) self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'] self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl' with open(remeshing_path, 'rb') as fp: self.remeshing_dict = pkl.load(fp) self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long) self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32) # load 3d data for the unity dogs (an optional shape prior for 11 breeds) self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs'] if self.unity_smal_shape_prior_dogs is not None: self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type) else: self.dog_betas_unity = None def forward(self, output_ref, output_ref_comp, target_dict, weight_dict_ref): # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image'] # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight'] batch_size = output_ref['keyp_2d'].shape[0] loss_dict_temp = {} # loss on reprojected keypoints output_kp_resh = (output_ref['keyp_2d']).reshape((-1, 2)) target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2)) weights_resh = target_dict['tpts'][:, :, 2].reshape((-1)) keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1)) loss_dict_temp['keyp_ref'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \ max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5) # loss on reprojected silhouette assert output_ref['silh'].shape == (target_dict['silh'][:, None, :, :]).shape silh_loss_type = 'default' if silh_loss_type == 'default': with torch.no_grad(): thr_silh = 20 diff = torch.norm(output_kp_resh - target_kp_resh, dim=1) diff_x = diff.reshape((batch_size, -1)) weights_resh_x = weights_resh.reshape((batch_size, -1)) unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6) loss_silh_bs = ((output_ref['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_ref['silh'].shape[2]*output_ref['silh'].shape[3]) loss_dict_temp['silh_ref'] = loss_silh_bs[unweighted_kp_mean_dist 0: if keep_smal_mesh: target_gc_class = target_dict['gc'][:, :, 0] gc_errors_plane = calculate_plane_errors_batch(output_ref['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching']) loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane) else: # use a uniformly sampled mesh target_gc_class = target_dict['gc'][:, :, 0] device = output_ref['vertices_smal'].device remeshing_relevant_faces = self.remeshing_relevant_faces.to(device) remeshing_relevant_barys = self.remeshing_relevant_barys.to(device) bs = output_ref['vertices_smal'].shape[0] # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_ref['vertices_smal'][:, self.remeshing_relevant_faces]) # sel_verts_comparison = output_ref['vertices_smal'][:, self.remeshing_relevant_faces] # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison) sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3)) verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts) target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32)) target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long) gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching']) loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane) loss_dict_temp['gc_blowplane'] = torch.mean(gc_errors_under_plane) # error on classification if the ground plane is flat if 'gc_isflat' in weight_dict_ref.keys(): # import pdb; pdb.set_trace() self.criterion_class_isflat.to(device) loss_dict_temp['gc_isflat'] = self.criterion_class(output_ref['isflat'], target_dict['isflat'].to(device)) # if we refine the shape WITHIN the refinement newtork (shaperef_type is not inexistent) # shape regularization # 'smal': loss on betas (pca coefficients), betas should be close to 0 # 'limbs...' loss on selected betas_limbs device = output_ref_comp['ref_trans_notnorm'].device loss_shape_weighted_list = [torch.zeros((1), device=device).mean()] if 'shape_options' in weight_dict_ref.keys(): for ind_sp, sp in enumerate(weight_dict_ref['shape_options']): weight_sp = weight_dict_ref['shape'][ind_sp] # self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'] if sp == 'smal': loss_shape_tmp = self.shape_prior(output_ref['betas']) elif sp == 'limbs': loss_shape_tmp = torch.mean((output_ref['betas_limbs'])**2) elif sp == 'limbs7': limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2] limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device) loss_shape_tmp = torch.mean((output_ref['betas_limbs'] * limb_coeffs[None, :])**2) else: raise NotImplementedError loss_shape_weighted_list.append(weight_sp * loss_shape_tmp) loss_shape_weighted = torch.stack((loss_shape_weighted_list)).sum() # 3D loss for dogs for which we have a unity model or toy figure loss_dict_temp['models3d'] = torch.zeros((1), device=device).mean().to(output_ref['betas'].device) if 'models3d' in weight_dict_ref.keys(): if weight_dict_ref['models3d'] > 0: assert (self.dog_betas_unity is not None) if weight_dict_ref['models3d'] > 0: for ind_dog in range(target_dict['breed_index'].shape[0]): breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy()) if breed_index in self.dog_betas_unity.keys(): betas_target = self.dog_betas_unity[breed_index][:output_ref['betas'].shape[1]].to(output_ref['betas'].device) betas_output = output_ref['betas'][ind_dog, :] betas_limbs_output = output_ref['betas_limbs'][ind_dog, :] loss_dict_temp['models3d'] += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_ref['betas'].shape[1] + output_ref['betas_limbs'].shape[1]) else: weight_dict_ref['models3d'] = 0.0 else: weight_dict_ref['models3d'] = 0.0 # weight the losses loss = torch.zeros((1)).mean().to(device=output_ref['keyp_2d'].device, dtype=output_ref['keyp_2d'].dtype) loss_dict = {} for loss_name in weight_dict_ref.keys(): if not loss_name in ['shape', 'shape_options']: if weight_dict_ref[loss_name] > 0: loss_weighted = loss_dict_temp[loss_name] * weight_dict_ref[loss_name] loss_dict[loss_name] = loss_weighted.item() loss += loss_weighted loss += loss_shape_weighted loss_dict['loss'] = loss.item() return loss, loss_dict