Spaces:
Runtime error
Runtime error
Upload smpl.py
Browse files- src/spin/smpl.py +436 -0
src/spin/smpl.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from smplx import SMPL as _SMPL
|
4 |
+
from smplx import SMPLX as _SMPLX
|
5 |
+
from smplx.body_models import SMPLOutput, SMPLXOutput
|
6 |
+
from smplx.lbs import vertices2joints
|
7 |
+
|
8 |
+
from .constants import JOINT_MAP, JOINT_NAMES
|
9 |
+
|
10 |
+
# Hand joints
|
11 |
+
SMPLX_HAND_TO_PANOPTIC = [
|
12 |
+
0,
|
13 |
+
13,
|
14 |
+
14,
|
15 |
+
15,
|
16 |
+
16,
|
17 |
+
1,
|
18 |
+
2,
|
19 |
+
3,
|
20 |
+
17,
|
21 |
+
4,
|
22 |
+
5,
|
23 |
+
6,
|
24 |
+
18,
|
25 |
+
10,
|
26 |
+
11,
|
27 |
+
12,
|
28 |
+
19,
|
29 |
+
7,
|
30 |
+
8,
|
31 |
+
9,
|
32 |
+
20,
|
33 |
+
] # Wrist Thumb to Pinky
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
class SMPL(_SMPL):
|
38 |
+
"""Extension of the official SMPL implementation to support more joints"""
|
39 |
+
|
40 |
+
JOINTS = (
|
41 |
+
"Hips",
|
42 |
+
"Left Upper Leg",
|
43 |
+
"Right Upper Leg",
|
44 |
+
"Spine",
|
45 |
+
"Left Leg",
|
46 |
+
"Right Leg",
|
47 |
+
"Spine1",
|
48 |
+
"Left Foot",
|
49 |
+
"Right Foot",
|
50 |
+
"Thorax",
|
51 |
+
"Left Toe",
|
52 |
+
"Right Toe",
|
53 |
+
"Neck",
|
54 |
+
"Left Shoulder",
|
55 |
+
"Right Shoulder",
|
56 |
+
"Head",
|
57 |
+
"Left ForeArm",
|
58 |
+
"Right ForeArm",
|
59 |
+
"Left Arm",
|
60 |
+
"Right Arm",
|
61 |
+
"Left Hand",
|
62 |
+
"Right Hand",
|
63 |
+
"Left Finger",
|
64 |
+
"Right Finger",
|
65 |
+
)
|
66 |
+
|
67 |
+
SKELETON = (
|
68 |
+
(0, 1),
|
69 |
+
(0, 2),
|
70 |
+
(0, 3),
|
71 |
+
(1, 4),
|
72 |
+
(2, 5),
|
73 |
+
(3, 6),
|
74 |
+
(4, 7),
|
75 |
+
(5, 8),
|
76 |
+
(6, 9),
|
77 |
+
(7, 10),
|
78 |
+
(8, 11),
|
79 |
+
(9, 12),
|
80 |
+
(12, 13),
|
81 |
+
(12, 14),
|
82 |
+
(12, 15),
|
83 |
+
(13, 16),
|
84 |
+
(14, 17),
|
85 |
+
(16, 18),
|
86 |
+
(17, 19),
|
87 |
+
(18, 20),
|
88 |
+
(19, 21),
|
89 |
+
(20, 22),
|
90 |
+
(21, 23),
|
91 |
+
)
|
92 |
+
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super(SMPL, self).__init__(*args, **kwargs)
|
95 |
+
joints = [JOINT_MAP[i] for i in JOINT_NAMES]
|
96 |
+
joint_regressor_extra = kwargs["joint_regressor_extra_path"]
|
97 |
+
J_regressor_extra = np.load(joint_regressor_extra)
|
98 |
+
self.register_buffer(
|
99 |
+
"J_regressor_extra", torch.tensor(J_regressor_extra, dtype=torch.float32)
|
100 |
+
)
|
101 |
+
self.joint_map = torch.tensor(joints, dtype=torch.long)
|
102 |
+
|
103 |
+
def forward(self, *args, **kwargs):
|
104 |
+
kwargs["get_skin"] = True
|
105 |
+
smpl_output = super(SMPL, self).forward(*args, **kwargs)
|
106 |
+
extra_joints = vertices2joints(
|
107 |
+
self.J_regressor_extra, smpl_output.vertices
|
108 |
+
) # Additional 9 joints #Check doc/J_regressor_extra.png
|
109 |
+
joints = torch.cat(
|
110 |
+
[smpl_output.joints, extra_joints], dim=1
|
111 |
+
) # [N, 24 + 21, 3] + [N, 9, 3]
|
112 |
+
joints = joints[:, self.joint_map, :]
|
113 |
+
output = SMPLOutput(
|
114 |
+
vertices=smpl_output.vertices,
|
115 |
+
global_orient=smpl_output.global_orient,
|
116 |
+
body_pose=smpl_output.body_pose,
|
117 |
+
joints=joints,
|
118 |
+
betas=smpl_output.betas,
|
119 |
+
full_pose=smpl_output.full_pose,
|
120 |
+
)
|
121 |
+
return output
|
122 |
+
|
123 |
+
|
124 |
+
class SMPLX(_SMPLX):
|
125 |
+
"""Extension of the official SMPL implementation to support more joints"""
|
126 |
+
|
127 |
+
JOINTS = (
|
128 |
+
"Hips",
|
129 |
+
"Left Upper Leg",
|
130 |
+
"Right Upper Leg",
|
131 |
+
"Spine",
|
132 |
+
"Left Leg",
|
133 |
+
"Right Leg",
|
134 |
+
"Spine1",
|
135 |
+
"Left Foot",
|
136 |
+
"Right Foot",
|
137 |
+
"Thorax",
|
138 |
+
"Left Toe",
|
139 |
+
"Right Toe",
|
140 |
+
"Neck",
|
141 |
+
"Left Shoulder",
|
142 |
+
"Right Shoulder",
|
143 |
+
"Head",
|
144 |
+
"Left ForeArm",
|
145 |
+
"Right ForeArm",
|
146 |
+
"Left Arm",
|
147 |
+
"Right Arm",
|
148 |
+
"Left Hand",
|
149 |
+
"Right Hand",
|
150 |
+
)
|
151 |
+
|
152 |
+
SKELETON = (
|
153 |
+
(0, 1),
|
154 |
+
(0, 2),
|
155 |
+
(0, 3),
|
156 |
+
(1, 4),
|
157 |
+
(2, 5),
|
158 |
+
(3, 6),
|
159 |
+
(4, 7),
|
160 |
+
(5, 8),
|
161 |
+
(6, 9),
|
162 |
+
(7, 10),
|
163 |
+
(8, 11),
|
164 |
+
(9, 12),
|
165 |
+
(12, 13),
|
166 |
+
(12, 14),
|
167 |
+
(12, 15),
|
168 |
+
(13, 16),
|
169 |
+
(14, 17),
|
170 |
+
(16, 18),
|
171 |
+
(17, 19),
|
172 |
+
(18, 20),
|
173 |
+
(19, 21),
|
174 |
+
)
|
175 |
+
|
176 |
+
def __init__(self, *args, **kwargs):
|
177 |
+
kwargs["ext"] = "pkl" # We have pkl file
|
178 |
+
super(SMPLX, self).__init__(*args, **kwargs)
|
179 |
+
joints = [JOINT_MAP[i] for i in JOINT_NAMES]
|
180 |
+
self.joint_map = torch.tensor(joints, dtype=torch.long)
|
181 |
+
|
182 |
+
def forward(self, *args, **kwargs):
|
183 |
+
kwargs["get_skin"] = True
|
184 |
+
|
185 |
+
# if pose parameter is for SMPL with 21 joints (ignoring root)
|
186 |
+
try:
|
187 |
+
if kwargs["body_pose"].shape[1] == 69:
|
188 |
+
kwargs["body_pose"] = kwargs["body_pose"][
|
189 |
+
:, : -2 * 3
|
190 |
+
] # Ignore the last two joints (which are on the palm. Not used)
|
191 |
+
|
192 |
+
if kwargs["body_pose"].shape[1] == 23:
|
193 |
+
kwargs["body_pose"] = kwargs["body_pose"][
|
194 |
+
:, :-2
|
195 |
+
] # Ignore the last two joints (which are on the palm. Not used)
|
196 |
+
except:
|
197 |
+
pass
|
198 |
+
|
199 |
+
smpl_output = super(SMPLX, self).forward(*args, **kwargs)
|
200 |
+
|
201 |
+
# SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
|
202 |
+
smplx_to_smpl = (
|
203 |
+
list(range(0, 22)) + [28, 43] + list(range(55, 76))
|
204 |
+
) # 28 left middle finger , 43: right middle finger 1
|
205 |
+
smpl_joints = smpl_output.joints[
|
206 |
+
:, smplx_to_smpl, :
|
207 |
+
] # Convert SMPL-X to SMPL 127 ->45
|
208 |
+
joints = smpl_joints
|
209 |
+
joints = joints[:, self.joint_map, :]
|
210 |
+
|
211 |
+
smplx_lhand = (
|
212 |
+
[20] + list(range(25, 40)) + list(range(66, 71))
|
213 |
+
) # 20 for left wrist. 20 finger joints
|
214 |
+
lhand_joints = smpl_output.joints[:, smplx_lhand, :] # (N,21,3)
|
215 |
+
lhand_joints = lhand_joints[
|
216 |
+
:, SMPLX_HAND_TO_PANOPTIC, :
|
217 |
+
] # Convert SMPL-X hand order to paonptic hand order
|
218 |
+
|
219 |
+
smplx_rhand = (
|
220 |
+
[21] + list(range(40, 55)) + list(range(71, 76))
|
221 |
+
) # 21 for right wrist. 20 finger joints
|
222 |
+
rhand_joints = smpl_output.joints[:, smplx_rhand, :] # (N,21,3)
|
223 |
+
rhand_joints = rhand_joints[
|
224 |
+
:, SMPLX_HAND_TO_PANOPTIC, :
|
225 |
+
] # Convert SMPL-X hand order to paonptic hand order
|
226 |
+
|
227 |
+
output = SMPLXOutput(
|
228 |
+
vertices=smpl_output.vertices,
|
229 |
+
global_orient=smpl_output.global_orient,
|
230 |
+
body_pose=smpl_output.body_pose,
|
231 |
+
joints=joints,
|
232 |
+
right_hand_pose=rhand_joints, # N,21,3
|
233 |
+
left_hand_pose=lhand_joints, # N,21,3
|
234 |
+
betas=smpl_output.betas,
|
235 |
+
full_pose=smpl_output.full_pose,
|
236 |
+
A=smpl_output.A,
|
237 |
+
)
|
238 |
+
return output
|
239 |
+
|
240 |
+
|
241 |
+
"""
|
242 |
+
0 pelvis',
|
243 |
+
1 left_hip',
|
244 |
+
2 right_hip',
|
245 |
+
3 spine1',
|
246 |
+
4 left_knee',
|
247 |
+
5 right_knee',
|
248 |
+
6 spine2',
|
249 |
+
7 left_ankle',
|
250 |
+
8 right_ankle',
|
251 |
+
9 spine3',
|
252 |
+
10 left_foot',
|
253 |
+
11 right_foot',
|
254 |
+
12 neck',
|
255 |
+
13 left_collar',
|
256 |
+
14 right_collar',
|
257 |
+
15 head',
|
258 |
+
16 left_shoulder',
|
259 |
+
17 right_shoulder',
|
260 |
+
18 left_elbow',
|
261 |
+
19 right_elbow',
|
262 |
+
20 left_wrist',
|
263 |
+
21 right_wrist',
|
264 |
+
22 jaw',
|
265 |
+
23 left_eye_smplhf',
|
266 |
+
24 right_eye_smplhf',
|
267 |
+
25 left_index1',
|
268 |
+
26 left_index2',
|
269 |
+
27 left_index3',
|
270 |
+
28 left_middle1',
|
271 |
+
29 left_middle2',
|
272 |
+
30 left_middle3',
|
273 |
+
31 left_pinky1',
|
274 |
+
32 left_pinky2',
|
275 |
+
33 left_pinky3',
|
276 |
+
34 left_ring1',
|
277 |
+
35 left_ring2',
|
278 |
+
36 left_ring3',
|
279 |
+
37 left_thumb1',
|
280 |
+
38 left_thumb2',
|
281 |
+
39 left_thumb3',
|
282 |
+
40 right_index1',
|
283 |
+
41 right_index2',
|
284 |
+
42 right_index3',
|
285 |
+
43 right_middle1',
|
286 |
+
44 right_middle2',
|
287 |
+
45 right_middle3',
|
288 |
+
46 right_pinky1',
|
289 |
+
47 right_pinky2',
|
290 |
+
48 right_pinky3',
|
291 |
+
49 right_ring1',
|
292 |
+
50 right_ring2',
|
293 |
+
51 right_ring3',
|
294 |
+
52 right_thumb1',
|
295 |
+
53 right_thumb2',
|
296 |
+
54 right_thumb3',
|
297 |
+
55 nose',
|
298 |
+
56 right_eye',
|
299 |
+
57 left_eye',
|
300 |
+
58 right_ear',
|
301 |
+
59 left_ear',
|
302 |
+
60 left_big_toe',
|
303 |
+
61 left_small_toe',
|
304 |
+
62 left_heel',
|
305 |
+
63 right_big_toe',
|
306 |
+
64 right_small_toe',
|
307 |
+
65 right_heel',
|
308 |
+
66 left_thumb',
|
309 |
+
67 left_index',
|
310 |
+
68 left_middle',
|
311 |
+
69 left_ring',
|
312 |
+
70 left_pinky',
|
313 |
+
71 right_thumb',
|
314 |
+
72 right_index',
|
315 |
+
73 right_middle',
|
316 |
+
74 right_ring',
|
317 |
+
75 right_pinky',
|
318 |
+
76 right_eye_brow1',
|
319 |
+
77 right_eye_brow2',
|
320 |
+
78 right_eye_brow3',
|
321 |
+
79 right_eye_brow4',
|
322 |
+
80 right_eye_brow5',
|
323 |
+
81 left_eye_brow5',
|
324 |
+
82 left_eye_brow4',
|
325 |
+
83 left_eye_brow3',
|
326 |
+
84 left_eye_brow2',
|
327 |
+
85 left_eye_brow1',
|
328 |
+
86 nose1',
|
329 |
+
87 nose2',
|
330 |
+
88 nose3',
|
331 |
+
89 nose4',
|
332 |
+
90 right_nose_2',
|
333 |
+
91 right_nose_1',
|
334 |
+
92 nose_middle',
|
335 |
+
93 left_nose_1',
|
336 |
+
94 left_nose_2',
|
337 |
+
95 right_eye1',
|
338 |
+
96 right_eye2',
|
339 |
+
97 right_eye3',
|
340 |
+
98 right_eye4',
|
341 |
+
99 right_eye5',
|
342 |
+
100 right_eye6',
|
343 |
+
101 left_eye4',
|
344 |
+
102 left_eye3',
|
345 |
+
103 left_eye2',
|
346 |
+
104 left_eye1',
|
347 |
+
105 left_eye6',
|
348 |
+
106 left_eye5',
|
349 |
+
107 right_mouth_1',
|
350 |
+
108 right_mouth_2',
|
351 |
+
109 right_mouth_3',
|
352 |
+
110 mouth_top',
|
353 |
+
111 left_mouth_3',
|
354 |
+
112 left_mouth_2',
|
355 |
+
113 left_mouth_1',
|
356 |
+
114 left_mouth_5', # 59 in OpenPose output
|
357 |
+
115 left_mouth_4', # 58 in OpenPose output
|
358 |
+
116 mouth_bottom',
|
359 |
+
117 right_mouth_4',
|
360 |
+
118 right_mouth_5',
|
361 |
+
119 right_lip_1',
|
362 |
+
120 right_lip_2',
|
363 |
+
121 lip_top',
|
364 |
+
122 left_lip_2',
|
365 |
+
123 left_lip_1',
|
366 |
+
124 left_lip_3',
|
367 |
+
125 lip_bottom',
|
368 |
+
126 right_lip_3',
|
369 |
+
127 right_contour_1',
|
370 |
+
128 right_contour_2',
|
371 |
+
129 right_contour_3',
|
372 |
+
130 right_contour_4',
|
373 |
+
131 right_contour_5',
|
374 |
+
132 right_contour_6',
|
375 |
+
133 right_contour_7',
|
376 |
+
134 right_contour_8',
|
377 |
+
135 contour_middle',
|
378 |
+
136 left_contour_8',
|
379 |
+
137 left_contour_7',
|
380 |
+
138 left_contour_6',
|
381 |
+
139 left_contour_5',
|
382 |
+
140 left_contour_4',
|
383 |
+
141 left_contour_3',
|
384 |
+
142 left_contour_2',
|
385 |
+
143 left_contour_1'
|
386 |
+
"""
|
387 |
+
|
388 |
+
|
389 |
+
# SMPL Joints:
|
390 |
+
"""
|
391 |
+
0 pelvis',
|
392 |
+
1 left_hip',
|
393 |
+
2 right_hip',
|
394 |
+
3 spine1',
|
395 |
+
4 left_knee',
|
396 |
+
5 right_knee',
|
397 |
+
6 spine2',
|
398 |
+
7 left_ankle',
|
399 |
+
8 right_ankle',
|
400 |
+
9 spine3',
|
401 |
+
10 left_foot',
|
402 |
+
11 right_foot',
|
403 |
+
12 neck',
|
404 |
+
13 left_collar',
|
405 |
+
14 right_collar',
|
406 |
+
15 head',
|
407 |
+
16 left_shoulder',
|
408 |
+
17 right_shoulder',
|
409 |
+
18 left_elbow',
|
410 |
+
19 right_elbow',
|
411 |
+
20 left_wrist',
|
412 |
+
21 right_wrist',
|
413 |
+
22
|
414 |
+
23
|
415 |
+
24 nose',
|
416 |
+
25 right_eye',
|
417 |
+
26 left_eye',
|
418 |
+
27 right_ear',
|
419 |
+
28 left_ear',
|
420 |
+
29 left_big_toe',
|
421 |
+
30 left_small_toe',
|
422 |
+
31 left_heel',
|
423 |
+
32 right_big_toe',
|
424 |
+
33 right_small_toe',
|
425 |
+
34 right_heel',
|
426 |
+
35 left_thumb',
|
427 |
+
36 left_index',
|
428 |
+
37 left_middle',
|
429 |
+
38 left_ring',
|
430 |
+
39 left_pinky',
|
431 |
+
40 right_thumb',
|
432 |
+
41 right_index',
|
433 |
+
42 right_middle',
|
434 |
+
43 right_ring',
|
435 |
+
44 right_pinky',
|
436 |
+
"""
|