''' Adjusted version of other PyTorch implementation of the SMAL/SMPL model see: 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py ''' from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import numpy as np def batch_skew(vec, batch_size=None): """ vec is N x 3, batch_size is int returns N x 3 x 3. Skew_sym version of each matrix. """ device = vec.device if batch_size is None: batch_size = vec.shape.as_list()[0] col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7]) indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1]) updates = torch.reshape( torch.stack( [ -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1], vec[:, 0] ], dim=1), [-1]) out_shape = [batch_size * 9] res = torch.Tensor(np.zeros(out_shape[0])).to(device=device) res[np.array(indices.flatten())] = updates res = torch.reshape(res, [batch_size, 3, 3]) return res def batch_rodrigues(theta): """ Theta is Nx3 """ device = theta.device batch_size = theta.shape[0] angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1) r = (torch.div(theta, angle)).unsqueeze(-1) angle = angle.unsqueeze(-1) cos = torch.cos(angle) sin = torch.sin(angle) outer = torch.matmul(r, r.transpose(1,2)) eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).to(device=device) H = batch_skew(r, batch_size=batch_size) R = cos * eyes + (1 - cos) * outer + sin * H return R def batch_lrotmin(theta): """ Output of this is used to compute joint-to-pose blend shape mapping. Equation 9 in SMPL paper. Args: pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints. This includes the global rotation so K=24 Returns diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted., """ # Ignore global rotation theta = theta[:,3:] Rs = batch_rodrigues(torch.reshape(theta, [-1,3])) lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207]) return lrotmin def batch_global_rigid_transformation(Rs, Js, parent, rotate_base=False): """ Computes absolute joint locations given pose. rotate_base: if True, rotates the global rotation by 90 deg in x axis. if False, this is the original SMPL coordinate. Args: Rs: N x 24 x 3 x 3 rotation vector of K joints Js: N x 24 x 3, joint locations before posing parent: 24 holding the parent id for each index Returns new_J : `Tensor`: N x 24 x 3 location of absolute joints A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. """ device = Rs.device if rotate_base: print('Flipping the SMPL coordinate frame!!!!') rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) else: root_rotation = Rs[:, 0, :, :] # Now Js is N x 24 x 3 x 1 Js = Js.unsqueeze(-1) N = Rs.shape[0] def make_A(R, t): # Rs is N x 3 x 3, ts is N x 3 x 1 R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(device=device)], 1) return torch.cat([R_homo, t_homo], 2) A0 = make_A(root_rotation, Js[:, 0]) results = [A0] for i in range(1, parent.shape[0]): j_here = Js[:, i] - Js[:, parent[i]] A_here = make_A(Rs[:, i], j_here) res_here = torch.matmul( results[parent[i]], A_here) results.append(res_here) # 10 x 24 x 4 x 4 results = torch.stack(results, dim=1) new_J = results[:, :, :3, 3] # --- Compute relative A: Skinning is based on # how much the bone moved (not the final location of the bone) # but (final_bone - init_bone) # --- Js_w0 = torch.cat([Js, torch.zeros([N, 35, 1, 1]).to(device=device)], 2) init_bone = torch.matmul(results, Js_w0) # Append empty 4 x 3: init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) A = results - init_bone return new_J, A ######################################################################################### def get_bone_length_scales(part_list, betas_logscale): leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) tail_joints = list(range(25, 32)) ear_joints = [33, 34] neck_joints = [15, 6] # ? core_joints = [4, 5] # ? mouth_joints = [16, 32] log_scales = torch.zeros(betas_logscale.shape[0], 35).to(betas_logscale.device) for ind, part in enumerate(part_list): if part == 'legs_l': log_scales[:, leg_joints] = betas_logscale[:, ind][:, None] elif part == 'tail_l': log_scales[:, tail_joints] = betas_logscale[:, ind][:, None] elif part == 'ears_l': log_scales[:, ear_joints] = betas_logscale[:, ind][:, None] elif part == 'neck_l': log_scales[:, neck_joints] = betas_logscale[:, ind][:, None] elif part == 'core_l': log_scales[:, core_joints] = betas_logscale[:, ind][:, None] elif part == 'head_l': log_scales[:, mouth_joints] = betas_logscale[:, ind][:, None] else: pass all_scales = torch.exp(log_scales) return all_scales[:, 1:] # don't count root def get_beta_scale_mask(part_list): # which joints belong to which bodypart leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) tail_joints = list(range(25, 32)) ear_joints = [33, 34] neck_joints = [15, 6] # ? core_joints = [4, 5] # ? mouth_joints = [16, 32] n_b_log = len(part_list) #betas_logscale.shape[1] # 8 # 6 beta_scale_mask = torch.zeros(35, 3, n_b_log) # .to(betas_logscale.device) for ind, part in enumerate(part_list): if part == 'legs_l': beta_scale_mask[leg_joints, [2], [ind]] = 1.0 # Leg lengthening elif part == 'legs_f': beta_scale_mask[leg_joints, [0], [ind]] = 1.0 # Leg fatness beta_scale_mask[leg_joints, [1], [ind]] = 1.0 # Leg fatness elif part == 'tail_l': beta_scale_mask[tail_joints, [0], [ind]] = 1.0 # Tail lengthening elif part == 'tail_f': beta_scale_mask[tail_joints, [1], [ind]] = 1.0 # Tail fatness beta_scale_mask[tail_joints, [2], [ind]] = 1.0 # Tail fatness elif part == 'ears_y': beta_scale_mask[ear_joints, [1], [ind]] = 1.0 # Ear y elif part == 'ears_l': beta_scale_mask[ear_joints, [2], [ind]] = 1.0 # Ear z elif part == 'neck_l': beta_scale_mask[neck_joints, [0], [ind]] = 1.0 # Neck lengthening elif part == 'neck_f': beta_scale_mask[neck_joints, [1], [ind]] = 1.0 # Neck fatness beta_scale_mask[neck_joints, [2], [ind]] = 1.0 # Neck fatness elif part == 'core_l': beta_scale_mask[core_joints, [0], [ind]] = 1.0 # Core lengthening # beta_scale_mask[core_joints, [1], [ind]] = 1.0 # Core fatness (height) elif part == 'core_fs': beta_scale_mask[core_joints, [2], [ind]] = 1.0 # Core fatness (side) elif part == 'head_l': beta_scale_mask[mouth_joints, [0], [ind]] = 1.0 # Head lengthening elif part == 'head_f': beta_scale_mask[mouth_joints, [1], [ind]] = 1.0 # Head fatness 0 beta_scale_mask[mouth_joints, [2], [ind]] = 1.0 # Head fatness 1 else: print(part + ' not available') raise ValueError beta_scale_mask = torch.transpose( beta_scale_mask.reshape(35*3, n_b_log), 0, 1) return beta_scale_mask def batch_global_rigid_transformation_biggs(Rs, Js, parent, scale_factors_3x3, rotate_base = False, betas_logscale=None, opts=None): """ Computes absolute joint locations given pose. rotate_base: if True, rotates the global rotation by 90 deg in x axis. if False, this is the original SMPL coordinate. Args: Rs: N x 24 x 3 x 3 rotation vector of K joints Js: N x 24 x 3, joint locations before posing parent: 24 holding the parent id for each index Returns new_J : `Tensor`: N x 24 x 3 location of absolute joints A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. """ if rotate_base: print('Flipping the SMPL coordinate frame!!!!') rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) else: root_rotation = Rs[:, 0, :, :] # Now Js is N x 24 x 3 x 1 Js = Js.unsqueeze(-1) N = Rs.shape[0] Js_orig = Js.clone() def make_A(R, t): # Rs is N x 3 x 3, ts is N x 3 x 1 R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1) return torch.cat([R_homo, t_homo], 2) A0 = make_A(root_rotation, Js[:, 0]) results = [A0] for i in range(1, parent.shape[0]): j_here = Js[:, i] - Js[:, parent[i]] try: s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]]) except: # import pdb; pdb.set_trace() s_par_inv = torch.max(scale_factors_3x3[:, parent[i]], 0.01*torch.eye((3))[None, :, :].to(scale_factors_3x3.device)) rot = Rs[:, i] s = scale_factors_3x3[:, i] rot_new = s_par_inv @ rot @ s A_here = make_A(rot_new, j_here) res_here = torch.matmul( results[parent[i]], A_here) results.append(res_here) # 10 x 24 x 4 x 4 results = torch.stack(results, dim=1) # scale updates new_J = results[:, :, :3, 3] # --- Compute relative A: Skinning is based on # how much the bone moved (not the final location of the bone) # but (final_bone - init_bone) # --- Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2) init_bone = torch.matmul(results, Js_w0) # Append empty 4 x 3: init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) A = results - init_bone return new_J, A