Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from smplx import SMPL as _SMPL | |
try: | |
from smplx.body_models import ModelOutput as SMPLOutput | |
except Exception as e: | |
from smplx.body_models import SMPLOutput | |
from smplx.lbs import vertices2joints | |
JOINT_NAMES = [ | |
# 25 OpenPose joints (in the order provided by OpenPose) | |
'OP Nose', | |
'OP Neck', | |
'OP RShoulder', | |
'OP RElbow', | |
'OP RWrist', | |
'OP LShoulder', | |
'OP LElbow', | |
'OP LWrist', | |
'OP MidHip', | |
'OP RHip', | |
'OP RKnee', | |
'OP RAnkle', | |
'OP LHip', | |
'OP LKnee', | |
'OP LAnkle', | |
'OP REye', | |
'OP LEye', | |
'OP REar', | |
'OP LEar', | |
'OP LBigToe', | |
'OP LSmallToe', | |
'OP LHeel', | |
'OP RBigToe', | |
'OP RSmallToe', | |
'OP RHeel', | |
# 24 Ground Truth joints (superset of joints from different datasets) | |
'Right Ankle', | |
'Right Knee', | |
'Right Hip', | |
'Left Hip', | |
'Left Knee', | |
'Left Ankle', | |
'Right Wrist', | |
'Right Elbow', | |
'Right Shoulder', | |
'Left Shoulder', | |
'Left Elbow', | |
'Left Wrist', | |
'Neck (LSP)', | |
'Top of Head (LSP)', | |
'Pelvis (MPII)', | |
'Thorax (MPII)', | |
'Spine (H36M)', | |
'Jaw (H36M)', | |
'Head (H36M)', | |
'Nose', | |
'Left Eye', | |
'Right Eye', | |
'Left Ear', | |
'Right Ear' | |
] | |
# Dict containing the joints in numerical order | |
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))} | |
# Map joints to SMPL joints | |
JOINT_MAP = { | |
'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, | |
'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, | |
'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, | |
'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, | |
'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, | |
'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, | |
'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, | |
'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, | |
'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, | |
'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, | |
'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, | |
'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, | |
'Neck (LSP)': 47, 'Top of Head (LSP)': 48, | |
'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, | |
'Spine (H36M)': 51, 'Jaw (H36M)': 52, | |
'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, | |
'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 | |
} | |
class SMPL(_SMPL): | |
""" Extension of the official SMPL implementation to support more joints """ | |
def __init__(self, *args, **kwargs): | |
super(SMPL, self).__init__(*args, **kwargs) | |
joints = [JOINT_MAP[i] for i in JOINT_NAMES] | |
J_regressor_extra = np.load('data/J_regressor_extra.npy') | |
self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) | |
self.joint_map = torch.tensor(joints, dtype=torch.long) | |
def forward(self, *args, **kwargs): | |
kwargs['get_skin'] = True | |
smpl_output = super(SMPL, self).forward(*args, **kwargs) | |
extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) | |
joints = torch.cat([smpl_output.joints, extra_joints], dim=1) | |
joints = joints[:, self.joint_map, :] | |
output = SMPLOutput(vertices=smpl_output.vertices, | |
global_orient=smpl_output.global_orient, | |
body_pose=smpl_output.body_pose, | |
joints=joints, | |
betas=smpl_output.betas, | |
full_pose=smpl_output.full_pose) | |
return output | |