File size: 3,674 Bytes
6325697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import torch
import hydra
import numpy as np
from ..smpl.body_models import SMPL
class SMPLServer(torch.nn.Module):
    """Wrapper around the project's SMPL layer that applies a global scale and
    translation and can express bone transforms relative to a canonical pose.

    Parameter vectors handled by this class are laid out as
    ``[scale (1) | transl (3) | thetas (72) | betas (10)]`` (86 values), which
    is how ``param_canonical`` is split before calling :meth:`forward`.

    Args:
        gender (str): forwarded to ``SMPL`` to select the body model.
        betas: optional shape coefficients (expected to hold 10 values). When
            given and ``v_template`` is not, they are written into the betas
            slot of the canonical parameter vector.
        v_template: optional vertex template forwarded to ``SMPL.forward``.
            When given, the ``betas`` passed to :meth:`forward` are replaced by
            zeros. Presumably shaped like the SMPL template -- TODO confirm.

    NOTE(review): ``v_template``, ``betas``, ``param_canonical``, ``verts_c``,
    ``joints_c`` and ``tfs_c_inv`` are plain tensor attributes created with
    ``.cuda()``. They are not registered buffers, so ``Module.to()`` does not
    move them.
    """

    def __init__(self, gender='neutral', betas=None, v_template=None):
        super().__init__()

        self.smpl = SMPL(model_path=hydra.utils.to_absolute_path('lib/smpl/smpl_model'),
                         gender=gender,
                         batch_size=1,
                         use_hands=False,
                         use_feet_keypoints=False,
                         dtype=torch.float32).cuda()

        # Kinematic tree from the SMPL layer. By name, bone_parents[i] is the
        # parent of joint i; the root (joint 0) is explicitly marked with -1.
        self.bone_parents = self.smpl.bone_parents.astype(int)
        self.bone_parents[0] = -1
        # One [parent, child] pair per joint for the 24 SMPL joints.
        self.bone_ids = [[self.bone_parents[i], i] for i in range(24)]
        self.faces = self.smpl.faces

        if v_template is not None:
            self.v_template = torch.tensor(v_template).float().cuda()
        else:
            self.v_template = None

        if betas is not None:
            self.betas = torch.tensor(betas).float().cuda()
        else:
            self.betas = None

        # Canonical parameter vector: [scale | transl | thetas | betas].
        param_canonical = torch.zeros((1, 86), dtype=torch.float32).cuda()
        param_canonical[0, 0] = 1  # unit scale, zero translation
        # thetas begin at column 4, so columns 9 and 12 are thetas[5] and
        # thetas[8]: the third axis-angle component of joints 1 and 2 (the hips
        # in the standard SMPL joint ordering). +/-30 degrees spreads the legs.
        param_canonical[0, 9] = np.pi / 6
        param_canonical[0, 12] = -np.pi / 6
        # Subject-specific betas only shape the canonical body when no explicit
        # template is supplied (forward() zeroes betas otherwise).
        if self.betas is not None and self.v_template is None:
            param_canonical[0, -10:] = self.betas
        self.param_canonical = param_canonical

        # absolute=True is required here: tfs_c_inv does not exist yet.
        output = self.forward(*torch.split(self.param_canonical, [1, 3, 72, 10], dim=1), absolute=True)
        self.verts_c = output['smpl_verts']
        self.joints_c = output['smpl_jnts']
        # Per-bone inverse of the canonical transforms (batch dim dropped).
        self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse()

    def forward(self, scale, transl, thetas, betas, absolute=False):
        """Return SMPL output from params, with global scale and translation applied.

        Every posed point p is mapped to ``p * scale + transl * scale``.

        Args:
            scale : scale factor. shape: [B, 1]
            transl: translation. shape: [B, 3]
            thetas: pose. shape: [B, 72]. ``thetas[:, :3]`` is passed as the
                global orientation, ``thetas[:, 3:]`` as the body pose.
            betas: shape. shape: [B, 10]. Replaced by zeros when this server
                was built with a ``v_template``.
            absolute (bool): if true return smpl_tfs wrt thetas=0. else wrt
                thetas=thetas_canonical.

        Returns:
            dict with keys:
            smpl_verts: vertices. shape: [B, 6890, 3]
            smpl_tfs: bone transformations. shape: [B, 24, 4, 4]
            smpl_jnts: joint positions. shape: [B, 25, 3]
            smpl_weights: skinning weights, passed through unchanged from the
                SMPL layer output.
        """
        # ignore betas if v_template is provided
        if self.v_template is not None:
            betas = torch.zeros_like(betas)

        # SMPL is evaluated with zero translation. Scale and translation are
        # applied below so that they also enter the bone transforms.
        smpl_output = self.smpl.forward(betas=betas,
                                        transl=torch.zeros_like(transl),
                                        body_pose=thetas[:, 3:],
                                        global_orient=thetas[:, :3],
                                        return_verts=True,
                                        return_full_pose=True,
                                        v_template=self.v_template)

        scale_b = scale.unsqueeze(1)            # [B, 1, 1]
        offset = transl.unsqueeze(1) * scale_b  # [B, 1, 3], translation after scaling

        output = {}
        output['smpl_verts'] = smpl_output.vertices * scale_b + offset
        output['smpl_jnts'] = smpl_output.joints * scale_b + offset

        # Clone because the slices below are modified in place.
        tf_mats = smpl_output.T.clone()
        # Scale the top three rows of each homogeneous 4x4 transform,
        # then add the scaled translation to the translation column.
        tf_mats[:, :, :3, :] = tf_mats[:, :, :3, :] * scale_b.unsqueeze(1)
        tf_mats[:, :, :3, 3] = tf_mats[:, :, :3, 3] + offset

        if not absolute:
            # Re-express each bone transform relative to the canonical pose:
            # T_rel = T_posed @ inverse(T_canonical).
            tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv)

        output['smpl_tfs'] = tf_mats
        output['smpl_weights'] = smpl_output.weights
        return output