# Uploaded by fffiloni; migrated from GitHub (revision 2252f3d, verified).
"""
original from https://github.com/vchoutas/smplx
modified by Vassilis and Yao
"""
import torch
import torch.nn as nn
import numpy as np
import pickle
from .lbs import (
Struct,
to_tensor,
to_np,
lbs,
vertices2landmarks,
JointsFromVerticesSelector,
find_dynamic_lmk_idx_and_bcoords,
)
# SMPLX
# Names of the 14 "J14" joints. The position of a name in this list is the
# index SMPLX.__init__ stores in ``target_idxs`` for the matching SMPL-X joint.
J14_NAMES = [
    # right leg, foot to hip
    "right_ankle", "right_knee", "right_hip",
    # left leg, hip to foot
    "left_hip", "left_knee", "left_ankle",
    # right arm, hand to shoulder
    "right_wrist", "right_elbow", "right_shoulder",
    # left arm, shoulder to hand
    "left_shoulder", "left_elbow", "left_wrist",
    # neck and head
    "neck", "head",
]
# Names of the joints returned by SMPLX.forward, in output order:
#   55 SMPL-X skeleton joints, then 68 face landmarks, then the 22 extra
#   vertex-selected joints listed in ``extra_names`` (145 names in total).
#
# The literal below holds only the first 123 names; the extra-joint names are
# appended exactly once from ``extra_names`` further down. Previously they were
# both spelled out here AND appended again, giving 167 entries with 22
# duplicates, so the list no longer matched the 145 joints it labels.
SMPLX_names = [
    # --- body (25)
    "pelvis", "left_hip", "right_hip", "spine1", "left_knee", "right_knee",
    "spine2", "left_ankle", "right_ankle", "spine3", "left_foot",
    "right_foot", "neck", "left_collar", "right_collar", "head",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "jaw", "left_eye_smplx", "right_eye_smplx",
    # --- left hand (15)
    "left_index1", "left_index2", "left_index3",
    "left_middle1", "left_middle2", "left_middle3",
    "left_pinky1", "left_pinky2", "left_pinky3",
    "left_ring1", "left_ring2", "left_ring3",
    "left_thumb1", "left_thumb2", "left_thumb3",
    # --- right hand (15)
    "right_index1", "right_index2", "right_index3",
    "right_middle1", "right_middle2", "right_middle3",
    "right_pinky1", "right_pinky2", "right_pinky3",
    "right_ring1", "right_ring2", "right_ring3",
    "right_thumb1", "right_thumb2", "right_thumb3",
    # --- face landmarks (68): eyebrows
    "right_eye_brow1", "right_eye_brow2", "right_eye_brow3",
    "right_eye_brow4", "right_eye_brow5",
    "left_eye_brow5", "left_eye_brow4", "left_eye_brow3",
    "left_eye_brow2", "left_eye_brow1",
    # nose
    "nose1", "nose2", "nose3", "nose4",
    "right_nose_2", "right_nose_1", "nose_middle",
    "left_nose_1", "left_nose_2",
    # eyes
    "right_eye1", "right_eye2", "right_eye3",
    "right_eye4", "right_eye5", "right_eye6",
    "left_eye4", "left_eye3", "left_eye2",
    "left_eye1", "left_eye6", "left_eye5",
    # outer mouth
    "right_mouth_1", "right_mouth_2", "right_mouth_3", "mouth_top",
    "left_mouth_3", "left_mouth_2", "left_mouth_1",
    "left_mouth_5", "left_mouth_4", "mouth_bottom",
    "right_mouth_4", "right_mouth_5",
    # inner lips
    "right_lip_1", "right_lip_2", "lip_top", "left_lip_2", "left_lip_1",
    "left_lip_3", "lip_bottom", "right_lip_3",
    # face contour
    "right_contour_1", "right_contour_2", "right_contour_3",
    "right_contour_4", "right_contour_5", "right_contour_6",
    "right_contour_7", "right_contour_8", "contour_middle",
    "left_contour_8", "left_contour_7", "left_contour_6",
    "left_contour_5", "left_contour_4", "left_contour_3",
    "left_contour_2", "left_contour_1",
]

