|
import os |
|
import sys |
|
import pickle |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
# Default floating-point dtype; used as the default value of the ``dtype``
# argument of the prior classes defined in this file.
DEFAULT_DTYPE = torch.float32
|
|
|
|
|
def create_prior(prior_type, **kwargs):
    """Build the prior selected by ``prior_type``.

    Args:
        prior_type: One of ``'gmm'``, ``'l2'``, ``'angle'``, ``'none'`` or
            ``None``.
        **kwargs: Forwarded unchanged to the constructor of the selected
            prior class.

    Returns:
        A ``MaxMixturePrior``, ``L2Prior`` or ``SMPLifyAnglePrior`` instance,
        or, for ``'none'`` / ``None``, a plain function that accepts any
        arguments and returns ``0.0``.

    Raises:
        ValueError: If ``prior_type`` is not one of the supported values.
    """
    if prior_type == 'gmm':
        return MaxMixturePrior(**kwargs)

    if prior_type == 'l2':
        return L2Prior(**kwargs)

    if prior_type == 'angle':
        return SMPLifyAnglePrior(**kwargs)

    if prior_type == 'none' or prior_type is None:
        # Stand-in used when no prior is requested: contributes zero.
        def no_prior(*args, **kwargs):
            return 0.0

        return no_prior

    raise ValueError('Prior {}'.format(prior_type) + ' is not implemented')
