Spaces:
Build error
Build error
File size: 12,641 Bytes
c5f8b57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
import torch
import os, sys
import pickle
import smplx
import numpy as np
sys.path.append(os.path.dirname(__file__))
from customloss import (camera_fitting_loss,
body_fitting_loss,
camera_fitting_loss_3d,
body_fitting_loss_3d,
)
from prior import MaxMixturePrior
from visualize.joints2smpl.src import config
@torch.no_grad()
def guess_init_3d(model_joints,
                  j3d,
                  joints_category="orig"):
    """Estimate an initial translation from four torso joints.

    The estimate is the mean offset between the target torso joints
    (right/left hip, right/left shoulder) and the matching model joints.
    Adding it to the model joints makes the centroid of the model torso
    coincide with the centroid of the target torso.

    :param model_joints: joints of the posed SMPL model, indexed as
        ``model_joints[:, idx]`` with ``config.JOINT_MAP`` indices
        (presumably shaped (B, J, 3) -- TODO confirm).
    :param j3d: target 3D joints, indexed as ``j3d[:, idx]`` with the map
        chosen by ``joints_category`` (presumably (B, J, 3) as well).
    :param joints_category: ``"orig"`` indexes ``j3d`` with
        ``config.JOINT_MAP``; ``"AMASS"`` uses ``config.AMASS_JOINT_MAP``.
    :returns: the mean torso offset, reduced over the joint axis (dim 1).
    :raises ValueError: if ``joints_category`` is not recognised.
    """
    gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']
    # Model joints are always addressed with the original SMPL joint map.
    gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]

    if joints_category == "orig":
        target_map = config.JOINT_MAP
    elif joints_category == "AMASS":
        target_map = config.AMASS_JOINT_MAP
    else:
        # Without a valid map the target indices would be undefined; fail
        # with a clear error instead of an UnboundLocalError further down.
        raise ValueError("NO SUCH JOINTS CATEGORY! Got %r" % (joints_category,))
    joints_ind_category = [target_map[joint] for joint in gt_joints]

    sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1)
    return sum_init_t / float(len(gt_joints))
