# Spaces: Running Running
# (page-header residue captured together with the listing; not part of the program)
import argparse
import math
from pathlib import Path

import cv2
import numpy as np
import PIL.Image as Image
import selfcontact
import selfcontact.losses
import shapely.geometry
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeometry
import tqdm
import trimesh
from skimage import measure

import fist_pose
import hist_cub
import losses
import pose_estimation
import spin
# Lookup table used by ``get_selector``: keys are the keypoint names found in
# ``pose_estimation.KPS``, values are the names searched in ``spin.JOINT_NAMES``.
# NOTE(review): both shoulders map to the "... ForeArm" entries and "Spine" to
# "Spine1" -- presumably a naming quirk of ``spin.JOINT_NAMES``; confirm there.
PE_KSP_TO_SPIN = {
    "Head": "Head",
    "Neck": "Neck",
    "Right Shoulder": "Right ForeArm",
    "Right Arm": "Right Arm",
    "Right Hand": "Right Hand",
    "Left Shoulder": "Left ForeArm",
    "Left Arm": "Left Arm",
    "Left Hand": "Left Hand",
    "Spine": "Spine1",
    "Hips": "Hips",
    "Right Upper Leg": "Right Upper Leg",
    "Right Leg": "Right Leg",
    "Right Foot": "Right Foot",
    "Left Upper Leg": "Left Upper Leg",
    "Left Leg": "Left Leg",
    "Left Foot": "Left Foot",
    "Left Toe": "Left Toe",
    "Right Toe": "Right Toe",
}

# Directory prefix shared by every default model/data path in ``parse_args``.
MODELS_DIR = "models"
def parse_args():
    """Parse the command-line options of the pose-fitting script.

    Options are registered in the same order as before so ``--help`` keeps its
    layout. Flags, types, defaults and choices are unchanged; only help texts
    were corrected or added.

    Returns:
        argparse.Namespace: model/data paths (defaults live under
        ``MODELS_DIR``), the torch device, the ``--use-*`` switches, the
        ``--c-*`` loss weights, the required ``--save-path`` / ``--img-path``
        and the optional ``--fist`` keys.
    """
    parser = argparse.ArgumentParser()

    def path_option(flag, default, help_text):
        # String option holding a file or directory path.
        parser.add_argument(flag, type=str, default=default, help=help_text)

    def switch(flag, help_text):
        # Boolean feature toggle, off unless given.
        parser.add_argument(flag, action="store_true", help=help_text)

    def weight(flag, default, help_text):
        # Float loss coefficient.
        parser.add_argument(flag, type=float, default=default, help=help_text)

    path_option(
        "--pose-estimation-model-path",
        f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
        "Pose Estimation model",
    )
    path_option(
        "--contact-model-path",
        f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
        "Contact model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Torch device",
    )
    path_option(
        "--spin-model-path",
        f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
        "SPIN model path",
    )
    parser.add_argument(
        "--smpl-type",
        type=str,
        default="smplx",
        choices=["smplx"],
        help="SMPL model type",
    )
    path_option("--smpl-model-dir", f"./{MODELS_DIR}/models/smplx", "SMPL model dir")
    path_option(
        "--smpl-mean-params-path",
        f"./{MODELS_DIR}/data/smpl_mean_params.npz",
        "SMPL mean params",
    )
    path_option(
        "--essentials-dir",
        f"./{MODELS_DIR}/smplify-xmc-essentials",
        "SMPL Essentials folder for contacts",
    )
    path_option(
        "--parametrization-path",
        f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
        "Parametrization path",
    )
    path_option(
        "--bone-parametrization-path",
        f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
        "Bone parametrization path",
    )
    path_option(
        "--foot-inds-path",
        f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
        "Foot indices",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the results",
    )
    parser.add_argument(
        "--img-path",
        type=str,
        required=True,
        help="Path to img to test",
    )

    switch("--use-contacts", "Use contact model")
    switch("--use-msc", "Use MSC loss")
    switch("--use-natural", "Use regularity")
    switch("--use-cos", "Use cos model")
    switch("--use-angle-transf", "Use cube foreshortening transformation")

    weight("--c-mse", 0, "MSE weight")
    # --c-par scales the combined cos/parallel term and gates the whole
    # parallel loss (see its use as ``c_new_mse``); --c-parallel weights the
    # parallel term inside it.
    weight(
        "--c-par",
        10,
        "Weight of the combined cos/parallel term (must be > 0 when --c-mse is 0)",
    )
    weight("--c-f", 1000, "Cos coef")
    weight("--c-parallel", 100, "Parallel weight")
    weight("--c-reg", 1000, "Regularity weight")
    weight("--c-cont2d", 1, "Contact 2D weight")
    weight("--c-msc", 17_500, "MSC weight")

    parser.add_argument(
        "--fist",
        nargs="+",
        type=str,
        choices=list(fist_pose.INT_TO_FIST),
        help="One or more hand-pose keys (keys of fist_pose.INT_TO_FIST)",
    )

    return parser.parse_args()
def freeze_layers(model):
    """Switch batch-norm and dropout layers of *model* to eval mode and freeze them.

    Every sub-module that is a ``_BatchNorm`` or ``nn.Dropout`` instance is put
    into eval mode and its own parameters get ``requires_grad = False``. All
    other modules are left untouched.

    The original ``type(module) is False`` guard could never be true and has
    been dropped; the two identical branches are merged.
    """
    frozen_types = (nn.modules.batchnorm._BatchNorm, nn.Dropout)
    for module in model.modules():
        if not isinstance(module, frozen_types):
            continue
        module.eval()
        for param in module.parameters():
            param.requires_grad = False
