sketch2pose / src /pose.py
kbrodt's picture
Update src/pose.py
5c5f6da
import argparse
import math
from pathlib import Path
import cv2
import numpy as np
import PIL.Image as Image
import selfcontact
import selfcontact.losses
import shapely.geometry
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeometry
import tqdm
import trimesh
from skimage import measure
import fist_pose
import hist_cub
import losses
import pose_estimation
import spin
# Maps each keypoint name of the 2D pose estimator (``pose_estimation.KPS``)
# to the joint name that ``get_selector`` looks up in ``spin.JOINT_NAMES``.
# NOTE(review): both "Shoulder" keypoints map to "... ForeArm" joints -
# presumably deliberate to match the SPIN joint naming; confirm against
# spin.JOINT_NAMES.
PE_KSP_TO_SPIN = {
    "Head": "Head",
    "Neck": "Neck",
    "Right Shoulder": "Right ForeArm",
    "Right Arm": "Right Arm",
    "Right Hand": "Right Hand",
    "Left Shoulder": "Left ForeArm",
    "Left Arm": "Left Arm",
    "Left Hand": "Left Hand",
    "Spine": "Spine1",
    "Hips": "Hips",
    "Right Upper Leg": "Right Upper Leg",
    "Right Leg": "Right Leg",
    "Right Foot": "Right Foot",
    "Left Upper Leg": "Left Upper Leg",
    "Left Leg": "Left Leg",
    "Left Foot": "Left Foot",
    "Left Toe": "Left Toe",
    "Right Toe": "Right Toe",
}
# Directory prefix used to build the default model/asset paths in parse_args.
MODELS_DIR = "models"
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with model/asset paths, the torch device, the
        feature switches (``use_*``), the loss weights (``c_*``) and the
        optional ``fist`` hand-pose list.
    """
    parser = argparse.ArgumentParser()

    # Model and asset locations.
    parser.add_argument(
        "--pose-estimation-model-path",
        type=str,
        default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
        help="Pose Estimation model",
    )
    parser.add_argument(
        "--contact-model-path",
        type=str,
        default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
        help="Contact model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Torch device",
    )
    parser.add_argument(
        "--spin-model-path",
        type=str,
        default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
        help="SPIN model path",
    )
    parser.add_argument(
        "--smpl-type",
        type=str,
        default="smplx",
        choices=["smplx"],
        help="SMPL model type",
    )
    parser.add_argument(
        "--smpl-model-dir",
        type=str,
        default=f"./{MODELS_DIR}/models/smplx",
        help="SMPL model dir",
    )
    parser.add_argument(
        "--smpl-mean-params-path",
        type=str,
        default=f"./{MODELS_DIR}/data/smpl_mean_params.npz",
        help="SMPL mean params",
    )
    parser.add_argument(
        "--essentials-dir",
        type=str,
        default=f"./{MODELS_DIR}/smplify-xmc-essentials",
        help="SMPL Essentials folder for contacts",
    )
    parser.add_argument(
        "--parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
        help="Parametrization path",
    )
    parser.add_argument(
        "--bone-parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
        help="Bone parametrization path",
    )
    parser.add_argument(
        "--foot-inds-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
        help="Foot indices",
    )

    # Input / output.
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the results",
    )
    parser.add_argument(
        "--img-path",
        type=str,
        required=True,
        help="Path to img to test",
    )

    # Feature switches.
    parser.add_argument(
        "--use-contacts",
        action="store_true",
        help="Use contact model",
    )
    parser.add_argument(
        "--use-msc",
        action="store_true",
        help="Use MSC loss",
    )
    parser.add_argument(
        "--use-natural",
        action="store_true",
        help="Use regularity",
    )
    parser.add_argument(
        "--use-cos",
        action="store_true",
        help="Use cos model",
    )
    parser.add_argument(
        "--use-angle-transf",
        action="store_true",
        help="Use cube foreshortening transformation",
    )

    # Loss weights.
    parser.add_argument(
        "--c-mse",
        type=float,
        default=0,
        help="MSE weight",
    )
    parser.add_argument(
        "--c-par",
        type=float,
        default=10,
        # This value scales the whole (c_f * cos + c_parallel * parallel) term
        # in the optimisation, so it is not the same knob as --c-parallel.
        help="Weight of the combined cos/parallel term (must be 0 when --c-mse > 0)",
    )
    parser.add_argument(
        "--c-f",
        type=float,
        default=1000,
        help="Cos coef",
    )
    parser.add_argument(
        "--c-parallel",
        type=float,
        default=100,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-reg",
        type=float,
        default=1000,
        help="Regularity weight",
    )
    parser.add_argument(
        "--c-cont2d",
        type=float,
        default=1,
        help="Contact 2D weight",
    )
    parser.add_argument(
        "--c-msc",
        type=float,
        default=17_500,
        help="MSC weight",
    )

    # Optional hand poses; valid names are the keys of fist_pose.INT_TO_FIST.
    parser.add_argument(
        "--fist",
        nargs="+",
        type=str,
        choices=list(fist_pose.INT_TO_FIST),
        help="Hand poses to apply to the final mesh",
    )

    return parser.parse_args()
def freeze_layers(model):
    """Freeze every batch-norm and ``nn.Dropout`` layer of ``model`` in place.

    Matching sub-modules are switched to eval mode and all of their
    parameters get ``requires_grad = False``. Other layers are left untouched.

    Args:
        model: an ``nn.Module`` whose sub-modules are inspected recursively.
    """
    frozen_types = (nn.modules.batchnorm._BatchNorm, nn.Dropout)
    for module in model.modules():
        # The original also tested `type(module) is False`, which can never be
        # true (type() returns a class), so that dead branch was removed.
        if not isinstance(module, frozen_types):
            continue

        module.eval()
        for param in module.parameters():
            param.requires_grad = False
def project_and_normalize_to_spin(vertices_3d, camera):
    """Apply the weak-perspective ``camera`` to points and map them to SPIN pixels.

    ``camera`` is read as ``[scale, tx, ty]``. The translation is added to the
    first two coordinates only; afterwards all three coordinates are scaled,
    shifted by one and multiplied by half of ``spin.constants.IMG_RES``.
    The result keeps all three columns of ``vertices_3d``.
    """
    scale = camera[0]

    # 3-vector offset whose last component stays zero.
    offset = scale.new_zeros(3)
    offset[:2] = camera[1:]

    half_res = spin.constants.IMG_RES / 2
    return half_res * (scale * (vertices_3d + offset) + 1)
def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
    """Return 2D vectors for the right and left leg segments in SPIN pixels.

    Args:
        vertices_3d: tensor used only as the dtype/device template for the
            constant matrices below.
        A: pair ``(transforms, joints)``; the first batch element of each is used.
        camera: camera tensor whose first entry is the scale.

    Returns:
        ``(rleg, lleg)`` - the first two components of the transformed
        ``J[5] - J[2]`` (right) and ``J[4] - J[1]`` (left) vectors.
    """
    transforms, rest_joints = A
    transforms = transforms[0]
    rest_joints = rest_joints[0]

    # Fixed 3x3 matrices combined with the per-joint rotations below.
    left_const = vertices_3d.new_tensor(
        [
            [0.98619063, 0.16560926, 0.00127302],
            [-0.16560601, 0.98603675, 0.01749799],
            [0.00164258, -0.01746717, 0.99984609],
        ]
    )
    right_const = vertices_3d.new_tensor(
        [
            [0.9910211, -0.13368178, -0.0025208],
            [0.13367888, 0.99027076, 0.03864949],
            [-0.00267045, -0.03863944, 0.99924965],
        ]
    )

    right_rot = transforms[2, :3, :3] @ right_const  # 2 - right
    left_rot = transforms[1, :3, :3] @ left_const  # 1 - left

    right_vec = rest_joints[5] - rest_joints[2]
    left_vec = rest_joints[4] - rest_joints[1]

    # Same evaluation order as `scale * IMG_RES / 2 * M @ v`.
    factor = camera[0] * spin.constants.IMG_RES / 2
    rleg = ((factor * right_rot) @ right_vec)[:2]
    lleg = ((factor * left_rot) @ left_vec)[:2]

    return rleg, lleg
def rotation_matrix_to_angle_axis(rotmat):
    """Convert per-joint rotation matrices to flattened axis-angle vectors.

    Args:
        rotmat: tensor whose first two dimensions are ``(batch, n_joints)`` and
            which can be viewed as ``(batch * n_joints, 3, 3)``.

    Returns:
        Tensor of shape ``(batch, 3 * n_joints)``.
    """
    bs, n_joints, *_ = rotmat.size()
    flat = rotmat.view(-1, 3, 3)

    # A constant [0, 0, 1] column is appended so that 3x4 matrices are handed
    # to torchgeometry. The column is expanded to the flattened batch size; the
    # original `.view(bs, 3, 1).expand(n_joints, ...)` only worked for bs == 1.
    hom_col = (
        rotmat.new_tensor([0, 0, 1], dtype=torch.float32)
        .view(1, 3, 1)
        .expand(flat.size(0), -1, -1)
    )
    rotmat_h = torch.cat([flat, hom_col], dim=-1)

    aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat_h)
    return aa.reshape(bs, 3 * n_joints)
def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
    """Run the body model on predicted rotations.

    For ``"SMPL"`` the rotation matrices are passed through unchanged. For
    ``"SMPL-X"`` they are first converted to axis-angle; with ``zero_hands``
    the entries of joints 20 and 21 are zeroed and the y component of joints
    12 and 15 (neck, head) is zeroed.

    Returns:
        ``(smpl_output, rotmat)`` where ``rotmat`` is the pose actually fed to
        the model (axis-angle for SMPL-X).

    Raises:
        NotImplementedError: for any other model name.
    """
    model_name = smpl.name()
    shape = betas if use_betas else None

    if model_name == "SMPL":
        smpl_output = smpl(
            betas=shape,
            body_pose=rotmat[:, 1:],
            global_orient=rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
        return smpl_output, rotmat

    if model_name != "SMPL-X":
        raise NotImplementedError

    rotmat = rotation_matrix_to_angle_axis(rotmat)
    if zero_hands:
        for joint in (20, 21):
            rotmat[:, 3 * joint : 3 * (joint + 1)] = 0

        for joint in (12, 15):  # neck, head
            rotmat[:, 3 * joint + 1] = 0  # y

    smpl_output = smpl(
        betas=shape,
        body_pose=rotmat[:, 3:],
        global_orient=rotmat[:, :3],
        pose2rot=True,
    )
    return smpl_output, rotmat
def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
    """Regress pose/shape/camera for one image and run the body model.

    A batch dimension is added to ``input_img`` before the forward pass and
    removed again from every returned tensor.

    Returns:
        ``(rotmat, betas, camera, smpl_output, joints)``.
    """
    batch = input_img.unsqueeze(0)
    rotmat, betas, camera = model_hmr(batch)
    smpl_output, pose = get_smpl_output(
        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
    )
    joints = smpl_output.joints.squeeze(0)

    return (
        pose.squeeze(0),
        betas.squeeze(0),
        camera.squeeze(0),
        smpl_output,
        joints,
    )
def get_pred_and_data(
    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
):
    """Predict the body and derive the projected quantities used by the losses.

    Returns a 9-tuple:
        ``(rotmat, betas, camera, joints_2d[selector], joints[selector],
        vertices_2d, smpl_output, (rleg, lleg), joints_2d)``
    where the ``*_2d`` values come from ``project_and_normalize_to_spin``.
    """
    rotmat, betas, camera, smpl_output, joints_sel = get_predictions(
        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
    )

    joints = smpl_output.joints.squeeze(0)
    joints_2d_orig = project_and_normalize_to_spin(joints, camera)
    legs = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)
    rleg, lleg = legs

    vertices_2d = project_and_normalize_to_spin(
        smpl_output.vertices.squeeze(0), camera
    )

    return (
        rotmat,
        betas,
        camera,
        joints_2d_orig[selector],
        joints_sel[selector],
        vertices_2d,
        smpl_output,
        (rleg, lleg),
        joints_2d_orig,
    )
def normalize_keypoints_to_spin(keypoints_2d, img_size):
    """Shift and scale keypoints into the ``spin.constants.IMG_RES`` frame.

    Args:
        keypoints_2d: array of keypoints; it is copied, never modified.
        img_size: ``(h, w)`` of the image the keypoints refer to.

    Returns:
        ``(normalized, shift, scale, ax2)`` - the last three values are what
        ``unnormalize_keypoints_from_spin`` needs to undo the mapping.
    """
    h, w = img_size
    # vertically: (1, 0); horizontal: (0, 1)
    ax1, ax2 = (1, 0) if h > w else (0, 1)

    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]

    normalized = np.copy(keypoints_2d)
    normalized[:, ax2] -= shift
    normalized *= scale

    return normalized, shift, scale, ax2
def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    """Undo ``normalize_keypoints_to_spin``: divide by ``scale``, then add
    ``shift`` back to column ``ax2``. The input array is not modified.
    """
    restored = np.copy(keypoints_2d)
    # In-place ufuncs keep the dtype handling of the `/=` and `+=` operators.
    np.divide(restored, scale, out=restored)
    column = restored[:, ax2]
    np.add(column, shift, out=column)
    return restored
def get_vertices_in_heatmap(contact_heatmap):
    """Turn each connected region of ``contact_heatmap`` into a convex hull.

    Regions are found with ``skimage.measure.label``; the pixel coordinates of
    each region are normalised with ``normalize_keypoints_to_spin``, truncated
    to integers and wrapped in a shapely convex hull.

    Returns:
        List with one shapely geometry per labelled region.
    """
    heatmap_size = contact_heatmap.shape[:2]
    labels = measure.label(contact_heatmap)

    def _region_hull(region_id):
        # np.nonzero yields index arrays per axis; reversing their order
        # before stacking gives the columns expected by the normaliser.
        pixels = np.vstack(np.nonzero(labels == region_id)[::-1]).T.astype("float")
        scaled, *_ = normalize_keypoints_to_spin(pixels, heatmap_size)
        points = torch.from_numpy(scaled).int().tolist()
        return shapely.geometry.MultiPoint(points).convex_hull

    return [_region_hull(region_id) for region_id in range(1, labels.max() + 1)]
def get_contact_heatmap(model_contact, img_path, thresh=0.5):
    """Run the contact model on an image and post-process its heatmap.

    Args:
        model_contact: network handed to ``pose_estimation.infer_single_image``.
        img_path: path of the image to analyse.
        thresh: threshold applied to the min-max normalised heatmap.

    Returns:
        ``(mask, heatmap_rgb, heatmap_raw)``:
        * ``mask`` - uint8 array, 255 where the normalised value exceeds
          ``thresh`` and 0 elsewhere;
        * ``heatmap_rgb`` - normalised heatmap repeated over 3 channels, uint8;
        * ``heatmap_raw`` - copy of the unnormalised model output.
    """
    raw = pose_estimation.infer_single_image(
        model_contact,
        img_path,
        input_img_size=(192, 256),
        return_kps=False,
    )
    raw = raw.squeeze(0)
    contact_heatmap_orig = raw.copy()

    lowest = raw.min()
    value_range = raw.max() - lowest
    if value_range > 0:
        normalized = (raw - lowest) / value_range
    else:
        # A constant heatmap would divide by zero and fill everything with
        # NaN; treat it as "no contact anywhere" instead.
        normalized = np.zeros_like(raw)

    mask = ((normalized > thresh) * 255).astype("uint8")
    heatmap_rgb = np.repeat(normalized[..., None], repeats=3, axis=-1)
    heatmap_rgb = (heatmap_rgb * 255).astype("uint8")

    return mask, heatmap_rgb, contact_heatmap_orig
def discretize(parametrization, n_bins=100):
    """Snap values to the left edge of their bin on a uniform [0, 1] grid.

    The grid has ``n_bins + 1`` edges; each value is replaced by the largest
    edge that is less than or equal to it.
    """
    edges = np.linspace(0, 1, n_bins + 1)
    # For increasing edges, a right-sided searchsorted gives the same indices
    # as np.digitize with its default arguments.
    right_edge_idx = np.searchsorted(edges, parametrization, side="right")
    return edges[right_edge_idx - 1]
def get_mapping_from_params_to_verts(verts, params):
    """Group vertices by their parameter value.

    Returns:
        Dict mapping each distinct value of ``params`` to the list of entries
        of ``verts`` paired with it, in input order.
    """
    grouped = {}
    for vertex, param in zip(verts, params):
        if param not in grouped:
            grouped[param] = []
        grouped[param].append(vertex)
    return grouped
def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
    """Match 2D contact regions against skeleton bones to pick contact vertices.

    Args:
        y_data_conts: iterable of shapely geometries (one per contact region).
        keypoints_2d: 2D keypoints indexed by the joint ids of
            ``pose_estimation.SKELETON``.
        bone_to_params: mapping ``(i, j) -> (verts, t3d)`` giving, for each
            bone, vertex ids and their parameter values along the bone.
        thresh: distance by which every region is buffered (grown).
        step: parameter resolution; determines the number of bins.

    Returns:
        ``(contact, contact_2d, for_mask)``:
        * ``contact`` - one sorted int array of unique vertex ids per region;
        * ``contact_2d`` - per region, a list of ``(i, j, t)`` with ``t`` the
          midpoint of the intersected parameter interval on bone ``(i, j)``;
        * ``for_mask`` - integer outline coordinates of every buffered region
          that produced at least one vertex.
    """
    n_bins = int(math.ceil(1 / step)) - 1 # mean face's circumradius
    contact = []
    contact_2d = []
    for_mask = []
    for y_data_cont in y_data_conts:
        contact_loc = []
        contact_2d_loc = []
        # Grow the region so bones passing nearby still intersect it.
        buffer = y_data_cont.buffer(thresh)
        mask_add = False
        for i, j in pose_estimation.SKELETON:
            verts, t3d = bone_to_params[(i, j)]
            if len(verts) == 0:
                continue
            # Bucket the vertices of this bone by discretised parameter value.
            t3d = discretize(t3d, n_bins=n_bins)
            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])
            # Intersect the 2D bone segment with the buffered region.
            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
            lint = buffer.intersection(line)
            # Skip bones whose intersection does not have two end points.
            if len(lint.boundary.geoms) < 2:
                continue
            # Normalised positions of the intersection end points on the bone.
            t2d_start = line.project(lint.boundary.geoms[0], normalized=True)
            t2d_end = line.project(lint.boundary.geoms[1], normalized=True)
            assert t2d_start <= t2d_end
            t2ds = discretize(
                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
            )
            to_add = False
            for t2d in t2ds:
                # Ignore positions outside the parameter range covered by vertices.
                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
                    continue
                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
                c = t3d_to_verts_sorted[t2d_ind][1]
                contact_loc.extend(c)
                to_add = True
                mask_add = True
                # Also take the vertices of the neighbouring buckets.
                if t2d_ind + 1 < len(t3d_to_verts_sorted):
                    c = t3d_to_verts_sorted[t2d_ind + 1][1]
                    contact_loc.extend(c)
                if t2d_ind > 0:
                    c = t3d_to_verts_sorted[t2d_ind - 1][1]
                    contact_loc.extend(c)
            if to_add:
                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))
        if mask_add:
            for_mask.append(buffer.exterior.coords.xy)
        # De-duplicate and sort the vertex ids collected for this region.
        contact_loc = sorted(set(contact_loc))
        contact_loc = np.array(contact_loc, dtype="int")
        contact.append(contact_loc)
        contact_2d.append(contact_2d_loc)
    for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]
    return contact, contact_2d, for_mask
def optimize(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    c_beta=1e-3,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
):
    """Run ``n_steps`` optimizer steps of the regressor on a single image.

    Each step re-predicts the body with ``get_pred_and_data``, sums the enabled
    loss terms and calls ``optimizer.step()``.

    Args:
        model_hmr, smpl, selector, input_img: forwarded to ``get_pred_and_data``.
        keypoints_2d: target 2D keypoints.
        optimizer: optimizer that is zeroed and stepped once per iteration.
        args: namespace providing ``c_f``, ``c_parallel``, ``c_reg``,
            ``c_cont2d`` and ``c_msc``.
        loss_mse: keypoint loss, used only when ``c_mse > 0``.
        loss_parallel: composite loss returning seven terms, used only when
            ``c_new_mse > 0``; optional attributes set on it elsewhere
            (``*_foot_inds``, ``silhuette_vertices_inds``, ``ground``) enable
            extra terms.
        c_mse, c_new_mse, c_beta: weights of the MSE term, the cos/parallel
            terms and the shape regulariser.
        sc_crit: optional criterion returning (contact loss, face-angle loss).
        msc_crit: optional criterion called as ``msc_crit(contact_ids, vertices)``.
        contact: contact vertex ids, or a list of them.
        n_steps: number of iterations (0 skips the loop entirely).
        i_ini: offset added to the iteration index to form ``global_step``.

    Returns:
        ``(rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred,
        vertices_2d_pred, smpl_output, z, joints_2d_orig)`` computed by a final
        no-grad forward pass with ``zero_hands=True``.
    """
    # First-seen median coordinate of each foot, keyed by attribute name.
    mean_zfoot_val = {}
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()
            (
                rotmat_pred,
                betas_pred,
                camera_pred,
                keypoints_3d_pred,
                z,
                vertices_2d_pred,
                smpl_output,
                (rleg, lleg),
                joints_2d_orig,
            ) = get_pred_and_data(
                model_hmr,
                smpl,
                selector,
                input_img,
            )
            keypoints_2d_pred = keypoints_3d_pred[:, :2]
            loss = l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2
            vertices_pred = smpl_output.vertices
            # z_loss is never updated below, so the "z" progress entry is always 0.
            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar
                # Keep the selected foot vertices at their initial median
                # value along coordinate index 1.
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values
                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot
                # Pull the stored silhouette vertices towards the stored ground value.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh
            # Shape regulariser and a penalty that grows as the camera scale shrinks.
            lbeta = (betas_pred**2).mean()
            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
            loss = loss + c_beta * lbeta + lcam
            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(
                    vertices_pred,
                )
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a
            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]
                # Every contact set adds to the loss; msc_loss keeps only the
                # last value for the progress bar.
                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss
            loss.backward()
            optimizer.step()
            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "beta": f"{lbeta:.3}",
                    "cam": f"{lcam:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )
    # Final prediction without gradients, this time with zero_hands enabled.
    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            z,
            vertices_2d_pred,
            smpl_output,
            (rleg, lleg),
            joints_2d_orig,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )
    return (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        vertices_2d_pred,
        smpl_output,
        z,
        joints_2d_orig,
    )
def optimize_ft(
    theta,
    camera,
    smpl,
    selector,
    keypoints_2d,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
    zero_hands=False,
    fist=None,
):
    """Optimise pose and camera parameters directly (no regressor involved).

    ``theta`` and ``camera`` are cloned into trainable parameters and updated
    with Adam (lr 1e-3) for ``n_steps`` iterations. The loss is a prior that
    keeps both close to their initial values plus the same optional terms as
    in ``optimize``. The body model is called without ``betas``.

    Args:
        theta: flat pose vector; the first 3 entries are the global
            orientation, the rest the body pose.
        camera: camera tensor used by ``project_and_normalize_to_spin``.
        smpl, selector, keypoints_2d, args, loss_mse, loss_parallel, c_mse,
        c_new_mse, sc_crit, msc_crit, contact, n_steps, i_ini: as in ``optimize``.
        zero_hands: zero joints 20/21 and the y component of joints 12/15
            in the final pose.
        fist: optional list of keys of ``fist_pose.INT_TO_FIST`` selecting
            hand poses for the final mesh.

    Returns:
        ``(rotmat_pred, smpl_output)`` - the detached optimised pose and the
        body-model output computed from it without gradients.

    Raises:
        RuntimeError: if a ``fist`` entry starts with neither "l" nor "r".
    """
    mean_zfoot_val = {}
    # Detached copies serve both as starting point and as prior target.
    theta = theta.detach().clone()
    camera = camera.detach().clone()
    rotmat_pred = nn.Parameter(theta)
    camera_pred = nn.Parameter(camera)
    optimizer = torch.optim.Adam(
        [
            rotmat_pred,
            camera_pred,
        ],
        lr=1e-3,
    )
    global_step = i_ini
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()
            global_orient = rotmat_pred[:3]
            body_pose = rotmat_pred[3:]
            smpl_output = smpl(
                global_orient=global_orient.unsqueeze(0),
                body_pose=body_pose.unsqueeze(0),
                pose2rot=True,
            )
            z = smpl_output.joints
            z = z.squeeze(0)
            joints = smpl_output.joints.squeeze(0)
            joints_2d = project_and_normalize_to_spin(joints, camera_pred)
            rleg, lleg = project_and_normalize_to_spin_legs(
                joints, smpl_output.A, camera_pred
            )
            joints_2d = joints_2d[selector]
            z = z[selector]
            keypoints_3d_pred = joints_2d
            keypoints_2d_pred = keypoints_3d_pred[:, :2]
            # Prior: squared deviation of pose and camera from their initial values.
            lprior = ((rotmat_pred - theta) ** 2).sum() + (
                (camera_pred - camera) ** 2
            ).sum()
            loss = lprior
            l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2
            vertices_pred = smpl_output.vertices
            # z_loss is never updated below, so the "z" progress entry is always 0.
            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar
                # Keep the selected foot vertices at their initial median
                # value along coordinate index 1.
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values
                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot
                # Pull the stored silhouette vertices towards the stored ground value.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh
            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a
            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]
                # Every contact set adds to the loss; msc_loss keeps only the
                # last value for the progress bar.
                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss
            loss.backward()
            optimizer.step()
            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )
    # From here on the pose is edited in place, outside of autograd.
    rotmat_pred = rotmat_pred.detach()
    if zero_hands:
        for i in [20, 21]:
            rotmat_pred[3 * i : 3 * (i + 1)] = 0
        for i in [12, 15]: # neck, head
            rotmat_pred[3 * i + 1] = 0 # y
    global_orient = rotmat_pred[:3]
    body_pose = rotmat_pred[3:]
    left_hand_pose = None
    right_hand_pose = None
    if fist is not None:
        # Start from the relaxed hand poses, then apply the requested overrides.
        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
        for f in fist:
            pp = fist_pose.INT_TO_FIST[f]
            if pp is not None:
                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)
            # "lf"/"rf" replace the finger pose; plain "l"/"r" overwrite three
            # body-pose entries (a view into rotmat_pred) and drop the hand pose.
            if f.startswith("lf"):
                left_hand_pose = pp
            elif f.startswith("rf"):
                right_hand_pose = pp
            elif f.startswith("l"):
                body_pose[19 * 3 : 19 * 3 + 3] = pp
                left_hand_pose = None
            elif f.startswith("r"):
                body_pose[20 * 3 : 20 * 3 + 3] = pp
                right_hand_pose = None
            else:
                raise RuntimeError(f"No such hand pose: {f}")
    with torch.no_grad():
        smpl_output = smpl(
            global_orient=global_orient.unsqueeze(0),
            body_pose=body_pose.unsqueeze(0),
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            pose2rot=True,
        )
    return rotmat_pred, smpl_output
def create_bone(i, j, keypoints_2d):
    """Return the normalised direction from keypoint ``i`` to keypoint ``j``."""
    direction = keypoints_2d[j] - keypoints_2d[i]
    return torch.nn.functional.normalize(direction, dim=0)
def is_parallel_to_plane(bone, thresh=21):
    """Tell whether the magnitude of ``bone[0]`` exceeds ``cos(thresh)``.

    ``thresh`` is an angle in degrees.
    """
    min_component = math.cos(math.radians(thresh))
    return abs(bone[0]) > min_component
def is_close_to_plane(bone, plane, thresh):
    """Tell whether ``bone[0]`` lies strictly within ``thresh`` of ``plane``."""
    return abs(bone[0] - plane) < thresh
def get_selector():
    """Return, for every ``pose_estimation.KPS`` entry, the index of the
    corresponding joint in ``spin.JOINT_NAMES`` (via ``PE_KSP_TO_SPIN``).
    """
    return [
        spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp_name])
        for kp_name in pose_estimation.KPS
    ]
def calc_cos(joints_2d, joints_3d):
    """Compute one cosine-like value per bone of ``pose_estimation.SKELETON``.

    For each bone the unit 2D direction is dotted with the first two
    components of the unit direction computed from ``joints_3d``.

    Returns:
        1-D tensor with one entry per skeleton bone.
    """

    def _bone_cos(i, j):
        dir_2d = nn.functional.normalize(joints_2d[i] - joints_2d[j], dim=0)
        # Normalise over all components first, then keep the first two.
        dir_3d = nn.functional.normalize(joints_3d[i] - joints_3d[j], dim=0)[:2]
        return (dir_2d * dir_3d).sum()

    return torch.stack(
        [_bone_cos(i, j) for i, j in pose_estimation.SKELETON], dim=0
    )
def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
    """Detect pose regularities in the 2D keypoints and store them on ``loss_parallel``.

    Nothing is returned; the function sets these attributes on ``loss_parallel``:
    * ``ground_parallel`` / ``parallel_in_3d`` - always;
    * ``right_foot_inds`` / ``left_foot_inds`` - only for feet that pass the
      test below;
    * ``silhuette_vertices_inds`` / ``ground`` - only when ground-parallel
      bones were found and matching mesh vertices exist.

    Args:
        keypoints_2d: tensor of 2D keypoints ordered like ``pose_estimation.KPS``.
        vertices: mesh vertices; only ``vertices[0]`` is used.
        right_foot_inds, left_foot_inds: vertex ids stored for qualifying feet.
        loss_parallel: object receiving the attributes.
        smpl: body model providing ``faces`` for the mesh.
    """
    # Extent and maximum of the keypoints along their first coordinate.
    # NOTE(review): is_close_to_plane below compares a normalised bone
    # direction component with plane_2d, which is a keypoint coordinate -
    # confirm this is intended.
    height_2d = (
        keypoints_2d.max(dim=0).values[0] - keypoints_2d.min(dim=0).values[0]
    ).item()
    plane_2d = keypoints_2d.max(dim=0).values[0].item()
    ground_parallel = []
    parallel_in_3d = []
    parallel3d_bones = set()
    # parallel chains
    for i, j, k in [
        ("Right Upper Leg", "Right Leg", "Right Foot"),
        ("Right Leg", "Right Foot", "Right Toe"), # to remove?
        ("Left Upper Leg", "Left Leg", "Left Foot"),
        ("Left Leg", "Left Foot", "Left Toe"), # to remove?
        ("Right Shoulder", "Right Arm", "Right Hand"),
        ("Left Shoulder", "Left Arm", "Left Hand"),
        # ("Hips", "Spine", "Neck"),
        # ("Spine", "Neck", "Head"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        k = pose_estimation.KPS.index(k)
        # Unit directions of the two consecutive bones of the chain.
        upleg_leg = create_bone(i, j, keypoints_2d)
        leg_foot = create_bone(j, k, keypoints_2d)
        if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot):
            if is_close_to_plane(
                upleg_leg, plane_2d, thresh=0.1 * height_2d
            ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):
                ground_parallel.append(((i, j), 1))
                ground_parallel.append(((j, k), 1))
        # Consecutive bones within 21 degrees of each other in 2D are recorded as a pair.
        if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):
            parallel_in_3d.append(((i, j), (j, k)))
            parallel3d_bones.add((i, j))
            parallel3d_bones.add((j, k))
    # parallel feets
    for i, j in [
        ("Right Foot", "Right Toe"),
        ("Left Foot", "Left Toe"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        # Feet already part of a recorded chain are skipped.
        if (i, j) in parallel3d_bones:
            continue
        foot_toe = create_bone(i, j, keypoints_2d)
        if is_parallel_to_plane(foot_toe, thresh=25):
            if "Right" in pose_estimation.KPS[i]:
                loss_parallel.right_foot_inds = right_foot_inds
            else:
                loss_parallel.left_foot_inds = left_foot_inds
    loss_parallel.ground_parallel = ground_parallel
    loss_parallel.parallel_in_3d = parallel_in_3d
    vertices_np = vertices[0].cpu().numpy()
    if len(ground_parallel) > 0:
        # Silhuette veritices
        mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
        # Vertices whose normal has a small third component ...
        silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1
        height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()
        plane_3d = vertices_np[:, 1].max()
        # ... and that lie within 15% of the mesh extent from its maximum
        # along coordinate index 1.
        silhuette_vertices_mask_2 = (
            np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d
        )
        silhuette_vertices_mask = np.logical_and(
            silhuette_vertices_mask_1, silhuette_vertices_mask_2
        )
        (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)
        if len(silhuette_vertices_inds) > 0:
            loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
            loss_parallel.ground = plane_3d
def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    """Compute per-bone cosines and store a target on ``loss_parallel.cos``.

    Args:
        keypoints_3d_pred: predicted keypoints; the first two columns are used
            as the 2D positions.
        use_angle_transf: when true, the angles are offset per bone group and
            remapped with ``hist_cub.cub`` before taking the cosine.
        loss_parallel: object that receives the ``cos`` attribute.

    Returns:
        The untransformed cosines from ``calc_cos``.
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)
    alpha = torch.acos(cos_r)
    if use_angle_transf:
        # Indices into pose_estimation.SKELETON.
        leg_inds = [
            5,
            6, # right leg
            7,
            8, # left leg
        ]
        foot_inds = [15, 16]
        nleg_inds = sorted(
            set(range(len(pose_estimation.SKELETON))) - set(leg_inds) - set(foot_inds)
        )
        # Offset the non-leg angles by their own minimum.
        alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()
        # The leg minimum is taken before the feet are added, then applied to
        # legs and feet together.
        amli = alpha[leg_inds].min()
        leg_inds.extend(foot_inds)
        alpha[leg_inds] = alpha[leg_inds] - amli
        # Remap the angles, expressed as fractions of pi/2, through hist_cub.cub.
        angles = alpha.detach().cpu().numpy()
        angles = hist_cub.cub(
            angles / (math.pi / 2),
            a=1.2121212121212122,
            b=-1.105527638190953,
            c=0.787878787878789,
        ) * (math.pi / 2)
        alpha = alpha.new_tensor(angles)
    loss_parallel.cos = torch.cos(alpha)
    return cos_r
