# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
"""Inference script: fit SMPL-X to each input image, then reconstruct the clothed human.

For every sample yielded by ``TestDataset`` this script
  1. optimizes the SMPL-X body parameters against predicted normals, silhouettes and
     landmarks (or reuses previously exported ``*_smpl_XX.obj`` fits),
  2. runs d-BiNI surface integration, optionally completes the shape with IF-Nets+,
     and fuses the parts (optionally with Poisson reconstruction),
  3. writes PNG visualizations, OBJ meshes and a ``*_in_tensor.pt`` dump under
     ``<out_dir>/<cfg.name>/``.
"""

import warnings
import logging

warnings.filterwarnings("ignore")
logging.getLogger("lightning").setLevel(logging.ERROR)
logging.getLogger("trimesh").setLevel(logging.ERROR)

import torch, torchvision
import trimesh
import numpy as np
import argparse
import os

from termcolor import colored
from tqdm.auto import tqdm
from apps.Normal import Normal
from apps.IFGeo import IFGeo
from pytorch3d.ops import SubdivideMeshes
from lib.common.config import cfg
from lib.common.render import query_color
from lib.common.train_util import init_loss, Format
from lib.common.imutils import blend_rgb_norm
from lib.common.BNI import BNI
from lib.common.BNI_utils import save_normal_tensor
from lib.dataset.TestDataset import TestDataset
from lib.common.local_affine import register
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
from lib.dataset.mesh_util import *
from lib.common.voxelize import VoxelGrid

torch.backends.cudnn.benchmark = True