def project_and_normalize_to_spin(vertices_3d, camera):
    """Apply the camera offset and scale to points and map them to SPIN image units.

    Computes ``IMG_RES / 2 * (scale * (points + offset) + 1)`` where
    ``scale = camera[0]`` and ``offset = (camera[1], camera[2], 0)``.

    The result keeps all three columns of the input (no slicing happens here);
    callers take ``[:, :2]`` when they need 2D coordinates.

    Args:
        vertices_3d: tensor whose last dimension has 3 entries.
        camera: tensor ``(scale, tx, ty)``.
    """
    scale = camera[0]

    # Only the first two axes are shifted; the third offset stays zero.
    offset = scale.new_zeros(3)
    offset[:2] = camera[1:]

    shifted = vertices_3d + offset
    normalized = scale * shifted + 1
    return spin.constants.IMG_RES / 2 * normalized
def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
    """Return scaled 2D vectors for the right and left leg segments.

    Args:
        vertices_3d: only used as the dtype/device template for the constant
            matrices below.
        A: pair that is unpacked into two batched tensors; the first batch
            element of each is used. NOTE(review): presumably per-joint
            transforms and joint locations from the body model -- confirm.
        camera: tensor whose first entry is the scale.

    Returns:
        tuple: ``(rleg, lleg)``, the first two components of each transformed
        segment.
    """
    transforms, joint_locs = A
    transforms = transforms[0]
    joint_locs = joint_locs[0]

    # Fixed 3x3 matrices composed with the transforms of entries 1 and 2.
    left_const = vertices_3d.new_tensor(
        [
            [0.98619063, 0.16560926, 0.00127302],
            [-0.16560601, 0.98603675, 0.01749799],
            [0.00164258, -0.01746717, 0.99984609],
        ]
    )
    right_const = vertices_3d.new_tensor(
        [
            [0.9910211, -0.13368178, -0.0025208],
            [0.13367888, 0.99027076, 0.03864949],
            [-0.00267045, -0.03863944, 0.99924965],
        ]
    )

    scale = camera[0]

    right_rot = transforms[2, :3, :3] @ right_const  # 2 - right
    left_rot = transforms[1, :3, :3] @ left_const  # 1 - left

    right_segment = joint_locs[5] - joint_locs[2]
    left_segment = joint_locs[4] - joint_locs[1]

    # `*`, `/` and `@` share precedence and associate left to right, so the
    # scalar factor is applied to the matrix before the matrix-vector product.
    rleg = scale * spin.constants.IMG_RES / 2 * right_rot @ right_segment
    lleg = scale * spin.constants.IMG_RES / 2 * left_rot @ left_segment

    return rleg[:2], lleg[:2]
def rotation_matrix_to_angle_axis(rotmat):
    """Convert batched 3x3 rotation matrices to a flat angle-axis vector.

    Each 3x3 matrix is padded with the column ``(0, 0, 1)`` to the 3x4 layout
    passed to ``torchgeometry.rotation_matrix_to_angle_axis``.

    The padding column used to be built with ``view(bs, 3, 1)`` and expanded
    to ``n_joints``, which only works for a batch size of 1. It is now
    expanded to ``bs * n_joints`` so any batch size is accepted.

    Args:
        rotmat: tensor of shape ``(bs, n_joints, 3, 3)``.

    Returns:
        Tensor of shape ``(bs, 3 * n_joints)``.
    """
    bs, n_joints, *_ = rotmat.size()

    flat = rotmat.reshape(-1, 3, 3)
    pad_column = (
        rotmat.new_tensor([0, 0, 1], dtype=torch.float32)
        .view(1, 3, 1)
        .expand(flat.size(0), -1, -1)
    )
    padded = torch.cat([flat, pad_column], dim=-1)

    aa = torchgeometry.rotation_matrix_to_angle_axis(padded)
    return aa.reshape(bs, 3 * n_joints)
def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
    """Run the body model on predicted rotations.

    * ``smpl.name() == "SMPL"``: rotation matrices are passed through
      (``pose2rot=False``) and returned unchanged.
    * ``smpl.name() == "SMPL-X"``: rotations are converted to angle-axis and
      passed with ``pose2rot=True``. With ``zero_hands`` the three components
      of joints 20 and 21 and the y component of joints 12 and 15 (neck, head)
      are zeroed in place first.

    NOTE(review): the listing lost its indentation; both zeroing loops are
    assumed to sit under ``if zero_hands`` -- confirm against the original.

    Args:
        smpl: body model exposing ``name()`` and keyword call.
        rotmat: ``(bs, n_joints, 3, 3)`` rotation matrices.
        betas: shape coefficients; ignored when ``use_betas`` is false.
        use_betas: pass ``betas`` to the model, otherwise ``None``.
        zero_hands: see above (SMPL-X only).

    Returns:
        tuple: ``(smpl_output, pose)`` where ``pose`` is the (possibly
        converted) pose actually fed to the model.

    Raises:
        NotImplementedError: for any other model name.
    """
    model_name = smpl.name()
    shape_params = betas if use_betas else None

    if model_name == "SMPL":
        smpl_output = smpl(
            betas=shape_params,
            body_pose=rotmat[:, 1:],
            global_orient=rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
    elif model_name == "SMPL-X":
        rotmat = rotation_matrix_to_angle_axis(rotmat)
        if zero_hands:
            for joint in [20, 21]:
                rotmat[:, 3 * joint : 3 * (joint + 1)] = 0

            for joint in [12, 15]:  # neck, head
                rotmat[:, 3 * joint + 1] = 0  # y
        smpl_output = smpl(
            betas=shape_params,
            body_pose=rotmat[:, 3:],
            global_orient=rotmat[:, :3],
            pose2rot=True,
        )
    else:
        raise NotImplementedError(f"Unsupported body model: {model_name}")

    return smpl_output, rotmat
def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
    """Run the regressor on one image and feed its output to the body model.

    A batch dimension is added to ``input_img`` before the forward pass and
    removed again from every returned tensor.

    Returns:
        tuple: ``(pose, betas, camera, smpl_output, joints)`` where ``pose`` is
        the pose returned by ``get_smpl_output`` and ``joints`` is
        ``smpl_output.joints`` without the batch dimension.
    """
    batch = input_img.unsqueeze(0)
    rotmat, betas, camera = model_hmr(batch)

    smpl_output, pose = get_smpl_output(
        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
    )

    joints = smpl_output.joints.squeeze(0)
    return pose.squeeze(0), betas.squeeze(0), camera.squeeze(0), smpl_output, joints
def get_pred_and_data(
    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
):
    """Predict the body and derive all projected quantities used by the losses.

    Args:
        selector: joint indices picking the tracked keypoints out of the body
            model's joints.

    Returns:
        tuple of nine items, in this order:
        ``rotmat``, ``betas``, ``camera``,
        selected projected joints (all three columns),
        selected unprojected joints,
        projected vertices,
        ``smpl_output``,
        ``(rleg, lleg)`` from ``project_and_normalize_to_spin_legs``,
        projected joints before selection.
    """
    rotmat, betas, camera, smpl_output, joints_raw = get_predictions(
        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
    )

    joints = smpl_output.joints.squeeze(0)
    joints_2d_orig = project_and_normalize_to_spin(joints, camera)
    legs = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)

    vertices = smpl_output.vertices.squeeze(0)
    vertices_2d = project_and_normalize_to_spin(vertices, camera)

    return (
        rotmat,
        betas,
        camera,
        joints_2d_orig[selector],
        joints_raw[selector],
        vertices_2d,
        smpl_output,
        legs,
        joints_2d_orig,
    )
