Nadine Rueegg
initial commit for barc
7629b39
raw
history blame
No virus
13 kB
"""
PyTorch implementation of the SMAL/SMPL model
see:
1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py
2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py
main changes compared to SMALST and WLDO:
* new model
(/ps/scratch/nrueegg/new_projects/side_packages/SMALify/new_smal_pca/results/my_tposeref_results_3/)
dogs are part of the pca to create the model
all meshes are centered around their root joint
the animals are all scaled such that their body length (butt to breast) is 1
X_init = np.concatenate((vertices_dogs, vertices_smal), axis=0) # vertices_dogs
X = []
for ind in range(0, X_init.shape[0]):
X_tmp, _, _, _ = align_smal_template_to_symmetry_axis(X_init[ind, :, :], subtract_mean=True) # not sure if this is necessary
X.append(X_tmp)
X = np.asarray(X)
# define points which will be used for normalization
idxs_front = [6, 16, 8, 964] # [1172, 6, 16, 8, 964]
idxs_back = [174, 2148, 175, 2149] # not in the middle, but pairs
reg_j = np.asarray(dd['J_regressor'].todense())
# normalize the meshes such that X_frontback_dist is 1 and the root joint is in the center (0, 0, 0)
X_front = X[:, idxs_front, :].mean(axis=1)
X_back = X[:, idxs_back, :].mean(axis=1)
X_frontback_dist = np.sqrt(((X_front - X_back)**2).sum(axis=1))
X = X / X_frontback_dist[:, None, None]
X_j0 = np.sum(X[:, reg_j[0, :]>0, :] * reg_j[0, (reg_j[0, :]>0)][None, :, None], axis=1)
X = X - X_j0[:, None, :]
* add limb length changes the same way as in WLDO
* overall scale factor is added
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import chumpy as ch
import os.path
from torch import nn
from torch.autograd import Variable
import pickle as pkl
from .batch_lbs import batch_rodrigues, batch_global_rigid_transformation, batch_global_rigid_transformation_biggs, get_bone_length_scales, get_beta_scale_mask
from .smal_basics import align_smal_template_to_symmetry_axis, get_symmetry_indices
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from configs.SMAL_configs import KEY_VIDS, CANONICAL_MODEL_JOINTS, IDXS_BONES_NO_REDUNDANCY, SMAL_MODEL_PATH
from smal_pytorch.utils import load_vertex_colors
# There are chumpy variables so convert them to numpy.
def undo_chumpy(x):
    """Return ``x`` as-is if it is a numpy array, otherwise return ``x.r``.

    Non-array inputs are expected to expose their numeric value through an
    ``r`` attribute (as chumpy variables do).
    """
    if not isinstance(x, np.ndarray):
        return x.r
    return x
# class SMAL(object):
class SMAL(nn.Module):
    """PyTorch SMAL body model with per-part limb scaling.

    The model parameters (template, shape/pose blend shapes, joint regressor,
    kinematic tree and skinning weights) are read from a pickle file and stored
    as module buffers. Calling the module poses the template mesh and returns
    vertices and joints.
    """

    # Number of joints the reshape operations below are written for. The
    # pose-feature size follows from it: (NUM_JOINTS - 1) * 9 == 306.
    NUM_JOINTS = 35

    def __init__(self, pkl_path=SMAL_MODEL_PATH, n_betas=None, template_name='neutral', use_smal_betas=True, logscale_part_list=None):
        """Load the model parameters from ``pkl_path``.

        Args:
            pkl_path: path to the pickled model dictionary.
            n_betas: unused; kept for interface compatibility.
            template_name: only ``'neutral'`` is implemented.
            use_smal_betas: stored on the instance as ``self.use_smal_betas``.
            logscale_part_list: names of the body parts that get a log-scale
                parameter. Defaults to legs, tail, ears and head entries.

        Raises:
            NotImplementedError: if ``template_name`` is not ``'neutral'``.
        """
        super(SMAL, self).__init__()
        if logscale_part_list is None:
            logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
        # Always store the list; previously a user-supplied list was never
        # assigned and the next line failed with AttributeError.
        self.logscale_part_list = logscale_part_list
        self.betas_scale_mask = get_beta_scale_mask(part_list=self.logscale_part_list)
        self.num_betas_logscale = len(self.logscale_part_list)
        self.use_smal_betas = use_smal_betas

        # -- Load SMPL params --
        # NOTE: unpickling executes arbitrary code; only load trusted model files.
        # 'latin1' is the encoding the original fallback path used.
        with open(pkl_path, 'rb') as f:
            dd = pkl.load(f, encoding='latin1')

        self.f = dd['f']
        self.register_buffer('faces', torch.from_numpy(self.f.astype(int)))

        # get the correct template (mean shape)
        if template_name != 'neutral':
            raise NotImplementedError
        v = dd['v_template']

        # Mean template vertices
        self.register_buffer('v_template', torch.Tensor(v))
        # Size of mesh [Number of vertices, 3]
        self.size = [self.v_template.shape[0], 3]
        self.num_betas = dd['shapedirs'].shape[-1]

        # symmetry indices
        self.sym_ids_dict = get_symmetry_indices()

        # Shape blend shape basis, stored as [num_betas, num_vertices * 3]
        shapedir = np.reshape(undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T
        shapedir.flags['WRITEABLE'] = True  # not sure why this is necessary
        self.register_buffer('shapedirs', torch.Tensor(shapedir))

        # Regressor for joint locations given shape
        self.register_buffer('J_regressor', torch.Tensor(dd['J_regressor'].T.todense()))

        # Pose blend shape basis, stored as [num_pose_basis, num_vertices * 3]
        num_pose_basis = dd['posedirs'].shape[-1]
        posedirs = np.reshape(undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T
        self.register_buffer('posedirs', torch.Tensor(posedirs))

        # indices of parents for each joints
        self.parents = dd['kintree_table'][0].astype(np.int32)

        # LBS weights
        self.register_buffer('weights', torch.Tensor(undo_chumpy(dd['weights'])))

    def _select_shapedirs(self, n_betas, shapedirs_sel):
        """Return the shape basis to use: the first ``n_betas`` rows of the model basis, or the checked override."""
        if shapedirs_sel is None:
            return self.shapedirs[:n_betas, :]
        assert shapedirs_sel.shape[0] == n_betas
        return shapedirs_sel

    def _regress_joints(self, vertices):
        """Apply the joint regressor to each coordinate of ``vertices`` ([B, V, 3]) and stack to [B, J, 3]."""
        per_axis = [torch.matmul(vertices[:, :, axis], self.J_regressor) for axis in range(3)]
        return torch.stack(per_axis, dim=2)

    def _caclulate_bone_lengths_from_J(self, J, betas_logscale):
        """Compute scaled bone lengths from joint locations.

        Args:
            J: joint locations, indexed as [batch, joint, xyz].
            betas_logscale: log-scale parameters passed on to
                ``get_bone_length_scales``.

        Returns:
            Tensor with one row per batch element and one column per
            non-root joint: the distance to the parent joint multiplied by
            the bone length scale.
        """
        n_joints = self.parents.shape[0]
        # The root joint (index 0) has no parent bone, so it is skipped.
        parent_idx = torch.as_tensor(self.parents[1:], dtype=torch.long, device=J.device)
        bone_vecs = J[:, 1:n_joints] - J.index_select(1, parent_idx)
        all_bone_lengths = torch.sqrt(torch.sum(bone_vecs ** 2, dim=2))
        # some bones are pairs, it is enough to take one of the two bones
        all_bone_length_scales = get_bone_length_scales(self.logscale_part_list, betas_logscale)
        return all_bone_lengths * all_bone_length_scales

    def caclulate_bone_lengths(self, beta, betas_logscale, shapedirs_sel=None, short=True):
        """Compute bone lengths of the shaped (unposed) template.

        Args:
            beta: shape coefficients, [batch, n_betas].
            betas_logscale: log-scale parameters for the limb scaling.
            shapedirs_sel: optional replacement shape basis; its first
                dimension must equal ``n_betas``.
            short: if True return only the bones selected by
                ``IDXS_BONES_NO_REDUNDANCY``, otherwise all bones.

        Returns:
            Tensor of bone lengths, one row per batch element.
        """
        n_betas = beta.shape[1]
        # 1. Add shape blend shapes
        shapedirs_sel = self._select_shapedirs(n_betas, shapedirs_sel)
        v_shaped = self.v_template + torch.reshape(
            torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]])
        # 2. Infer shape-dependent joint locations.
        J = self._regress_joints(v_shaped)
        # calculate bone lengths
        all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale)
        if short:
            return all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY]
        return all_bone_lengths

    # NOTE(review): nn.Module subclasses normally implement ``forward`` and let
    # ``nn.Module.__call__`` dispatch to it (which also runs hooks). The
    # override is kept so existing behaviour is unchanged.
    def __call__(self, beta, betas_limbs, theta=None, pose=None, trans=None, del_v=None, get_skin=True, keyp_conf='red', get_all_info=False, shapedirs_sel=None):
        """Shape, pose and skin the template mesh.

        Exactly one of ``theta`` and ``pose`` must be given.

        Args:
            beta: shape coefficients, [batch, n_betas].
            betas_limbs: log-scale parameters for the limb scaling.
            theta: axis-angle rotations, BSxNJointsx3 or BSx(NJoints*3).
            pose: rotation matrices, NxNJointsx3x3.
            trans: optional translation, [batch, 3]; zeros if omitted.
            del_v: optional vertex offsets added to the template.
            get_skin: if False (and ``get_all_info`` is False) return only
                the joints.
            keyp_conf: which joint set to return: ``'red'``, ``'green'``,
                ``'blue'`` or ``'dict'`` (all three in a dictionary).
            get_all_info: if True also return the bone lengths.
            shapedirs_sel: optional replacement shape basis.

        Returns:
            ``(verts, joints, Rs, all_bone_lengths)`` if ``get_all_info``,
            else ``(verts, joints, Rs)`` if ``get_skin``, else ``joints``.

        Raises:
            ValueError: if neither or both of ``theta`` and ``pose`` are given.
            NotImplementedError: if ``keyp_conf`` is not a known option.

        Side effects:
            Stores the transformed joints in ``self.J_transformed``.
        """
        device = beta.device
        betas_logscale = betas_limbs
        n_joints = self.NUM_JOINTS

        # NEW: allow that rotation is given as rotation matrices instead of axis angle rotation
        if (theta is None) and (pose is None):
            raise ValueError("Either pose (rotation matrices NxNJointsx3x3) or theta (axis angle BSxNJointsx3) must be given")
        elif (theta is not None) and (pose is not None):
            raise ValueError("Not both pose (rotation matrices NxNJointsx3x3) and theta (axis angle BSxNJointsx3) can be given")
        # Reject unknown keypoint configurations before doing any work.
        if keyp_conf not in ('red', 'green', 'blue', 'dict'):
            raise NotImplementedError

        n_betas = beta.shape[1]

        # 1. Add shape blend shapes
        shapedirs_sel = self._select_shapedirs(n_betas, shapedirs_sel)
        template = self.v_template if del_v is None else self.v_template + del_v
        if n_betas > 0:
            v_shaped = template + torch.reshape(
                torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]])
        elif del_v is None:
            # No shape coefficients and no offsets: add the batch dimension explicitly.
            v_shaped = self.v_template.unsqueeze(0)
        else:
            v_shaped = template

        # 2. Infer shape-dependent joint locations.
        J = self._regress_joints(v_shaped)

        # 3. Add pose blend shapes
        if pose is None:
            Rs = torch.reshape(batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, n_joints, 3, 3])
        else:
            Rs = pose
        # Ignore global rotation.
        pose_feature = torch.reshape(
            Rs[:, 1:, :, :] - torch.eye(3, device=device), [-1, (n_joints - 1) * 9])
        v_posed = torch.reshape(
            torch.matmul(pose_feature, self.posedirs),
            [-1, self.size[0], self.size[1]]) + v_shaped

        # -------------------------
        # new: add corrections of bone lengths to the template (before hypothetical pose blend shapes!)
        # see biggs batch_lbs.py
        betas_scale = torch.exp(betas_logscale @ self.betas_scale_mask.to(betas_logscale.device))
        scaling_factors = betas_scale.reshape(-1, n_joints, 3)
        scale_factors_3x3 = torch.diag_embed(scaling_factors, dim1=-2, dim2=-1)

        # 4. Get the global joint location
        self.J_transformed, A = batch_global_rigid_transformation_biggs(
            Rs, J, self.parents, scale_factors_3x3, betas_logscale=betas_logscale)

        # 2-BONES. Calculate bone lengths
        all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale)
        # -------------------------

        # 5. Do skinning:
        num_batch = Rs.shape[0]
        # Broadcast the skinning weights over the batch without copying them.
        W = self.weights.unsqueeze(0).expand(num_batch, -1, -1)
        T = torch.reshape(
            torch.matmul(W, torch.reshape(A, [num_batch, n_joints, 16])),
            [num_batch, -1, 4, 4])
        v_posed_homo = torch.cat(
            [v_posed, torch.ones([num_batch, v_posed.shape[1], 1], device=device)], 2)
        v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1))
        verts = v_homo[:, :, :3, 0]

        if trans is None:
            trans = torch.zeros((num_batch, 3), device=device)
        verts = verts + trans[:, None, :]

        # Get joints:
        joints = self._regress_joints(verts)
        # New... (see https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py)
        joints = torch.cat([
            joints,
            verts[:, None, 1863],  # end_of_nose
            verts[:, None, 26],  # chin
            verts[:, None, 2124],  # right ear tip
            verts[:, None, 150],  # left ear tip
            verts[:, None, 3055],  # left eye
            verts[:, None, 1097],  # right eye
        ], dim=1)

        if keyp_conf == 'blue' or keyp_conf == 'dict':
            # Generate keypoints
            nLandmarks = KEY_VIDS.shape[0]  # 24
            j3d = torch.zeros((num_batch, nLandmarks, 3), device=device)
            for j in range(nLandmarks):
                # translation is already added to the vertices
                j3d[:, j, :] = torch.mean(verts[:, KEY_VIDS[j], :], dim=1)
            joints_blue = j3d

        # 'red' drops the six vertex-based joints appended above.
        joints_red = joints[:, :-6, :]
        joints_green = joints[:, CANONICAL_MODEL_JOINTS, :]

        if keyp_conf == 'red':
            relevant_joints = joints_red
        elif keyp_conf == 'green':
            relevant_joints = joints_green
        elif keyp_conf == 'blue':
            relevant_joints = joints_blue
        else:  # 'dict' (validated above)
            relevant_joints = {'red': joints_red,
                               'green': joints_green,
                               'blue': joints_blue}

        if get_all_info:
            return verts, relevant_joints, Rs, all_bone_lengths
        if get_skin:
            return verts, relevant_joints, Rs
        return relevant_joints