def get_contacts(
    args,
    sc_module,
    y_data_conts,
    keypoints_2d,
    vertices,
    bone_to_params,
    loss_parallel,
):
    """Select the contact vertex ids according to the command-line switches.

    * ``args.use_contacts``: use ``find_contacts`` on the 2D contact regions
      (storing non-empty 2D contacts on ``loss_parallel.contact_2d``) and fall
      back to the self-contact module when nothing was found.
    * ``args.use_msc`` only: use the self-contact module.
    * neither: return an empty array.
    """
    use_contacts = args.use_contacts
    use_msc = args.use_msc
    c_mse = args.c_mse

    def _self_contact_ids():
        _, ids = sc_module.verts_in_contact(vertices, return_idx=True)
        return ids.cpu().numpy().ravel()

    if not use_contacts:
        return _self_contact_ids() if use_msc else np.array([])

    assert c_mse == 0
    contact, contact_2d, _ = find_contacts(
        y_data_conts, keypoints_2d, bone_to_params
    )
    if len(contact_2d) > 0:
        loss_parallel.contact_2d = contact_2d

    if len(contact) == 0:
        contact = _self_contact_ids()

    return contact
def save_mesh(
    smpl,
    smpl_output,
    save_path,
    fname,
):
    """Export the first mesh of ``smpl_output`` to ``save_path / f"{fname}.glb"``.

    Before exporting, the mesh is rotated by pi about the x axis and then by
    pi about the y axis.
    """
    vertices = smpl_output.vertices[0].cpu().numpy()
    mesh = trimesh.Trimesh(vertices=vertices, faces=smpl.faces, process=False)

    for axis in ([1, 0, 0], [0, 1, 0]):
        half_turn = trimesh.transformations.rotation_matrix(np.pi, axis)
        mesh.apply_transform(half_turn)

    mesh.export(save_path / f"{fname}.glb")