def normalize_keypoints_to_spin(keypoints_2d, img_size):
    """Map image keypoints to a square ``IMG_RES`` frame.

    The longer image side is scaled to ``spin.constants.IMG_RES``. The
    keypoint column at index ``ax2`` is first shifted by half the side
    difference, which amounts to centring the shorter side in a square.

    Args:
        keypoints_2d: array of shape ``(n, 2)``; it is copied, not modified.
        img_size: ``(h, w)``.

    Returns:
        tuple: ``(normalized_keypoints, shift, scale, ax2)``. The last three
        values are what ``unnormalize_keypoints_from_spin`` needs to invert
        the mapping.
    """
    h, w = img_size
    # ax2 indexes both the longer side in img_size and the keypoint column
    # that is shifted: 0 for portrait images, 1 otherwise.
    ax2 = 0 if h > w else 1
    ax1 = 1 - ax2

    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]

    normalized = np.copy(keypoints_2d)
    normalized[:, ax2] -= shift
    normalized *= scale

    return normalized, shift, scale, ax2
def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    """Invert ``normalize_keypoints_to_spin``.

    Divides a copy of the keypoints by ``scale`` and adds ``shift`` back to
    column ``ax2``. The input array is left untouched.
    """
    restored = np.copy(keypoints_2d)
    restored /= scale
    restored[:, ax2] += shift
    return restored
def get_vertices_in_heatmap(contact_heatmap):
    """Turn each connected region of a contact mask into a convex hull.

    Regions are found with ``skimage.measure.label``. The pixel coordinates of
    each region are reordered from (row, col) to (col, row), normalised with
    ``normalize_keypoints_to_spin``, truncated to integers and wrapped in the
    convex hull of a shapely ``MultiPoint``.

    Returns:
        list: one shapely geometry per labelled region, in label order.
    """
    heatmap_size = contact_heatmap.shape[:2]
    labels = measure.label(contact_heatmap)

    hulls = []
    for region_id in range(1, labels.max() + 1):
        rows_cols = np.nonzero(labels == region_id)
        points = np.vstack(rows_cols[::-1]).T.astype("float")
        points_scaled, *_ = normalize_keypoints_to_spin(points, heatmap_size)
        points_int = torch.from_numpy(points_scaled).int().tolist()
        hulls.append(shapely.geometry.MultiPoint(points_int).convex_hull)

    return hulls
def get_contact_heatmap(model_contact, img_path, thresh=0.5):
    """Run the contact network and return binary, grey-scale and raw heatmaps.

    The raw heatmap is min-max normalised to [0, 1]. A constant heatmap (max
    equal to min) used to divide by zero and produce NaNs; it is now treated
    as an all-zero heatmap.

    Returns:
        tuple: ``(mask, heatmap_rgb, raw)``:
        ``mask`` is a uint8 array, 255 where the normalised value exceeds
        ``thresh`` and 0 elsewhere;
        ``heatmap_rgb`` is the normalised heatmap repeated over three channels
        and scaled to uint8;
        ``raw`` is a copy of the network output after ``squeeze(0)``.
    """
    raw = pose_estimation.infer_single_image(
        model_contact,
        img_path,
        input_img_size=(192, 256),
        return_kps=False,
    )
    raw = raw.squeeze(0)
    raw_copy = raw.copy()

    lowest = raw.min()
    span = raw.max() - lowest
    if span > 0:
        normalized = (raw - lowest) / span
    else:
        normalized = np.zeros_like(raw)

    mask = ((normalized > thresh) * 255).astype("uint8")
    heatmap_rgb = np.repeat(normalized[..., None], repeats=3, axis=-1)
    heatmap_rgb = (heatmap_rgb * 255).astype("uint8")

    return mask, heatmap_rgb, raw_copy
def discretize(parametrization, n_bins=100):
    """Snap values in [0, 1] down to the lower edge of their bin.

    The unit interval is split into ``n_bins`` equal bins. Each value is
    replaced by the left edge of the bin it falls in; exactly 1.0 maps to 1.0.

    Args:
        parametrization: array-like of values expected to lie in [0, 1].
        n_bins: number of equal-width bins.
    """
    edges = np.linspace(0, 1, n_bins + 1)
    return edges[np.digitize(parametrization, edges) - 1]
