# ECON / apps/infer.py
# Last commit by Yuliang: e0ba903 ("upgrade to Gradio 4.14.0")
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import logging
import warnings
# Silence every Python warning and raise the "lightning" and "trimesh" loggers
# to ERROR. This sits between the imports — presumably so that warnings emitted
# while the heavy third-party imports below load are suppressed too (TODO confirm).
warnings.filterwarnings("ignore")
logging.getLogger("lightning").setLevel(logging.ERROR)
logging.getLogger("trimesh").setLevel(logging.ERROR)
import os
import numpy as np
import torch
import torchvision
import trimesh
from pytorch3d.ops import SubdivideMeshes
from huggingface_hub import hf_hub_download
from termcolor import colored
from tqdm import tqdm
from apps.IFGeo import IFGeo
from apps.Normal import Normal
from lib.common.BNI import BNI
from lib.common.BNI_utils import save_normal_tensor
from lib.common.config import cfg
from lib.common.imutils import blend_rgb_norm
from lib.common.local_affine import register
from lib.common.render import query_color, Render
from lib.common.train_util import Format, init_loss
from lib.common.voxelize import VoxelGrid
from lib.dataset.mesh_util import *
from lib.dataset.TestDataset import TestDataset
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
# Enable the cuDNN auto-tuner: convolution algorithms are benchmarked per input
# shape and the fastest is cached (faster for repeated fixed-size inputs, but
# results may be non-deterministic).
torch.backends.cudnn.benchmark = True
def generate_video(vis_tensor_path):
    """Render a self-rotating video from a tensor dict saved by ``generate_model``.

    The SMPL-X body meshes and the final clothed meshes stored in the dict are
    rendered into a temporary video, which is then re-encoded with ffmpeg
    (libx264).

    Args:
        vis_tensor_path: path of a ``*_in_tensor.pt`` file.

    Returns:
        Path of the encoded video: ``vis_tensor_path`` with the
        ``_in_tensor.pt`` suffix replaced by ``.mp4``.

    Raises:
        ValueError: if ``vis_tensor_path`` does not end with ``_in_tensor.pt``.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    import subprocess

    suffix = "_in_tensor.pt"
    if not vis_tensor_path.endswith(suffix):
        # Without the suffix, the derived output paths would equal the input
        # path and the renderer would overwrite the tensor file itself.
        raise ValueError(f"expected a path ending with {suffix!r}, got {vis_tensor_path!r}")
    stem = vis_tensor_path[:-len(suffix)]
    tmp_path = stem + "_tmp.mp4"
    out_path = stem + ".mp4"

    in_tensor = torch.load(vis_tensor_path)
    render = Render(size=512, device=torch.device("cuda:0"))

    # visualize the final results in self-rotation mode:
    # SMPL-X bodies first, then the final clothed meshes
    verts_lst = in_tensor["body_verts"] + in_tensor["BNI_verts"]
    faces_lst = in_tensor["body_faces"] + in_tensor["BNI_faces"]

    # self-rotated video (raw render goes to tmp_path)
    render.load_meshes(verts_lst, faces_lst)
    render.get_rendered_video_multi(in_tensor, tmp_path)

    # Argument list, no shell: paths containing spaces or shell metacharacters
    # are passed verbatim, and a failed encode raises instead of returning a
    # path to a video that was never written.
    subprocess.run(
        [
            "ffmpeg", "-y", "-loglevel", "quiet", "-stats", "-i", tmp_path, "-vcodec", "libx264",
            out_path
        ],
        check=True,
    )
    return out_path
import sys
class Logger:
    """Tee text written to it into both the current ``sys.stdout`` and a file.

    Intended to be assigned to ``sys.stdout`` so that everything printed is
    echoed to the original stream and written to ``filename`` (opened in
    ``"w"`` mode, i.e. truncated).
    """

    def __init__(self, filename):
        # Whatever stdout is at construction time becomes the forwarding target.
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output (e.g. subject names derived from
        # file names) does not depend on the platform's locale encoding.
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        # After close() keep forwarding to the terminal instead of raising.
        if not self.log.closed:
            self.log.write(message)

    def flush(self):
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def isatty(self):
        # Always report a non-interactive stream, presumably so libraries that
        # probe stdout do not emit TTY-only output into the log file.
        return False

    def close(self):
        """Close the log file (idempotent); the wrapped terminal stream stays open."""
        if not self.log.closed:
            self.log.close()
def generate_model(in_path, fitting_step=50):
    """Reconstruct a clothed human from the input image at ``in_path``.

    Runs SMPL-X fitting (or reuses a cached fit), normal estimation, d-BiNI
    surface reconstruction, optional IF-Nets+ completion and Poisson fusion,
    writing all artefacts under ``./results/<cfg.name>/``.

    While the pipeline runs, stdout is tee'd into ``./output.log``. The previous
    stdout is restored afterwards, then unreachable objects are collected and
    PyTorch's cached CUDA memory is released.

    Args:
        in_path: input image path (forwarded to ``TestDataset`` as ``image_path``).
        fitting_step: number of SMPL-X optimisation iterations; unused when a
            cached fit (``<name>_smpl_00.obj``) already exists.

    Returns:
        ``[smpl_glb_path, refine_glb_path, overlap_path, vis_tensor_path]``.
        When several bodies are detected, the two GLB paths refer to the last one.

    Raises:
        ValueError: if ``fitting_step < 1`` and no cached SMPL-X fit exists.
        KeyError: if the ``ICON`` environment variable (Hub token) is unset.
    """
    import gc

    # Tee stdout for the duration of this call only. Restoring the previous
    # stream afterwards keeps repeated calls (one per request in a long-running
    # app) from stacking Logger objects that each hold a handle on the same file.
    prev_stdout = sys.stdout
    logger = Logger("./output.log")
    sys.stdout = logger
    try:
        return _generate_model_impl(in_path, fitting_step)
    finally:
        sys.stdout = prev_stdout
        logger.log.close()
        # All models and tensors were locals of _generate_model_impl, so they are
        # unreachable now that it has returned; only at this point can gc and the
        # CUDA caching allocator actually reclaim their memory.
        gc.collect()
        torch.cuda.empty_cache()


def _generate_model_impl(in_path, fitting_step):
    """Run the reconstruction pipeline for ``generate_model`` and return its output paths."""
    out_dir = "./results"

    # cfg read and merge
    cfg.merge_from_file("./configs/econ.yaml")
    cfg.merge_from_file("./lib/pymafx/configs/pymafx_config.yaml")
    device = torch.device("cuda:0")

    # setting for testing on in-the-wild images
    cfg_show_list = [
        "test_gpus", [0], "mcube_res", 512, "clean_mesh", True, "test_mode", True, "batch_size", 1
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    # load normal model
    normal_net = Normal.load_from_checkpoint(
        cfg=cfg,
        checkpoint_path=hf_hub_download(
            repo_id="Yuliang/ICON", use_auth_token=os.environ["ICON"], filename=cfg.normal_path
        ),
        map_location=device,
        strict=False
    )
    normal_net = normal_net.to(device)
    normal_net.netG.eval()
    print(colored(f"Resume Normal Estimator from : {cfg.normal_path} ", "green"))

    # SMPLX object
    SMPLX_object = SMPLX()

    dataset_param = {
        "image_path": in_path,
        "use_seg": True,  # w/ or w/o segmentation
        "hps_type": cfg.bni.hps_type,  # pymafx/pixie
        "vol_res": cfg.vol_res,
        "single": True,
    }

    if cfg.bni.use_ifnet:
        # load IFGeo model
        ifnet = IFGeo.load_from_checkpoint(
            cfg=cfg,
            checkpoint_path=hf_hub_download(
                repo_id="Yuliang/ICON", use_auth_token=os.environ["ICON"], filename=cfg.ifnet_path
            ),
            map_location=device,
            strict=False
        )
        ifnet = ifnet.to(device)
        ifnet.netG.eval()

        print(colored(f"Resume IF-Net+ from : {cfg.ifnet_path} ", "green"))
        print(colored("Complete with : IF-Nets+ (Implicit) ", "green"))
    else:
        print(colored("Complete with : SMPL-X (Explicit) ", "green"))

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    # only the first item of the dataset is processed
    data = dataset[0]

    losses = init_loss()

    print(f"Subject name: {data['name']}")

    # final results rendered as image (PNG)
    # 1. Render the final fitted SMPL (xxx_smpl.png)
    # 2. Render the final reconstructed clothed human (xxx_cloth.png)
    # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)
    # 4. Blend the cropped image with predicted cloth normal (xxx_crop.png)
    os.makedirs(osp.join(out_dir, cfg.name, "png"), exist_ok=True)

    # final reconstruction meshes (OBJ)
    # 1. SMPL mesh (xxx_smpl_xx.obj)
    # 2. SMPL params (xxx_smpl.npy)
    # 3. d-BiNI surfaces (xxx_BNI.obj)
    # 4. separate face/hand mesh (xxx_hand/face.obj)
    # 5. full shape inpainted by IF-Nets+ after remeshing (xxx_IF.obj)
    # 6. sided or occluded parts (xxx_side.obj)
    # 7. final reconstructed clothed human (xxx_full.obj)
    os.makedirs(osp.join(out_dir, cfg.name, "obj"), exist_ok=True)

    in_tensor = {
        "smpl_faces": data["smpl_faces"],
        "image": data["img_icon"].to(device),
        "mask": data["img_mask"].to(device),
    }

    # The optimizer and variables
    optimed_pose = data["body_pose"].requires_grad_(True)
    optimed_trans = data["trans"].requires_grad_(True)
    optimed_betas = data["betas"].requires_grad_(True)
    optimed_orient = data["global_orient"].requires_grad_(True)

    optimizer_smpl = torch.optim.Adam(
        [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
        lr=1e-2,
        amsgrad=True,
    )
    # `verbose` is omitted: its default is falsy and the argument is deprecated
    # in recent PyTorch releases.
    scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_smpl,
        mode="min",
        factor=0.5,
        min_lr=1e-5,
        patience=5,
    )

    # [result_loop_1, result_loop_2, ...]
    per_data_lst = []

    N_body, N_pose = optimed_pose.shape[:2]

    smpl_path = f"{out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"

    # remove this line if you change the loop_smpl and obtain different SMPL-X fits
    if osp.exists(smpl_path):
        # reuse the SMPL-X meshes exported by a previous run for this subject
        smpl_verts_lst = []
        smpl_faces_lst = []

        for idx in range(N_body):
            smpl_obj = f"{out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"
            smpl_mesh = trimesh.load(smpl_obj)
            smpl_verts = torch.tensor(smpl_mesh.vertices).to(device).float()
            smpl_faces = torch.tensor(smpl_mesh.faces).to(device).long()
            smpl_verts_lst.append(smpl_verts)
            smpl_faces_lst.append(smpl_faces)

        batch_smpl_verts = torch.stack(smpl_verts_lst)
        batch_smpl_faces = torch.stack(smpl_faces_lst)

        # render optimized mesh as normal [-1,1]
        in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
            batch_smpl_verts, batch_smpl_faces
        )

        with torch.no_grad():
            in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

        in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
        in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]

    else:
        if fitting_step < 1:
            # the loop below must run at least once to define smpl_verts and the
            # SMPL visualisation saved after it
            raise ValueError(
                f"fitting_step must be >= 1 when no cached SMPL-X fit exists, got {fitting_step}"
            )

        # smpl optimization
        loop_smpl = tqdm(range(fitting_step))

        for i in loop_smpl:
            per_loop_lst = []

            optimizer_smpl.zero_grad()

            # 6d_rot to rot_mat
            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).view(N_body, 1, 3, 3)
            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).view(N_body, N_pose, 3, 3)

            smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(data["exp"], device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(data["jaw_pose"], device),
                left_hand_pose=tensor2variable(data["left_hand_pose"], device),
                right_hand_pose=tensor2variable(data["right_hand_pose"], device),
            )

            smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
            smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor([
                1.0, 1.0, -1.0
            ]).to(device)

            # landmark errors
            smpl_joints_3d = (
                smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0
            ) * 0.5
            in_tensor["smpl_joint"] = smpl_joints[:, dataset.smpl_data.smpl_joint_ids_24_pixie, :]

            ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
            ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)
            smpl_lmks = smpl_joints_3d[:, SMPLX_object.ghum_smpl_pairs[:, 1], :2].to(device)

            # render optimized mesh as normal [-1,1]
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                smpl_verts * torch.tensor([1.0, -1.0, -1.0]).to(device),
                in_tensor["smpl_faces"],
            )

            T_mask_F, T_mask_B = dataset.render.get_image(type="mask")

            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

            diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
            diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])

            # silhouette loss
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)
            gt_arr = in_tensor["mask"].repeat(1, 1, 2)
            diff_S = torch.abs(smpl_arr - gt_arr)
            losses["silhouette"]["value"] = diff_S.mean()

            # large cloth_overlap --> big difference between body and cloth mask
            # for loose clothing, rely more on landmarks instead of silhouette+normal loss
            cloth_overlap = diff_S.sum(dim=[1, 2]) / gt_arr.sum(dim=[1, 2])
            cloth_overlap_flag = cloth_overlap > cfg.cloth_overlap_thres
            losses["joint"]["weight"] = [50.0 if flag else 5.0 for flag in cloth_overlap_flag]

            # small body_overlap --> large occlusion or out-of-frame
            # for highly occluded body, rely only on high-confidence landmarks, no silhouette+normal loss

            # BUG: PyTorch3D silhouette renderer generates dilated mask
            # (the body mask is therefore rebuilt from the normal render: every
            # pixel differing from the corner/background value counts as body)
            bg_value = in_tensor["T_normal_F"][0, 0, 0, 0].to(device)
            smpl_arr_fake = torch.cat([
                in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
                in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
            ],
                                      dim=-1)

            body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)
                           ).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
            body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
            body_overlap_flag = body_overlap < cfg.body_overlap_thres

            losses["normal"]["value"] = (
                diff_F_smpl * body_overlap_mask[..., :512] +
                diff_B_smpl * body_overlap_mask[..., 512:]
            ).mean() / 2.0

            losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
            occluded_idx = torch.where(body_overlap_flag)[0]
            # for occluded bodies keep only landmarks with confidence > 0.95
            ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
            losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) *
                                        ghum_conf).mean(dim=1)

            # Weighted sum of the losses
            smpl_loss = 0.0
            pbar_desc = "Body Fitting -- "
            for k in ["normal", "silhouette", "joint"]:
                per_loop_loss = (losses[k]["value"] *
                                 torch.tensor(losses[k]["weight"]).to(device)).mean()
                pbar_desc += f"{k}: {per_loop_loss:.3f} | "
                smpl_loss += per_loop_loss
            pbar_desc += f"Total: {smpl_loss:.3f}"
            loose_str = ''.join([str(j) for j in cloth_overlap_flag.int().tolist()])
            occlude_str = ''.join([str(j) for j in body_overlap_flag.int().tolist()])
            pbar_desc += colored(f"| loose:{loose_str}, occluded:{occlude_str}", "yellow")
            loop_smpl.set_description(pbar_desc)
            print(pbar_desc)

            # save intermediate results (last iteration only)
            if i == fitting_step - 1:
                per_loop_lst.extend([
                    in_tensor["image"],
                    in_tensor["T_normal_F"],
                    in_tensor["normal_F"],
                    diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
                ])
                per_loop_lst.extend([
                    in_tensor["image"],
                    in_tensor["T_normal_B"],
                    in_tensor["normal_B"],
                    diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
                ])
                per_data_lst.append(
                    get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
                )

            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)

        in_tensor["smpl_verts"] = smpl_verts * torch.tensor([1.0, 1.0, -1.0]).to(device)
        in_tensor["smpl_faces"] = in_tensor["smpl_faces"][:, :, [0, 2, 1]]

        per_data_lst[-1].save(osp.join(out_dir, cfg.name, f"png/{data['name']}_smpl.png"))

    img_crop_path = osp.join(out_dir, cfg.name, "png", f"{data['name']}_crop.png")
    torchvision.utils.save_image(
        torch.cat([
            data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
            (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
        ],
                  dim=3), img_crop_path
    )

    rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
    rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)

    img_overlap_path = osp.join(out_dir, cfg.name, f"png/{data['name']}_overlap.png")
    torchvision.utils.save_image(
        torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., img_overlap_path
    )

    smpl_obj_lst = []

    for idx in range(N_body):
        smpl_obj = trimesh.Trimesh(
            in_tensor["smpl_verts"].detach().cpu()[idx] * torch.tensor([1.0, -1.0, 1.0]),
            in_tensor["smpl_faces"].detach().cpu()[0][:, [0, 2, 1]],
            process=False,
            maintains_order=True,
        )

        smpl_obj_path = f"{out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"

        # existing exports (from a previous run) are kept as-is
        if not osp.exists(smpl_obj_path):
            smpl_obj.export(smpl_obj_path)
            smpl_obj.export(smpl_obj_path.replace(".obj", ".glb"))
            smpl_info = {
                "betas":
                optimed_betas[idx].detach().cpu().unsqueeze(0),
                "body_pose":
                rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()).cpu().unsqueeze(0),
                "global_orient":
                rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()).cpu().unsqueeze(0),
                "transl":
                optimed_trans[idx].detach().cpu(),
                "expression":
                data["exp"][idx].cpu().unsqueeze(0),
                "jaw_pose":
                rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
                "left_hand_pose":
                rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]).cpu().unsqueeze(0),
                "right_hand_pose":
                rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]).cpu().unsqueeze(0),
                "scale":
                data["scale"][idx].cpu(),
            }
            np.save(
                smpl_obj_path.replace(".obj", ".npy"),
                smpl_info,
                allow_pickle=True,
            )
        smpl_obj_lst.append(smpl_obj)

    del optimizer_smpl
    del optimed_betas
    del optimed_orient
    del optimed_pose
    del optimed_trans

    torch.cuda.empty_cache()

    # ------------------------------------------------------------------------------------------------------------------
    # clothing refinement

    per_data_lst = []

    batch_smpl_verts = in_tensor["smpl_verts"].detach() * torch.tensor([1.0, -1.0, 1.0],
                                                                         device=device)
    batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]

    in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
        batch_smpl_verts, batch_smpl_faces
    )

    per_loop_lst = []

    in_tensor["BNI_verts"] = []
    in_tensor["BNI_faces"] = []
    in_tensor["body_verts"] = []
    in_tensor["body_faces"] = []

    vis_tensor_path = osp.join(out_dir, cfg.name, f"vid/{data['name']}_in_tensor.pt")

    for idx in range(N_body):

        final_path = f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_full.obj"

        side_mesh = smpl_obj_lst[idx].copy()
        face_mesh = smpl_obj_lst[idx].copy()
        hand_mesh = smpl_obj_lst[idx].copy()
        smplx_mesh = smpl_obj_lst[idx].copy()

        # save normals, depths and masks
        BNI_dict = save_normal_tensor(
            in_tensor,
            idx,
            osp.join(out_dir, cfg.name, f"BNI/{data['name']}_{idx}"),
            cfg.bni.thickness,
        )

        # BNI process
        BNI_object = BNI(
            dir_path=osp.join(out_dir, cfg.name, "BNI"),
            name=data["name"],
            BNI_dict=BNI_dict,
            cfg=cfg.bni,
            device=device
        )

        BNI_object.extract_surface(False)

        in_tensor["body_verts"].append(torch.tensor(smpl_obj_lst[idx].vertices).float())
        in_tensor["body_faces"].append(torch.tensor(smpl_obj_lst[idx].faces).long())

        # requires shape completion when low overlap
        # replace SMPL by completed mesh as side_mesh

        if cfg.bni.use_ifnet:

            side_mesh_path = f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_IF.obj"

            side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)

            # mesh completion via IF-net
            in_tensor.update(
                dataset.depth_to_voxel({
                    "depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B":
                    BNI_object.B_depth.unsqueeze(0)
                })
            )

            occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
                0,
            ] * 3, scale=2.0).data.transpose(2, 1, 0)
            occupancies = np.flip(occupancies, axis=1)

            in_tensor["body_voxels"] = torch.tensor(occupancies.copy()
                                                   ).float().unsqueeze(0).to(device)

            with torch.no_grad():
                sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
                verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)

            if ifnet.clean_mesh_flag:
                verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)

            side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
            side_mesh = remesh_laplacian(side_mesh, side_mesh_path)

        else:
            side_mesh = apply_vertex_mask(
                side_mesh,
                (
                    SMPLX_object.front_flame_vertex_mask + SMPLX_object.smplx_mano_vertex_mask +
                    SMPLX_object.eyeball_vertex_mask
                ).eq(0).float(),
            )

            # register side_mesh to BNI surfaces
            side_mesh = Meshes(
                verts=[torch.tensor(side_mesh.vertices).float()],
                faces=[torch.tensor(side_mesh.faces).long()],
            ).to(device)
            sm = SubdivideMeshes(side_mesh)
            side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)

        # Poisson Fusion between SMPLX and BNI
        # 1. keep the faces invisible to front+back cameras
        # 2. keep the front-FLAME+MANO faces
        # 3. remove eyeball faces

        # export intermediate meshes
        BNI_object.F_B_trimesh.export(f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
        full_lst = []

        if "face" in cfg.bni.use_smpl:

            # only face
            face_mesh = apply_vertex_mask(face_mesh, SMPLX_object.front_flame_vertex_mask)
            face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])

            # remove face neighbor triangles
            BNI_object.F_B_trimesh = part_removal(
                BNI_object.F_B_trimesh,
                face_mesh,
                cfg.bni.face_thres,
                device,
                smplx_mesh,
                region="face"
            )
            side_mesh = part_removal(
                side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
            )
            face_mesh.export(f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
            full_lst += [face_mesh]

        if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):

            hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], )
            if data['hands_visibility'][idx][0]:
                hand_mask.index_fill_(
                    0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0
                )
            if data['hands_visibility'][idx][1]:
                hand_mask.index_fill_(
                    0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0
                )

            # only hands
            hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)

            # remove hand neighbor triangles
            BNI_object.F_B_trimesh = part_removal(
                BNI_object.F_B_trimesh,
                hand_mesh,
                cfg.bni.hand_thres,
                device,
                smplx_mesh,
                region="hand"
            )
            side_mesh = part_removal(
                side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
            )
            hand_mesh.export(f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
            full_lst += [hand_mesh]

        full_lst += [BNI_object.F_B_trimesh]

        # initial side_mesh could be SMPLX or IF-net
        side_mesh = part_removal(
            side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
        )

        full_lst += [side_mesh]

        # export intermediate meshes (BNI overwritten after part removal)
        BNI_object.F_B_trimesh.export(f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
        side_mesh.export(f"{out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")

        final_mesh = poisson(
            sum(full_lst),
            final_path,
            cfg.bni.poisson_depth,
        )
        print(colored(f"Poisson completion to : {final_path} ", "yellow"))

        dataset.render.load_meshes(final_mesh.vertices, final_mesh.faces)
        rotate_recon_lst = dataset.render.get_image(cam_type="four")
        per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)

        if cfg.bni.texture_src == 'image':

            # coloring the final mesh (front: RGB pixels, back: normal colors)
            final_colors = query_color(
                torch.tensor(final_mesh.vertices).float(),
                torch.tensor(final_mesh.faces).long(),
                in_tensor["image"][idx:idx + 1],
                device=device,
            )
            final_mesh.visual.vertex_colors = final_colors
            final_mesh.export(final_path)

        elif cfg.bni.texture_src == 'SD':

            # !TODO: add texture from Stable Diffusion
            pass

        # The GLB path is returned for every texture source, so always export it
        # (after any colouring above has been applied).
        final_mesh.export(final_path.replace(".obj", ".glb"))

        if len(per_loop_lst) > 0:

            per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=5, type="cloth"))
            per_data_lst[-1].save(osp.join(out_dir, cfg.name, f"png/{data['name']}_cloth.png"))

        # for video rendering
        in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
        in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())

        os.makedirs(osp.join(out_dir, cfg.name, "vid"), exist_ok=True)
        in_tensor["uncrop_param"] = data["uncrop_param"]
        in_tensor["img_raw"] = data["img_raw"]
        torch.save(in_tensor, vis_tensor_path)

    # paths of the last body processed
    smpl_glb_path = smpl_obj_path.replace(".obj", ".glb")
    refine_glb_path = final_path.replace(".obj", ".glb")
    overlap_path = img_overlap_path

    return [smpl_glb_path, refine_glb_path, overlap_path, vis_tensor_path]