kbrodt commited on
Commit
9da5f30
1 Parent(s): 38e46a7

Upload pose.py

Browse files
Files changed (1) hide show
  1. src/pose.py +1482 -0
src/pose.py ADDED
@@ -0,0 +1,1482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import PIL.Image as Image
8
+ import selfcontact
9
+ import selfcontact.losses
10
+ import shapely.geometry
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.optim as optim
14
+ import torchgeometry
15
+ import tqdm
16
+ import trimesh
17
+ from skimage import measure
18
+
19
+ import fist_pose
20
+ import hist_cub
21
+ import losses
22
+ import pose_estimation
23
+ import spin
24
+ import utils
25
+
26
+ PE_KSP_TO_SPIN = {
27
+ "Head": "Head",
28
+ "Neck": "Neck",
29
+ "Right Shoulder": "Right ForeArm",
30
+ "Right Arm": "Right Arm",
31
+ "Right Hand": "Right Hand",
32
+ "Left Shoulder": "Left ForeArm",
33
+ "Left Arm": "Left Arm",
34
+ "Left Hand": "Left Hand",
35
+ "Spine": "Spine1",
36
+ "Hips": "Hips",
37
+ "Right Upper Leg": "Right Upper Leg",
38
+ "Right Leg": "Right Leg",
39
+ "Right Foot": "Right Foot",
40
+ "Left Upper Leg": "Left Upper Leg",
41
+ "Left Leg": "Left Leg",
42
+ "Left Foot": "Left Foot",
43
+ "Left Toe": "Left Toe",
44
+ "Right Toe": "Right Toe",
45
+ }
46
+ MODELS_DIR = "models"
47
+
48
+
49
+ def parse_args():
50
+ parser = argparse.ArgumentParser()
51
+
52
+ parser.add_argument(
53
+ "--pose-estimation-model-path",
54
+ type=str,
55
+ default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
56
+ help="Pose Estimation model",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--contact-model-path",
61
+ type=str,
62
+ default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
63
+ help="Contact model",
64
+ )
65
+
66
+ parser.add_argument(
67
+ "--device",
68
+ type=str,
69
+ default="cuda",
70
+ choices=["cpu", "cuda"],
71
+ help="Torch device",
72
+ )
73
+
74
+ parser.add_argument(
75
+ "--spin-model-path",
76
+ type=str,
77
+ default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
78
+ help="SPIN model path",
79
+ )
80
+
81
+ parser.add_argument(
82
+ "--smpl-type",
83
+ type=str,
84
+ default="smplx",
85
+ choices=["smplx"],
86
+ help="SMPL model type",
87
+ )
88
+ parser.add_argument(
89
+ "--smpl-model-dir",
90
+ type=str,
91
+ default=f"./{MODELS_DIR}/models/smplx",
92
+ help="SMPL model dir",
93
+ )
94
+ parser.add_argument(
95
+ "--smpl-mean-params-path",
96
+ type=str,
97
+ default=f"./{MODELS_DIR}/data/smpl_mean_params.npz",
98
+ help="SMPL mean params",
99
+ )
100
+ parser.add_argument(
101
+ "--essentials-dir",
102
+ type=str,
103
+ default=f"./{MODELS_DIR}/smplify-xmc-essentials",
104
+ help="SMPL Essentials folder for contacts",
105
+ )
106
+
107
+ parser.add_argument(
108
+ "--parametrization-path",
109
+ type=str,
110
+ default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
111
+ help="Parametrization path",
112
+ )
113
+ parser.add_argument(
114
+ "--bone-parametrization-path",
115
+ type=str,
116
+ default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
117
+ help="Bone parametrization path",
118
+ )
119
+ parser.add_argument(
120
+ "--foot-inds-path",
121
+ type=str,
122
+ default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
123
+ help="Foot indinces",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--save-path",
128
+ type=str,
129
+ required=True,
130
+ help="Path to save the results",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--img-path",
135
+ type=str,
136
+ required=True,
137
+ help="Path to img to test",
138
+ )
139
+
140
+ parser.add_argument(
141
+ "--use-contacts",
142
+ action="store_true",
143
+ help="Use contact model",
144
+ )
145
+ parser.add_argument(
146
+ "--use-msc",
147
+ action="store_true",
148
+ help="Use MSC loss",
149
+ )
150
+ parser.add_argument(
151
+ "--use-natural",
152
+ action="store_true",
153
+ help="Use regularity",
154
+ )
155
+ parser.add_argument(
156
+ "--use-cos",
157
+ action="store_true",
158
+ help="Use cos model",
159
+ )
160
+ parser.add_argument(
161
+ "--use-angle-transf",
162
+ action="store_true",
163
+ help="Use cube foreshortening transformation",
164
+ )
165
+
166
+ parser.add_argument(
167
+ "--c-mse",
168
+ type=float,
169
+ default=0,
170
+ help="MSE weight",
171
+ )
172
+ parser.add_argument(
173
+ "--c-par",
174
+ type=float,
175
+ default=10,
176
+ help="Parallel weight",
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--c-f",
181
+ type=float,
182
+ default=1000,
183
+ help="Cos coef",
184
+ )
185
+ parser.add_argument(
186
+ "--c-parallel",
187
+ type=float,
188
+ default=100,
189
+ help="Parallel weight",
190
+ )
191
+ parser.add_argument(
192
+ "--c-reg",
193
+ type=float,
194
+ default=1000,
195
+ help="Regularity weight",
196
+ )
197
+ parser.add_argument(
198
+ "--c-cont2d",
199
+ type=float,
200
+ default=1,
201
+ help="Contact 2D weight",
202
+ )
203
+ parser.add_argument(
204
+ "--c-msc",
205
+ type=float,
206
+ default=17_500,
207
+ help="MSC weight",
208
+ )
209
+
210
+ parser.add_argument(
211
+ "--fist",
212
+ nargs="+",
213
+ type=str,
214
+ choices=list(fist_pose.INT_TO_FIST),
215
+ )
216
+
217
+ args = parser.parse_args()
218
+
219
+ return args
220
+
221
+
222
+ def freeze_layers(model):
223
+ for module in model.modules():
224
+ if type(module) is False:
225
+ continue
226
+
227
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
228
+ module.eval()
229
+ for m in module.parameters():
230
+ m.requires_grad = False
231
+
232
+ if isinstance(module, nn.Dropout):
233
+ module.eval()
234
+ for m in module.parameters():
235
+ m.requires_grad = False
236
+
237
+
238
+ def project_and_normalize_to_spin(vertices_3d, camera):
239
+ vertices_2d = vertices_3d # [:, :2]
240
+
241
+ scale, translate = camera[0], camera[1:]
242
+ translate = scale.new_zeros(3)
243
+ translate[:2] = camera[1:]
244
+
245
+ vertices_2d = vertices_2d + translate
246
+ vertices_2d = scale * vertices_2d + 1
247
+ vertices_2d = spin.constants.IMG_RES / 2 * vertices_2d
248
+
249
+ return vertices_2d
250
+
251
+
252
+ def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
253
+ A, J = A
254
+ A = A[0]
255
+ J = J[0]
256
+ L = vertices_3d.new_tensor(
257
+ [
258
+ [0.98619063, 0.16560926, 0.00127302],
259
+ [-0.16560601, 0.98603675, 0.01749799],
260
+ [0.00164258, -0.01746717, 0.99984609],
261
+ ]
262
+ )
263
+ R = vertices_3d.new_tensor(
264
+ [
265
+ [0.9910211, -0.13368178, -0.0025208],
266
+ [0.13367888, 0.99027076, 0.03864949],
267
+ [-0.00267045, -0.03863944, 0.99924965],
268
+ ]
269
+ )
270
+ scale = camera[0]
271
+ R = A[2, :3, :3] @ R # 2 - right
272
+ L = A[1, :3, :3] @ L # 1 - left
273
+ r = J[5] - J[2]
274
+ l = J[4] - J[1]
275
+
276
+ rleg = scale * spin.constants.IMG_RES / 2 * R @ r
277
+ lleg = scale * spin.constants.IMG_RES / 2 * L @ l
278
+
279
+ rleg = rleg[:2]
280
+ lleg = lleg[:2]
281
+
282
+ return rleg, lleg
283
+
284
+
285
+ def rotation_matrix_to_angle_axis(rotmat):
286
+ bs, n_joints, *_ = rotmat.size()
287
+ rotmat = torch.cat(
288
+ [
289
+ rotmat.view(-1, 3, 3),
290
+ rotmat.new_tensor([0, 0, 1], dtype=torch.float32)
291
+ .view(bs, 3, 1)
292
+ .expand(n_joints, -1, -1),
293
+ ],
294
+ dim=-1,
295
+ )
296
+ aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat)
297
+ aa = aa.reshape(bs, 3 * n_joints)
298
+
299
+ return aa
300
+
301
+
302
+ def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
303
+ if smpl.name() == "SMPL":
304
+ smpl_output = smpl(
305
+ betas=betas if use_betas else None,
306
+ body_pose=rotmat[:, 1:],
307
+ global_orient=rotmat[:, 0].unsqueeze(1),
308
+ pose2rot=False,
309
+ )
310
+ elif smpl.name() == "SMPL-X":
311
+ rotmat = rotation_matrix_to_angle_axis(rotmat)
312
+ if zero_hands:
313
+ for i in [20, 21]:
314
+ rotmat[:, 3 * i : 3 * (i + 1)] = 0
315
+
316
+ for i in [12, 15]: # neck, head
317
+ rotmat[:, 3 * i + 1] = 0 # y
318
+ smpl_output = smpl(
319
+ betas=betas if use_betas else None,
320
+ body_pose=rotmat[:, 3:],
321
+ global_orient=rotmat[:, :3],
322
+ pose2rot=True,
323
+ )
324
+ else:
325
+ raise NotImplementedError
326
+
327
+ return smpl_output, rotmat
328
+
329
+
330
+ def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
331
+ input_img = input_img.unsqueeze(0)
332
+ rotmat, betas, camera = model_hmr(input_img)
333
+
334
+ smpl_output, rotmat = get_smpl_output(
335
+ smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
336
+ )
337
+
338
+ rotmat = rotmat.squeeze(0)
339
+ betas = betas.squeeze(0)
340
+ camera = camera.squeeze(0)
341
+ z = smpl_output.joints
342
+ z = z.squeeze(0)
343
+
344
+ return rotmat, betas, camera, smpl_output, z
345
+
346
+
347
+ def get_pred_and_data(
348
+ model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
349
+ ):
350
+ rotmat, betas, camera, smpl_output, zz = get_predictions(
351
+ model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
352
+ )
353
+
354
+ joints = smpl_output.joints.squeeze(0)
355
+ joints_2d = project_and_normalize_to_spin(joints, camera)
356
+ rleg, lleg = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)
357
+ joints_2d_orig = joints_2d
358
+ joints_2d = joints_2d[selector]
359
+
360
+ vertices = smpl_output.vertices.squeeze(0)
361
+ vertices_2d = project_and_normalize_to_spin(vertices, camera)
362
+
363
+ zz = zz[selector]
364
+
365
+ return (
366
+ rotmat,
367
+ betas,
368
+ camera,
369
+ joints_2d,
370
+ zz,
371
+ vertices_2d,
372
+ smpl_output,
373
+ (rleg, lleg),
374
+ joints_2d_orig,
375
+ )
376
+
377
+
378
+ def normalize_keypoints_to_spin(keypoints_2d, img_size):
379
+ h, w = img_size
380
+ if h > w: # vertically
381
+ ax1 = 1
382
+ ax2 = 0
383
+ else: # horizontal
384
+ ax1 = 0
385
+ ax2 = 1
386
+
387
+ shift = (img_size[ax1] - img_size[ax2]) / 2
388
+ scale = spin.constants.IMG_RES / img_size[ax2]
389
+ keypoints_2d_normalized = np.copy(keypoints_2d)
390
+ keypoints_2d_normalized[:, ax2] -= shift
391
+ keypoints_2d_normalized *= scale
392
+
393
+ return keypoints_2d_normalized, shift, scale, ax2
394
+
395
+
396
+ def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
397
+ keypoints_2d_normalized = np.copy(keypoints_2d)
398
+ keypoints_2d_normalized /= scale
399
+ keypoints_2d_normalized[:, ax2] += shift
400
+
401
+ return keypoints_2d_normalized
402
+
403
+
404
+ def get_vertices_in_heatmap(contact_heatmap):
405
+ contact_heatmap_size = contact_heatmap.shape[:2]
406
+ label = measure.label(contact_heatmap)
407
+
408
+ y_data_conts = []
409
+ for i in range(1, label.max() + 1):
410
+ predicted_kps_contact = np.vstack(np.nonzero(label == i)[::-1]).T.astype(
411
+ "float"
412
+ )
413
+ predicted_kps_contact_scaled, *_ = normalize_keypoints_to_spin(
414
+ predicted_kps_contact, contact_heatmap_size
415
+ )
416
+ y_data_cont = torch.from_numpy(predicted_kps_contact_scaled).int().tolist()
417
+ y_data_cont = shapely.geometry.MultiPoint(y_data_cont).convex_hull
418
+ y_data_conts.append(y_data_cont)
419
+
420
+ return y_data_conts
421
+
422
+
423
+ def get_contact_heatmap(model_contact, img_path, thresh=0.5):
424
+ contact_heatmap = pose_estimation.infer_single_image(
425
+ model_contact,
426
+ img_path,
427
+ input_img_size=(192, 256),
428
+ return_kps=False,
429
+ )
430
+ contact_heatmap = contact_heatmap.squeeze(0)
431
+ contact_heatmap_orig = contact_heatmap.copy()
432
+
433
+ mi = contact_heatmap.min()
434
+ ma = contact_heatmap.max()
435
+ contact_heatmap = (contact_heatmap - mi) / (ma - mi)
436
+ contact_heatmap_ = ((contact_heatmap > thresh) * 255).astype("uint8")
437
+
438
+ contact_heatmap = np.repeat(contact_heatmap[..., None], repeats=3, axis=-1)
439
+ contact_heatmap = (contact_heatmap * 255).astype("uint8")
440
+
441
+ return contact_heatmap_, contact_heatmap, contact_heatmap_orig
442
+
443
+
444
+ def discretize(parametrization, n_bins=100):
445
+ bins = np.linspace(0, 1, n_bins + 1)
446
+ inds = np.digitize(parametrization, bins)
447
+ disc_parametrization = bins[inds - 1]
448
+
449
+ return disc_parametrization
450
+
451
+
452
+ def get_mapping_from_params_to_verts(verts, params):
453
+ mapping = {}
454
+ for v, t in zip(verts, params):
455
+ mapping.setdefault(t, []).append(v)
456
+
457
+ return mapping
458
+
459
+
460
+ def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
461
+ n_bins = int(math.ceil(1 / step)) - 1 # mean face's circumradius
462
+ contact = []
463
+ contact_2d = []
464
+ for_mask = []
465
+ for y_data_cont in y_data_conts:
466
+ contact_loc = []
467
+ contact_2d_loc = []
468
+ buffer = y_data_cont.buffer(thresh)
469
+ mask_add = False
470
+ for i, j in pose_estimation.SKELETON:
471
+ verts, t3d = bone_to_params[(i, j)]
472
+ if len(verts) == 0:
473
+ continue
474
+
475
+ t3d = discretize(t3d, n_bins=n_bins)
476
+ t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
477
+ t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
478
+ t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])
479
+
480
+ line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
481
+ lint = buffer.intersection(line)
482
+ if len(lint.boundary.geoms) < 2:
483
+ continue
484
+
485
+ t2d_start = line.project(lint.boundary.geoms[0], normalized=True)
486
+ t2d_end = line.project(lint.boundary.geoms[1], normalized=True)
487
+ assert t2d_start <= t2d_end
488
+
489
+ t2ds = discretize(
490
+ np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
491
+ )
492
+ to_add = False
493
+ for t2d in t2ds:
494
+ if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
495
+ continue
496
+
497
+ t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
498
+ c = t3d_to_verts_sorted[t2d_ind][1]
499
+
500
+ contact_loc.extend(c)
501
+ to_add = True
502
+ mask_add = True
503
+
504
+ if t2d_ind + 1 < len(t3d_to_verts_sorted):
505
+ c = t3d_to_verts_sorted[t2d_ind + 1][1]
506
+ contact_loc.extend(c)
507
+
508
+ if t2d_ind > 0:
509
+ c = t3d_to_verts_sorted[t2d_ind - 1][1]
510
+ contact_loc.extend(c)
511
+
512
+ if to_add:
513
+ contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))
514
+
515
+ if mask_add:
516
+ for_mask.append(buffer.exterior.coords.xy)
517
+
518
+ contact_loc = sorted(set(contact_loc))
519
+ contact_loc = np.array(contact_loc, dtype="int")
520
+ contact.append(contact_loc)
521
+ contact_2d.append(contact_2d_loc)
522
+
523
+ for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]
524
+
525
+ return contact, contact_2d, for_mask
526
+
527
+
528
+ def optimize(
529
+ model_hmr,
530
+ smpl,
531
+ selector,
532
+ input_img,
533
+ keypoints_2d,
534
+ optimizer,
535
+ args,
536
+ loss_mse=None,
537
+ loss_parallel=None,
538
+ c_mse=0.0,
539
+ c_new_mse=1.0,
540
+ c_beta=1e-3,
541
+ sc_crit=None,
542
+ msc_crit=None,
543
+ contact=None,
544
+ n_steps=60,
545
+ i_ini=0,
546
+ ):
547
+ mean_zfoot_val = {}
548
+ with tqdm.trange(n_steps) as pbar:
549
+ for i in pbar:
550
+ global_step = i + i_ini
551
+ optimizer.zero_grad()
552
+
553
+ (
554
+ rotmat_pred,
555
+ betas_pred,
556
+ camera_pred,
557
+ keypoints_3d_pred,
558
+ z,
559
+ vertices_2d_pred,
560
+ smpl_output,
561
+ (rleg, lleg),
562
+ joints_2d_orig,
563
+ ) = get_pred_and_data(
564
+ model_hmr,
565
+ smpl,
566
+ selector,
567
+ input_img,
568
+ )
569
+ keypoints_2d_pred = keypoints_3d_pred[:, :2]
570
+
571
+ loss = l2 = 0.0
572
+ if c_mse > 0 and loss_mse is not None:
573
+ l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
574
+ loss = loss + c_mse * l2
575
+
576
+ vertices_pred = smpl_output.vertices
577
+
578
+ lpar = z_loss = loss_sh = 0.0
579
+ if c_new_mse > 0 and loss_parallel is not None:
580
+ Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
581
+ keypoints_3d_pred,
582
+ keypoints_2d,
583
+ z,
584
+ (rleg, lleg),
585
+ global_step=global_step,
586
+ )
587
+ lpar = (
588
+ Ltan
589
+ + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
590
+ + Lspine
591
+ + args.c_reg * Lgr
592
+ + args.c_reg * Lstraight3d
593
+ + args.c_cont2d * Lcon2d
594
+ )
595
+ loss = loss + 300 * lpar
596
+
597
+ for side in ["left", "right"]:
598
+ attr = f"{side}_foot_inds"
599
+ if hasattr(loss_parallel, attr):
600
+ foot_inds = getattr(loss_parallel, attr)
601
+ zind = 1
602
+ if attr not in mean_zfoot_val:
603
+ with torch.no_grad():
604
+ mean_zfoot_val[attr] = torch.median(
605
+ vertices_pred[0, foot_inds, zind], dim=0
606
+ ).values
607
+
608
+ loss_foot = (
609
+ (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
610
+ ** 2
611
+ ).sum()
612
+ loss = loss + args.c_reg * loss_foot
613
+
614
+ if hasattr(loss_parallel, "silhuette_vertices_inds"):
615
+ inds = loss_parallel.silhuette_vertices_inds
616
+ loss_sh = (
617
+ (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
618
+ ).sum()
619
+ loss = loss + args.c_reg * loss_sh
620
+
621
+ lbeta = (betas_pred**2).mean()
622
+ lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
623
+ loss = loss + c_beta * lbeta + lcam
624
+
625
+ lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
626
+ if sc_crit is not None:
627
+ gsc_contact_loss, faces_angle_loss = sc_crit(
628
+ vertices_pred,
629
+ )
630
+ lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
631
+ loss = loss + lgsc_a
632
+
633
+ msc_loss = 0.0
634
+ if contact is not None and len(contact) > 0 and msc_crit is not None:
635
+ if not isinstance(contact, list):
636
+ contact = [contact]
637
+
638
+ for cntct in contact:
639
+ msc_loss = msc_crit(
640
+ cntct,
641
+ vertices_pred,
642
+ )
643
+ loss = loss + args.c_msc * msc_loss
644
+
645
+ loss.backward()
646
+ optimizer.step()
647
+
648
+ epoch_loss = loss.item()
649
+ pbar.set_postfix(
650
+ **{
651
+ "l": f"{epoch_loss:.3}",
652
+ "l2": f"{l2:.3}",
653
+ "par": f"{lpar:.3}",
654
+ "beta": f"{lbeta:.3}",
655
+ "cam": f"{lcam:.3}",
656
+ "z": f"{z_loss:.3}",
657
+ "gsc_contact": f"{float(gsc_contact_loss):.3}",
658
+ "faces_angle": f"{float(faces_angle_loss):.3}",
659
+ "msc": f"{float(msc_loss):.3}",
660
+ }
661
+ )
662
+
663
+ with torch.no_grad():
664
+ (
665
+ rotmat_pred,
666
+ betas_pred,
667
+ camera_pred,
668
+ keypoints_3d_pred,
669
+ z,
670
+ vertices_2d_pred,
671
+ smpl_output,
672
+ (rleg, lleg),
673
+ joints_2d_orig,
674
+ ) = get_pred_and_data(
675
+ model_hmr,
676
+ smpl,
677
+ selector,
678
+ input_img,
679
+ zero_hands=True,
680
+ )
681
+
682
+ return (
683
+ rotmat_pred,
684
+ betas_pred,
685
+ camera_pred,
686
+ keypoints_3d_pred,
687
+ vertices_2d_pred,
688
+ smpl_output,
689
+ z,
690
+ joints_2d_orig,
691
+ )
692
+
693
+
694
+ def optimize_ft(
695
+ theta,
696
+ camera,
697
+ smpl,
698
+ selector,
699
+ keypoints_2d,
700
+ args,
701
+ loss_mse=None,
702
+ loss_parallel=None,
703
+ c_mse=0.0,
704
+ c_new_mse=1.0,
705
+ sc_crit=None,
706
+ msc_crit=None,
707
+ contact=None,
708
+ n_steps=60,
709
+ i_ini=0,
710
+ zero_hands=False,
711
+ fist=None,
712
+ ):
713
+ mean_zfoot_val = {}
714
+
715
+ theta = theta.detach().clone()
716
+ camera = camera.detach().clone()
717
+ rotmat_pred = nn.Parameter(theta)
718
+ camera_pred = nn.Parameter(camera)
719
+ optimizer = torch.optim.Adam(
720
+ [
721
+ rotmat_pred,
722
+ camera_pred,
723
+ ],
724
+ lr=1e-3,
725
+ )
726
+ global_step = i_ini
727
+
728
+ with tqdm.trange(n_steps) as pbar:
729
+ for i in pbar:
730
+ global_step = i + i_ini
731
+ optimizer.zero_grad()
732
+
733
+ global_orient = rotmat_pred[:3]
734
+ body_pose = rotmat_pred[3:]
735
+ smpl_output = smpl(
736
+ global_orient=global_orient.unsqueeze(0),
737
+ body_pose=body_pose.unsqueeze(0),
738
+ pose2rot=True,
739
+ )
740
+
741
+ z = smpl_output.joints
742
+ z = z.squeeze(0)
743
+
744
+ joints = smpl_output.joints.squeeze(0)
745
+ joints_2d = project_and_normalize_to_spin(joints, camera_pred)
746
+ rleg, lleg = project_and_normalize_to_spin_legs(
747
+ joints, smpl_output.A, camera_pred
748
+ )
749
+ joints_2d = joints_2d[selector]
750
+ z = z[selector]
751
+ keypoints_3d_pred = joints_2d
752
+
753
+ keypoints_2d_pred = keypoints_3d_pred[:, :2]
754
+
755
+ lprior = ((rotmat_pred - theta) ** 2).sum() + (
756
+ (camera_pred - camera) ** 2
757
+ ).sum()
758
+ loss = lprior
759
+
760
+ l2 = 0.0
761
+ if c_mse > 0 and loss_mse is not None:
762
+ l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
763
+ loss = loss + c_mse * l2
764
+
765
+ vertices_pred = smpl_output.vertices
766
+
767
+ lpar = z_loss = loss_sh = 0.0
768
+ if c_new_mse > 0 and loss_parallel is not None:
769
+ Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
770
+ keypoints_3d_pred,
771
+ keypoints_2d,
772
+ z,
773
+ (rleg, lleg),
774
+ global_step=global_step,
775
+ )
776
+ lpar = (
777
+ Ltan
778
+ + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
779
+ + Lspine
780
+ + args.c_reg * Lgr
781
+ + args.c_reg * Lstraight3d
782
+ + args.c_cont2d * Lcon2d
783
+ )
784
+ loss = loss + 300 * lpar
785
+
786
+ for side in ["left", "right"]:
787
+ attr = f"{side}_foot_inds"
788
+ if hasattr(loss_parallel, attr):
789
+ foot_inds = getattr(loss_parallel, attr)
790
+ zind = 1
791
+ if attr not in mean_zfoot_val:
792
+ with torch.no_grad():
793
+ mean_zfoot_val[attr] = torch.median(
794
+ vertices_pred[0, foot_inds, zind], dim=0
795
+ ).values
796
+
797
+ loss_foot = (
798
+ (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
799
+ ** 2
800
+ ).sum()
801
+ loss = loss + args.c_reg * loss_foot
802
+
803
+ if hasattr(loss_parallel, "silhuette_vertices_inds"):
804
+ inds = loss_parallel.silhuette_vertices_inds
805
+ loss_sh = (
806
+ (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
807
+ ).sum()
808
+ loss = loss + args.c_reg * loss_sh
809
+
810
+ lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
811
+ if sc_crit is not None:
812
+ gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
813
+ lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
814
+ loss = loss + lgsc_a
815
+
816
+ msc_loss = 0.0
817
+ if contact is not None and len(contact) > 0 and msc_crit is not None:
818
+ if not isinstance(contact, list):
819
+ contact = [contact]
820
+
821
+ for cntct in contact:
822
+ msc_loss = msc_crit(
823
+ cntct,
824
+ vertices_pred,
825
+ )
826
+ loss = loss + args.c_msc * msc_loss
827
+
828
+ loss.backward()
829
+ optimizer.step()
830
+
831
+ epoch_loss = loss.item()
832
+ pbar.set_postfix(
833
+ **{
834
+ "l": f"{epoch_loss:.3}",
835
+ "l2": f"{l2:.3}",
836
+ "par": f"{lpar:.3}",
837
+ "z": f"{z_loss:.3}",
838
+ "gsc_contact": f"{float(gsc_contact_loss):.3}",
839
+ "faces_angle": f"{float(faces_angle_loss):.3}",
840
+ "msc": f"{float(msc_loss):.3}",
841
+ }
842
+ )
843
+
844
+ rotmat_pred = rotmat_pred.detach()
845
+
846
+ if zero_hands:
847
+ for i in [20, 21]:
848
+ rotmat_pred[3 * i : 3 * (i + 1)] = 0
849
+
850
+ for i in [12, 15]: # neck, head
851
+ rotmat_pred[3 * i + 1] = 0 # y
852
+
853
+ global_orient = rotmat_pred[:3]
854
+ body_pose = rotmat_pred[3:]
855
+ left_hand_pose = None
856
+ right_hand_pose = None
857
+ if fist is not None:
858
+ left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
859
+ right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
860
+ for f in fist:
861
+ pp = fist_pose.INT_TO_FIST[f]
862
+ if pp is not None:
863
+ pp = rotmat_pred.new_tensor(pp).unsqueeze(0)
864
+
865
+ if f.startswith("lf"):
866
+ left_hand_pose = pp
867
+ elif f.startswith("rf"):
868
+ right_hand_pose = pp
869
+ elif f.startswith("l"):
870
+ body_pose[19 * 3 : 19 * 3 + 3] = pp
871
+ left_hand_pose = None
872
+ elif f.startswith("r"):
873
+ body_pose[20 * 3 : 20 * 3 + 3] = pp
874
+ right_hand_pose = None
875
+ else:
876
+ raise RuntimeError(f"No such hand pose: {f}")
877
+
878
+ with torch.no_grad():
879
+ smpl_output = smpl(
880
+ global_orient=global_orient.unsqueeze(0),
881
+ body_pose=body_pose.unsqueeze(0),
882
+ left_hand_pose=left_hand_pose,
883
+ right_hand_pose=right_hand_pose,
884
+ pose2rot=True,
885
+ )
886
+
887
+ return rotmat_pred, smpl_output
888
+
889
+
890
+ def create_bone(i, j, keypoints_2d):
891
+ a = keypoints_2d[i]
892
+ b = keypoints_2d[j]
893
+ ab = b - a
894
+ ab = torch.nn.functional.normalize(ab, dim=0)
895
+
896
+ return ab
897
+
898
+
899
+ def is_parallel_to_plane(bone, thresh=21):
900
+ return abs(bone[0]) > math.cos(math.radians(thresh))
901
+
902
+
903
+ def is_close_to_plane(bone, plane, thresh):
904
+ dist = abs(bone[0] - plane)
905
+
906
+ return dist < thresh
907
+
908
+
909
+ def get_selector():
910
+ selector = []
911
+ for kp in pose_estimation.KPS:
912
+ tmp = spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp])
913
+ selector.append(tmp)
914
+
915
+ return selector
916
+
917
+
918
+ def calc_cos(joints_2d, joints_3d):
919
+ cos = []
920
+ for i, j in pose_estimation.SKELETON:
921
+ a = joints_2d[i] - joints_2d[j]
922
+ a = nn.functional.normalize(a, dim=0)
923
+
924
+ b = joints_3d[i] - joints_3d[j]
925
+ b = nn.functional.normalize(b, dim=0)[:2]
926
+
927
+ c = (a * b).sum()
928
+ cos.append(c)
929
+
930
+ cos = torch.stack(cos, dim=0)
931
+
932
+ return cos
933
+
934
+
935
+ def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
936
+ height_2d = (
937
+ keypoints_2d.max(dim=0).values[0] - keypoints_2d.min(dim=0).values[0]
938
+ ).item()
939
+ plane_2d = keypoints_2d.max(dim=0).values[0].item()
940
+
941
+ ground_parallel = []
942
+ parallel_in_3d = []
943
+ parallel3d_bones = set()
944
+
945
+ # parallel chains
946
+ for i, j, k in [
947
+ ("Right Upper Leg", "Right Leg", "Right Foot"),
948
+ ("Right Leg", "Right Foot", "Right Toe"), # to remove?
949
+ ("Left Upper Leg", "Left Leg", "Left Foot"),
950
+ ("Left Leg", "Left Foot", "Left Toe"), # to remove?
951
+ ("Right Shoulder", "Right Arm", "Right Hand"),
952
+ ("Left Shoulder", "Left Arm", "Left Hand"),
953
+ # ("Hips", "Spine", "Neck"),
954
+ # ("Spine", "Neck", "Head"),
955
+ ]:
956
+ i = pose_estimation.KPS.index(i)
957
+ j = pose_estimation.KPS.index(j)
958
+ k = pose_estimation.KPS.index(k)
959
+ upleg_leg = create_bone(i, j, keypoints_2d)
960
+ leg_foot = create_bone(j, k, keypoints_2d)
961
+
962
+ if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot):
963
+ if is_close_to_plane(
964
+ upleg_leg, plane_2d, thresh=0.1 * height_2d
965
+ ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):
966
+ ground_parallel.append(((i, j), 1))
967
+ ground_parallel.append(((j, k), 1))
968
+
969
+ if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):
970
+ parallel_in_3d.append(((i, j), (j, k)))
971
+ parallel3d_bones.add((i, j))
972
+ parallel3d_bones.add((j, k))
973
+
974
+ # parallel feets
975
+ for i, j in [
976
+ ("Right Foot", "Right Toe"),
977
+ ("Left Foot", "Left Toe"),
978
+ ]:
979
+ i = pose_estimation.KPS.index(i)
980
+ j = pose_estimation.KPS.index(j)
981
+ if (i, j) in parallel3d_bones:
982
+ continue
983
+
984
+ foot_toe = create_bone(i, j, keypoints_2d)
985
+ if is_parallel_to_plane(foot_toe, thresh=25):
986
+ if "Right" in pose_estimation.KPS[i]:
987
+ loss_parallel.right_foot_inds = right_foot_inds
988
+ else:
989
+ loss_parallel.left_foot_inds = left_foot_inds
990
+
991
+ loss_parallel.ground_parallel = ground_parallel
992
+ loss_parallel.parallel_in_3d = parallel_in_3d
993
+
994
+ vertices_np = vertices[0].cpu().numpy()
995
+ if len(ground_parallel) > 0:
996
+ # Silhuette veritices
997
+ mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
998
+ silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1
999
+ height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()
1000
+ plane_3d = vertices_np[:, 1].max()
1001
+ silhuette_vertices_mask_2 = (
1002
+ np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d
1003
+ )
1004
+ silhuette_vertices_mask = np.logical_and(
1005
+ silhuette_vertices_mask_1, silhuette_vertices_mask_2
1006
+ )
1007
+ (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)
1008
+ if len(silhuette_vertices_inds) > 0:
1009
+ loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
1010
+ loss_parallel.ground = plane_3d
1011
+
1012
+
1013
+ def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
1014
+ keypoints_2d_pred = keypoints_3d_pred[:, :2]
1015
+ with torch.no_grad():
1016
+ cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)
1017
+
1018
+ alpha = torch.acos(cos_r)
1019
+ if use_angle_transf:
1020
+ leg_inds = [
1021
+ 5,
1022
+ 6, # right leg
1023
+ 7,
1024
+ 8, # left leg
1025
+ ]
1026
+ foot_inds = [15, 16]
1027
+ nleg_inds = sorted(
1028
+ set(range(len(pose_estimation.SKELETON))) - set(leg_inds) - set(foot_inds)
1029
+ )
1030
+ alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()
1031
+
1032
+ amli = alpha[leg_inds].min()
1033
+ leg_inds.extend(foot_inds)
1034
+ alpha[leg_inds] = alpha[leg_inds] - amli
1035
+
1036
+ angles = alpha.detach().cpu().numpy()
1037
+ angles = hist_cub.cub(
1038
+ angles / (math.pi / 2),
1039
+ a=1.2121212121212122,
1040
+ b=-1.105527638190953,
1041
+ c=0.787878787878789,
1042
+ ) * (math.pi / 2)
1043
+ alpha = alpha.new_tensor(angles)
1044
+
1045
+ loss_parallel.cos = torch.cos(alpha)
1046
+
1047
+ return cos_r
1048
+
1049
+
1050
+ def get_contacts(
1051
+ args,
1052
+ sc_module,
1053
+ y_data_conts,
1054
+ keypoints_2d,
1055
+ vertices,
1056
+ bone_to_params,
1057
+ loss_parallel,
1058
+ ):
1059
+ use_contacts = args.use_contacts
1060
+ use_msc = args.use_msc
1061
+ c_mse = args.c_mse
1062
+
1063
+ if use_contacts:
1064
+ assert c_mse == 0
1065
+ contact, contact_2d, _ = find_contacts(
1066
+ y_data_conts, keypoints_2d, bone_to_params
1067
+ )
1068
+ if len(contact_2d) > 0:
1069
+ loss_parallel.contact_2d = contact_2d
1070
+
1071
+ if len(contact) == 0:
1072
+ _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
1073
+ contact = contact.cpu().numpy().ravel()
1074
+ elif use_msc:
1075
+ _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
1076
+ contact = contact.cpu().numpy().ravel()
1077
+ else:
1078
+ contact = np.array([])
1079
+
1080
+ return contact
1081
+
1082
+
1083
+ def save_all(
1084
+ smpl,
1085
+ smpl_output,
1086
+ save_path,
1087
+ fname,
1088
+ ):
1089
+ utils.save_mesh_with_colors(
1090
+ smpl_output.vertices[0].cpu().numpy(),
1091
+ smpl.faces,
1092
+ save_path / f"{fname}.ply",
1093
+ )
1094
+
1095
+
1096
+ def eft_step(
1097
+ model_hmr,
1098
+ smpl,
1099
+ selector,
1100
+ input_img,
1101
+ keypoints_2d,
1102
+ optimizer,
1103
+ args,
1104
+ loss_mse,
1105
+ loss_parallel,
1106
+ c_beta,
1107
+ sc_module,
1108
+ y_data_conts,
1109
+ bone_to_params,
1110
+ ):
1111
+ (
1112
+ _,
1113
+ _,
1114
+ _,
1115
+ keypoints_3d_pred,
1116
+ _,
1117
+ smpl_output,
1118
+ _,
1119
+ _,
1120
+ ) = optimize(
1121
+ model_hmr,
1122
+ smpl,
1123
+ selector,
1124
+ input_img,
1125
+ keypoints_2d,
1126
+ optimizer,
1127
+ args,
1128
+ loss_mse=loss_mse,
1129
+ loss_parallel=loss_parallel,
1130
+ c_mse=1,
1131
+ c_new_mse=0,
1132
+ c_beta=c_beta,
1133
+ sc_crit=None,
1134
+ msc_crit=None,
1135
+ contact=None,
1136
+ n_steps=60 + 90,
1137
+ )
1138
+
1139
+ # find contacts
1140
+ vertices = smpl_output.vertices.detach()
1141
+ contact = get_contacts(
1142
+ args,
1143
+ sc_module,
1144
+ y_data_conts,
1145
+ keypoints_2d,
1146
+ vertices,
1147
+ bone_to_params,
1148
+ loss_parallel,
1149
+ )
1150
+
1151
+ return vertices, keypoints_3d_pred, contact
1152
+
1153
+
1154
+ def dc_step(
1155
+ model_hmr,
1156
+ smpl,
1157
+ selector,
1158
+ input_img,
1159
+ keypoints_2d,
1160
+ optimizer,
1161
+ args,
1162
+ loss_mse,
1163
+ loss_parallel,
1164
+ c_mse,
1165
+ c_new_mse,
1166
+ c_beta,
1167
+ sc_crit,
1168
+ msc_crit,
1169
+ contact,
1170
+ use_contacts,
1171
+ use_msc,
1172
+ ):
1173
+ rotmat_pred, *_ = optimize(
1174
+ model_hmr,
1175
+ smpl,
1176
+ selector,
1177
+ input_img,
1178
+ keypoints_2d,
1179
+ optimizer,
1180
+ args,
1181
+ loss_mse=loss_mse,
1182
+ loss_parallel=loss_parallel,
1183
+ c_mse=c_mse,
1184
+ c_new_mse=c_new_mse,
1185
+ c_beta=c_beta,
1186
+ sc_crit=sc_crit,
1187
+ msc_crit=msc_crit if use_contacts or use_msc else None,
1188
+ contact=contact if use_contacts or use_msc else None,
1189
+ n_steps=60 if c_new_mse > 0 or use_contacts or use_msc else 0, # + 60,,
1190
+ i_ini=60 + 90,
1191
+ )
1192
+
1193
+ return rotmat_pred
1194
+
1195
+
1196
+ def us_step(
1197
+ model_hmr,
1198
+ smpl,
1199
+ selector,
1200
+ input_img,
1201
+ rotmat_pred,
1202
+ keypoints_2d,
1203
+ args,
1204
+ loss_mse,
1205
+ loss_parallel,
1206
+ c_mse,
1207
+ c_new_mse,
1208
+ sc_crit,
1209
+ msc_crit,
1210
+ contact,
1211
+ use_contacts,
1212
+ use_msc,
1213
+ save_path,
1214
+ ):
1215
+ (_, _, camera_pred_us, _, _, _, smpl_output_us, _, _,) = get_pred_and_data(
1216
+ model_hmr,
1217
+ smpl,
1218
+ selector,
1219
+ input_img,
1220
+ use_betas=False,
1221
+ zero_hands=True,
1222
+ )
1223
+
1224
+ _, smpl_output_us = optimize_ft(
1225
+ rotmat_pred,
1226
+ camera_pred_us,
1227
+ smpl,
1228
+ selector,
1229
+ keypoints_2d,
1230
+ args,
1231
+ loss_mse=loss_mse,
1232
+ loss_parallel=loss_parallel,
1233
+ c_mse=c_mse,
1234
+ c_new_mse=c_new_mse,
1235
+ sc_crit=sc_crit,
1236
+ msc_crit=msc_crit if use_contacts or use_msc else None,
1237
+ contact=contact if use_contacts or use_msc else None,
1238
+ n_steps=60 if use_contacts or use_msc else 0, # + 60,
1239
+ i_ini=60 + 90 + 60,
1240
+ zero_hands=True,
1241
+ fist=args.fist,
1242
+ )
1243
+
1244
+ save_all(
1245
+ smpl,
1246
+ smpl_output_us,
1247
+ save_path,
1248
+ "us",
1249
+ )
1250
+
1251
+
1252
+ def main():
1253
+ args = parse_args()
1254
+ print(args)
1255
+
1256
+ # models
1257
+ model_pose = cv2.dnn.readNetFromONNX(
1258
+ args.pose_estimation_model_path
1259
+ ) # "hrn_w48_384x288.onnx"
1260
+ model_contact = cv2.dnn.readNetFromONNX(
1261
+ args.contact_model_path
1262
+ ) # "contact_hrn_w32_256x192.onnx"
1263
+
1264
+ device = (
1265
+ torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
1266
+ )
1267
+ model_hmr = spin.hmr(args.smpl_mean_params_path) # "smpl_mean_params.npz"
1268
+ model_hmr.to(device)
1269
+ checkpoint = torch.load(
1270
+ args.spin_model_path, # "spin_model_smplx_eft_18.pt"
1271
+ map_location="cpu"
1272
+ )
1273
+
1274
+ smpl = spin.SMPLX(
1275
+ args.smpl_model_dir, # "models/smplx"
1276
+ batch_size=1,
1277
+ create_transl=False,
1278
+ use_pca=False,
1279
+ flat_hand_mean=args.fist is not None,
1280
+ )
1281
+ smpl.to(device)
1282
+
1283
+ selector = get_selector()
1284
+
1285
+ use_contacts = args.use_contacts
1286
+ use_msc = args.use_msc
1287
+
1288
+ bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
1289
+ foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
1290
+ left_foot_inds = foot_inds["left_foot_inds"]
1291
+ right_foot_inds = foot_inds["right_foot_inds"]
1292
+
1293
+ if use_contacts:
1294
+ model_type = args.smpl_type
1295
+ sc_module = selfcontact.SelfContact(
1296
+ essentials_folder=args.essentials_dir, # "smplify-xmc-essentials"
1297
+ geothres=0.3,
1298
+ euclthres=0.02,
1299
+ test_segments=True,
1300
+ compute_hd=True,
1301
+ model_type=model_type,
1302
+ device=device,
1303
+ )
1304
+ sc_module.to(device)
1305
+
1306
+ sc_crit = selfcontact.losses.SelfContactLoss(
1307
+ contact_module=sc_module,
1308
+ inside_loss_weight=0.5,
1309
+ outside_loss_weight=0.0,
1310
+ contact_loss_weight=0.5,
1311
+ align_faces=True,
1312
+ use_hd=True,
1313
+ test_segments=True,
1314
+ device=device,
1315
+ model_type=model_type,
1316
+ )
1317
+ sc_crit.to(device)
1318
+
1319
+ msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
1320
+ msc_crit.to(device)
1321
+ else:
1322
+ sc_module = None
1323
+ sc_crit = None
1324
+ msc_crit = None
1325
+
1326
+ loss_mse = losses.MSE([1, 10, 13]) # Neck + Right Upper Leg + Left Upper Leg
1327
+
1328
+ ignore = (
1329
+ (1, 2), # Neck + Right Shoulder
1330
+ (1, 5), # Neck + Left Shoulder
1331
+ (9, 10), # Hips + Right Upper Leg
1332
+ (9, 13), # Hips + Left Upper Leg
1333
+ )
1334
+ loss_parallel = losses.Parallel(
1335
+ skeleton=pose_estimation.SKELETON,
1336
+ ignore=ignore,
1337
+ )
1338
+
1339
+ c_mse = args.c_mse
1340
+ c_new_mse = args.c_par
1341
+ c_beta = 1e-3
1342
+
1343
+ if c_mse > 0:
1344
+ assert c_new_mse == 0
1345
+ elif c_mse == 0:
1346
+ assert c_new_mse > 0
1347
+
1348
+ root_path = Path(args.save_path)
1349
+ root_path.mkdir(exist_ok=True, parents=True)
1350
+
1351
+ path_to_imgs = Path(args.img_path)
1352
+ if path_to_imgs.is_dir():
1353
+ path_to_imgs = path_to_imgs.iterdir()
1354
+ else:
1355
+ path_to_imgs = [path_to_imgs]
1356
+
1357
+ for img_path in path_to_imgs:
1358
+ if not any(
1359
+ img_path.name.lower().endswith(ext) for ext in [".jpg", ".png", ".jpeg"]
1360
+ ):
1361
+ continue
1362
+
1363
+ img_name = img_path.stem
1364
+
1365
+ # use 2d keypoints detection
1366
+ (
1367
+ img_original,
1368
+ predicted_keypoints_2d,
1369
+ _,
1370
+ _,
1371
+ ) = pose_estimation.infer_single_image(
1372
+ model_pose,
1373
+ img_path,
1374
+ input_img_size=pose_estimation.IMG_SIZE,
1375
+ return_kps=True,
1376
+ )
1377
+
1378
+ save_path = root_path / img_name
1379
+ save_path.mkdir(exist_ok=True, parents=True)
1380
+
1381
+ img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
1382
+ img_size_original = img_original.shape[:2]
1383
+ keypoints_2d, *_ = normalize_keypoints_to_spin(
1384
+ predicted_keypoints_2d, img_size_original
1385
+ )
1386
+ keypoints_2d = torch.from_numpy(keypoints_2d)
1387
+ keypoints_2d = keypoints_2d.to(device)
1388
+
1389
+ (
1390
+ predicted_contact_heatmap,
1391
+ predicted_contact_heatmap_raw,
1392
+ very_hm_raw,
1393
+ ) = get_contact_heatmap(model_contact, img_path)
1394
+ predicted_contact_heatmap_raw = Image.fromarray(
1395
+ predicted_contact_heatmap_raw
1396
+ ).resize(img_size_original[::-1])
1397
+ predicted_contact_heatmap_raw = cv2.resize(very_hm_raw, img_size_original[::-1])
1398
+
1399
+ if c_new_mse == 0:
1400
+ predicted_contact_heatmap_raw = None
1401
+
1402
+ y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)
1403
+
1404
+ model_hmr.load_state_dict(checkpoint["model"], strict=True)
1405
+ model_hmr.train()
1406
+ freeze_layers(model_hmr)
1407
+
1408
+ _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
1409
+ input_img = input_img.to(device)
1410
+
1411
+ optimizer = optim.Adam(
1412
+ filter(lambda p: p.requires_grad, model_hmr.parameters()),
1413
+ lr=1e-6,
1414
+ )
1415
+
1416
+ vertices, keypoints_3d_pred, contact = eft_step(
1417
+ model_hmr,
1418
+ smpl,
1419
+ selector,
1420
+ input_img,
1421
+ keypoints_2d,
1422
+ optimizer,
1423
+ args,
1424
+ loss_mse,
1425
+ loss_parallel,
1426
+ c_beta,
1427
+ sc_module,
1428
+ y_data_conts,
1429
+ bone_to_params,
1430
+ )
1431
+
1432
+ if args.use_natural:
1433
+ get_natural(
1434
+ keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl,
1435
+ )
1436
+
1437
+ if args.use_cos:
1438
+ get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)
1439
+
1440
+ rotmat_pred = dc_step(
1441
+ model_hmr,
1442
+ smpl,
1443
+ selector,
1444
+ input_img,
1445
+ keypoints_2d,
1446
+ optimizer,
1447
+ args,
1448
+ loss_mse,
1449
+ loss_parallel,
1450
+ c_mse,
1451
+ c_new_mse,
1452
+ c_beta,
1453
+ sc_crit,
1454
+ msc_crit,
1455
+ contact,
1456
+ use_contacts,
1457
+ use_msc,
1458
+ )
1459
+
1460
+ us_step(
1461
+ model_hmr,
1462
+ smpl,
1463
+ selector,
1464
+ input_img,
1465
+ rotmat_pred,
1466
+ keypoints_2d,
1467
+ args,
1468
+ loss_mse,
1469
+ loss_parallel,
1470
+ c_mse,
1471
+ c_new_mse,
1472
+ sc_crit,
1473
+ msc_crit,
1474
+ contact,
1475
+ use_contacts,
1476
+ use_msc,
1477
+ save_path,
1478
+ )
1479
+
1480
+
1481
+ if __name__ == "__main__":
1482
+ main()