def get_mapping_from_params_to_verts(verts, params):
    """Group vertices by their parameter value.

    Returns:
        dict: each distinct value of ``params`` mapped to the list of
        ``verts`` paired with it. Keys appear in first-seen order and each
        list keeps input order.
    """
    mapping = {}
    for vert, param in zip(verts, params):
        if param not in mapping:
            mapping[param] = []
        mapping[param].append(vert)
    return mapping
def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
    """Match 2D contact regions to mesh vertices along the skeleton bones.

    Each contact region is grown by ``thresh``. For every bone of
    ``pose_estimation.SKELETON`` that crosses the grown region, the crossed
    part of the 2D bone is expressed as a normalised position range along the
    bone. That range is looked up in the bone's discretised vertex
    parametrisation (``bone_to_params``), and the vertices of the matching
    bin and its two neighbours are collected.

    Changes from the listing:
      * a torch ``keypoints_2d`` is moved to CPU/NumPy first, because shapely
        cannot consume CUDA tensors;
      * the two intersection endpoints are sorted instead of asserted to be
        ordered, so a reversed pair no longer aborts the run.

    NOTE(review): nesting was reconstructed from a listing without
    indentation. In particular every region is assumed to append to
    ``contact`` / ``contact_2d`` even when nothing matched -- confirm.

    Args:
        y_data_conts: shapely geometries, one per contact region.
        keypoints_2d: 2D keypoints indexable by skeleton joint index.
        bone_to_params: ``{(i, j): (verts, t3d)}`` per bone.
        thresh: buffer distance applied to each region.
        step: bin width of the parametrisation (the listing annotates it as
            "mean face's circumradius").

    Returns:
        tuple: ``(contact, contact_2d, for_mask)``:
        ``contact`` holds one sorted int array of vertex ids per region;
        ``contact_2d`` holds, per region, a list of ``(i, j, t_mid)`` entries;
        ``for_mask`` holds one integer outline of the grown region, for each
        region that matched at least one vertex bin.
    """
    if torch.is_tensor(keypoints_2d):
        keypoints_2d = keypoints_2d.detach().cpu().numpy()

    n_bins = int(math.ceil(1 / step)) - 1

    contact = []
    contact_2d = []
    for_mask = []
    for y_data_cont in y_data_conts:
        contact_loc = []
        contact_2d_loc = []
        buffer = y_data_cont.buffer(thresh)
        mask_add = False

        for i, j in pose_estimation.SKELETON:
            verts, t3d = bone_to_params[(i, j)]
            if len(verts) == 0:
                continue

            t3d = discretize(t3d, n_bins=n_bins)
            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])

            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
            lint = buffer.intersection(line)
            endpoints = lint.boundary.geoms
            if len(endpoints) < 2:
                continue

            # Normalised positions of the crossed part, ordered along the bone.
            t2d_start, t2d_end = sorted(
                (
                    line.project(endpoints[0], normalized=True),
                    line.project(endpoints[1], normalized=True),
                )
            )

            t2ds = discretize(
                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
            )

            to_add = False
            for t2d in t2ds:
                # Skip positions outside the range covered by this bone.
                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
                    continue

                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
                contact_loc.extend(t3d_to_verts_sorted[t2d_ind][1])
                to_add = True
                mask_add = True

                # Also take the neighbouring bins on either side.
                if t2d_ind + 1 < len(t3d_to_verts_sorted):
                    contact_loc.extend(t3d_to_verts_sorted[t2d_ind + 1][1])
                if t2d_ind > 0:
                    contact_loc.extend(t3d_to_verts_sorted[t2d_ind - 1][1])

            if to_add:
                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))

        if mask_add:
            for_mask.append(buffer.exterior.coords.xy)

        contact_loc = sorted(set(contact_loc))
        contact_loc = np.array(contact_loc, dtype="int")
        contact.append(contact_loc)
        contact_2d.append(contact_2d_loc)

    for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]

    return contact, contact_2d, for_mask
def optimize(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    c_beta=1e-3,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
):
    """Fine-tune the regressor ``model_hmr`` on a single image.

    Runs ``n_steps`` optimiser steps. Each step performs a forward pass
    through ``get_pred_and_data`` and sums the enabled loss terms:

    * ``c_mse * loss_mse`` on the projected 2D keypoints;
    * the ``loss_parallel`` bundle (weights from ``args``), plus the foot and
      silhouette terms when ``loss_parallel`` carries those attributes;
    * a shape prior ``c_beta * mean(betas**2)`` and a camera term
      ``exp(-10 * scale)**2``;
    * ``sc_crit`` (self-contact) and ``msc_crit`` (per contact set, weighted
      by ``args.c_msc``).

    After the loop one extra forward pass is made under ``no_grad`` with
    ``zero_hands=True`` and its results are returned.

    NOTE(review): nesting was reconstructed from a listing without
    indentation. Two placements are assumptions to confirm: the foot and
    silhouette terms are inside the ``c_new_mse > 0`` branch, and the MSC
    term is added to the loss once per contact set.

    Args:
        i_ini: offset added to the step index to form the ``global_step``
            handed to ``loss_parallel``.

    Returns:
        tuple: ``(rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred,
        vertices_2d_pred, smpl_output, z, joints_2d_orig)``.
    """
    # Reference foot height per side, fixed at the first step that needs it.
    mean_zfoot_val = {}

    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            (
                rotmat_pred,
                betas_pred,
                camera_pred,
                keypoints_3d_pred,
                z,
                vertices_2d_pred,
                smpl_output,
                (rleg, lleg),
                joints_2d_orig,
            ) = get_pred_and_data(
                model_hmr,
                smpl,
                selector,
                input_img,
            )
            keypoints_2d_pred = keypoints_3d_pred[:, :2]

            loss = l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

            vertices_pred = smpl_output.vertices

            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar

                # Keep selected foot vertices at a common height (axis 1).
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values

                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot

                # Pull the selected silhouette vertices onto the ground value.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh

            lbeta = (betas_pred**2).mean()
            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
            loss = loss + c_beta * lbeta + lcam

            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(
                    vertices_pred,
                )
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a

            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]

                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "beta": f"{lbeta:.3}",
                    "cam": f"{lcam:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    # Final forward pass with the tuned weights.
    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            z,
            vertices_2d_pred,
            smpl_output,
            (rleg, lleg),
            joints_2d_orig,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )

    return (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        vertices_2d_pred,
        smpl_output,
        z,
        joints_2d_orig,
    )
