# ECON / apps/infer.py
# Commit 3577d3c ("SMPL-X based Animatable Avatar"), Yuliang
# (web file-view residue: raw | history | blame | 24 kB)
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import warnings
import logging
warnings.filterwarnings("ignore")
logging.getLogger("lightning").setLevel(logging.ERROR)
logging.getLogger("trimesh").setLevel(logging.ERROR)
import torch, torchvision
import trimesh
import numpy as np
import argparse
import os
from termcolor import colored
from tqdm.auto import tqdm
from apps.Normal import Normal
from apps.IFGeo import IFGeo
from pytorch3d.ops import SubdivideMeshes
from lib.common.config import cfg
from lib.common.train_util import init_loss, load_normal_networks, load_networks
from lib.common.BNI import BNI
from lib.common.BNI_utils import save_normal_tensor
from lib.dataset.TestDataset import TestDataset
from lib.common.local_affine import register
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
from lib.dataset.mesh_util import *
from lib.common.voxelize import VoxelGrid
torch.backends.cudnn.benchmark = True

if __name__ == "__main__":

    # Pipeline per input image:
    #   1. fit SMPL-X to the image (or reuse a previously exported fit),
    #   2. predict front/back clothed normals with the normal network,
    #   3. export SMPL-X meshes/params and visualizations,
    #   4. build d-BiNI surfaces, complete them (IF-Nets+ or SMPL-X), fuse and export.

    # loading cfg file
    parser = argparse.ArgumentParser()
    parser.add_argument("-gpu", "--gpu_device", type=int, default=0)
    parser.add_argument("-loop_smpl", "--loop_smpl", type=int, default=50)
    parser.add_argument("-patience", "--patience", type=int, default=5)
    parser.add_argument("-vis_freq", "--vis_freq", type=int, default=1000)
    # NOTE: store_false -> args.multi is True by default, and it is forwarded to the dataset
    # as "single"; passing -multi therefore sets single=False.
    parser.add_argument(
        "-multi",
        action="store_false",
        help="disable single mode (passes single=False to TestDataset)",
    )
    parser.add_argument("-in_dir", "--in_dir", type=str, default="./examples")
    parser.add_argument("-out_dir", "--out_dir", type=str, default="./results")
    parser.add_argument("-seg_dir", "--seg_dir", type=str, default=None)
    parser.add_argument("-cfg", "--config", type=str, default="./configs/econ.yaml")
    args = parser.parse_args()

    # cfg read and merge
    cfg.merge_from_file(args.config)
    cfg.merge_from_file("./lib/pymafx/configs/pymafx_config.yaml")
    device = torch.device(f"cuda:{args.gpu_device}")

    # setting for testing on in-the-wild images
    cfg_show_list = [
        "test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True,
        "batch_size", 1
    ]
    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    # load model
    normal_model = Normal(cfg).to(device)
    load_normal_networks(normal_model, cfg.normal_path)
    normal_model.netG.eval()

    # load IFGeo model
    ifnet_model = IFGeo(cfg).to(device)
    load_networks(ifnet_model, mlp_path=cfg.ifnet_path)
    ifnet_model.netG.eval()

    # SMPLX object
    SMPLX_object = SMPLX()

    dataset_param = {
        "image_dir": args.in_dir,
        "seg_dir": args.seg_dir,
        "use_seg": True,    # w/ or w/o segmentation
        "hps_type": cfg.bni.hps_type,    # pymafx/pixie
        "vol_res": cfg.vol_res,
        "single": args.multi,
    }

    if cfg.bni.use_ifnet:
        print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
    else:
        print(colored("Use SMPL-X (Explicit) for completion", "green"))

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    # Output directories do not depend on the current sample, so create them once.
    # final results rendered as image (PNG)
    # 1. Render the final fitted SMPL (xxx_smpl.png)
    # 2. Render the final reconstructed clothed human (xxx_cloth.png)
    # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)
    # 4. Blend the cropped image with predicted cloth normal (xxx_crop.png)
    os.makedirs(osp.join(args.out_dir, cfg.name, "png"), exist_ok=True)
    # final reconstruction meshes (OBJ)
    # 1. SMPL mesh (xxx_smpl_xx.obj)
    # 2. SMPL params (xxx_smpl.npy)
    # 3. d-BiNI surfaces (xxx_BNI.obj)
    # 4. separate face/hand mesh (xxx_hand/face.obj)
    # 5. full shape inpainted by IF-Nets+, and remeshed shape (xxx_IF_(remesh).obj)
    # 6. sided or occluded parts (xxx_side.obj)
    # 7. final reconstructed clothed human (xxx_full.obj)
    os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
    # tensors saved for video rendering (xxx_in_tensor.pt)
    os.makedirs(osp.join(args.out_dir, cfg.name, "vid"), exist_ok=True)

    # Per-axis sign-flip vectors, created once instead of on every use.
    flip_y = torch.tensor([1.0, -1.0, 1.0], device=device)
    flip_z = torch.tensor([1.0, 1.0, -1.0], device=device)
    flip_yz = torch.tensor([1.0, -1.0, -1.0], device=device)

    pbar = tqdm(dataset)

    for data in pbar:

        losses = init_loss()

        pbar.set_description(f"{data['name']}")

        in_tensor = {
            "smpl_faces": data["smpl_faces"],
            "image": data["img_icon"].to(device),
            "mask": data["img_mask"].to(device)
        }

        # Front and back renders are concatenated along the last axis below and compared
        # against `mask` repeated twice, so the mask width is where they split again
        # (previously hard-coded as 512).
        img_w = in_tensor["mask"].shape[-1]

        # The optimizer and variables
        optimed_pose = data["body_pose"].requires_grad_(True)
        optimed_trans = data["trans"].requires_grad_(True)
        optimed_betas = data["betas"].requires_grad_(True)
        optimed_orient = data["global_orient"].requires_grad_(True)

        optimizer_smpl = torch.optim.Adam(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True
        )
        # `verbose` is deprecated in recent PyTorch releases; its default is already silent.
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            min_lr=1e-5,
            patience=args.patience,
        )

        # [result_loop_1, result_loop_2, ...]
        per_data_lst = []

        N_body, N_pose = optimed_pose.shape[:2]

        smpl_obj_paths = [
            f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj" for idx in range(N_body)
        ]

        # Reuse a previous SMPL-X fit only if every body of this image was exported.
        # A partial cache (e.g. from an interrupted run) is re-fitted instead of crashing.
        use_cache = all(osp.exists(path) for path in smpl_obj_paths)

        if use_cache:

            smpl_verts_lst = []
            smpl_faces_lst = []

            for smpl_obj_path in smpl_obj_paths:
                smpl_mesh = trimesh.load(smpl_obj_path)
                smpl_verts_lst.append(torch.tensor(smpl_mesh.vertices).to(device).float())
                smpl_faces_lst.append(torch.tensor(smpl_mesh.faces).to(device).long())

            batch_smpl_verts = torch.stack(smpl_verts_lst)
            batch_smpl_faces = torch.stack(smpl_faces_lst)

            # render optimized mesh as normal [-1,1]
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                batch_smpl_verts, batch_smpl_faces
            )

            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor)

            in_tensor["smpl_verts"] = batch_smpl_verts * flip_y
            in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]

        else:
            # The last fitting iteration supplies the mesh, the pose matrices and the
            # `_smpl.png` grid used below, so at least one iteration is required.
            if args.loop_smpl < 1:
                raise ValueError("--loop_smpl must be >= 1 when no cached SMPL-X fit exists")

            # smpl optimization
            loop_smpl = tqdm(range(args.loop_smpl))

            for i in loop_smpl:

                per_loop_lst = []

                optimizer_smpl.zero_grad()

                # 6d_rot to rot_mat
                optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).view(N_body, 1, 3, 3)
                optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).view(N_body, N_pose, 3, 3)

                smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
                    shape_params=optimed_betas,
                    expression_params=tensor2variable(data["exp"], device),
                    body_pose=optimed_pose_mat,
                    global_pose=optimed_orient_mat,
                    jaw_pose=tensor2variable(data["jaw_pose"], device),
                    left_hand_pose=tensor2variable(data["left_hand_pose"], device),
                    right_hand_pose=tensor2variable(data["right_hand_pose"], device),
                )

                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
                smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * flip_z

                # landmark errors
                smpl_joints_3d = (smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0) * 0.5
                in_tensor["smpl_joint"] = smpl_joints[:, dataset.smpl_data.smpl_joint_ids_24_pixie, :]

                ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
                ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)
                smpl_lmks = smpl_joints_3d[:, SMPLX_object.ghum_smpl_pairs[:, 1], :2]

                # render optimized mesh as normal [-1,1]
                in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                    smpl_verts * flip_yz,
                    in_tensor["smpl_faces"],
                )

                T_mask_F, T_mask_B = dataset.render.get_image(type="mask")

                with torch.no_grad():
                    in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor)

                diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
                diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])

                # silhouette loss
                smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)
                gt_arr = in_tensor["mask"].repeat(1, 1, 2)
                diff_S = torch.abs(smpl_arr - gt_arr)
                losses["silhouette"]["value"] = diff_S.mean()

                # large cloth_overlap --> big difference between body and cloth mask
                # for loose clothing, rely more on landmarks instead of silhouette+normal loss
                cloth_overlap = diff_S.sum(dim=[1, 2]) / gt_arr.sum(dim=[1, 2])
                cloth_overlap_flag = cloth_overlap > cfg.cloth_overlap_thres
                losses["joint"]["weight"] = [50.0 if flag else 5.0 for flag in cloth_overlap_flag]

                # small body_overlap --> large occlusion or out-of-frame
                # for highly occluded body, rely only on high-confidence landmarks, no silhouette+normal loss

                # BUG: PyTorch3D silhouette renderer generates dilated mask
                # (so body pixels are taken from the normal renders instead: any pixel that
                # differs from the top-left background value)
                bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
                smpl_arr_fake = torch.cat(
                    [
                        in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
                        in_tensor["T_normal_B"][:, 0].ne(bg_value).float(),
                    ],
                    dim=-1,
                )

                body_pixels = smpl_arr_fake.gt(0.0)
                # clamp: a body rendered completely out of frame yields overlap 0 (flagged as
                # occluded) instead of a NaN that would silently pass the threshold test
                body_overlap = (gt_arr * body_pixels).sum(dim=[1, 2]) / body_pixels.sum(dim=[1, 2]).clamp(min=1)
                body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
                body_overlap_flag = body_overlap < cfg.body_overlap_thres

                losses["normal"]["value"] = (
                    diff_F_smpl * body_overlap_mask[..., :img_w] +
                    diff_B_smpl * body_overlap_mask[..., img_w:]
                ).mean() / 2.0

                losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
                # for occluded bodies keep only landmarks whose confidence exceeds 0.95
                occluded_idx = torch.where(body_overlap_flag)[0]
                ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
                losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) * ghum_conf).mean(dim=1)

                # Weighted sum of the losses
                smpl_loss = 0.0
                pbar_desc = "Body Fitting --- "
                for k in ["normal", "silhouette", "joint"]:
                    per_loop_loss = (losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)).mean()
                    pbar_desc += f"{k}: {per_loop_loss:.3f} | "
                    smpl_loss += per_loop_loss
                pbar_desc += f"Total: {smpl_loss:.3f}"
                loose_str = ''.join([str(j) for j in cloth_overlap_flag.int().tolist()])
                occlude_str = ''.join([str(j) for j in body_overlap_flag.int().tolist()])
                pbar_desc += colored(f"| loose:{loose_str}, occluded:{occlude_str}", "yellow")
                loop_smpl.set_description(pbar_desc)

                # save intermediate results / vis_freq and final_step
                if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1):

                    per_loop_lst.extend([
                        in_tensor["image"],
                        in_tensor["T_normal_F"],
                        in_tensor["normal_F"],
                        diff_S[:, :, :img_w].unsqueeze(1).repeat(1, 3, 1, 1),
                    ])
                    per_loop_lst.extend([
                        in_tensor["image"],
                        in_tensor["T_normal_B"],
                        in_tensor["normal_B"],
                        diff_S[:, :, img_w:].unsqueeze(1).repeat(1, 3, 1, 1),
                    ])
                    per_data_lst.append(
                        get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
                    )

                # The mesh, normals and pose matrices kept after this loop all come from the
                # forward pass above. Skipping the final update keeps optimed_betas /
                # optimed_trans (saved to the .npy below) consistent with the exported .obj.
                if i == args.loop_smpl - 1:
                    break

                smpl_loss.backward()
                optimizer_smpl.step()
                scheduler_smpl.step(smpl_loss)

            in_tensor["smpl_verts"] = smpl_verts * flip_z
            in_tensor["smpl_faces"] = in_tensor["smpl_faces"][:, :, [0, 2, 1]]

            per_data_lst[-1].save(osp.join(args.out_dir, cfg.name, f"png/{data['name']}_smpl.png"))

        img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
        torchvision.utils.save_image(
            torch.cat(
                [
                    data["img_crop"][:, :3],
                    (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
                    (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5,
                ],
                dim=3,
            ),
            img_crop_path,
        )

        rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
        rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)

        img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
        # Stack the HWC images in NumPy first: building a tensor from a Python list of
        # ndarrays is a documented slow path in PyTorch.
        overlap_imgs = torch.as_tensor(
            np.stack([data["img_raw"], rgb_norm_F, rgb_norm_B]), dtype=torch.float32
        )
        torchvision.utils.save_image(overlap_imgs.permute(0, 3, 1, 2) / 255., img_overlap_path)

        smpl_obj_lst = []

        for idx, smpl_obj_path in enumerate(smpl_obj_paths):

            smpl_obj = trimesh.Trimesh(
                in_tensor["smpl_verts"].detach().cpu()[idx] * torch.tensor([1.0, -1.0, 1.0]),
                in_tensor["smpl_faces"].detach().cpu()[0][:, [0, 2, 1]],
                process=False,
                maintains_order=True,
            )

            # A fresh fit is always written (overwriting any partial cache); a complete
            # cache is left untouched.
            if not use_cache:
                smpl_obj.export(smpl_obj_path)
                smpl_info = {
                    "betas": optimed_betas[idx].detach().cpu().unsqueeze(0),
                    "body_pose": rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()).cpu().unsqueeze(0),
                    "global_orient": rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()).cpu().unsqueeze(0),
                    "transl": optimed_trans[idx].detach().cpu(),
                    "expression": data["exp"][idx].cpu().unsqueeze(0),
                    "jaw_pose": rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
                    "left_hand_pose": rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]).cpu().unsqueeze(0),
                    "right_hand_pose": rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]).cpu().unsqueeze(0),
                    "scale": data["scale"][idx].cpu(),
                }
                np.save(
                    smpl_obj_path.replace(".obj", ".npy"),
                    smpl_info,
                    allow_pickle=True,
                )
            smpl_obj_lst.append(smpl_obj)

        # The scheduler holds a reference to the optimizer, so it has to be dropped as well
        # for the optimizer state to become collectable.
        del scheduler_smpl
        del optimizer_smpl
        del optimed_betas
        del optimed_orient
        del optimed_pose
        del optimed_trans

        torch.cuda.empty_cache()

        # ------------------------------------------------------------------------------------------------------------------
        # clothing refinement

        per_data_lst = []

        batch_smpl_verts = in_tensor["smpl_verts"].detach() * flip_y
        batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]

        in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(batch_smpl_verts, batch_smpl_faces)

        per_loop_lst = []

        in_tensor["BNI_verts"] = []
        in_tensor["BNI_faces"] = []
        in_tensor["body_verts"] = []
        in_tensor["body_faces"] = []

        for idx in range(N_body):

            final_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_full.obj"

            side_mesh = smpl_obj_lst[idx].copy()
            face_mesh = smpl_obj_lst[idx].copy()
            hand_mesh = smpl_obj_lst[idx].copy()
            smplx_mesh = smpl_obj_lst[idx].copy()

            # save normals, depths and masks
            BNI_dict = save_normal_tensor(
                in_tensor,
                idx,
                osp.join(args.out_dir, cfg.name, f"BNI/{data['name']}_{idx}"),
                cfg.bni.thickness,
            )

            # BNI process
            BNI_object = BNI(
                dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
                name=data["name"],
                BNI_dict=BNI_dict,
                cfg=cfg.bni,
                device=device,
            )

            BNI_object.extract_surface(False)

            in_tensor["body_verts"].append(torch.tensor(smpl_obj_lst[idx].vertices).float())
            in_tensor["body_faces"].append(torch.tensor(smpl_obj_lst[idx].faces).long())

            # requires shape completion when low overlap
            # replace SMPL by completed mesh as side_mesh

            if cfg.bni.use_ifnet:

                side_mesh_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_IF.obj"

                side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)

                # mesh completion via IF-net
                in_tensor.update(
                    dataset.depth_to_voxel({
                        "depth_F": BNI_object.F_depth.unsqueeze(0),
                        "depth_B": BNI_object.B_depth.unsqueeze(0),
                    })
                )

                occupancies = VoxelGrid.from_mesh(
                    side_mesh, cfg.vol_res, loc=[0] * 3, scale=2.0
                ).data.transpose(2, 1, 0)
                occupancies = np.flip(occupancies, axis=1)

                in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)

                with torch.no_grad():
                    sdf = ifnet_model.reconEngine(netG=ifnet_model.netG, batch=in_tensor)
                    verts_IF, faces_IF = ifnet_model.reconEngine.export_mesh(sdf)

                if ifnet_model.clean_mesh_flag:
                    verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)

                side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
                side_mesh = remesh(side_mesh, side_mesh_path)

            else:
                side_mesh = apply_vertex_mask(
                    side_mesh,
                    (SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
                     SMPLX_object.eyeball_vertex_mask).eq(0).float(),
                )

            # register side_mesh to BNI surfaces
            side_mesh = Meshes(
                verts=[torch.tensor(side_mesh.vertices).float()],
                faces=[torch.tensor(side_mesh.faces).long()],
            ).to(device)
            sm = SubdivideMeshes(side_mesh)
            side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)

            # Poisson Fusion between SMPLX and BNI
            # 1. keep the faces invisible to front+back cameras
            # 2. keep the front-FLAME+MANO faces
            # 3. remove eyeball faces
            # (the BNI surface is exported once, after face/hand trimming below; an earlier
            # export to the same path was always overwritten)

            full_lst = []

            if "face" in cfg.bni.use_smpl:

                # only face
                face_mesh = apply_vertex_mask(face_mesh, SMPLX_object.front_flame_vertex_mask)
                face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])

                # remove face neighbor triangles
                BNI_object.F_B_trimesh = part_removal(
                    BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
                )
                side_mesh = part_removal(
                    side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
                )
                face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
                full_lst += [face_mesh]

            if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):

                hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], )
                if data['hands_visibility'][idx][0]:
                    hand_mask.index_fill_(
                        0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0
                    )
                if data['hands_visibility'][idx][1]:
                    hand_mask.index_fill_(
                        0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0
                    )

                # only hands
                hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)

                # remove hand neighbor triangles
                BNI_object.F_B_trimesh = part_removal(
                    BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
                )
                side_mesh = part_removal(
                    side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
                )
                hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
                full_lst += [hand_mesh]

            full_lst += [BNI_object.F_B_trimesh]

            # initial side_mesh could be SMPLX or IF-net
            side_mesh = part_removal(
                side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
            )

            full_lst += [side_mesh]

            # export intermediate meshes
            BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
            side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")

            if cfg.bni.use_poisson:
                # poisson() receives final_path as its output location
                final_mesh = poisson(
                    sum(full_lst),
                    final_path,
                    cfg.bni.poisson_depth,
                )
            else:
                final_mesh = sum(full_lst)
                final_mesh.export(final_path)

            dataset.render.load_meshes(final_mesh.vertices, final_mesh.faces)
            rotate_recon_lst = dataset.render.get_image(cam_type="four")
            per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)

            # for video rendering
            in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
            in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())

        if len(per_loop_lst) > 0:
            per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=5, type="cloth"))
            per_data_lst[-1].save(osp.join(args.out_dir, cfg.name, f"png/{data['name']}_cloth.png"))

        in_tensor["uncrop_param"] = data["uncrop_param"]
        in_tensor["img_raw"] = data["img_raw"]
        torch.save(in_tensor, osp.join(args.out_dir, cfg.name, f"vid/{data['name']}_in_tensor.pt"))