import torch
import numpy as np
import pickle as pkl
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
# from priors.pose_prior_35 import Prior
# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
from priors.shape_prior import ShapePrior
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa
from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS


class Loss(torch.nn.Module):
    def __init__(self, data_info, nf_version=None):
        super(Loss, self).__init__()
        self.criterion_regr = torch.nn.MSELoss()  # takes the mean
        self.criterion_class = torch.nn.CrossEntropyLoss()
        self.data_info = data_info
        self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
        self.l_anchor = None
        self.l_pos = None
        self.l_neg = None
        if nf_version is not None:
            self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
        self.shape_prior = ShapePrior(UNITY_SMAL_SHAPE_PRIOR_DOGS)
        self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1)
        # load 3d data for the unity dogs (an optional shape prior for 11 breeds)
        with open(UNITY_SMAL_SHAPE_PRIOR_DOGS, 'rb') as f:
            data = pkl.load(f)
        dog_betas_unity = data['dogs_betas']
        self.dog_betas_unity = {
            29: torch.tensor(dog_betas_unity[0, :]).float(),
            91: torch.tensor(dog_betas_unity[1, :]).float(),
            84: torch.tensor(0.5 * dog_betas_unity[3, :] + 0.5 * dog_betas_unity[14, :]).float(),
            85: torch.tensor(dog_betas_unity[5, :]).float(),
            28: torch.tensor(dog_betas_unity[6, :]).float(),
            94: torch.tensor(dog_betas_unity[7, :]).float(),
            92: torch.tensor(dog_betas_unity[8, :]).float(),
            95: torch.tensor(dog_betas_unity[10, :]).float(),
            20: torch.tensor(dog_betas_unity[11, :]).float(),
            83: torch.tensor(dog_betas_unity[12, :]).float(),
            99: torch.tensor(dog_betas_unity[16, :]).float()}
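
    # The helper below prepares index lists for the breed triplet loss. It assumes
    # the dataloader delivers the batch as consecutive same-breed pairs, i.e. samples
    # (0, 1) share a breed, (2, 3) share a breed, and so on; each pair member serves
    # as anchor/positive while every sample outside the pair serves as a negative.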
    def prepare_anchor_pos_neg(self, batch_size, device):
        l0 = np.arange(0, batch_size, 2)
        l_anchor = []
        l_pos = []
        l_neg = []
        for ind in l0:
            xx = set(np.arange(0, batch_size))
            xx.discard(ind)
            xx.discard(ind + 1)
            for ind2 in xx:
                if ind2 % 2 == 0:
                    l_anchor.append(ind)
                    l_pos.append(ind + 1)
                else:
                    l_anchor.append(ind + 1)
                    l_pos.append(ind)
                l_neg.append(ind2)
        self.l_anchor = torch.tensor(l_anchor, dtype=torch.int64, device=device)
        self.l_pos = torch.tensor(l_pos, dtype=torch.int64, device=device)
        self.l_neg = torch.tensor(l_neg, dtype=torch.int64, device=device)
        return
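
    # Worked example for prepare_anchor_pos_neg with batch_size = 4, i.e. the
    # same-breed pairs (0, 1) and (2, 3) (up to set iteration order):
    #   l_anchor = [0, 1, 2, 3]
    #   l_pos    = [1, 0, 3, 2]
    #   l_neg    = [2, 3, 0, 1]
    # Each anchor is pulled towards its same-breed partner and pushed away from one
    # sample of another pair. Note that the index lists are cached after the first
    # call (see forward), so a final batch smaller than the first one would index
    # out of range.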

    def forward(self, output_reproj, target_dict, weight_dict=None):
        # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'betas_limbs',
        #                 'pose_rot6d', 'normflow_z', 'dog_breed', 'z', 'flength', ...]
        # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'silh', 'breed_index']
        batch_size = output_reproj['keyp_2d'].shape[0]

        # loss on reprojected keypoints
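        # tpts[:, :, :2] are given in 64x64 heatmap coordinates and are rescaled to the
        # 256x256 input image here; the loss is a mean Euclidean keypoint distance,
        # weighted both by the per-keypoint visibility scores (tpts[:, :, 2]) and by the
        # per-keypoint weights from data_info, normalized by the total weight.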
        output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2))
        target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
        weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
        keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
        loss_keyp = ((((output_kp_resh - target_kp_resh)[weights_resh > 0]**2).sum(axis=1).sqrt()
                      * weights_resh[weights_resh > 0]) * keyp_w_resh[weights_resh > 0]).sum() / \
            max((weights_resh[weights_resh > 0] * keyp_w_resh[weights_resh > 0]).sum(), 1e-5)

        # loss on reprojected silhouette
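        # The silhouette term is gated: each sample's mean keypoint error is computed
        # without gradients, and only samples whose error is below thr_silh (20 px)
        # contribute, so poses that are still badly misaligned receive keypoint
        # gradients only, not silhouette gradients.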
        assert output_reproj['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
        silh_loss_type = 'default'
        if silh_loss_type == 'default':
            with torch.no_grad():
                thr_silh = 20
                diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
                diff_x = diff.reshape((batch_size, -1))
                weights_resh_x = weights_resh.reshape((batch_size, -1))
                unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1) + 1e-6)
            loss_silh_bs = ((output_reproj['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_reproj['silh'].shape[2] * output_reproj['silh'].shape[3])
            loss_silh = loss_silh_bs[unweighted_kp_mean_dist < thr_silh].sum() / batch_size
        else:
            print('silh_loss_type: ' + silh_loss_type)
            raise ValueError

        # shape regularization
        #   'smal':     loss on betas (pca coefficients), betas should be close to 0
        #   'limbs...': loss on selected betas_limbs
        loss_shape_weighted_list = [torch.zeros((1)).mean().to(output_reproj['keyp_2d'].device)]
        for ind_sp, sp in enumerate(weight_dict['shape_options']):
            weight_sp = weight_dict['shape'][ind_sp]
            # self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
            if sp == 'smal':
                loss_shape_tmp = self.shape_prior(output_reproj['betas'])
            elif sp == 'limbs':
                loss_shape_tmp = torch.mean((output_reproj['betas_limbs'])**2)
            elif sp == 'limbs7':
                # one coefficient per entry of logscale_part_list above
                limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
                limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device)
                loss_shape_tmp = torch.mean((output_reproj['betas_limbs'] * limb_coeffs[None, :])**2)
            else:
                raise NotImplementedError
            loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
        loss_shape_weighted = torch.stack(loss_shape_weighted_list).sum()

        # 3D loss for dogs for which we have a unity model or toy figure
        loss_models3d = torch.zeros((1)).mean().to(output_reproj['betas'].device)
        if 'models3d' in weight_dict.keys():
            if weight_dict['models3d'] > 0:
                for ind_dog in range(target_dict['breed_index'].shape[0]):
                    # .item() replaces np.asscalar, which is removed in recent NumPy versions
                    breed_index = int(target_dict['breed_index'][ind_dog].item())
                    if breed_index in self.dog_betas_unity.keys():
                        betas_target = self.dog_betas_unity[breed_index][:output_reproj['betas'].shape[1]].to(output_reproj['betas'].device)
                        betas_output = output_reproj['betas'][ind_dog, :]
                        betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :]
                        loss_models3d += ((betas_limbs_output**2).sum() + ((betas_output - betas_target)**2).sum()) / (output_reproj['betas'].shape[1] + output_reproj['betas_limbs'].shape[1])
        else:
            weight_dict['models3d'] = 0

        # shape regularization loss on shapedirs
        # -> in the current version the shapedirs are kept fixed, so we do not need these losses
        if weight_dict['shapedirs'] > 0:
            raise NotImplementedError
        else:
            loss_shapedirs = torch.zeros((1)).mean().to(output_reproj['betas'].device)

        # prior on back joints (not used in the CVPR 2022 paper)
        # -> elementwise MSE loss on all 6 coefficients of the 6d rotation representation
        if 'pose_0' in weight_dict.keys():
            if weight_dict['pose_0'] > 0:
                pred_pose_rot6d = output_reproj['pose_rot6d']
                w_rj_np = np.zeros((pred_pose_rot6d.shape[1]))
                w_rj_np[[2, 3, 4, 5]] = 1.0  # back joints
                w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device)
                # [1, 0, 0, 1, 0, 0] corresponds to the identity rotation in the 6d
                # convention used here (reshaped to 3x2, its columns are the first two
                # columns of the identity matrix)
                zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, pred_pose_rot6d.shape[1], 1))
                loss_pose = self.criterion_regr(pred_pose_rot6d * w_rj[None, :, None], zero_rot * w_rj[None, :, None])
            else:
                loss_pose = torch.zeros((1)).mean()
        else:
            weight_dict['pose_0'] = 0
            loss_pose = torch.zeros((1)).mean()

        # pose prior
        # -> we experimented with different pose priors, for example:
        #    * similar to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py,
        #      https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py)
        #    * a VAE
        #    * a normalizing flow pose prior
        # -> our CVPR 2022 paper uses the normalizing flow pose prior as implemented below
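        # If the network already returns the flow's latent code (normflow_z), the loss is
        # computed directly from it; otherwise the predicted 6d pose is passed through the
        # prior. 'square' presumably penalizes the squared norm of the latent code, while
        # 'neg_log_prob' penalizes the negative log-likelihood under the flow.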
        if 'poseprior' in weight_dict.keys():
            if weight_dict['poseprior'] > 0:
                pred_pose_rot6d = output_reproj['pose_rot6d']
                pred_pose = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
                if 'normalizing_flow_tiger' in weight_dict['poseprior_options']:
                    if output_reproj['normflow_z'] is not None:
                        loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='square')
                    else:
                        loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='square')
                elif 'normalizing_flow_tiger_logprob' in weight_dict['poseprior_options']:
                    if output_reproj['normflow_z'] is not None:
                        loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='neg_log_prob')
                    else:
                        loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='neg_log_prob')
                else:
                    raise NotImplementedError
            else:
                loss_poseprior = torch.zeros((1)).mean()
        else:
            weight_dict['poseprior'] = 0
            loss_poseprior = torch.zeros((1)).mean()

        # add a prior which penalizes side-movement angles for the legs
        use_pose_legs_side_loss = 'poselegssidemovement' in weight_dict.keys()
        if use_pose_legs_side_loss:
            # joint indices of the right and left legs (front, back)
            leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20])
            leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24])
            # pred_pose is only computed in the pose-prior branch above, so recompute
            # it here when that branch was skipped
            if not weight_dict['poseprior'] > 0:
                pred_pose = rot6d_to_rotmat(output_reproj['pose_rot6d'].reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
            vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype)
            vec[2] = -1
            x0_rotmat = pred_pose
            x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
            x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
            x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3)) @ vec
            x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3)) @ vec
            # rotate the unit vector (0, 0, -1) by each leg-joint rotation and penalize
            # its y-component, i.e. the component of the vector which points to the side
            loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
        else:
            loss_poselegssidemovement = torch.zeros((1)).mean()
            weight_dict['poselegssidemovement'] = 0

        # dog breed classification loss
        dog_breed_gt = target_dict['breed_index']
        dog_breed_pred = output_reproj['dog_breed']
        loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt)

        # dog breed relationship loss
        # -> we experimented with many other options, but none was significantly better
        if '4' in weight_dict['breed_options']:  # the batch contains pairs of dogs of the same breed
            assert weight_dict['breed'] > 0
            z = output_reproj['z']
            # go through all pairs and compare them to each other sample
            if self.l_anchor is None:
                self.prepare_anchor_pos_neg(batch_size, z.device)
            anchor = torch.index_select(z, 0, self.l_anchor)
            positive = torch.index_select(z, 0, self.l_pos)
            negative = torch.index_select(z, 0, self.l_neg)
            loss_breed = self.criterion_triplet(anchor, positive, negative)
        else:
            loss_breed = torch.zeros((1)).mean()

        # regularization for the focal length
        # (flength is presumably predicted as a normalized offset from a mean focal
        # length, so penalizing its square keeps predictions close to that mean)
        loss_flength_near_mean = torch.mean(output_reproj['flength']**2)
        loss_flength = loss_flength_near_mean

        # body part segmentation loss
        if 'partseg' in weight_dict.keys():
            if weight_dict['partseg'] > 0:
                raise NotImplementedError
            else:
                loss_partseg = torch.zeros((1)).mean()
        else:
            weight_dict['partseg'] = 0
            loss_partseg = torch.zeros((1)).mean()

        # weight and combine the losses
        loss_keyp_weighted = loss_keyp * weight_dict['keyp']
        loss_silh_weighted = loss_silh * weight_dict['silh']
        loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs']
        loss_pose_weighted = loss_pose * weight_dict['pose_0']
        loss_class_weighted = loss_class * weight_dict['class']
        loss_breed_weighted = loss_breed * weight_dict['breed']
        loss_flength_weighted = loss_flength * weight_dict['flength']
        loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior']
        loss_partseg_weighted = loss_partseg * weight_dict['partseg']
        loss_models3d_weighted = loss_models3d * weight_dict['models3d']
        loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement']
        ####################################################################################################
        loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \
            loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \
            loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted
        ####################################################################################################
        loss_dict = {'loss': loss.item(),
                     'loss_keyp_weighted': loss_keyp_weighted.item(),
                     'loss_silh_weighted': loss_silh_weighted.item(),
                     'loss_shape_weighted': loss_shape_weighted.item(),
                     'loss_shapedirs_weighted': loss_shapedirs_weighted.item(),
                     'loss_pose0_weighted': loss_pose_weighted.item(),
                     'loss_class_weighted': loss_class_weighted.item(),
                     'loss_breed_weighted': loss_breed_weighted.item(),
                     'loss_flength_weighted': loss_flength_weighted.item(),
                     'loss_poseprior_weighted': loss_poseprior_weighted.item(),
                     'loss_partseg_weighted': loss_partseg_weighted.item(),
                     'loss_models3d_weighted': loss_models3d_weighted.item(),
                     'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item()}
        return loss, loss_dict
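

# Minimal usage sketch (hypothetical): the weight values below are illustrative
# assumptions, not taken from this repository. It only shows how forward() is wired:
# output_reproj comes from the reconstruction model, target_dict from the dataloader,
# and weight_dict selects and weights the individual loss terms (missing optional
# keys such as 'models3d', 'poseprior' or 'partseg' are set to 0 inside forward()).
#
#     loss_module = Loss(data_info, nf_version=None)
#     weight_dict = {'keyp': 0.2, 'silh': 50.0, 'shape': [0.1], 'shape_options': ['smal'],
#                    'shapedirs': 0, 'pose_0': 0, 'class': 0.1, 'breed': 0,
#                    'breed_options': [], 'flength': 1.0}
#     loss, loss_dict = loss_module(output_reproj, target_dict, weight_dict=weight_dict)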