bill-jiang's picture
Init
4409449
raw
history blame
6.81 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import contextlib
from typing import Optional
import torch
from torch import Tensor
from mGPT.utils.joints import smplh_to_mmm_scaling_factor, smplh2mmm_indexes, get_root_idx
from mGPT.utils.easyconvert import rep_to_rep
from .base import Rots2Joints
def slice_or_none(data, cslice):
if data is None:
return data
else:
return data[cslice]
class SMPLX(Rots2Joints):
def __init__(self,
path: str,
jointstype: str = "mmm",
input_pose_rep: str = "matrix",
batch_size: int = 512,
gender="neutral",
**kwargs) -> None:
super().__init__(path=None, normalization=False)
self.batch_size = batch_size
self.input_pose_rep = input_pose_rep
self.jointstype = jointstype
self.training = False
from smplx.body_models import SMPLXLayer
import os
# rel_p = path.split('/')
# rel_p = rel_p[rel_p.index('data'):]
# rel_p = '/'.join(rel_p)
# Remove annoying print
with contextlib.redirect_stdout(None):
self.smplx = SMPLXLayer(path,
ext="npz",
gender=gender,
batch_size=batch_size).eval()
self.faces = self.smplx.faces
for p in self.parameters():
p.requires_grad = False
def train(self, *args, **kwargs):
return self
def forward(self,
smpl_data: dict,
jointstype: Optional[str] = None,
input_pose_rep: Optional[str] = None,
batch_size: Optional[int] = None) -> Tensor:
# Take values from init if not specified there
jointstype = self.jointstype if jointstype is None else jointstype
batch_size = self.batch_size if batch_size is None else batch_size
input_pose_rep = self.input_pose_rep if input_pose_rep is None else input_pose_rep
poses = smpl_data.rots
trans = smpl_data.trans
from functools import reduce
import operator
save_shape_bs_len = poses.shape[:-3]
nposes = reduce(operator.mul, save_shape_bs_len, 1)
matrix_poses = rep_to_rep(self.input_pose_rep, input_pose_rep, poses)
# Reshaping
matrix_poses = matrix_poses.reshape((nposes, *matrix_poses.shape[-3:]))
global_orient = matrix_poses[:, 0]
if trans is None:
trans = torch.zeros((*save_shape_bs_len, 3),
dtype=poses.dtype,
device=poses.device)
trans_all = trans.reshape((nposes, *trans.shape[-1:]))
body_pose = matrix_poses[:, 1:22]
if poses.shape[-3] == 55:
nohands = False
nofaces = False
elif poses.shape[-3] == 52:
nohands = False
nofaces = True
elif poses.shape[-3] == 22:
nohands = True
nofaces = True
else:
raise NotImplementedError("Could not parse the poses.")
if nohands:
left_hand_pose = None
right_hand_pose = None
else:
left_hand_pose = matrix_poses[:, 25:40]
right_hand_pose = matrix_poses[:, 40:55]
if nofaces:
jaw_pose = None
leye_pose = None
reye_pose = None
else:
jaw_pose = matrix_poses[:, 22:23]
leye_pose = matrix_poses[:, 23:24]
reye_pose = matrix_poses[:, 24:25]
n = len(body_pose)
outputs = []
for chunk in range(int((n - 1) / batch_size) + 1):
chunk_slice = slice(chunk * batch_size, (chunk + 1) * batch_size)
smpl_output = self.smplx(
global_orient=slice_or_none(global_orient, chunk_slice),
body_pose=slice_or_none(body_pose, chunk_slice),
left_hand_pose=slice_or_none(left_hand_pose, chunk_slice),
right_hand_pose=slice_or_none(right_hand_pose, chunk_slice),
jaw_pose=slice_or_none(jaw_pose, chunk_slice),
leye_pose=slice_or_none(leye_pose, chunk_slice),
reye_pose=slice_or_none(reye_pose, chunk_slice),
transl=slice_or_none(trans_all, chunk_slice))
if jointstype == "vertices":
output_chunk = smpl_output.vertices
else:
joints = smpl_output.joints
output_chunk = joints
outputs.append(output_chunk)
outputs = torch.cat(outputs)
outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:]))
# Change topology if needed
outputs = smplx_to(jointstype, outputs, trans)
return outputs
def inverse(self, joints: Tensor) -> Tensor:
raise NotImplementedError("Cannot inverse SMPLX layer.")
def smplx_to(jointstype, data, trans):
if "mmm" in jointstype:
indexes = smplh2mmm_indexes
data = data[..., indexes, :]
# make it compatible with mmm
if jointstype == "mmm":
data *= smplh_to_mmm_scaling_factor
if jointstype == "smplmmm":
pass
elif jointstype in ["mmm", "mmmns"]:
# swap axis
data = data[..., [1, 2, 0]]
# revert left and right
data[..., 2] = -data[..., 2]
elif jointstype == "smplnh":
from mGPT.utils.joints import smplh2smplnh_indexes
indexes = smplh2smplnh_indexes
data = data[..., indexes, :]
elif jointstype == "smplh":
pass
elif jointstype == "vertices":
pass
else:
raise NotImplementedError(f"SMPLX to {jointstype} is not implemented.")
if jointstype != "vertices":
# shift the output in each batch
# such that it is centered on the pelvis/root on the first frame
root_joint_idx = get_root_idx(jointstype)
shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :]
data += shift[..., None, None, :]
return data