# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn

from ..builder import MESH_MODELS

try:
    from smplx import SMPL as SMPL_
    has_smpl = True
except (ImportError, ModuleNotFoundError):
    has_smpl = False


@MESH_MODELS.register_module()
class SMPL(nn.Module):
    """SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned
    multi-person linear model''. This module is based on the smplx project
    (https://github.com/vchoutas/smplx).

    Three body models (neutral, male, female) are created from the same
    ``smpl_path``; :meth:`forward` picks one per sample according to the
    ``gender`` argument.

    Args:
        smpl_path (str): The path to the folder where the model weights are
            stored.
        joints_regressor (str): The path to the file where the joints
            regressor weight are stored.
    """

    def __init__(self, smpl_path, joints_regressor):
        super().__init__()

        assert has_smpl, 'Please install smplx to use SMPL.'

        self.smpl_neutral = SMPL_(
            model_path=smpl_path,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='neutral')

        self.smpl_male = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='male')

        self.smpl_female = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='female')

        # Leading singleton axis lets the regressor broadcast over the batch
        # dimension in ``torch.matmul`` inside ``smpl_forward``.
        joints_regressor = torch.tensor(
            np.load(joints_regressor), dtype=torch.float)[None, ...]
        self.register_buffer('joints_regressor', joints_regressor)

        self.num_verts = self.smpl_neutral.get_num_verts()
        self.num_joints = self.joints_regressor.shape[1]

    def smpl_forward(self, model, **kwargs):
        """Apply a specific SMPL model with given model parameters.

        Note:
            B: batch size
            V: number of vertices
            K: number of joints

        Args:
            model: The body model to call (one of ``self.smpl_neutral``,
                ``self.smpl_male`` or ``self.smpl_female``).
            **kwargs: Keyword arguments forwarded unchanged to ``model``.
                Must contain ``betas``; its first dimension is taken as the
                batch size and its device is used for the joints regressor.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed
                    from mesh vertices.
        """
        betas = kwargs['betas']
        batch_size = betas.shape[0]
        device = betas.device
        output = {}
        if batch_size == 0:
            # Skip the body model entirely for an empty batch and return
            # correctly shaped empty tensors instead.
            output['vertices'] = betas.new_zeros([0, self.num_verts, 3])
            output['joints'] = betas.new_zeros([0, self.num_joints, 3])
        else:
            smpl_out = model(**kwargs)
            output['vertices'] = smpl_out.vertices
            output['joints'] = torch.matmul(
                self.joints_regressor.to(device), output['vertices'])
        return output

    def get_faces(self):
        """Return mesh faces.

        Note:
            F: number of faces

        Returns:
            faces: np.ndarray([F, 3]), mesh faces
        """
        return self.smpl_neutral.faces

    def forward(self,
                betas,
                body_pose,
                global_orient,
                transl=None,
                gender=None):
        """Forward function.

        Note:
            B: batch size
            J: number of controllable joints of model, for smpl model J=23
            K: number of joints

        Args:
            betas: Tensor([B, 10]), human body shape parameters of SMPL model.
            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
                parameters of SMPL model. It should be axis-angle vector
                ([B, J*3]) or rotation matrix ([B, J, 3, 3]).
            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
                of human body. It should be axis-angle vector ([B, 3]) or
                rotation matrix ([B, 1, 3, 3]).
            transl: Tensor([B, 3]), global translation of human body.
            gender: Tensor([B]), gender parameters of human body. -1 for
                neutral, 0 for male , 1 for female. If None, the neutral
                model is used for the whole batch.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed from
                    mesh vertices.
        """
        batch_size = betas.shape[0]
        # A 2-D body_pose is treated as axis-angle and must be converted to
        # rotation matrices by the body model; otherwise it is passed as is.
        pose2rot = body_pose.dim() == 2

        if batch_size == 0 or gender is None:
            return self.smpl_forward(
                self.smpl_neutral,
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl,
                pose2rot=pose2rot)

        output = {
            'vertices': betas.new_zeros([batch_size, self.num_verts, 3]),
            'joints': betas.new_zeros([batch_size, self.num_joints, 3])
        }

        # Each sample is routed to the body model matching its gender code.
        # Samples whose code matches none of the masks keep zero outputs.
        gender_models = (
            (gender < 0, self.smpl_neutral),
            (gender == 0, self.smpl_male),
            (gender == 1, self.smpl_female),
        )
        for mask, model in gender_models:
            # An empty selection is fine: ``smpl_forward`` returns zero-size
            # tensors and the masked assignment below becomes a no-op.
            _out = self.smpl_forward(
                model,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']

        return output