Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: ps-license@tuebingen.mpg.de | |
import warnings | |
import logging | |
warnings.filterwarnings("ignore") | |
logging.getLogger("lightning").setLevel(logging.ERROR) | |
logging.getLogger("trimesh").setLevel(logging.ERROR) | |
import torch, torchvision | |
import trimesh | |
import numpy as np | |
import argparse | |
import os | |
from termcolor import colored | |
from tqdm.auto import tqdm | |
from apps.Normal import Normal | |
from apps.IFGeo import IFGeo | |
from pytorch3d.ops import SubdivideMeshes | |
from lib.common.config import cfg | |
from lib.common.render import query_color | |
from lib.common.train_util import init_loss, Format | |
from lib.common.imutils import blend_rgb_norm | |
from lib.common.BNI import BNI | |
from lib.common.BNI_utils import save_normal_tensor | |
from lib.dataset.TestDataset import TestDataset | |
from lib.common.local_affine import register | |
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis | |
from lib.dataset.mesh_util import * | |
from lib.common.voxelize import VoxelGrid | |
torch.backends.cudnn.benchmark = True | |
if __name__ == "__main__": | |
# loading cfg file | |
parser = argparse.ArgumentParser() | |
parser.add_argument("-gpu", "--gpu_device", type=int, default=0) | |
parser.add_argument("-loop_smpl", "--loop_smpl", type=int, default=50) | |
parser.add_argument("-patience", "--patience", type=int, default=5) | |
parser.add_argument("-in_dir", "--in_dir", type=str, default="./examples") | |
parser.add_argument("-out_dir", "--out_dir", type=str, default="./results") | |
parser.add_argument("-seg_dir", "--seg_dir", type=str, default=None) | |
parser.add_argument("-cfg", "--config", type=str, default="./configs/econ.yaml") | |
parser.add_argument("-multi", action="store_false") | |
parser.add_argument("-novis", action="store_true") | |
args = parser.parse_args() | |
# cfg read and merge | |
cfg.merge_from_file(args.config) | |
cfg.merge_from_file("./lib/pymafx/configs/pymafx_config.yaml") | |
device = torch.device(f"cuda:{args.gpu_device}") | |
# setting for testing on in-the-wild images | |
cfg_show_list = [ | |
"test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True, | |
"batch_size", 1 | |
] | |
cfg.merge_from_list(cfg_show_list) | |
cfg.freeze() | |
# load normal model | |
normal_net = Normal.load_from_checkpoint( | |
cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False | |
) | |
normal_net = normal_net.to(device) | |
normal_net.netG.eval() | |
print( | |
colored( | |
f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green" | |
) | |
) | |
# SMPLX object | |
SMPLX_object = SMPLX() | |
dataset_param = { | |
"image_dir": args.in_dir, | |
"seg_dir": args.seg_dir, | |
"use_seg": True, # w/ or w/o segmentation | |
"hps_type": cfg.bni.hps_type, # pymafx/pixie | |
"vol_res": cfg.vol_res, | |
"single": args.multi, | |
} | |
if cfg.bni.use_ifnet: | |
# load IFGeo model | |
ifnet = IFGeo.load_from_checkpoint( | |
cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False | |
) | |
ifnet = ifnet.to(device) | |
ifnet.netG.eval() | |
print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green")) | |
print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green")) | |
else: | |
print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green")) | |
dataset = TestDataset(dataset_param, device) | |
print(colored(f"Dataset Size: {len(dataset)}", "green")) | |
pbar = tqdm(dataset) | |
for data in pbar: | |
losses = init_loss() | |
pbar.set_description(f"{data['name']}") | |
# final results rendered as image (PNG) | |
# 1. Render the final fitted SMPL (xxx_smpl.png) | |
# 2. Render the final reconstructed clothed human (xxx_cloth.png) | |
# 3. Blend the original image with predicted cloth normal (xxx_overlap.png) | |
# 4. Blend the cropped image with predicted cloth normal (xxx_crop.png) | |
os.makedirs(osp.join(args.out_dir, cfg.name, "png"), exist_ok=True) | |
# final reconstruction meshes (OBJ) | |
# 1. SMPL mesh (xxx_smpl_xx.obj) | |
# 2. SMPL params (xxx_smpl.npy) | |
# 3. d-BiNI surfaces (xxx_BNI.obj) | |
# 4. seperate face/hand mesh (xxx_hand/face.obj) | |
# 5. full shape impainted by IF-Nets+ after remeshing (xxx_IF.obj) | |
# 6. sideded or occluded parts (xxx_side.obj) | |
# 7. final reconstructed clothed human (xxx_full.obj) | |
os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True) | |
in_tensor = { | |
"smpl_faces": data["smpl_faces"], | |
"image": data["img_icon"].to(device), | |
"mask": data["img_mask"].to(device) | |
} | |
# The optimizer and variables | |
optimed_pose = data["body_pose"].requires_grad_(True) | |
optimed_trans = data["trans"].requires_grad_(True) | |
optimed_betas = data["betas"].requires_grad_(True) | |
optimed_orient = data["global_orient"].requires_grad_(True) | |
optimizer_smpl = torch.optim.Adam( | |
[optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True | |
) | |
scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau( | |
optimizer_smpl, | |
mode="min", | |
factor=0.5, | |
verbose=0, | |
min_lr=1e-5, | |
patience=args.patience, | |
) | |
# [result_loop_1, result_loop_2, ...] | |
per_data_lst = [] | |
N_body, N_pose = optimed_pose.shape[:2] | |
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj" | |
# remove this line if you change the loop_smpl and obtain different SMPL-X fits | |
if osp.exists(smpl_path): | |
smpl_verts_lst = [] | |
smpl_faces_lst = [] | |
for idx in range(N_body): | |
smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj" | |
smpl_mesh = trimesh.load(smpl_obj) | |
smpl_verts = torch.tensor(smpl_mesh.vertices).to(device).float() | |
smpl_faces = torch.tensor(smpl_mesh.faces).to(device).long() | |
smpl_verts_lst.append(smpl_verts) | |
smpl_faces_lst.append(smpl_faces) | |
batch_smpl_verts = torch.stack(smpl_verts_lst) | |
batch_smpl_faces = torch.stack(smpl_faces_lst) | |
# render optimized mesh as normal [-1,1] | |
in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal( | |
batch_smpl_verts, batch_smpl_faces | |
) | |
with torch.no_grad(): | |
in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor) | |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device) | |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]] | |
else: | |
# smpl optimization | |
loop_smpl = tqdm(range(args.loop_smpl)) | |
for i in loop_smpl: | |
per_loop_lst = [] | |
optimizer_smpl.zero_grad() | |
N_body, N_pose = optimed_pose.shape[:2] | |
# 6d_rot to rot_mat | |
optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, | |
6)).view(N_body, 1, 3, 3) | |
optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, | |
6)).view(N_body, N_pose, 3, 3) | |
smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model( | |
shape_params=optimed_betas, | |
expression_params=tensor2variable(data["exp"], device), | |
body_pose=optimed_pose_mat, | |
global_pose=optimed_orient_mat, | |
jaw_pose=tensor2variable(data["jaw_pose"], device), | |
left_hand_pose=tensor2variable(data["left_hand_pose"], device), | |
right_hand_pose=tensor2variable(data["right_hand_pose"], device), | |
) | |
smpl_verts = (smpl_verts + optimed_trans) * data["scale"] | |
smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor( | |
[1.0, 1.0, -1.0] | |
).to(device) | |
# landmark errors | |
smpl_joints_3d = ( | |
smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0 | |
) * 0.5 | |
in_tensor["smpl_joint"] = smpl_joints[:, | |
dataset.smpl_data.smpl_joint_ids_24_pixie, :] | |
ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device) | |
ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device) | |
smpl_lmks = smpl_joints_3d[:, SMPLX_object.ghum_smpl_pairs[:, 1], :2] | |
# render optimized mesh as normal [-1,1] | |
in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal( | |
smpl_verts * torch.tensor([1.0, -1.0, -1.0]).to(device), | |
in_tensor["smpl_faces"], | |
) | |
T_mask_F, T_mask_B = dataset.render.get_image(type="mask") | |
with torch.no_grad(): | |
in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor) | |
diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"]) | |
diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"]) | |
# silhouette loss | |
smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1) | |
gt_arr = in_tensor["mask"].repeat(1, 1, 2) | |
diff_S = torch.abs(smpl_arr - gt_arr) | |
losses["silhouette"]["value"] = diff_S.mean() | |
# large cloth_overlap --> big difference between body and cloth mask | |
# for loose clothing, reply more on landmarks instead of silhouette+normal loss | |
cloth_overlap = diff_S.sum(dim=[1, 2]) / gt_arr.sum(dim=[1, 2]) | |
cloth_overlap_flag = cloth_overlap > cfg.cloth_overlap_thres | |
losses["joint"]["weight"] = [50.0 if flag else 5.0 for flag in cloth_overlap_flag] | |
# small body_overlap --> large occlusion or out-of-frame | |
# for highly occluded body, reply only on high-confidence landmarks, no silhouette+normal loss | |
# BUG: PyTorch3D silhouette renderer generates dilated mask | |
bg_value = in_tensor["T_normal_F"][0, 0, 0, 0] | |
smpl_arr_fake = torch.cat( | |
[ | |
in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), | |
in_tensor["T_normal_B"][:, 0].ne(bg_value).float() | |
], | |
dim=-1 | |
) | |
body_overlap = (gt_arr * smpl_arr_fake.gt(0.0) | |
).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2]) | |
body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1) | |
body_overlap_flag = body_overlap < cfg.body_overlap_thres | |
losses["normal"]["value"] = ( | |
diff_F_smpl * body_overlap_mask[..., :512] + | |
diff_B_smpl * body_overlap_mask[..., 512:] | |
).mean() / 2.0 | |
losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag] | |
occluded_idx = torch.where(body_overlap_flag)[0] | |
ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95 | |
losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) * | |
ghum_conf).mean(dim=1) | |
# Weighted sum of the losses | |
smpl_loss = 0.0 | |
pbar_desc = "Body Fitting -- " | |
for k in ["normal", "silhouette", "joint"]: | |
per_loop_loss = ( | |
losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device) | |
).mean() | |
pbar_desc += f"{k}: {per_loop_loss:.3f} | " | |
smpl_loss += per_loop_loss | |
pbar_desc += f"Total: {smpl_loss:.3f}" | |
loose_str = ''.join([str(j) for j in cloth_overlap_flag.int().tolist()]) | |
occlude_str = ''.join([str(j) for j in body_overlap_flag.int().tolist()]) | |
pbar_desc += colored(f"| loose:{loose_str}, occluded:{occlude_str}", "yellow") | |
loop_smpl.set_description(pbar_desc) | |
# save intermediate results | |
if (i == args.loop_smpl - 1) and (not args.novis): | |
per_loop_lst.extend( | |
[ | |
in_tensor["image"], | |
in_tensor["T_normal_F"], | |
in_tensor["normal_F"], | |
diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1), | |
] | |
) | |
per_loop_lst.extend( | |
[ | |
in_tensor["image"], | |
in_tensor["T_normal_B"], | |
in_tensor["normal_B"], | |
diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1), | |
] | |
) | |
per_data_lst.append( | |
get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl") | |
) | |
smpl_loss.backward() | |
optimizer_smpl.step() | |
scheduler_smpl.step(smpl_loss) | |
in_tensor["smpl_verts"] = smpl_verts * torch.tensor([1.0, 1.0, -1.0]).to(device) | |
in_tensor["smpl_faces"] = in_tensor["smpl_faces"][:, :, [0, 2, 1]] | |
if not args.novis: | |
per_data_lst[-1].save( | |
osp.join(args.out_dir, cfg.name, f"png/{data['name']}_smpl.png") | |
) | |
if not args.novis: | |
img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png") | |
torchvision.utils.save_image( | |
torch.cat( | |
[ | |
data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, | |
(in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5 | |
], | |
dim=3 | |
), img_crop_path | |
) | |
rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data) | |
rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data) | |
img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png") | |
torchvision.utils.save_image( | |
torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., | |
img_overlap_path | |
) | |
smpl_obj_lst = [] | |
for idx in range(N_body): | |
smpl_obj = trimesh.Trimesh( | |
in_tensor["smpl_verts"].detach().cpu()[idx] * torch.tensor([1.0, -1.0, 1.0]), | |
in_tensor["smpl_faces"].detach().cpu()[0][:, [0, 2, 1]], | |
process=False, | |
maintains_order=True, | |
) | |
smpl_obj_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj" | |
if not osp.exists(smpl_obj_path): | |
smpl_obj.export(smpl_obj_path) | |
smpl_info = { | |
"betas": | |
optimed_betas[idx].detach().cpu().unsqueeze(0), | |
"body_pose": | |
rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach() | |
).cpu().unsqueeze(0), | |
"global_orient": | |
rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach() | |
).cpu().unsqueeze(0), | |
"transl": | |
optimed_trans[idx].detach().cpu(), | |
"expression": | |
data["exp"][idx].cpu().unsqueeze(0), | |
"jaw_pose": | |
rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0), | |
"left_hand_pose": | |
rotation_matrix_to_angle_axis(data["left_hand_pose"][idx] | |
).cpu().unsqueeze(0), | |
"right_hand_pose": | |
rotation_matrix_to_angle_axis(data["right_hand_pose"][idx] | |
).cpu().unsqueeze(0), | |
"scale": | |
data["scale"][idx].cpu(), | |
} | |
np.save( | |
smpl_obj_path.replace(".obj", ".npy"), | |
smpl_info, | |
allow_pickle=True, | |
) | |
smpl_obj_lst.append(smpl_obj) | |
del optimizer_smpl | |
del optimed_betas | |
del optimed_orient | |
del optimed_pose | |
del optimed_trans | |
torch.cuda.empty_cache() | |
# ------------------------------------------------------------------------------------------------------------------ | |
# clothing refinement | |
per_data_lst = [] | |
batch_smpl_verts = in_tensor["smpl_verts"].detach( | |
) * torch.tensor([1.0, -1.0, 1.0], device=device) | |
batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]] | |
in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth( | |
batch_smpl_verts, batch_smpl_faces | |
) | |
per_loop_lst = [] | |
in_tensor["BNI_verts"] = [] | |
in_tensor["BNI_faces"] = [] | |
in_tensor["body_verts"] = [] | |
in_tensor["body_faces"] = [] | |
for idx in range(N_body): | |
final_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_full.obj" | |
side_mesh = smpl_obj_lst[idx].copy() | |
face_mesh = smpl_obj_lst[idx].copy() | |
hand_mesh = smpl_obj_lst[idx].copy() | |
smplx_mesh = smpl_obj_lst[idx].copy() | |
# save normals, depths and masks | |
BNI_dict = save_normal_tensor( | |
in_tensor, | |
idx, | |
osp.join(args.out_dir, cfg.name, f"BNI/{data['name']}_{idx}"), | |
cfg.bni.thickness, | |
) | |
# BNI process | |
BNI_object = BNI( | |
dir_path=osp.join(args.out_dir, cfg.name, "BNI"), | |
name=data["name"], | |
BNI_dict=BNI_dict, | |
cfg=cfg.bni, | |
device=device | |
) | |
BNI_object.extract_surface(False) | |
in_tensor["body_verts"].append(torch.tensor(smpl_obj_lst[idx].vertices).float()) | |
in_tensor["body_faces"].append(torch.tensor(smpl_obj_lst[idx].faces).long()) | |
# requires shape completion when low overlap | |
# replace SMPL by completed mesh as side_mesh | |
if cfg.bni.use_ifnet: | |
side_mesh_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_IF.obj" | |
side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask) | |
# mesh completion via IF-net | |
in_tensor.update( | |
dataset.depth_to_voxel( | |
{ | |
"depth_F": BNI_object.F_depth.unsqueeze(0), | |
"depth_B": BNI_object.B_depth.unsqueeze(0) | |
} | |
) | |
) | |
occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[ | |
0, | |
] * 3, scale=2.0).data.transpose(2, 1, 0) | |
occupancies = np.flip(occupancies, axis=1) | |
in_tensor["body_voxels"] = torch.tensor(occupancies.copy() | |
).float().unsqueeze(0).to(device) | |
with torch.no_grad(): | |
sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor) | |
verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf) | |
if ifnet.clean_mesh_flag: | |
verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF) | |
side_mesh = trimesh.Trimesh(verts_IF, faces_IF) | |
side_mesh = remesh_laplacian(side_mesh, side_mesh_path) | |
else: | |
side_mesh = apply_vertex_mask( | |
side_mesh, | |
( | |
SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + | |
SMPLX_object.eyeball_vertex_mask | |
).eq(0).float(), | |
) | |
#register side_mesh to BNI surfaces | |
side_mesh = Meshes( | |
verts=[torch.tensor(side_mesh.vertices).float()], | |
faces=[torch.tensor(side_mesh.faces).long()], | |
).to(device) | |
sm = SubdivideMeshes(side_mesh) | |
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device) | |
side_verts = torch.tensor(side_mesh.vertices).float().to(device) | |
side_faces = torch.tensor(side_mesh.faces).long().to(device) | |
# Possion Fusion between SMPLX and BNI | |
# 1. keep the faces invisible to front+back cameras | |
# 2. keep the front-FLAME+MANO faces | |
# 3. remove eyeball faces | |
# export intermediate meshes | |
BNI_object.F_B_trimesh.export( | |
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj" | |
) | |
full_lst = [] | |
if "face" in cfg.bni.use_smpl: | |
# only face | |
face_mesh = apply_vertex_mask(face_mesh, SMPLX_object.front_flame_vertex_mask) | |
face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness]) | |
# remove face neighbor triangles | |
BNI_object.F_B_trimesh = part_removal( | |
BNI_object.F_B_trimesh, | |
face_mesh, | |
cfg.bni.face_thres, | |
device, | |
smplx_mesh, | |
region="face" | |
) | |
side_mesh = part_removal( | |
side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face" | |
) | |
face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj") | |
full_lst += [face_mesh] | |
if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]): | |
hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], ) | |
if data['hands_visibility'][idx][0]: | |
hand_mask.index_fill_( | |
0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0 | |
) | |
if data['hands_visibility'][idx][1]: | |
hand_mask.index_fill_( | |
0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0 | |
) | |
# only hands | |
hand_mesh = apply_vertex_mask(hand_mesh, hand_mask) | |
# remove hand neighbor triangles | |
BNI_object.F_B_trimesh = part_removal( | |
BNI_object.F_B_trimesh, | |
hand_mesh, | |
cfg.bni.hand_thres, | |
device, | |
smplx_mesh, | |
region="hand" | |
) | |
side_mesh = part_removal( | |
side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand" | |
) | |
hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj") | |
full_lst += [hand_mesh] | |
full_lst += [BNI_object.F_B_trimesh] | |
# initial side_mesh could be SMPLX or IF-net | |
side_mesh = part_removal( | |
side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False | |
) | |
full_lst += [side_mesh] | |
# # export intermediate meshes | |
BNI_object.F_B_trimesh.export( | |
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj" | |
) | |
side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj") | |
if cfg.bni.use_poisson: | |
final_mesh = poisson( | |
sum(full_lst), | |
final_path, | |
cfg.bni.poisson_depth, | |
) | |
else: | |
final_mesh = sum(full_lst) | |
final_mesh.export(final_path) | |
if not args.novis: | |
dataset.render.load_meshes(final_mesh.vertices, final_mesh.faces) | |
rotate_recon_lst = dataset.render.get_image(cam_type="four") | |
per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst) | |
if cfg.bni.texture_src == 'image': | |
# coloring the final mesh (front: RGB pixels, back: normal colors) | |
final_colors = query_color( | |
torch.tensor(final_mesh.vertices).float(), | |
torch.tensor(final_mesh.faces).long(), | |
in_tensor["image"][idx:idx + 1], | |
device=device, | |
) | |
final_mesh.visual.vertex_colors = final_colors | |
final_mesh.export(final_path) | |
elif cfg.bni.texture_src == 'SD': | |
# !TODO: add texture from Stable Diffusion | |
pass | |
if len(per_loop_lst) > 0 and (not args.novis): | |
per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=5, type="cloth")) | |
per_data_lst[-1].save(osp.join(args.out_dir, cfg.name, f"png/{data['name']}_cloth.png")) | |
# for video rendering | |
in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float()) | |
in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long()) | |
os.makedirs(osp.join(args.out_dir, cfg.name, "vid"), exist_ok=True) | |
in_tensor["uncrop_param"] = data["uncrop_param"] | |
in_tensor["img_raw"] = data["img_raw"] | |
torch.save( | |
in_tensor, osp.join(args.out_dir, cfg.name, f"vid/{data['name']}_in_tensor.pt") | |
) | |