# Names of the 22 extra joints appended after the landmarks.
extra_names = [
    "head_top", "left_big_toe", "left_ear", "left_eye", "left_heel",
    "left_index", "left_middle", "left_pinky", "left_ring",
    "left_small_toe", "left_thumb", "nose",
    "right_big_toe", "right_ear", "right_eye", "right_heel",
    "right_index", "right_middle", "right_pinky", "right_ring",
    "right_small_toe", "right_thumb",
]

# Indices 123..144 of SMPLX_names.
SMPLX_names += extra_names
# Joint-index groups keyed by body-part name. Each value is a 1-D integer
# array of positions in the 145-entry joint list. Contiguous runs are written
# as ranges instead of being listed number by number.
part_indices = {
    "body": np.array([
        *range(25),
        *range(123, 128),
        132,
        *range(134, 139),
        143,
    ]),
    "torso": np.array([
        0, 1, 2, 3, 6, 9,
        *range(12, 20),
        22, 23, 24,
        *range(55, 60),
        *range(76, 145),
    ]),
    "head": np.array([
        12, 15, 22, 23, 24,
        *range(55, 124),
        125, 126, 134, 136, 137,
    ]),
    "face": np.array([*range(55, 123)]),
    "upper": np.array([12, 13, 14, *range(55, 123)]),
    "hand": np.array([
        20, 21,
        *range(25, 55),
        128, 129, 130, 131, 133,
        139, 140, 141, 142, 144,
    ]),
    "left_hand": np.array([
        20,
        *range(25, 40),
        128, 129, 130, 131, 133,
    ]),
    "right_hand": np.array([
        21,
        *range(40, 55),
        139, 140, 141, 142, 144,
    ]),
}
# Kinematic chain of the head: joint indices listed from the head (15) back
# to the root (0). SMPLX registers it as a buffer and passes it to
# find_dynamic_lmk_idx_and_bcoords when selecting the dynamic face landmarks.
head_kin_chain = [15, 12, 9, 6, 3, 0]
# Legend for the first 25 SMPL-X joint indices (index - name).
# NOTE: these labels use a different naming scheme than SMPLX_names
# (e.g. index 01 is "L_Thigh" here and "left_hip" there).
# 00 - Global
# 01 - L_Thigh
# 02 - R_Thigh
# 03 - Spine
# 04 - L_Calf
# 05 - R_Calf
# 06 - Spine1
# 07 - L_Foot
# 08 - R_Foot
# 09 - Spine2
# 10 - L_Toes
# 11 - R_Toes
# 12 - Neck
# 13 - L_Shoulder
# 14 - R_Shoulder
# 15 - Head
# 16 - L_UpperArm
# 17 - R_UpperArm
# 18 - L_ForeArm
# 19 - R_ForeArm
# 20 - L_Hand
# 21 - R_Hand
# 22 - Jaw
# 23 - L_Eye
# 24 - R_Eye
class SMPLX(nn.Module):
    """Differentiable SMPL-X layer.

    Given SMPL-X parameters (shape, expression and per-joint rotation
    matrices) this module outputs the posed mesh vertices, the face landmarks
    and the 3D joints.
    """

    # Joint-index chains, listed from the named joint back to the pelvis (0).
    # Shared by pose_abs2rel and pose_rel2abs.
    _KIN_CHAINS = {
        # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
        "head": (15, 12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Neck
        "neck": (12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> right collar -> right shoulder
        # -> right elbow -> right wrist
        "right_wrist": (21, 19, 17, 14, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> left collar -> left shoulder
        # -> left elbow -> left wrist
        "left_wrist": (20, 18, 16, 13, 9, 6, 3, 0),
    }

    # (buffer name, number of joints) for the default identity poses, in
    # registration order. ``head_pose`` and ``neck_pose`` are not read by
    # forward() but are kept so the state_dict layout does not change.
    _DEFAULT_POSES = (
        ("global_pose", 1),
        ("head_pose", 1),
        ("neck_pose", 1),
        ("jaw_pose", 1),
        ("eye_pose", 2),
        ("body_pose", 21),
        ("left_hand_pose", 15),
        ("right_hand_pose", 15),
    )

    def __init__(self, config):
        """Load the SMPL-X model data and register it as module buffers.

        Args:
            config: object providing ``smplx_model_path``, ``n_shape``,
                ``n_exp``, ``extra_joint_path`` and ``j14_regressor_path``.
        """
        super().__init__()

        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point this at trusted model files.
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)

        self.dtype = torch.float32

        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        # The vertices of the template model
        self.register_buffer(
            "v_template",
            to_tensor(to_np(smplx_model.v_template), dtype=self.dtype),
        )

        # Shape and expression components (expression space is the same as
        # FLAME). The expression directions are read starting at index 300 of
        # the last axis of ``shapedirs``.
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, :config.n_shape],
                shapedirs[:, :, 300:300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)

        # The pose components, flattened to (num_pose_basis, -1)
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs",
                             to_tensor(to_np(posedirs), dtype=self.dtype))

        self.register_buffer(
            "J_regressor",
            to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype),
        )

        # Parent index of every joint; the root has no parent (-1)
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)

        self.register_buffer(
            "lbs_weights",
            to_tensor(to_np(smplx_model.weights), dtype=self.dtype),
        )

        # Face keypoints: static and dynamic landmark embeddings
        self.register_buffer(
            "lmk_faces_idx",
            torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords,
                         dtype=self.dtype),
        )
        # Pelvis to head, to calculate the head yaw angle and then find the
        # dynamic landmarks
        self.register_buffer("head_kin_chain",
                             torch.tensor(head_kin_chain, dtype=torch.long))

        # -- default parameters, used by forward() for omitted arguments.
        # Registered as plain tensors: buffers are never trainable, so
        # wrapping them in nn.Parameter is unnecessary.
        self.register_buffer(
            "shape_params",
            torch.zeros([1, config.n_shape], dtype=self.dtype),
        )
        self.register_buffer(
            "expression_params",
            torch.zeros([1, config.n_exp], dtype=self.dtype),
        )
        # Poses are rotation matrices of shape [number of joints, 3, 3]
        for buffer_name, num_joints in self._DEFAULT_POSES:
            identity = torch.eye(3, dtype=self.dtype).unsqueeze(0)
            self.register_buffer(buffer_name,
                                 identity.repeat(num_joints, 1, 1))

        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(
                fname=config.extra_joint_path)

        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load trusted regressor files.
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")

            # Pair every keypoint that is also a J14 joint with its J14 index
            j14_index = {name: i for i, name in enumerate(J14_NAMES)}
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in j14_index:
                    source.append(idx)
                    target.append(j14_index[name])
            self.register_buffer("source_idxs",
                                 torch.from_numpy(np.asarray(source)))
            self.register_buffer("target_idxs",
                                 torch.from_numpy(np.asarray(target)))
            self.register_buffer(
                "extra_joint_regressor",
                torch.from_numpy(j14_regressor).to(dtype=torch.float32),
            )

        self.part_indices = part_indices

    @staticmethod
    def _infer_batch_size(*tensors):
        """Return ``shape[0]`` of the first non-None argument, or 1 if none."""
        for tensor in tensors:
            if tensor is not None:
                return tensor.shape[0]
        return 1

    @classmethod
    def _kin_chain_for(cls, abs_joint, caller):
        """Look up the kinematic chain of ``abs_joint``.

        Raises:
            NotImplementedError: if ``abs_joint`` has no registered chain.
        """
        try:
            return cls._KIN_CHAINS[abs_joint]
        except (KeyError, TypeError):
            raise NotImplementedError(
                f"{caller} does not support: {abs_joint}") from None

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """Pose the SMPL-X model.

        Every argument is optional; an omitted argument falls back to the
        module's default buffer (zeros for shape/expression, identity
        rotations for poses) expanded to the batch size. The batch size is
        taken from the first supplied argument, checked in the order
        shape_params, global_pose, then the remaining arguments; it is 1 when
        nothing is supplied.

        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]
        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3]
        """
        batch_size = self._infer_batch_size(
            shape_params,
            global_pose,
            expression_params,
            body_pose,
            jaw_pose,
            eye_pose,
            left_hand_pose,
            right_hand_pose,
        )

        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = self.global_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if body_pose is None:
            body_pose = self.body_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if jaw_pose is None:
            jaw_pose = self.jaw_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if eye_pose is None:
            eye_pose = self.eye_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        # Joint order along dim 1: global, body, jaw, eyes, left/right hand
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(
            batch_size, -1, -1)

        # Linear blend skinning; poses are already rotation matrices
        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )

        # Face landmarks: the static embedding followed by the dynamic one
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
            batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            # Add any extra joints that might be needed
            final_joint_set.append(
                self.extra_joint_selector(vertices, self.faces_tensor))
        # Create the final joint set
        joints = torch.cat(final_joint_set, dim=1)

        # NOTE: the J14 regressor buffers registered in __init__
        # (extra_joint_regressor, source_idxs, target_idxs) are intentionally
        # not applied here; the returned joints are never overridden.
        return vertices, landmarks, joints

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """Change the absolute pose of one joint to a relative pose.

        Basic knowledge for the SMPLX kinematic tree:
            absolute pose = parent pose * relative pose
        Poses must be rotation matrices (batch_size x n x 3 x 3).

        ``body_pose`` is modified IN PLACE: the entry of ``abs_joint``, which
        is assumed to hold an absolute rotation on input, is overwritten with
        its rotation relative to the parent chain. The same tensor is
        returned.

        Args:
            global_pose: pelvis rotation, indexed as joint 0.
            body_pose: body joint rotations; joint ``i`` is ``body_pose[:, i - 1]``.
            abs_joint: "head", "neck", "right_wrist" or "left_wrist".
        Raises:
            NotImplementedError: for any other ``abs_joint``.
        """
        kin_chain = self._kin_chain_for(abs_joint, "pose_abs2rel")

        batch_size = global_pose.shape[0]
        full_pose = torch.cat([global_pose, body_pose], dim=1)

        # Accumulate the parents' rotations from the direct parent up to the
        # pelvis: result = R_pelvis @ ... @ R_parent
        rel_rot_mat = torch.eye(
            3, device=global_pose.device,
            dtype=global_pose.dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)

        # Absolute pose of the parent; no gradient flows through it
        abs_parent_pose = rel_rot_mat.detach()

        # body_pose has no pelvis entry, hence the -1 offset
        joint_slot = kin_chain[0] - 1
        abs_joint_pose = body_pose[:, joint_slot]
        # abs_joint = abs_parent * rel_joint ==> rel_joint = abs_parent.T * abs_joint
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        # Replace the absolute pose with the new relative pose
        body_pose[:, joint_slot, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """Compute the absolute pose of one joint from relative poses.

        Basic knowledge for the SMPLX kinematic tree:
            absolute pose = parent pose * relative pose
        Poses must be rotation matrices (batch_size x n x 3 x 3).

        Args:
            global_pose: pelvis rotation, indexed as joint 0.
            body_pose: relative body joint rotations.
            abs_joint: "head", "neck", "right_wrist" or "left_wrist".
        Returns:
            Absolute rotation of ``abs_joint`` with a singleton joint axis,
            i.e. shape (batch_size, 1, 3, 3). The inputs are not modified.
        Raises:
            NotImplementedError: for any other ``abs_joint``.
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        kin_chain = self._kin_chain_for(abs_joint, "pose_rel2abs")

        # Walk from the joint up to the pelvis, left-multiplying each
        # rotation: result = R_pelvis @ ... @ R_parent @ R_joint
        rel_rot_mat = torch.eye(3,
                                device=full_pose.device,
                                dtype=full_pose.dtype).unsqueeze_(dim=0)
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        return rel_rot_mat[:, None, :, :]