# SMPLIfy 3D
class SMPLify3D:
    """Implementation of SMPLify, use 3D joints.

    Fits SMPL parameters and a global translation to target 3D joints in
    two stages:

    1. Only the global orientation and the translation are optimized.
    2. Body pose, global orientation and translation are optimized. Betas
       are also optimized, but only when ``seq_ind == 0``.
    """

    def __init__(self,
                 smplxmodel,
                 step_size=1e-2,
                 batch_size=1,
                 num_iters=100,
                 use_collision=False,
                 use_lbfgs=True,
                 joints_category="orig",
                 device=torch.device('cuda:0'),
                 ):
        """
        :param smplxmodel: body model. It is called with ``global_orient`` /
            ``body_pose`` / ``betas`` and must expose ``faces_tensor``.
        :param step_size: learning rate for LBFGS or Adam.
        :param batch_size: stored on the instance; not read by this class.
        :param num_iters: LBFGS ``max_iter``. Also the number of outer
            iterations in stage 2.
        :param use_collision: add a mesh self-penetration term. This lazily
            imports ``mesh_intersection`` and uses ``config.Part_Seg_DIR``.
        :param use_lbfgs: use LBFGS with strong-Wolfe line search; otherwise
            use Adam.
        :param joints_category: ``"orig"`` or ``"AMASS"``. Selects the joint
            index lists that relate model joints to the target joints.
        :param device: device for the pose prior and the collision face filter.
        :raises ValueError: if ``joints_category`` is not recognised.
        """
        # Validate first: with an unknown category, __call__ could never work.
        if joints_category == "orig":
            smpl_index = config.full_smpl_idx
            corr_index = config.full_smpl_idx
        elif joints_category == "AMASS":
            smpl_index = config.amass_smpl_idx
            corr_index = config.amass_idx
        else:
            raise ValueError("NO SUCH JOINTS CATEGORY! Got %r" % (joints_category,))
        self.joints_category = joints_category
        self.smpl_index = smpl_index
        self.corr_index = corr_index

        # Store options
        self.batch_size = batch_size
        self.device = device
        self.step_size = step_size
        self.num_iters = num_iters
        self.use_lbfgs = use_lbfgs

        # GMM pose prior
        self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
                                          num_gaussians=8,
                                          dtype=torch.float32).to(device)

        # Collision part: the segmentation path is only needed when enabled.
        self.use_collision = use_collision
        if self.use_collision:
            self.part_segm_fn = config.Part_Seg_DIR

        self.smpl = smplxmodel
        self.model_faces = smplxmodel.faces_tensor.view(-1)

    def _smpl_forward(self, global_orient, body_pose, betas, **kwargs):
        """Run the body model with the given parameters (extra kwargs are forwarded)."""
        return self.smpl(global_orient=global_orient,
                         body_pose=body_pose,
                         betas=betas,
                         **kwargs)

    def _build_collision_terms(self):
        """Create ``(search_tree, pen_distance, filter_faces)`` for the penetration loss.

        ``filter_faces`` is ``None`` when no part-segmentation file is configured.
        """
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        search_tree = BVH(max_collisions=8)
        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True)
        filter_faces = None
        if self.part_segm_fn:
            # NOTE: pickle.load can execute arbitrary code. This path comes
            # from config and must point to a trusted file.
            part_segm_fn = os.path.expandvars(self.part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file, encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs.
            filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=None).to(device=self.device)
        return search_tree, pen_distance, filter_faces

    def _body_loss(self, smpl_output, body_pose, preserve_pose, betas,
                   camera_translation, j3d, conf_3d, collision_terms, **extra):
        """Evaluate the stage-2 objective ``body_fitting_loss_3d`` for ``smpl_output``.

        ``extra`` is forwarded to the loss, e.g. ``pose_preserve_weight``.
        """
        search_tree, pen_distance, filter_faces = collision_terms
        return body_fitting_loss_3d(body_pose, preserve_pose, betas,
                                    smpl_output.joints[:, self.smpl_index],
                                    camera_translation,
                                    j3d[:, self.corr_index], self.pose_prior,
                                    joints3d_conf=conf_3d,
                                    joint_loss_weight=600.0,
                                    use_collision=self.use_collision,
                                    model_vertices=smpl_output.vertices,
                                    model_faces=self.model_faces,
                                    search_tree=search_tree,
                                    pen_distance=pen_distance,
                                    filter_faces=filter_faces,
                                    **extra)

    def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0):
        """Perform body fitting.

        Input:
            init_pose: SMPL pose estimate. Columns ``[:3]`` are the global
                orientation and ``[3:]`` are the body pose.
            init_betas: SMPL betas estimate.
            init_cam_t: ignored. The initial translation is always
                re-estimated with ``guess_init_3d``; the parameter is kept for
                interface compatibility.
            j3d: target 3D joints (keypoints).
            conf_3d: confidence for the 3D joints.
            seq_ind: index in the sequence. Betas are optimized only when it is 0.
        Returns:
            vertices: vertices of the optimized shape (detached).
            joints: 3D joints of the optimized shape (detached).
            pose: ``[global_orient, body_pose]`` concatenated on the last dim (detached).
            betas: optimized SMPL betas (detached).
            camera_translation: optimized translation (detached).
            final_loss: the stage-2 objective, evaluated under
                ``torch.no_grad()`` without passing ``pose_preserve_weight``
                (the LBFGS stage-2 loop passes 5.0).
        """
        collision_terms = (None, None, None)
        if self.use_collision:
            collision_terms = self._build_collision_terms()

        # Split SMPL pose into body pose and global orientation.
        body_pose = init_pose[:, 3:].detach().clone()
        global_orient = init_pose[:, :3].detach().clone()
        betas = init_betas.detach().clone()
        # Untouched copy of the initial body pose, passed to the loss as preserve_pose.
        preserve_pose = init_pose[:, 3:].detach().clone()

        # NOTE(review): the caller's init_cam_t is overwritten by a torso-based guess.
        smpl_output = self._smpl_forward(global_orient, body_pose, betas)
        init_cam_t = guess_init_3d(smpl_output.joints, j3d,
                                   self.joints_category).unsqueeze(1).detach()
        camera_translation = init_cam_t.clone()

        # -------- Step 1: optimize global orientation and translation only --------
        body_pose.requires_grad = False
        betas.requires_grad = False
        global_orient.requires_grad = True
        camera_translation.requires_grad = True
        camera_opt_params = [global_orient, camera_translation]

        def camera_loss():
            smpl_out = self._smpl_forward(global_orient, body_pose, betas)
            # Both optimizer paths pass the full joint sets plus the category.
            # This matches how guess_init_3d selects the torso joints itself.
            return camera_fitting_loss_3d(smpl_out.joints, camera_translation,
                                          init_cam_t, j3d, self.joints_category)

        if self.use_lbfgs:
            camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters,
                                                 lr=self.step_size, line_search_fn='strong_wolfe')

            def camera_closure():
                camera_optimizer.zero_grad()
                loss = camera_loss()
                loss.backward()
                return loss

            for _ in range(10):
                camera_optimizer.step(camera_closure)
        else:
            camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size,
                                                betas=(0.9, 0.999))
            for _ in range(20):
                loss = camera_loss()
                camera_optimizer.zero_grad()
                loss.backward()
                camera_optimizer.step()

        # -------- Step 2: optimize body pose, orientation and translation --------
        # The translation is NOT frozen here; it keeps being refined.
        body_pose.requires_grad = True
        global_orient.requires_grad = True
        camera_translation.requires_grad = True
        # For sequences, shape is fitted only on the first item and then fixed.
        if seq_ind == 0:
            betas.requires_grad = True
            body_opt_params = [body_pose, betas, global_orient, camera_translation]
        else:
            betas.requires_grad = False
            body_opt_params = [body_pose, global_orient, camera_translation]

        if self.use_lbfgs:
            body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters,
                                               lr=self.step_size, line_search_fn='strong_wolfe')

            def body_closure():
                body_optimizer.zero_grad()
                loss = self._body_loss(self._smpl_forward(global_orient, body_pose, betas),
                                       body_pose, preserve_pose, betas, camera_translation,
                                       j3d, conf_3d, collision_terms,
                                       pose_preserve_weight=5.0)
                loss.backward()
                return loss

            for _ in range(self.num_iters):
                body_optimizer.step(body_closure)
        else:
            body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size,
                                              betas=(0.9, 0.999))
            for _ in range(self.num_iters):
                # NOTE(review): unlike the LBFGS path, pose_preserve_weight is not
                # passed here, so the loss function's default applies.
                loss = self._body_loss(self._smpl_forward(global_orient, body_pose, betas),
                                       body_pose, preserve_pose, betas, camera_translation,
                                       j3d, conf_3d, collision_terms)
                body_optimizer.zero_grad()
                loss.backward()
                body_optimizer.step()

        # Get final loss value
        with torch.no_grad():
            smpl_output = self._smpl_forward(global_orient, body_pose, betas,
                                             return_full_pose=True)
            final_loss = self._body_loss(smpl_output, body_pose, preserve_pose, betas,
                                         camera_translation, j3d, conf_3d, collision_terms)

        vertices = smpl_output.vertices.detach()
        joints = smpl_output.joints.detach()
        pose = torch.cat([global_orient, body_pose], dim=-1).detach()
        betas = betas.detach()
        # Detach like the other outputs; a leaf with requires_grad=True would
        # make e.g. ``.numpy()`` fail for callers.
        camera_translation = camera_translation.detach()
        return vertices, joints, pose, betas, camera_translation, final_loss
|