def eft_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_beta,
    sc_module,
    y_data_conts,
    bone_to_params,
):
    """First stage: fit the regressor to the 2D keypoints, then find contacts.

    Runs ``optimize`` for 150 steps with only the MSE term enabled
    (``c_mse=1``, ``c_new_mse=0``, no contact criteria) and afterwards calls
    ``get_contacts`` on the resulting vertices.

    Returns:
        ``(vertices, keypoints_3d_pred, contact)``.
    """
    fit = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=1,
        c_new_mse=0,
        c_beta=c_beta,
        sc_crit=None,
        msc_crit=None,
        contact=None,
        n_steps=60 + 90,
    )
    # optimize returns an 8-tuple; only two of its entries are needed here.
    _, _, _, keypoints_3d_pred, _, smpl_output, _, _ = fit

    # find contacts
    vertices = smpl_output.vertices.detach()
    contact = get_contacts(
        args,
        sc_module,
        y_data_conts,
        keypoints_2d,
        vertices,
        bone_to_params,
        loss_parallel,
    )

    return vertices, keypoints_3d_pred, contact
def dc_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    c_beta,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
):
    """Second stage: continue optimising with the full set of loss terms.

    The contact criterion and contact ids are forwarded only when
    ``use_contacts`` or ``use_msc`` is set. 60 steps are run when
    ``c_new_mse > 0`` or either switch is set, otherwise none.
    ``global_step`` starts at 150.

    Returns:
        The pose (first element) returned by ``optimize``.
    """
    with_contacts = use_contacts or use_msc
    run_steps = c_new_mse > 0 or use_contacts or use_msc

    result = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        c_beta=c_beta,
        sc_crit=sc_crit,
        msc_crit=msc_crit if with_contacts else None,
        contact=contact if with_contacts else None,
        n_steps=60 if run_steps else 0,  # + 60,,
        i_ini=60 + 90,
    )
    rotmat_pred, *_ = result

    return rotmat_pred