def optimize_ft(
    theta,
    camera,
    smpl,
    selector,
    keypoints_2d,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
    zero_hands=False,
    fist=None,
):
    """Optimise pose and camera directly, without the regressor.

    ``theta`` and ``camera`` are detached, cloned and wrapped as parameters of
    a fresh Adam optimiser (lr 1e-3). Each step evaluates the body model
    without shape coefficients and sums the enabled loss terms on top of a
    squared-distance prior to the initial ``theta`` / ``camera``.

    After the loop the pose is optionally post-processed (``zero_hands``,
    ``fist``) and the body model is evaluated once more under ``no_grad``.

    NOTE(review): nesting was reconstructed from a listing without
    indentation. Assumptions to confirm:
      * the foot and silhouette terms are inside the ``c_new_mse > 0`` branch;
      * the MSC term is added once per contact set;
      * both zeroing loops are under ``if zero_hands``;
      * the ``startswith`` dispatch in the ``fist`` loop also runs when the
        table entry is ``None``.

    Args:
        theta: flat angle-axis pose; the first 3 entries are the global
            orientation and the rest the body pose (as sliced below).
        camera: ``(scale, tx, ty)`` tensor.
        fist: optional iterable of keys of ``fist_pose.INT_TO_FIST``.

    Returns:
        tuple: ``(rotmat_pred, smpl_output)``, the detached (and possibly
        post-processed) pose and the final body-model output.

    Raises:
        RuntimeError: if a ``fist`` key does not start with ``l`` or ``r``.
    """
    # Reference foot height per side, fixed at the first step that needs it.
    mean_zfoot_val = {}

    theta = theta.detach().clone()
    camera = camera.detach().clone()

    rotmat_pred = nn.Parameter(theta)
    camera_pred = nn.Parameter(camera)

    optimizer = torch.optim.Adam(
        [
            rotmat_pred,
            camera_pred,
        ],
        lr=1e-3,
    )

    global_step = i_ini
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            global_orient = rotmat_pred[:3]
            body_pose = rotmat_pred[3:]
            smpl_output = smpl(
                global_orient=global_orient.unsqueeze(0),
                body_pose=body_pose.unsqueeze(0),
                pose2rot=True,
            )

            z = smpl_output.joints
            z = z.squeeze(0)

            joints = smpl_output.joints.squeeze(0)
            joints_2d = project_and_normalize_to_spin(joints, camera_pred)
            rleg, lleg = project_and_normalize_to_spin_legs(
                joints, smpl_output.A, camera_pred
            )

            joints_2d = joints_2d[selector]
            z = z[selector]

            keypoints_3d_pred = joints_2d
            keypoints_2d_pred = keypoints_3d_pred[:, :2]

            # Stay close to the starting pose and camera.
            lprior = ((rotmat_pred - theta) ** 2).sum() + (
                (camera_pred - camera) ** 2
            ).sum()
            loss = lprior

            l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

            vertices_pred = smpl_output.vertices

            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar

                # Keep selected foot vertices at a common height (axis 1).
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values

                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot

                # Pull the selected silhouette vertices onto the ground value.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh

            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a

            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]

                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    rotmat_pred = rotmat_pred.detach()

    if zero_hands:
        for i in [20, 21]:
            rotmat_pred[3 * i : 3 * (i + 1)] = 0

        for i in [12, 15]:  # neck, head
            rotmat_pred[3 * i + 1] = 0  # y

    # Views into rotmat_pred: the in-place wrist edits below change it too.
    global_orient = rotmat_pred[:3]
    body_pose = rotmat_pred[3:]

    left_hand_pose = None
    right_hand_pose = None
    if fist is not None:
        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
        for f in fist:
            pp = fist_pose.INT_TO_FIST[f]
            if pp is not None:
                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)

            if f.startswith("lf"):
                left_hand_pose = pp
            elif f.startswith("rf"):
                right_hand_pose = pp
            elif f.startswith("l"):
                # Body-pose slot 19 is overall joint 20 (zeroed above).
                body_pose[19 * 3 : 19 * 3 + 3] = pp
                left_hand_pose = None
            elif f.startswith("r"):
                # Body-pose slot 20 is overall joint 21 (zeroed above).
                body_pose[20 * 3 : 20 * 3 + 3] = pp
                right_hand_pose = None
            else:
                raise RuntimeError(f"No such hand pose: {f}")

    with torch.no_grad():
        smpl_output = smpl(
            global_orient=global_orient.unsqueeze(0),
            body_pose=body_pose.unsqueeze(0),
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            pose2rot=True,
        )

    return rotmat_pred, smpl_output
def create_bone(i, j, keypoints_2d):
    """Return the unit vector pointing from keypoint ``i`` to keypoint ``j``."""
    direction = keypoints_2d[j] - keypoints_2d[i]
    return torch.nn.functional.normalize(direction, dim=0)
def is_parallel_to_plane(bone, thresh=21):
    """Tell whether a unit bone lies within ``thresh`` degrees of axis 0.

    True when the absolute first component of ``bone`` exceeds
    ``cos(thresh)``. Either direction along the axis qualifies.
    """
    min_cos = math.cos(math.radians(thresh))
    return abs(bone[0]) > min_cos
def is_close_to_plane(bone, plane, thresh):
    """Tell whether ``bone[0]`` is strictly closer than ``thresh`` to ``plane``."""
    return abs(bone[0] - plane) < thresh
