Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| import contextlib | |
| from typing import Optional | |
| import torch | |
| from torch import Tensor | |
| from mGPT.utils.joints import smplh_to_mmm_scaling_factor, smplh2mmm_indexes, get_root_idx | |
| from mGPT.utils.easyconvert import rep_to_rep | |
| from .base import Rots2Joints | |
def slice_or_none(data, cslice):
    """Return ``data[cslice]``, or ``None`` unchanged when ``data`` is ``None``.

    Lets optional inputs flow through the same slicing code path as
    mandatory ones without a ``None`` check at every call site.
    """
    return None if data is None else data[cslice]
class SMPLX(Rots2Joints):
    """Convert SMPL-X pose rotations into joints or mesh vertices.

    Wraps ``smplx.body_models.SMPLXLayer`` (frozen and kept in eval mode) and
    runs it over the flattened batch in chunks of ``batch_size`` poses. The
    raw SMPL-X output is then re-mapped to the requested ``jointstype`` by
    ``smplx_to``.
    """

    def __init__(self,
                 path: str,
                 jointstype: str = "mmm",
                 input_pose_rep: str = "matrix",
                 batch_size: int = 512,
                 gender="neutral",
                 **kwargs) -> None:
        """
        Args:
            path: location of the SMPL-X model, forwarded to ``SMPLXLayer``
                (loaded with ``ext="npz"``).
            jointstype: default output topology used by ``forward``.
            input_pose_rep: default rotation representation of input poses.
            batch_size: default number of poses per SMPL-X call.
            gender: SMPL-X model gender.
            **kwargs: accepted for config compatibility and ignored.
        """
        super().__init__(path=None, normalization=False)
        self.batch_size = batch_size
        self.input_pose_rep = input_pose_rep
        self.jointstype = jointstype
        self.training = False

        from smplx.body_models import SMPLXLayer

        # Remove annoying print emitted while the model loads
        with contextlib.redirect_stdout(None):
            self.smplx = SMPLXLayer(path,
                                    ext="npz",
                                    gender=gender,
                                    batch_size=batch_size).eval()

        self.faces = self.smplx.faces
        # The body model is used as a fixed function: freeze every parameter.
        for p in self.parameters():
            p.requires_grad = False

    def train(self, *args, **kwargs):
        """Ignore train-mode requests so the body model always stays in eval."""
        return self

    def forward(self,
                smpl_data: dict,
                jointstype: Optional[str] = None,
                input_pose_rep: Optional[str] = None,
                batch_size: Optional[int] = None) -> Tensor:
        """Run SMPL-X on ``smpl_data`` and return joints or vertices.

        Args:
            smpl_data: object exposing ``.rots`` (pose rotations) and
                ``.trans`` (root translations or ``None``); note it is read
                via attribute access despite the ``dict`` annotation.
            jointstype: output topology; defaults to ``self.jointstype``.
                ``"vertices"`` returns mesh vertices instead of joints.
            input_pose_rep: rotation representation of ``smpl_data.rots``;
                defaults to ``self.input_pose_rep``.
            batch_size: poses per SMPL-X call; defaults to ``self.batch_size``.

        Returns:
            Tensor whose leading dims are the leading (batch/time) dims of the
            poses, followed by the per-pose dims produced by SMPL-X and
            re-mapped by ``smplx_to``.

        Raises:
            NotImplementedError: if the pose does not have 22, 52 or 55 joints.
        """
        from functools import reduce
        import operator

        # Take values from init if not specified there
        jointstype = self.jointstype if jointstype is None else jointstype
        batch_size = self.batch_size if batch_size is None else batch_size
        input_pose_rep = (self.input_pose_rep
                          if input_pose_rep is None else input_pose_rep)

        poses = smpl_data.rots
        trans = smpl_data.trans

        # SMPLXLayer consumes rotation matrices: convert FROM the input
        # representation TO "matrix" (identity when the input already is).
        matrix_poses = rep_to_rep(input_pose_rep, "matrix", poses)

        # Everything before (njoints, 3, 3) is flattened into one batch dim.
        # Shapes are read after conversion so non-matrix inputs work too.
        save_shape_bs_len = matrix_poses.shape[:-3]
        njoints = matrix_poses.shape[-3]
        nposes = reduce(operator.mul, save_shape_bs_len, 1)
        matrix_poses = matrix_poses.reshape((nposes, *matrix_poses.shape[-3:]))

        if trans is None:
            trans = torch.zeros((*save_shape_bs_len, 3),
                                dtype=poses.dtype,
                                device=poses.device)
        trans_all = trans.reshape((nposes, *trans.shape[-1:]))

        # Supported layouts:
        #   55 joints: body (22) + jaw, left eye, right eye (3) + hands (2 x 15)
        #   52 joints: body (22) + hands (2 x 15), no face joints
        #   22 joints: body only
        if njoints == 55:
            nohands, nofaces = False, False
        elif njoints == 52:
            nohands, nofaces = False, True
        elif njoints == 22:
            nohands, nofaces = True, True
        else:
            raise NotImplementedError("Could not parse the poses.")

        global_orient = matrix_poses[:, 0]
        body_pose = matrix_poses[:, 1:22]

        if nofaces:
            jaw_pose = leye_pose = reye_pose = None
        else:
            jaw_pose = matrix_poses[:, 22:23]
            leye_pose = matrix_poses[:, 23:24]
            reye_pose = matrix_poses[:, 24:25]

        if nohands:
            left_hand_pose = right_hand_pose = None
        else:
            # Hands come right after the face joints when present, otherwise
            # right after the body; each hand has 15 joints.
            hand_start = 22 if nofaces else 25
            left_hand_pose = matrix_poses[:, hand_start:hand_start + 15]
            right_hand_pose = matrix_poses[:, hand_start + 15:hand_start + 30]

        n = len(body_pose)
        outputs = []
        # max(n, 1) keeps a single (empty) chunk when n == 0, as before.
        for start in range(0, max(n, 1), batch_size):
            chunk_slice = slice(start, start + batch_size)
            smpl_output = self.smplx(
                global_orient=slice_or_none(global_orient, chunk_slice),
                body_pose=slice_or_none(body_pose, chunk_slice),
                left_hand_pose=slice_or_none(left_hand_pose, chunk_slice),
                right_hand_pose=slice_or_none(right_hand_pose, chunk_slice),
                jaw_pose=slice_or_none(jaw_pose, chunk_slice),
                leye_pose=slice_or_none(leye_pose, chunk_slice),
                reye_pose=slice_or_none(reye_pose, chunk_slice),
                transl=slice_or_none(trans_all, chunk_slice))

            if jointstype == "vertices":
                outputs.append(smpl_output.vertices)
            else:
                outputs.append(smpl_output.joints)

        outputs = torch.cat(outputs)
        # Restore the original leading (batch/time) dims.
        outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:]))

        # Change topology if needed
        outputs = smplx_to(jointstype, outputs, trans)
        return outputs

    def inverse(self, joints: Tensor) -> Tensor:
        """Not supported: SMPL-X forward kinematics is not inverted here."""
        raise NotImplementedError("Cannot inverse SMPLX layer.")