def us_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    rotmat_pred,
    keypoints_2d,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    save_path,
):
    """Final stage: refine the pose without ``betas`` and save the mesh as ``us.glb``.

    The camera comes from a fresh prediction made with ``use_betas=False``;
    ``optimize_ft`` then runs 60 steps when ``use_contacts`` or ``use_msc`` is
    set (otherwise none) and applies ``args.fist``. ``global_step`` starts at 210.
    """
    prediction = get_pred_and_data(
        model_hmr,
        smpl,
        selector,
        input_img,
        use_betas=False,
        zero_hands=True,
    )
    # 9-tuple; only the camera (index 2) is needed here.
    (_, _, camera_pred_us, _, _, _, _, _, _) = prediction

    with_contacts = use_contacts or use_msc
    _, smpl_output_us = optimize_ft(
        rotmat_pred,
        camera_pred_us,
        smpl,
        selector,
        keypoints_2d,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        sc_crit=sc_crit,
        msc_crit=msc_crit if with_contacts else None,
        contact=contact if with_contacts else None,
        n_steps=60 if with_contacts else 0,  # + 60,
        i_ini=60 + 90 + 60,
        zero_hands=True,
        fist=args.fist,
    )

    save_mesh(smpl, smpl_output_us, save_path, "us")
def main():
    """Command-line entry point: reconstruct a 3D pose for every input image.

    For each ``.jpg``/``.jpeg``/``.png`` file under ``--img-path`` the script
    detects 2D keypoints and contact regions, resets the regressor to its
    checkpoint, and runs the three stages ``eft_step``, ``dc_step`` and
    ``us_step``. The final mesh is written to ``<save-path>/<image stem>/us.glb``.
    """
    args = parse_args()
    print(args)

    # models
    model_pose = cv2.dnn.readNetFromONNX(
        args.pose_estimation_model_path
    )  # "hrn_w48_384x288.onnx"
    model_contact = cv2.dnn.readNetFromONNX(
        args.contact_model_path
    )  # "contact_hrn_w32_256x192.onnx"

    # Fall back to the CPU when CUDA is requested but unavailable.
    device = (
        torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
    )

    model_hmr = spin.hmr(args.smpl_mean_params_path)  # "smpl_mean_params.npz"
    model_hmr.to(device)
    # NOTE: torch.load and np.load(allow_pickle=True) unpickle data; only use
    # trusted model/asset files.
    checkpoint = torch.load(
        args.spin_model_path,  # "spin_model_smplx_eft_18.pt"
        map_location="cpu",
    )

    smpl = spin.SMPLX(
        args.smpl_model_dir,  # "models/smplx"
        batch_size=1,
        create_transl=False,
        use_pca=False,
        flat_hand_mean=args.fist is not None,
    )
    smpl.to(device)

    selector = get_selector()

    use_contacts = args.use_contacts
    use_msc = args.use_msc

    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
    left_foot_inds = foot_inds["left_foot_inds"]
    right_foot_inds = foot_inds["right_foot_inds"]

    # The self-contact module is needed by both --use-contacts and --use-msc:
    # get_contacts calls sc_module.verts_in_contact on the MSC path and the
    # MSC criterion needs sc_module.geomask. Building it only for
    # --use-contacts made `--use-msc` alone fail on a None sc_module.
    if use_contacts or use_msc:
        model_type = args.smpl_type
        sc_module = selfcontact.SelfContact(
            essentials_folder=args.essentials_dir,  # "smplify-xmc-essentials"
            geothres=0.3,
            euclthres=0.02,
            test_segments=True,
            compute_hd=True,
            model_type=model_type,
            device=device,
        )
        sc_module.to(device)

        sc_crit = selfcontact.losses.SelfContactLoss(
            contact_module=sc_module,
            inside_loss_weight=0.5,
            outside_loss_weight=0.0,
            contact_loss_weight=0.5,
            align_faces=True,
            use_hd=True,
            test_segments=True,
            device=device,
            model_type=model_type,
        )
        sc_crit.to(device)

        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
        msc_crit.to(device)
    else:
        sc_module = None
        sc_crit = None
        msc_crit = None

    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg

    ignore = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )

    c_mse = args.c_mse
    c_new_mse = args.c_par
    c_beta = 1e-3

    # Exactly one of the two keypoint objectives may be active.
    if c_mse > 0:
        assert c_new_mse == 0, "--c-par must be 0 when --c-mse > 0"
    elif c_mse == 0:
        assert c_new_mse > 0, "--c-par must be > 0 when --c-mse is 0"

    root_path = Path(args.save_path)
    root_path.mkdir(exist_ok=True, parents=True)

    path_to_imgs = Path(args.img_path)
    if path_to_imgs.is_dir():
        # Sorted so that images are processed in a reproducible order.
        path_to_imgs = sorted(path_to_imgs.iterdir())
    else:
        path_to_imgs = [path_to_imgs]

    for img_path in path_to_imgs:
        if not img_path.name.lower().endswith((".jpg", ".png", ".jpeg")):
            continue

        img_name = img_path.stem

        # use 2d keypoints detection
        (
            img_original,
            predicted_keypoints_2d,
            _,
            _,
        ) = pose_estimation.infer_single_image(
            model_pose,
            img_path,
            input_img_size=pose_estimation.IMG_SIZE,
            return_kps=True,
        )

        save_path = root_path / img_name
        save_path.mkdir(exist_ok=True, parents=True)

        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        img_size_original = img_original.shape[:2]

        keypoints_2d, *_ = normalize_keypoints_to_spin(
            predicted_keypoints_2d, img_size_original
        )
        keypoints_2d = torch.from_numpy(keypoints_2d).to(device)

        # Only the thresholded contact mask is used below; the resized copies
        # of the raw heatmap that used to be computed here were never read.
        predicted_contact_heatmap, _, _ = get_contact_heatmap(model_contact, img_path)
        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)

        # A fresh loss object per image: get_natural / get_contacts attach
        # image-specific attributes (foot indices, silhouette vertices, 2D
        # contacts) only conditionally, so reusing one object across images
        # would carry stale attributes over to the next image.
        loss_parallel = losses.Parallel(
            skeleton=pose_estimation.SKELETON,
            ignore=ignore,
        )

        # Every image starts from the same pretrained weights.
        model_hmr.load_state_dict(checkpoint["model"], strict=True)
        model_hmr.train()
        freeze_layers(model_hmr)

        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
        input_img = input_img.to(device)

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model_hmr.parameters()),
            lr=1e-6,
        )

        vertices, keypoints_3d_pred, contact = eft_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_beta,
            sc_module,
            y_data_conts,
            bone_to_params,
        )

        if args.use_natural:
            get_natural(
                keypoints_2d,
                vertices,
                right_foot_inds,
                left_foot_inds,
                loss_parallel,
                smpl,
            )

        if args.use_cos:
            get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)

        rotmat_pred = dc_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            c_beta,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
        )

        us_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            rotmat_pred,
            keypoints_2d,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            save_path,
        )
# Run the pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()