def get_selector():
    """Return, for each ``pose_estimation.KPS`` name, its index in ``spin.JOINT_NAMES``.

    Names are translated through ``PE_KSP_TO_SPIN`` first. The result has one
    entry per keypoint, in ``KPS`` order.
    """
    return [spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp]) for kp in pose_estimation.KPS]
def calc_cos(joints_2d, joints_3d):
    """Compute one dot product per skeleton bone between its 2D and 3D directions.

    For each bone of ``pose_estimation.SKELETON`` the 2D direction is
    normalised and dotted with the first two components of the normalised 3D
    direction.

    Returns:
        1-D tensor with one value per bone, in skeleton order.
    """
    cosines = []
    for i, j in pose_estimation.SKELETON:
        dir_2d = nn.functional.normalize(joints_2d[i] - joints_2d[j], dim=0)
        dir_3d = nn.functional.normalize(joints_3d[i] - joints_3d[j], dim=0)
        cosines.append((dir_2d * dir_3d[:2]).sum())

    return torch.stack(cosines, dim=0)
def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
    """Derive "natural pose" constraints from the 2D keypoints and store them on ``loss_parallel``.

    Attributes written on ``loss_parallel``:
      * ``ground_parallel`` and ``parallel_in_3d`` -- always;
      * ``right_foot_inds`` / ``left_foot_inds`` -- when the matching
        foot-toe bone passes ``is_parallel_to_plane(thresh=25)`` and is not
        already part of a straight chain;
      * ``silhuette_vertices_inds`` and ``ground`` -- when at least one
        ground-parallel bone was found and some vertices pass both masks.

    NOTE(review): nesting was reconstructed from a listing without
    indentation; the straight-chain test is assumed to be independent of the
    ground-parallel test -- confirm.

    NOTE(review): ``is_close_to_plane`` receives a normalised direction from
    ``create_bone``, but ``plane_2d`` is a keypoint coordinate. This looks
    like a position was intended. Behaviour is left unchanged.
    """
    # Extent and maximum of the keypoints along column 0.
    col_max = keypoints_2d.max(dim=0).values[0]
    col_min = keypoints_2d.min(dim=0).values[0]
    height_2d = (col_max - col_min).item()
    plane_2d = col_max.item()

    ground_parallel = []
    parallel_in_3d = []
    parallel3d_bones = set()

    # parallel chains
    chains = [
        ("Right Upper Leg", "Right Leg", "Right Foot"),
        ("Right Leg", "Right Foot", "Right Toe"),  # to remove?
        ("Left Upper Leg", "Left Leg", "Left Foot"),
        ("Left Leg", "Left Foot", "Left Toe"),  # to remove?
        ("Right Shoulder", "Right Arm", "Right Hand"),
        ("Left Shoulder", "Left Arm", "Left Hand"),
        # ("Hips", "Spine", "Neck"),
        # ("Spine", "Neck", "Head"),
    ]
    for i, j, k in chains:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        k = pose_estimation.KPS.index(k)

        upleg_leg = create_bone(i, j, keypoints_2d)
        leg_foot = create_bone(j, k, keypoints_2d)

        if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot):
            if is_close_to_plane(
                upleg_leg, plane_2d, thresh=0.1 * height_2d
            ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):
                ground_parallel.append(((i, j), 1))
                ground_parallel.append(((j, k), 1))

        # Two consecutive bones within 21 degrees of each other in 2D.
        if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):
            parallel_in_3d.append(((i, j), (j, k)))
            parallel3d_bones.add((i, j))
            parallel3d_bones.add((j, k))

    # parallel feet
    for i, j in [
        ("Right Foot", "Right Toe"),
        ("Left Foot", "Left Toe"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        if (i, j) in parallel3d_bones:
            continue

        foot_toe = create_bone(i, j, keypoints_2d)
        if is_parallel_to_plane(foot_toe, thresh=25):
            if "Right" in pose_estimation.KPS[i]:
                loss_parallel.right_foot_inds = right_foot_inds
            else:
                loss_parallel.left_foot_inds = left_foot_inds

    loss_parallel.ground_parallel = ground_parallel
    loss_parallel.parallel_in_3d = parallel_in_3d

    vertices_np = vertices[0].cpu().numpy()
    if len(ground_parallel) > 0:
        # Silhouette vertices: small third normal component AND within 15% of
        # the mesh extent from the maximum along axis 1.
        mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
        silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1

        height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()
        plane_3d = vertices_np[:, 1].max()
        silhuette_vertices_mask_2 = (
            np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d
        )

        silhuette_vertices_mask = np.logical_and(
            silhuette_vertices_mask_1, silhuette_vertices_mask_2
        )
        (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)
        if len(silhuette_vertices_inds) > 0:
            loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
            loss_parallel.ground = plane_3d
def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    """Compute per-bone target cosines and store them as ``loss_parallel.cos``.

    The cosines from ``calc_cos`` are converted to angles. With
    ``use_angle_transf`` the angles are re-based and passed through
    ``hist_cub.cub``: the leg bones (5-8) and foot bones (15, 16) are shifted
    by the minimum over the leg bones, and all remaining bones by their own
    minimum.

    The cosines are clamped to [-1, 1] before ``acos``. Rounding can push a
    dot product of unit vectors marginally past 1, which would otherwise
    produce NaN targets.

    NOTE(review): the cubic mapping is assumed to belong to the
    ``use_angle_transf`` branch (indentation was lost) -- confirm.

    Returns:
        The unclamped cosines from ``calc_cos``.
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)
        alpha = torch.acos(cos_r.clamp(-1.0, 1.0))

        if use_angle_transf:
            leg_inds = [
                5,
                6,  # right leg
                7,
                8,  # left leg
            ]
            foot_inds = [15, 16]
            nleg_inds = sorted(
                set(range(len(pose_estimation.SKELETON)))
                - set(leg_inds)
                - set(foot_inds)
            )

            alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()

            # Feet share the legs' reference minimum.
            amli = alpha[leg_inds].min()
            leg_and_foot_inds = leg_inds + foot_inds
            alpha[leg_and_foot_inds] = alpha[leg_and_foot_inds] - amli

            angles = alpha.detach().cpu().numpy()
            angles = hist_cub.cub(
                angles / (math.pi / 2),
                a=1.2121212121212122,
                b=-1.105527638190953,
                c=0.787878787878789,
            ) * (math.pi / 2)
            alpha = alpha.new_tensor(angles)

        loss_parallel.cos = torch.cos(alpha)

    return cos_r
def get_contacts(
    args,
    sc_module,
    y_data_conts,
    keypoints_2d,
    vertices,
    bone_to_params,
    loss_parallel,
):
    """Choose the contact vertex sets for the contact losses.

    * ``args.use_contacts``: contacts come from ``find_contacts``. Its 2D
      contact list is stored as ``loss_parallel.contact_2d`` when non-empty.
      If no contact set is returned, ``sc_module.verts_in_contact`` is used
      instead.
    * otherwise, with ``args.use_msc``: contacts come from
      ``sc_module.verts_in_contact``.
    * otherwise: an empty array.

    Returns:
        Either the list of index arrays from ``find_contacts`` or a flat
        NumPy array.

    Raises:
        AssertionError: if ``use_contacts`` is combined with a non-zero
            ``args.c_mse``.
    """
    if args.use_contacts:
        assert args.c_mse == 0
        contact, contact_2d, _ = find_contacts(
            y_data_conts, keypoints_2d, bone_to_params
        )
        if len(contact_2d) > 0:
            loss_parallel.contact_2d = contact_2d

        if len(contact) == 0:
            _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
            contact = contact.cpu().numpy().ravel()
        return contact

    if args.use_msc:
        _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
        return contact.cpu().numpy().ravel()

    return np.array([])
def save_mesh(
    smpl,
    smpl_output,
    save_path,
    fname,
):
    """Export the first mesh of ``smpl_output`` as ``.obj``, ``.glb`` and ``.gltf``.

    The mesh is rotated by pi around the x axis before export. Files are
    written to ``save_path / f"{fname}.<ext>"``.
    """
    mesh = trimesh.Trimesh(
        vertices=smpl_output.vertices[0].cpu().numpy(),
        faces=smpl.faces,
        process=False,
    )
    flip = trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])
    mesh.apply_transform(flip)

    for extension in ("obj", "glb", "gltf"):
        mesh.export(save_path / f"{fname}.{extension}")
def eft_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_beta,
    sc_module,
    y_data_conts,
    bone_to_params,
):
    """First stage: fit the regressor to the 2D keypoints, then look up contacts.

    Runs ``optimize`` for 150 steps with only the keypoint MSE enabled
    (``c_mse=1``, ``c_new_mse=0``, no contact criteria), then calls
    ``get_contacts`` on the resulting vertices.

    Returns:
        tuple: ``(vertices, keypoints_3d_pred, contact)`` where ``vertices``
        are the detached mesh vertices.
    """
    fit = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=1,
        c_new_mse=0,
        c_beta=c_beta,
        sc_crit=None,
        msc_crit=None,
        contact=None,
        n_steps=60 + 90,
    )
    # optimize returns (rotmat, betas, camera, keypoints_3d, vertices_2d,
    # smpl_output, z, joints_2d_orig); only two entries are needed here.
    keypoints_3d_pred = fit[3]
    smpl_output = fit[5]

    # find contacts
    vertices = smpl_output.vertices.detach()
    contact = get_contacts(
        args,
        sc_module,
        y_data_conts,
        keypoints_2d,
        vertices,
        bone_to_params,
        loss_parallel,
    )

    return vertices, keypoints_3d_pred, contact
def dc_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    c_beta,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
):
    """Second stage: continue fitting the regressor with all losses enabled.

    Runs 60 further ``optimize`` steps, or none when ``c_new_mse`` is 0 and
    both contact modes are off. The MSC criterion and the contact sets are
    only passed on when ``use_contacts`` or ``use_msc`` is set. The step
    counter continues at 150, after the first stage.

    Returns:
        The predicted pose (first element of the ``optimize`` result).
    """
    with_contacts = use_contacts or use_msc
    n_steps = 60 if c_new_mse > 0 or with_contacts else 0

    rotmat_pred, *_ = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        c_beta=c_beta,
        sc_crit=sc_crit,
        msc_crit=msc_crit if with_contacts else None,
        contact=contact if with_contacts else None,
        n_steps=n_steps,
        i_ini=60 + 90,
    )

    return rotmat_pred
def us_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    rotmat_pred,
    keypoints_2d,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    save_path,
):
    """Third stage: refine pose and camera directly and save the mesh as "us".

    The camera is taken from one forward pass of the regressor
    (``use_betas=False``, ``zero_hands=True``). ``optimize_ft`` then refines
    ``rotmat_pred`` and that camera for 60 steps, or none when both contact
    modes are off, with the step counter continuing at 210. The resulting
    mesh is exported via ``save_mesh``.

    The forward pass now runs under ``no_grad``: ``optimize_ft`` detaches the
    camera anyway and nothing else from that pass is used.
    """
    with torch.no_grad():
        prediction = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            use_betas=False,
            zero_hands=True,
        )
    camera_pred_us = prediction[2]

    with_contacts = use_contacts or use_msc
    _, smpl_output_us = optimize_ft(
        rotmat_pred,
        camera_pred_us,
        smpl,
        selector,
        keypoints_2d,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        sc_crit=sc_crit,
        msc_crit=msc_crit if with_contacts else None,
        contact=contact if with_contacts else None,
        n_steps=60 if with_contacts else 0,
        i_ini=60 + 90 + 60,
        zero_hands=True,
        fist=args.fist,
    )

    save_mesh(
        smpl,
        smpl_output_us,
        save_path,
        "us",
    )
def main():
    """Fit a 3D body to every image under ``--img-path`` and export the meshes.

    For each ``.jpg`` / ``.png`` / ``.jpeg`` image:
      1. detect 2D keypoints and the contact heatmap;
      2. reload the regressor checkpoint and run ``eft_step``;
      3. optionally derive natural-pose and cosine targets;
      4. run ``dc_step`` and ``us_step``; the latter writes the mesh to
         ``<save-path>/<image stem>/us.{obj,glb,gltf}``.

    Changes from the listing:
      * the self-contact modules are also built for ``--use-msc`` alone;
        ``get_contacts`` needs ``sc_module`` in that mode and used to fail on
        ``None``;
      * ``loss_parallel`` is created per image, so attributes stored on it by
        ``get_natural`` / ``get_cos`` / ``get_contacts`` no longer carry over
        to the next image;
      * unused heatmap resizing and colour conversion were removed;
      * directory entries are processed in sorted order.
    """
    args = parse_args()
    print(args)

    # models
    model_pose = cv2.dnn.readNetFromONNX(
        args.pose_estimation_model_path
    )  # "hrn_w48_384x288.onnx"
    model_contact = cv2.dnn.readNetFromONNX(
        args.contact_model_path
    )  # "contact_hrn_w32_256x192.onnx"

    # CPU is used whenever CUDA is unavailable, whatever --device says.
    device = (
        torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
    )

    model_hmr = spin.hmr(args.smpl_mean_params_path)  # "smpl_mean_params.npz"
    model_hmr.to(device)
    # NOTE: torch.load and np.load(allow_pickle=True) unpickle arbitrary
    # objects; only point these options at trusted files.
    checkpoint = torch.load(
        args.spin_model_path,  # "spin_model_smplx_eft_18.pt"
        map_location="cpu",
    )

    smpl = spin.SMPLX(
        args.smpl_model_dir,  # "models/smplx"
        batch_size=1,
        create_transl=False,
        use_pca=False,
        flat_hand_mean=args.fist is not None,
    )
    smpl.to(device)

    selector = get_selector()

    use_contacts = args.use_contacts
    use_msc = args.use_msc

    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
    left_foot_inds = foot_inds["left_foot_inds"]
    right_foot_inds = foot_inds["right_foot_inds"]

    if use_contacts or use_msc:
        model_type = args.smpl_type
        sc_module = selfcontact.SelfContact(
            essentials_folder=args.essentials_dir,  # "smplify-xmc-essentials"
            geothres=0.3,
            euclthres=0.02,
            test_segments=True,
            compute_hd=True,
            model_type=model_type,
            device=device,
        )
        sc_module.to(device)

        sc_crit = selfcontact.losses.SelfContactLoss(
            contact_module=sc_module,
            inside_loss_weight=0.5,
            outside_loss_weight=0.0,
            contact_loss_weight=0.5,
            align_faces=True,
            use_hd=True,
            test_segments=True,
            device=device,
            model_type=model_type,
        )
        sc_crit.to(device)

        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
        msc_crit.to(device)
    else:
        sc_module = None
        sc_crit = None
        msc_crit = None

    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg

    ignore = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )

    c_mse = args.c_mse
    c_new_mse = args.c_par
    c_beta = 1e-3

    # Exactly one of the two data terms must be active.
    if c_mse > 0:
        assert c_new_mse == 0
    elif c_mse == 0:
        assert c_new_mse > 0

    root_path = Path(args.save_path)
    root_path.mkdir(exist_ok=True, parents=True)

    path_to_imgs = Path(args.img_path)
    if path_to_imgs.is_dir():
        path_to_imgs = sorted(path_to_imgs.iterdir())
    else:
        path_to_imgs = [path_to_imgs]

    for img_path in path_to_imgs:
        if not img_path.name.lower().endswith((".jpg", ".png", ".jpeg")):
            continue

        img_name = img_path.stem

        # use 2d keypoints detection
        (
            img_original,
            predicted_keypoints_2d,
            _,
            _,
        ) = pose_estimation.infer_single_image(
            model_pose,
            img_path,
            input_img_size=pose_estimation.IMG_SIZE,
            return_kps=True,
        )

        save_path = root_path / img_name
        save_path.mkdir(exist_ok=True, parents=True)

        img_size_original = img_original.shape[:2]

        keypoints_2d, *_ = normalize_keypoints_to_spin(
            predicted_keypoints_2d, img_size_original
        )
        keypoints_2d = torch.from_numpy(keypoints_2d)
        keypoints_2d = keypoints_2d.to(device)

        # Only the thresholded mask is needed downstream.
        predicted_contact_heatmap, _, _ = get_contact_heatmap(model_contact, img_path)
        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)

        # Fresh loss object per image: the helpers below attach per-image
        # attributes to it.
        loss_parallel = losses.Parallel(
            skeleton=pose_estimation.SKELETON,
            ignore=ignore,
        )

        # Start every image from the checkpoint weights.
        model_hmr.load_state_dict(checkpoint["model"], strict=True)
        model_hmr.train()
        freeze_layers(model_hmr)

        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
        input_img = input_img.to(device)

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model_hmr.parameters()),
            lr=1e-6,
        )

        vertices, keypoints_3d_pred, contact = eft_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_beta,
            sc_module,
            y_data_conts,
            bone_to_params,
        )

        if args.use_natural:
            get_natural(
                keypoints_2d,
                vertices,
                right_foot_inds,
                left_foot_inds,
                loss_parallel,
                smpl,
            )

        if args.use_cos:
            get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)

        rotmat_pred = dc_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            c_beta,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
        )

        us_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            rotmat_pred,
            keypoints_2d,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            save_path,
        )
# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()