kbrodt commited on
Commit
f73c7b1
1 Parent(s): 347ea73

Upload smpl.py

Browse files
Files changed (1) hide show
  1. src/spin/smpl.py +436 -0
src/spin/smpl.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from smplx import SMPL as _SMPL
4
+ from smplx import SMPLX as _SMPLX
5
+ from smplx.body_models import SMPLOutput, SMPLXOutput
6
+ from smplx.lbs import vertices2joints
7
+
8
+ from .constants import JOINT_MAP, JOINT_NAMES
9
+
10
+ # Hand joints
11
+ SMPLX_HAND_TO_PANOPTIC = [
12
+ 0,
13
+ 13,
14
+ 14,
15
+ 15,
16
+ 16,
17
+ 1,
18
+ 2,
19
+ 3,
20
+ 17,
21
+ 4,
22
+ 5,
23
+ 6,
24
+ 18,
25
+ 10,
26
+ 11,
27
+ 12,
28
+ 19,
29
+ 7,
30
+ 8,
31
+ 9,
32
+ 20,
33
+ ] # Wrist Thumb to Pinky
34
+
35
+
36
+
37
+ class SMPL(_SMPL):
38
+ """Extension of the official SMPL implementation to support more joints"""
39
+
40
+ JOINTS = (
41
+ "Hips",
42
+ "Left Upper Leg",
43
+ "Right Upper Leg",
44
+ "Spine",
45
+ "Left Leg",
46
+ "Right Leg",
47
+ "Spine1",
48
+ "Left Foot",
49
+ "Right Foot",
50
+ "Thorax",
51
+ "Left Toe",
52
+ "Right Toe",
53
+ "Neck",
54
+ "Left Shoulder",
55
+ "Right Shoulder",
56
+ "Head",
57
+ "Left ForeArm",
58
+ "Right ForeArm",
59
+ "Left Arm",
60
+ "Right Arm",
61
+ "Left Hand",
62
+ "Right Hand",
63
+ "Left Finger",
64
+ "Right Finger",
65
+ )
66
+
67
+ SKELETON = (
68
+ (0, 1),
69
+ (0, 2),
70
+ (0, 3),
71
+ (1, 4),
72
+ (2, 5),
73
+ (3, 6),
74
+ (4, 7),
75
+ (5, 8),
76
+ (6, 9),
77
+ (7, 10),
78
+ (8, 11),
79
+ (9, 12),
80
+ (12, 13),
81
+ (12, 14),
82
+ (12, 15),
83
+ (13, 16),
84
+ (14, 17),
85
+ (16, 18),
86
+ (17, 19),
87
+ (18, 20),
88
+ (19, 21),
89
+ (20, 22),
90
+ (21, 23),
91
+ )
92
+
93
+ def __init__(self, *args, **kwargs):
94
+ super(SMPL, self).__init__(*args, **kwargs)
95
+ joints = [JOINT_MAP[i] for i in JOINT_NAMES]
96
+ joint_regressor_extra = kwargs["joint_regressor_extra_path"]
97
+ J_regressor_extra = np.load(joint_regressor_extra)
98
+ self.register_buffer(
99
+ "J_regressor_extra", torch.tensor(J_regressor_extra, dtype=torch.float32)
100
+ )
101
+ self.joint_map = torch.tensor(joints, dtype=torch.long)
102
+
103
+ def forward(self, *args, **kwargs):
104
+ kwargs["get_skin"] = True
105
+ smpl_output = super(SMPL, self).forward(*args, **kwargs)
106
+ extra_joints = vertices2joints(
107
+ self.J_regressor_extra, smpl_output.vertices
108
+ ) # Additional 9 joints #Check doc/J_regressor_extra.png
109
+ joints = torch.cat(
110
+ [smpl_output.joints, extra_joints], dim=1
111
+ ) # [N, 24 + 21, 3] + [N, 9, 3]
112
+ joints = joints[:, self.joint_map, :]
113
+ output = SMPLOutput(
114
+ vertices=smpl_output.vertices,
115
+ global_orient=smpl_output.global_orient,
116
+ body_pose=smpl_output.body_pose,
117
+ joints=joints,
118
+ betas=smpl_output.betas,
119
+ full_pose=smpl_output.full_pose,
120
+ )
121
+ return output
122
+
123
+
124
+ class SMPLX(_SMPLX):
125
+ """Extension of the official SMPL implementation to support more joints"""
126
+
127
+ JOINTS = (
128
+ "Hips",
129
+ "Left Upper Leg",
130
+ "Right Upper Leg",
131
+ "Spine",
132
+ "Left Leg",
133
+ "Right Leg",
134
+ "Spine1",
135
+ "Left Foot",
136
+ "Right Foot",
137
+ "Thorax",
138
+ "Left Toe",
139
+ "Right Toe",
140
+ "Neck",
141
+ "Left Shoulder",
142
+ "Right Shoulder",
143
+ "Head",
144
+ "Left ForeArm",
145
+ "Right ForeArm",
146
+ "Left Arm",
147
+ "Right Arm",
148
+ "Left Hand",
149
+ "Right Hand",
150
+ )
151
+
152
+ SKELETON = (
153
+ (0, 1),
154
+ (0, 2),
155
+ (0, 3),
156
+ (1, 4),
157
+ (2, 5),
158
+ (3, 6),
159
+ (4, 7),
160
+ (5, 8),
161
+ (6, 9),
162
+ (7, 10),
163
+ (8, 11),
164
+ (9, 12),
165
+ (12, 13),
166
+ (12, 14),
167
+ (12, 15),
168
+ (13, 16),
169
+ (14, 17),
170
+ (16, 18),
171
+ (17, 19),
172
+ (18, 20),
173
+ (19, 21),
174
+ )
175
+
176
+ def __init__(self, *args, **kwargs):
177
+ kwargs["ext"] = "pkl" # We have pkl file
178
+ super(SMPLX, self).__init__(*args, **kwargs)
179
+ joints = [JOINT_MAP[i] for i in JOINT_NAMES]
180
+ self.joint_map = torch.tensor(joints, dtype=torch.long)
181
+
182
+ def forward(self, *args, **kwargs):
183
+ kwargs["get_skin"] = True
184
+
185
+ # if pose parameter is for SMPL with 21 joints (ignoring root)
186
+ try:
187
+ if kwargs["body_pose"].shape[1] == 69:
188
+ kwargs["body_pose"] = kwargs["body_pose"][
189
+ :, : -2 * 3
190
+ ] # Ignore the last two joints (which are on the palm. Not used)
191
+
192
+ if kwargs["body_pose"].shape[1] == 23:
193
+ kwargs["body_pose"] = kwargs["body_pose"][
194
+ :, :-2
195
+ ] # Ignore the last two joints (which are on the palm. Not used)
196
+ except:
197
+ pass
198
+
199
+ smpl_output = super(SMPLX, self).forward(*args, **kwargs)
200
+
201
+ # SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
202
+ smplx_to_smpl = (
203
+ list(range(0, 22)) + [28, 43] + list(range(55, 76))
204
+ ) # 28 left middle finger , 43: right middle finger 1
205
+ smpl_joints = smpl_output.joints[
206
+ :, smplx_to_smpl, :
207
+ ] # Convert SMPL-X to SMPL 127 ->45
208
+ joints = smpl_joints
209
+ joints = joints[:, self.joint_map, :]
210
+
211
+ smplx_lhand = (
212
+ [20] + list(range(25, 40)) + list(range(66, 71))
213
+ ) # 20 for left wrist. 20 finger joints
214
+ lhand_joints = smpl_output.joints[:, smplx_lhand, :] # (N,21,3)
215
+ lhand_joints = lhand_joints[
216
+ :, SMPLX_HAND_TO_PANOPTIC, :
217
+ ] # Convert SMPL-X hand order to paonptic hand order
218
+
219
+ smplx_rhand = (
220
+ [21] + list(range(40, 55)) + list(range(71, 76))
221
+ ) # 21 for right wrist. 20 finger joints
222
+ rhand_joints = smpl_output.joints[:, smplx_rhand, :] # (N,21,3)
223
+ rhand_joints = rhand_joints[
224
+ :, SMPLX_HAND_TO_PANOPTIC, :
225
+ ] # Convert SMPL-X hand order to paonptic hand order
226
+
227
+ output = SMPLXOutput(
228
+ vertices=smpl_output.vertices,
229
+ global_orient=smpl_output.global_orient,
230
+ body_pose=smpl_output.body_pose,
231
+ joints=joints,
232
+ right_hand_pose=rhand_joints, # N,21,3
233
+ left_hand_pose=lhand_joints, # N,21,3
234
+ betas=smpl_output.betas,
235
+ full_pose=smpl_output.full_pose,
236
+ A=smpl_output.A,
237
+ )
238
+ return output
239
+
240
+
241
+ """
242
+ 0 pelvis',
243
+ 1 left_hip',
244
+ 2 right_hip',
245
+ 3 spine1',
246
+ 4 left_knee',
247
+ 5 right_knee',
248
+ 6 spine2',
249
+ 7 left_ankle',
250
+ 8 right_ankle',
251
+ 9 spine3',
252
+ 10 left_foot',
253
+ 11 right_foot',
254
+ 12 neck',
255
+ 13 left_collar',
256
+ 14 right_collar',
257
+ 15 head',
258
+ 16 left_shoulder',
259
+ 17 right_shoulder',
260
+ 18 left_elbow',
261
+ 19 right_elbow',
262
+ 20 left_wrist',
263
+ 21 right_wrist',
264
+ 22 jaw',
265
+ 23 left_eye_smplhf',
266
+ 24 right_eye_smplhf',
267
+ 25 left_index1',
268
+ 26 left_index2',
269
+ 27 left_index3',
270
+ 28 left_middle1',
271
+ 29 left_middle2',
272
+ 30 left_middle3',
273
+ 31 left_pinky1',
274
+ 32 left_pinky2',
275
+ 33 left_pinky3',
276
+ 34 left_ring1',
277
+ 35 left_ring2',
278
+ 36 left_ring3',
279
+ 37 left_thumb1',
280
+ 38 left_thumb2',
281
+ 39 left_thumb3',
282
+ 40 right_index1',
283
+ 41 right_index2',
284
+ 42 right_index3',
285
+ 43 right_middle1',
286
+ 44 right_middle2',
287
+ 45 right_middle3',
288
+ 46 right_pinky1',
289
+ 47 right_pinky2',
290
+ 48 right_pinky3',
291
+ 49 right_ring1',
292
+ 50 right_ring2',
293
+ 51 right_ring3',
294
+ 52 right_thumb1',
295
+ 53 right_thumb2',
296
+ 54 right_thumb3',
297
+ 55 nose',
298
+ 56 right_eye',
299
+ 57 left_eye',
300
+ 58 right_ear',
301
+ 59 left_ear',
302
+ 60 left_big_toe',
303
+ 61 left_small_toe',
304
+ 62 left_heel',
305
+ 63 right_big_toe',
306
+ 64 right_small_toe',
307
+ 65 right_heel',
308
+ 66 left_thumb',
309
+ 67 left_index',
310
+ 68 left_middle',
311
+ 69 left_ring',
312
+ 70 left_pinky',
313
+ 71 right_thumb',
314
+ 72 right_index',
315
+ 73 right_middle',
316
+ 74 right_ring',
317
+ 75 right_pinky',
318
+ 76 right_eye_brow1',
319
+ 77 right_eye_brow2',
320
+ 78 right_eye_brow3',
321
+ 79 right_eye_brow4',
322
+ 80 right_eye_brow5',
323
+ 81 left_eye_brow5',
324
+ 82 left_eye_brow4',
325
+ 83 left_eye_brow3',
326
+ 84 left_eye_brow2',
327
+ 85 left_eye_brow1',
328
+ 86 nose1',
329
+ 87 nose2',
330
+ 88 nose3',
331
+ 89 nose4',
332
+ 90 right_nose_2',
333
+ 91 right_nose_1',
334
+ 92 nose_middle',
335
+ 93 left_nose_1',
336
+ 94 left_nose_2',
337
+ 95 right_eye1',
338
+ 96 right_eye2',
339
+ 97 right_eye3',
340
+ 98 right_eye4',
341
+ 99 right_eye5',
342
+ 100 right_eye6',
343
+ 101 left_eye4',
344
+ 102 left_eye3',
345
+ 103 left_eye2',
346
+ 104 left_eye1',
347
+ 105 left_eye6',
348
+ 106 left_eye5',
349
+ 107 right_mouth_1',
350
+ 108 right_mouth_2',
351
+ 109 right_mouth_3',
352
+ 110 mouth_top',
353
+ 111 left_mouth_3',
354
+ 112 left_mouth_2',
355
+ 113 left_mouth_1',
356
+ 114 left_mouth_5', # 59 in OpenPose output
357
+ 115 left_mouth_4', # 58 in OpenPose output
358
+ 116 mouth_bottom',
359
+ 117 right_mouth_4',
360
+ 118 right_mouth_5',
361
+ 119 right_lip_1',
362
+ 120 right_lip_2',
363
+ 121 lip_top',
364
+ 122 left_lip_2',
365
+ 123 left_lip_1',
366
+ 124 left_lip_3',
367
+ 125 lip_bottom',
368
+ 126 right_lip_3',
369
+ 127 right_contour_1',
370
+ 128 right_contour_2',
371
+ 129 right_contour_3',
372
+ 130 right_contour_4',
373
+ 131 right_contour_5',
374
+ 132 right_contour_6',
375
+ 133 right_contour_7',
376
+ 134 right_contour_8',
377
+ 135 contour_middle',
378
+ 136 left_contour_8',
379
+ 137 left_contour_7',
380
+ 138 left_contour_6',
381
+ 139 left_contour_5',
382
+ 140 left_contour_4',
383
+ 141 left_contour_3',
384
+ 142 left_contour_2',
385
+ 143 left_contour_1'
386
+ """
387
+
388
+
389
+ # SMPL Joints:
390
+ """
391
+ 0 pelvis',
392
+ 1 left_hip',
393
+ 2 right_hip',
394
+ 3 spine1',
395
+ 4 left_knee',
396
+ 5 right_knee',
397
+ 6 spine2',
398
+ 7 left_ankle',
399
+ 8 right_ankle',
400
+ 9 spine3',
401
+ 10 left_foot',
402
+ 11 right_foot',
403
+ 12 neck',
404
+ 13 left_collar',
405
+ 14 right_collar',
406
+ 15 head',
407
+ 16 left_shoulder',
408
+ 17 right_shoulder',
409
+ 18 left_elbow',
410
+ 19 right_elbow',
411
+ 20 left_wrist',
412
+ 21 right_wrist',
413
+ 22
414
+ 23
415
+ 24 nose',
416
+ 25 right_eye',
417
+ 26 left_eye',
418
+ 27 right_ear',
419
+ 28 left_ear',
420
+ 29 left_big_toe',
421
+ 30 left_small_toe',
422
+ 31 left_heel',
423
+ 32 right_big_toe',
424
+ 33 right_small_toe',
425
+ 34 right_heel',
426
+ 35 left_thumb',
427
+ 36 left_index',
428
+ 37 left_middle',
429
+ 38 left_ring',
430
+ 39 left_pinky',
431
+ 40 right_thumb',
432
+ 41 right_index',
433
+ 42 right_middle',
434
+ 43 right_ring',
435
+ 44 right_pinky',
436
+ """