|
|
|
|
|
class SMPLifyAnglePrior(nn.Module):
    """Exponential penalty on four selected axis-angle pose components.

    The four components and their signs are fixed at construction time and
    stored as buffers, so they follow the module across devices.
    NOTE(review): indices 55, 58, 12, 15 presumably address the elbow and
    knee bending axes of the SMPL pose vector -- confirm against the model.
    """

    def __init__(self, dtype=DEFAULT_DTYPE, **kwargs):
        """Register the index and sign buffers.

        Args:
            dtype: Floating-point dtype of the sign buffer.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        super().__init__()

        # Positions inside a pose vector that includes the global
        # orientation; ``forward`` shifts them when it is absent.
        idxs = torch.tensor([55, 58, 12, 15], dtype=torch.long)
        self.register_buffer('angle_prior_idxs', idxs)

        # Per-component sign applied before the exponential.
        signs = torch.tensor([1, -1, -1, -1], dtype=dtype)
        self.register_buffer('angle_prior_signs', signs)

    def forward(self, pose, with_global_pose=False):
        ''' Returns the angle prior loss for the given pose

        Args:
            pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle
                representation of the rotations of the joints of the SMPL
                model.
        Kwargs:
            with_global_pose: Whether the pose vector also contains the
                global orientation of the SMPL model. If it does not, the
                stored indices are shifted down by 3.
        Returns:
            A (B, 4) tensor with ``exp(sign * angle) ** 2`` for each of the
            four selected pose components of every batch element.
        '''
        # Without the 3 global-orientation entries every index moves by 3.
        shift = 0 if with_global_pose else 3
        selected_idxs = self.angle_prior_idxs - shift

        signed_angles = pose[:, selected_idxs] * self.angle_prior_signs
        return torch.exp(signed_angles).pow(2)
|
|
|
|
|
class L2Prior(nn.Module):
    """Sum-of-squares penalty on its input."""

    def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs):
        """Create the prior.

        ``dtype``, ``reduction`` and any extra keyword arguments are accepted
        for interface compatibility but are not used.
        """
        super().__init__()

    def forward(self, module_input, *args):
        """Return the sum of the squared elements of ``module_input``.

        Extra positional arguments are ignored. The result is a scalar
        (0-dim) tensor.
        """
        squared = module_input.pow(2)
        return squared.sum()
|
|
|
|
|
class MaxMixturePrior(nn.Module):
    """Max-mixture prior built from a Gaussian mixture model stored on disk.

    The mixture is read from ``<prior_folder>/gmm_<num_gaussians:02d>.pkl``.
    The pickle must contain either a dict with the keys ``'means'``,
    ``'covars'`` and ``'weights'`` or an ``sklearn.mixture.gmm.GMM`` object
    (attributes ``means_``, ``covars_``, ``weights_``).

    Registered buffers (M components, D dimensions):
        means:       (M, D) component means.
        covs:        (M, D, D) component covariances.
        precisions:  (M, D, D) inverses of the covariances.
        nll_weights: (1, M) mixture weights divided by the Gaussian
                     normalisation constant and the determinant ratio.
        weights:     (1, M) raw mixture weights.
        pi_term:     scalar ``log(2 * pi)``.
        cov_dets:    (M,) ``log(det(cov) + epsilon)``.
    """

    def __init__(self, prior_folder='prior',
                 num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16,
                 use_merged=True,
                 **kwargs):
        """Load the mixture and precompute the buffers.

        Args:
            prior_folder: Directory that contains the pickled mixture.
            num_gaussians: Number of mixture components; selects the file
                name ``gmm_<num_gaussians:02d>.pkl``.
            dtype: ``torch.float32`` or ``torch.float64``.
            epsilon: Small value added to determinants before taking logs.
            use_merged: If true ``forward`` uses ``merged_log_likelihood``,
                otherwise ``log_likelihood``.
            **kwargs: Accepted for interface compatibility; ignored.

        The process exits via ``sys.exit(-1)`` when the dtype is
        unsupported, the file is missing, or the pickled object has an
        unknown type.
        """
        super().__init__()

        if dtype == torch.float32:
            np_dtype = np.float32
        elif dtype == torch.float64:
            np_dtype = np.float64
        else:
            print('Unknown float type {}, exiting!'.format(dtype))
            sys.exit(-1)

        self.num_gaussians = num_gaussians
        self.epsilon = epsilon
        self.use_merged = use_merged
        gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians)

        full_gmm_fn = os.path.join(prior_folder, gmm_fn)
        if not os.path.exists(full_gmm_fn):
            print('The path to the mixture prior "{}"'.format(full_gmm_fn) +
                  ' does not exist, exiting!')
            sys.exit(-1)

        with open(full_gmm_fn, 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code; only load
            # prior files that come from a trusted source.
            gmm = pickle.load(f, encoding='latin1')

        # Pull the raw arrays out once so that everything below works for
        # both supported container types (the sklearn object cannot be
        # indexed with string keys).
        if isinstance(gmm, dict):
            raw_means = gmm['means']
            raw_covs = gmm['covars']
            raw_weights = gmm['weights']
        elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)):
            raw_means = gmm.means_
            raw_covs = gmm.covars_
            raw_weights = gmm.weights_
        else:
            print('Unknown type for the prior: {}, exiting!'.format(type(gmm)))
            sys.exit(-1)

        raw_covs = np.asarray(raw_covs)
        raw_weights = np.asarray(raw_weights)
        means = np.asarray(raw_means).astype(np_dtype)
        covs = raw_covs.astype(np_dtype)

        self.register_buffer('means', torch.tensor(means, dtype=dtype))

        self.register_buffer('covs', torch.tensor(covs, dtype=dtype))

        # np.linalg.inv / det operate on stacks of matrices, so the (M, D, D)
        # array is processed in one call.
        precisions = np.linalg.inv(covs).astype(np_dtype)
        self.register_buffer('precisions',
                             torch.tensor(precisions, dtype=dtype))

        # The determinants are taken from the covariances in their stored
        # precision (before the cast to np_dtype).
        sqrdets = np.sqrt(np.linalg.det(raw_covs))
        # Gaussian normalisation constant for the dimensionality of the
        # loaded mixture.
        const = (2 * np.pi) ** (means.shape[1] / 2.)

        # Determinants are expressed relative to the smallest one.
        nll_weights = np.asarray(raw_weights / (const *
                                                (sqrdets / sqrdets.min())))
        nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0)
        self.register_buffer('nll_weights', nll_weights)

        weights = torch.tensor(raw_weights, dtype=dtype).unsqueeze(dim=0)
        self.register_buffer('weights', weights)

        self.register_buffer('pi_term',
                             torch.log(torch.tensor(2 * np.pi, dtype=dtype)))

        cov_dets = np.log(np.linalg.det(covs) + epsilon)
        self.register_buffer('cov_dets',
                             torch.tensor(cov_dets, dtype=dtype))

        # Dimensionality D of the modelled random variable.
        self.random_var_dim = self.means.shape[1]

    def get_mean(self):
        ''' Returns the mean of the mixture

        The result is the weight-averaged component mean, shape (1, D).
        '''
        mean_pose = torch.matmul(self.weights, self.means)
        return mean_pose

    def merged_log_likelihood(self, pose, betas):
        """Return, per sample, the minimum over components of
        ``0.5 * (x - mu)^T P (x - mu) - log(nll_weight)``.

        Args:
            pose: (B, D) tensor.
            betas: Unused.

        Returns:
            A (B,) tensor.
        """
        # (B, 1, D) - (M, D) -> (B, M, D)
        diff_from_mean = pose.unsqueeze(dim=1) - self.means

        prec_diff_prod = torch.einsum('mij,bmj->bmi',
                                      self.precisions, diff_from_mean)
        diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1)

        curr_loglikelihood = (0.5 * diff_prec_quadratic -
                              torch.log(self.nll_weights))

        min_likelihood, _ = torch.min(curr_loglikelihood, dim=1)
        return min_likelihood

    def log_likelihood(self, pose, betas, *args, **kwargs):
        ''' Create graph operation for negative log-likelihood calculation

        For every component the term
        ``(x - mu)^T P (x - mu) + 0.5 * (log(det(cov) + eps) + D * log(2 pi))``
        is evaluated; for each sample the smallest term is selected and
        ``-log(nll_weight)`` of the selected component is added.

        NOTE(review): unlike ``merged_log_likelihood`` the quadratic form is
        not scaled by 0.5 here; this is kept as in the original code.

        Args:
            pose: (B, D) tensor.
            betas: Unused, as are the extra positional/keyword arguments.

        Returns:
            A (B,) tensor with one value per sample.
        '''
        likelihoods = []

        for idx in range(self.num_gaussians):
            mean = self.means[idx]
            prec = self.precisions[idx]
            cov = self.covs[idx]
            diff_from_mean = pose - mean

            curr_loglikelihood = torch.einsum('bj,ji->bi',
                                              diff_from_mean, prec)
            curr_loglikelihood = torch.einsum('bi,bi->b',
                                              curr_loglikelihood,
                                              diff_from_mean)
            cov_term = torch.log(torch.det(cov) + self.epsilon)
            curr_loglikelihood = curr_loglikelihood + 0.5 * (
                cov_term + self.random_var_dim * self.pi_term)
            likelihoods.append(curr_loglikelihood)

        # (B, M): one column per mixture component.
        log_likelihoods = torch.stack(likelihoods, dim=1)
        # Select value and component per sample. Indexing the columns with
        # the whole index vector would produce a (B, B) matrix instead.
        min_likelihood, min_idx = torch.min(log_likelihoods, dim=1)
        weight_component = -torch.log(self.nll_weights[0, min_idx])

        return weight_component + min_likelihood

    def forward(self, pose, betas):
        """Evaluate the prior with the variant selected by ``use_merged``."""
        if self.use_merged:
            return self.merged_log_likelihood(pose, betas)
        else:
            return self.log_likelihood(pose, betas)
|
|