if __name__ == "__main__":

    # loading cfg file
    parser = argparse.ArgumentParser()

    parser.add_argument("-gpu", "--gpu_device", type=int, default=0)
    parser.add_argument("-loop_smpl", "--loop_smpl", type=int, default=50)
    parser.add_argument("-patience", "--patience", type=int, default=5)
    parser.add_argument("-vis_freq", "--vis_freq", type=int, default=1000)
    # NOTE: store_false -> args.multi defaults to True and is passed to the dataset as
    # "single"; giving -multi on the command line sets "single" to False
    # (presumably multi-person mode -- confirm in TestDataset).
    parser.add_argument("-multi", action="store_false")
    parser.add_argument("-in_dir", "--in_dir", type=str, default="./examples")
    parser.add_argument("-out_dir", "--out_dir", type=str, default="./results")
    parser.add_argument("-seg_dir", "--seg_dir", type=str, default=None)
    parser.add_argument("-cfg", "--config", type=str, default="./configs/econ.yaml")

    args = parser.parse_args()

    # cfg read and merge
    cfg.merge_from_file(args.config)
    cfg.merge_from_file("./lib/pymafx/configs/pymafx_config.yaml")
    device = torch.device(f"cuda:{args.gpu_device}")

    # setting for testing on in-the-wild images
    cfg_show_list = [
        "test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True,
        "batch_size", 1
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    # load normal model
    normal_net = Normal.load_from_checkpoint(
        cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False
    )
    normal_net = normal_net.to(device)
    normal_net.netG.eval()
    print(
        colored(
            f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green"
        )
    )

    # SMPLX object
    SMPLX_object = SMPLX()

    dataset_param = {
        "image_dir": args.in_dir,
        "seg_dir": args.seg_dir,
        "use_seg": True,    # w/ or w/o segmentation
        "hps_type": cfg.bni.hps_type,    # pymafx/pixie
        "vol_res": cfg.vol_res,
        "single": args.multi,
    }

    if cfg.bni.use_ifnet:
        # load IFGeo model (only needed for implicit shape completion)
        ifnet = IFGeo.load_from_checkpoint(
            cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False
        )
        ifnet = ifnet.to(device)
        ifnet.netG.eval()

        print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green"))
        print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green"))
    else:
        print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green"))

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    pbar = tqdm(dataset)

    for data in pbar:

        losses = init_loss()

        pbar.set_description(f"{data['name']}")

        # final results rendered as image (PNG)
        # 1. Render the final fitted SMPL (xxx_smpl.png)
        # 2. Render the final reconstructed clothed human (xxx_cloth.png)
        # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)
        # 4. Blend the cropped image with predicted cloth normal (xxx_crop.png)

        os.makedirs(osp.join(args.out_dir, cfg.name, "png"), exist_ok=True)

        # final reconstruction meshes (OBJ)
        # 1. SMPL mesh (xxx_smpl_xx.obj)
        # 2. SMPL params (xxx_smpl.npy)
        # 3. d-BiNI surfaces (xxx_BNI.obj)
        # 4. separate face/hand mesh (xxx_hand/face.obj)
        # 5. full shape inpainted by IF-Nets+ after remeshing (xxx_IF.obj)
        # 6. side or occluded parts (xxx_side.obj)
        # 7. final reconstructed clothed human (xxx_full.obj)

        os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)

        in_tensor = {
            "smpl_faces": data["smpl_faces"],
            "image": data["img_icon"].to(device),
            "mask": data["img_mask"].to(device)
        }

        # The optimizer and variables
        optimed_pose = data["body_pose"].requires_grad_(True)
        optimed_trans = data["trans"].requires_grad_(True)
        optimed_betas = data["betas"].requires_grad_(True)
        optimed_orient = data["global_orient"].requires_grad_(True)

        optimizer_smpl = torch.optim.Adam(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True
        )
        # `verbose` is deprecated in recent PyTorch releases; its default is already silent.
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            min_lr=1e-5,
            patience=args.patience,
        )

        # [result_loop_1, result_loop_2, ...]
        per_data_lst = []

        N_body, N_pose = optimed_pose.shape[:2]

        smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"

        # Previously exported SMPL-X fits are reused when present. Delete the cached
        # *_smpl_XX.obj files (or this check) if you change loop_smpl and want a fresh fit.
        if osp.exists(smpl_path):

            smpl_verts_lst = []
            smpl_faces_lst = []

            for idx in range(N_body):

                smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"
                smpl_mesh = trimesh.load(smpl_obj)
                smpl_verts = torch.tensor(smpl_mesh.vertices).to(device).float()
                smpl_faces = torch.tensor(smpl_mesh.faces).to(device).long()
                smpl_verts_lst.append(smpl_verts)
                smpl_faces_lst.append(smpl_faces)

            batch_smpl_verts = torch.stack(smpl_verts_lst)
            batch_smpl_faces = torch.stack(smpl_faces_lst)

            # render optimized mesh as normal [-1,1]
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                batch_smpl_verts, batch_smpl_faces
            )

            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

            # undo the export-time axis flip / face winding so later stages see the same
            # convention as the optimization branch below
            in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
            in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]

        else:
            # smpl optimization

            # Front and back renders are concatenated along the last axis below
            # (gt_arr = mask.repeat(1, 1, 2)), so a single view is `img_w` columns wide.
            img_w = in_tensor["mask"].shape[-1]

            loop_smpl = tqdm(range(args.loop_smpl))

            for i in loop_smpl:

                per_loop_lst = []

                optimizer_smpl.zero_grad()

                N_body, N_pose = optimed_pose.shape[:2]

                # 6d_rot to rot_mat
                optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).view(N_body, 1, 3, 3)
                optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).view(N_body, N_pose, 3, 3)

                smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
                    shape_params=optimed_betas,
                    expression_params=tensor2variable(data["exp"], device),
                    body_pose=optimed_pose_mat,
                    global_pose=optimed_orient_mat,
                    jaw_pose=tensor2variable(data["jaw_pose"], device),
                    left_hand_pose=tensor2variable(data["left_hand_pose"], device),
                    right_hand_pose=tensor2variable(data["right_hand_pose"], device),
                )

                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
                smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor(
                    [1.0, 1.0, -1.0]
                ).to(device)

                # landmark errors (joints mapped from [-1, 1] to [0, 1])
                smpl_joints_3d = (
                    smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0
                ) * 0.5
                in_tensor["smpl_joint"] = smpl_joints[:, dataset.smpl_data.smpl_joint_ids_24_pixie, :]

                ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
                ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)
                smpl_lmks = smpl_joints_3d[:, SMPLX_object.ghum_smpl_pairs[:, 1], :2]

                # render optimized mesh as normal [-1,1]
                in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                    smpl_verts * torch.tensor([1.0, -1.0, -1.0]).to(device),
                    in_tensor["smpl_faces"],
                )

                T_mask_F, T_mask_B = dataset.render.get_image(type="mask")

                with torch.no_grad():
                    in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

                diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
                diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])

                # silhouette loss
                smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)
                gt_arr = in_tensor["mask"].repeat(1, 1, 2)
                diff_S = torch.abs(smpl_arr - gt_arr)
                losses["silhouette"]["value"] = diff_S.mean()

                # large cloth_overlap --> big difference between body and cloth mask
                # for loose clothing, rely more on landmarks instead of silhouette+normal loss
                # (denominator clamped so an empty mask cannot produce inf/NaN)
                cloth_overlap = diff_S.sum(dim=[1, 2]) / gt_arr.sum(dim=[1, 2]).clamp(min=1)
                cloth_overlap_flag = cloth_overlap > cfg.cloth_overlap_thres
                losses["joint"]["weight"] = [50.0 if flag else 5.0 for flag in cloth_overlap_flag]

                # small body_overlap --> large occlusion or out-of-frame
                # for highly occluded body, rely only on high-confidence landmarks, no silhouette+normal loss

                # NOTE: the PyTorch3D silhouette renderer generates a dilated mask, so the body
                # mask is instead derived from the normal renders: every pixel whose value
                # differs from the corner (background) pixel counts as body.
                bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
                smpl_arr_fake = torch.cat(
                    [
                        in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
                        in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
                    ],
                    dim=-1
                )

                # Clamp the denominator: a body rendered entirely out of frame yields
                # body_overlap = 0 (flagged as occluded) instead of 0/0 = NaN, which
                # would compare False and be treated as fully visible.
                smpl_fg = smpl_arr_fake.gt(0.0)
                body_overlap = (gt_arr * smpl_fg).sum(dim=[1, 2]) / smpl_fg.sum(dim=[1, 2]).clamp(min=1)
                body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
                body_overlap_flag = body_overlap < cfg.body_overlap_thres

                losses["normal"]["value"] = (
                    diff_F_smpl * body_overlap_mask[..., :img_w] +
                    diff_B_smpl * body_overlap_mask[..., img_w:]
                ).mean() / 2.0

                losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
                occluded_idx = torch.where(body_overlap_flag)[0]
                # for occluded bodies keep only landmarks with confidence > 0.95
                ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
                losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) *
                                            ghum_conf).mean(dim=1)

                # Weighted sum of the losses
                smpl_loss = 0.0
                pbar_desc = "Body Fitting -- "
                for k in ["normal", "silhouette", "joint"]:
                    per_loop_loss = (
                        losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)
                    ).mean()
                    pbar_desc += f"{k}: {per_loop_loss:.3f} | "
                    smpl_loss += per_loop_loss
                pbar_desc += f"Total: {smpl_loss:.3f}"
                loose_str = ''.join([str(j) for j in cloth_overlap_flag.int().tolist()])
                occlude_str = ''.join([str(j) for j in body_overlap_flag.int().tolist()])
                pbar_desc += colored(f"| loose:{loose_str}, occluded:{occlude_str}", "yellow")
                loop_smpl.set_description(pbar_desc)

                # save intermediate results / vis_freq and final_step
                if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1):

                    per_loop_lst.extend(
                        [
                            in_tensor["image"],
                            in_tensor["T_normal_F"],
                            in_tensor["normal_F"],
                            diff_S[:, :, :img_w].unsqueeze(1).repeat(1, 3, 1, 1),
                        ]
                    )
                    per_loop_lst.extend(
                        [
                            in_tensor["image"],
                            in_tensor["T_normal_B"],
                            in_tensor["normal_B"],
                            diff_S[:, :, img_w:].unsqueeze(1).repeat(1, 3, 1, 1),
                        ]
                    )
                    per_data_lst.append(
                        get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
                    )

                smpl_loss.backward()
                optimizer_smpl.step()
                scheduler_smpl.step(smpl_loss)

            in_tensor["smpl_verts"] = smpl_verts * torch.tensor([1.0, 1.0, -1.0]).to(device)
            in_tensor["smpl_faces"] = in_tensor["smpl_faces"][:, :, [0, 2, 1]]

            per_data_lst[-1].save(osp.join(args.out_dir, cfg.name, f"png/{data['name']}_smpl.png"))

        img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
        torchvision.utils.save_image(
            torch.cat(
                [
                    data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
                    (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
                ],
                dim=3
            ), img_crop_path
        )

        rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
        rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)

        img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
        torchvision.utils.save_image(
            torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., img_overlap_path
        )

        smpl_obj_lst = []

        for idx in range(N_body):

            smpl_obj = trimesh.Trimesh(
                in_tensor["smpl_verts"].detach().cpu()[idx] * torch.tensor([1.0, -1.0, 1.0]),
                in_tensor["smpl_faces"].detach().cpu()[0][:, [0, 2, 1]],
                process=False,
                maintains_order=True,
            )

            smpl_obj_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"

            # Only reached after a fresh optimization: in the cached branch every
            # *_smpl_XX.obj already exists, so optimed_*_mat are never needed there.
            if not osp.exists(smpl_obj_path):
                smpl_obj.export(smpl_obj_path)
                smpl_info = {
                    "betas":
                        optimed_betas[idx].detach().cpu().unsqueeze(0),
                    "body_pose":
                        rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()
                                                     ).cpu().unsqueeze(0),
                    "global_orient":
                        rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()
                                                     ).cpu().unsqueeze(0),
                    "transl":
                        optimed_trans[idx].detach().cpu(),
                    "expression":
                        data["exp"][idx].cpu().unsqueeze(0),
                    "jaw_pose":
                        rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
                    "left_hand_pose":
                        rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]
                                                     ).cpu().unsqueeze(0),
                    "right_hand_pose":
                        rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]
                                                     ).cpu().unsqueeze(0),
                    "scale":
                        data["scale"][idx].cpu(),
                }
                np.save(
                    smpl_obj_path.replace(".obj", ".npy"),
                    smpl_info,
                    allow_pickle=True,
                )
            smpl_obj_lst.append(smpl_obj)

        # Release the fitting state before the memory-heavy refinement stage. The scheduler
        # holds a reference to the optimizer, so it must be deleted as well, otherwise the
        # Adam state survives empty_cache().
        del scheduler_smpl
        del optimizer_smpl
        del optimed_betas
        del optimed_orient
        del optimed_pose
        del optimed_trans

        torch.cuda.empty_cache()

        # ------------------------------------------------------------------------------------------------------------------
        # clothing refinement

        per_data_lst = []

        batch_smpl_verts = in_tensor["smpl_verts"].detach(
        ) * torch.tensor([1.0, -1.0, 1.0], device=device)
        batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]

        in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
            batch_smpl_verts, batch_smpl_faces
        )

        per_loop_lst = []

        in_tensor["BNI_verts"] = []
        in_tensor["BNI_faces"] = []
        in_tensor["body_verts"] = []
        in_tensor["body_faces"] = []

        for idx in range(N_body):

            final_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_full.obj"

            side_mesh = smpl_obj_lst[idx].copy()
            face_mesh = smpl_obj_lst[idx].copy()
            hand_mesh = smpl_obj_lst[idx].copy()
            smplx_mesh = smpl_obj_lst[idx].copy()

            # save normals, depths and masks
            BNI_dict = save_normal_tensor(
                in_tensor,
                idx,
                osp.join(args.out_dir, cfg.name, f"BNI/{data['name']}_{idx}"),
                cfg.bni.thickness,
            )

            # BNI process
            BNI_object = BNI(
                dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
                name=data["name"],
                BNI_dict=BNI_dict,
                cfg=cfg.bni,
                device=device
            )

            BNI_object.extract_surface(False)

            in_tensor["body_verts"].append(torch.tensor(smpl_obj_lst[idx].vertices).float())
            in_tensor["body_faces"].append(torch.tensor(smpl_obj_lst[idx].faces).long())

            # requires shape completion when low overlap
            # replace SMPL by completed mesh as side_mesh

            if cfg.bni.use_ifnet:

                side_mesh_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_IF.obj"

                side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)

                # mesh completion via IF-net
                in_tensor.update(
                    dataset.depth_to_voxel(
                        {
                            "depth_F": BNI_object.F_depth.unsqueeze(0),
                            "depth_B": BNI_object.B_depth.unsqueeze(0)
                        }
                    )
                )

                occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
                    0,
                ] * 3, scale=2.0).data.transpose(2, 1, 0)
                occupancies = np.flip(occupancies, axis=1)

                in_tensor["body_voxels"] = torch.tensor(occupancies.copy()
                                                       ).float().unsqueeze(0).to(device)

                with torch.no_grad():
                    sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
                    verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)

                if ifnet.clean_mesh_flag:
                    verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)

                side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
                side_mesh = remesh_laplacian(side_mesh, side_mesh_path)

            else:
                # explicit mode: SMPL-X minus front-FLAME, MANO and eyeball vertices
                side_mesh = apply_vertex_mask(
                    side_mesh,
                    (
                        SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
                        SMPLX_object.eyeball_vertex_mask
                    ).eq(0).float(),
                )

            # register side_mesh to BNI surfaces
            side_mesh = Meshes(
                verts=[torch.tensor(side_mesh.vertices).float()],
                faces=[torch.tensor(side_mesh.faces).long()],
            ).to(device)
            sm = SubdivideMeshes(side_mesh)
            side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)

            side_verts = torch.tensor(side_mesh.vertices).float().to(device)
            side_faces = torch.tensor(side_mesh.faces).long().to(device)

            # Poisson Fusion between SMPLX and BNI
            # 1. keep the faces invisible to front+back cameras
            # 2. keep the front-FLAME+MANO faces
            # 3. remove eyeball faces

            # export intermediate meshes (overwritten below after face/hand removal)
            BNI_object.F_B_trimesh.export(
                f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
            )
            full_lst = []

            if "face" in cfg.bni.use_smpl:

                # only face
                face_mesh = apply_vertex_mask(face_mesh, SMPLX_object.front_flame_vertex_mask)
                face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])

                # remove face neighbor triangles
                BNI_object.F_B_trimesh = part_removal(
                    BNI_object.F_B_trimesh,
                    face_mesh,
                    cfg.bni.face_thres,
                    device,
                    smplx_mesh,
                    region="face"
                )
                side_mesh = part_removal(
                    side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
                )
                face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
                full_lst += [face_mesh]

            if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):

                hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], )
                # hands_visibility[idx] = (left visible, right visible)
                if data['hands_visibility'][idx][0]:
                    hand_mask.index_fill_(
                        0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0
                    )
                if data['hands_visibility'][idx][1]:
                    hand_mask.index_fill_(
                        0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0
                    )

                # only hands
                hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)

                # remove hand neighbor triangles
                BNI_object.F_B_trimesh = part_removal(
                    BNI_object.F_B_trimesh,
                    hand_mesh,
                    cfg.bni.hand_thres,
                    device,
                    smplx_mesh,
                    region="hand"
                )
                side_mesh = part_removal(
                    side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
                )
                hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
                full_lst += [hand_mesh]

            full_lst += [BNI_object.F_B_trimesh]

            # initial side_mesh could be SMPLX or IF-net
            side_mesh = part_removal(
                side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
            )

            full_lst += [side_mesh]

            # export intermediate meshes
            BNI_object.F_B_trimesh.export(
                f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
            )
            side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")

            if cfg.bni.use_poisson:
                final_mesh = poisson(
                    sum(full_lst),
                    final_path,
                    cfg.bni.poisson_depth,
                )
            else:
                final_mesh = sum(full_lst)
                final_mesh.export(final_path)

            dataset.render.load_meshes(final_mesh.vertices, final_mesh.faces)
            rotate_recon_lst = dataset.render.get_image(cam_type="four")
            per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)

            if cfg.bni.texture_src == 'image':

                # coloring the final mesh (front: RGB pixels, back: normal colors)
                final_colors = query_color(
                    torch.tensor(final_mesh.vertices).float(),
                    torch.tensor(final_mesh.faces).long(),
                    in_tensor["image"][idx:idx + 1],
                    device=device,
                )
                final_mesh.visual.vertex_colors = final_colors
                final_mesh.export(final_path)

            elif cfg.bni.texture_src == 'SD':

                # TODO: add texture from Stable Diffusion
                pass

            # for video rendering
            in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
            in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())

        if len(per_loop_lst) > 0:

            # one row per body: input image + four rotated renders
            per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=5, type="cloth"))
            per_data_lst[-1].save(osp.join(args.out_dir, cfg.name, f"png/{data['name']}_cloth.png"))

        os.makedirs(osp.join(args.out_dir, cfg.name, "vid"), exist_ok=True)
        in_tensor["uncrop_param"] = data["uncrop_param"]
        in_tensor["img_raw"] = data["img_raw"]
        torch.save(in_tensor, osp.join(args.out_dir, cfg.name, f"vid/{data['name']}_in_tensor.pt"))