def smplx_to(jointstype, data, trans):
    """Re-map raw SMPL-X output to the requested joint topology.

    Args:
        jointstype: one of ``"mmm"``, ``"mmmns"``, ``"smplmmm"``,
            ``"smplnh"``, ``"smplh"`` or ``"vertices"``.
        data: SMPL-X output with joints (or vertices) on axis -2 and
            coordinates on axis -1; for non-vertex types it is indexed as
            ``data[..., frame, joint, :]``.
        trans: root translations, indexed as ``trans[..., frame, :]``.

    Returns:
        The re-mapped tensor. ``data`` itself is never modified in place.
        For every type except ``"vertices"``, the result is shifted so that
        the root joint of the first frame coincides with the first-frame
        translation.

    Raises:
        NotImplementedError: for an unsupported ``jointstype``.
    """
    if "mmm" in jointstype:
        # Select the MMM joint subset (advanced indexing returns a copy).
        data = data[..., smplh2mmm_indexes, :]

        # make it compatible with mmm
        if jointstype == "mmm":
            data = data * smplh_to_mmm_scaling_factor

        if jointstype in ("mmm", "mmmns"):
            # swap axis (advanced indexing -> copy, so the write below is safe)
            data = data[..., [1, 2, 0]]
            # revert left and right
            data[..., 2] = -data[..., 2]
        elif jointstype != "smplmmm":
            raise NotImplementedError(f"SMPLX to {jointstype} is not implemented.")
    elif jointstype == "smplnh":
        from mGPT.utils.joints import smplh2smplnh_indexes
        data = data[..., smplh2smplnh_indexes, :]
    elif jointstype in ("smplh", "vertices"):
        pass
    else:
        raise NotImplementedError(f"SMPLX to {jointstype} is not implemented.")

    if jointstype != "vertices":
        # shift the output in each batch
        # such that it is centered on the pelvis/root on the first frame
        root_joint_idx = get_root_idx(jointstype)
        shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :]
        # Out-of-place add so the caller's tensor is not mutated; cast to keep
        # data's dtype, as the previous in-place add did.
        data = data + shift[..., None, None, :].to(data.dtype)
    return data