from core.remesh import calc_vertex_normals
from core.opt import MeshOptimizer
from utils.func import make_sparse_camera, make_round_views
from utils.render import NormalsRenderer
import torch.optim as optim
from tqdm import tqdm
from utils.video_utils import write_video
from omegaconf import OmegaConf
import numpy as np
import os
from PIL import Image
import kornia
import torch
import torch.nn as nn
import trimesh
from icecream import ic
from utils.project_mesh import multiview_color_projection, project_color, get_cameras_list
from utils.mesh_utils import to_py3d_mesh, rot6d_to_rotmat, tensor2variable
from utils.smpl_util import SMPLX
from lib.dataset.mesh_util import apply_vertex_mask, part_removal, poisson, keep_largest
from scipy.spatial.transform import Rotation as R
from scipy.spatial import KDTree
import argparse

#### ------------------- config ----------------------
bg_color = np.array([1, 1, 1])
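
# Per-vertex color optimization wrapper: the vertex colors are the only trainable
# parameter; forward() renders the fixed mesh (v, f) with the current colors and
# composites RGB over bg_color using the rendered alpha.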
class colorModel(nn.Module):
    def __init__(self, renderer, v, f, c):
        super().__init__()
        self.renderer = renderer
        self.v = v
        self.f = f
        self.colors = nn.Parameter(c, requires_grad=True)
        self.bg_color = torch.from_numpy(bg_color).float().to(self.colors.device)

    def forward(self, return_mask=False):
        rgba = self.renderer.render(self.v, self.f, colors=self.colors)
        if return_mask:
            return rgba
        else:
            mask = rgba[..., 3:]
            return rgba[..., :3] * mask + self.bg_color * (1 - mask)
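
# Compute a normalizing transform for a vertex set: offset recenters the
# bounding box at the origin, scale maps the farthest vertex onto the unit sphere.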
def scale_mesh(vert):
    min_bbox, max_bbox = vert.min(0)[0], vert.max(0)[0]
    center = (min_bbox + max_bbox) / 2
    offset = -center
    vert = vert + offset
    max_dist = torch.max(torch.sqrt(torch.sum(vert**2, dim=1)))
    scale = 1.0 / max_dist
    return scale, offset
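
# Export a mesh via trimesh; vertex colors are expected in [0, 1] and converted
# to uint8 on the way out.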
def save_mesh(save_name, vertices, faces, color=None):
    trimesh.Trimesh(
        vertices.detach().cpu().numpy(),
        faces.detach().cpu().numpy(),
        vertex_colors=(color.detach().cpu().numpy() * 255).astype(np.uint8) if color is not None else None) \
        .export(save_name)
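
# End-to-end reconstruction driver: fits SMPL-X to multi-view normal predictions,
# refines free-form geometry by iterative remeshing, optionally swaps in the
# SMPL-X hands, then projects the multi-view colors onto the final mesh.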
class ReMesh:
    def __init__(self, opt, econ_dataset):
        self.opt = opt
        self.device = torch.device(f"cuda:{opt.gpu_id}" if torch.cuda.is_available() else "cpu")
        self.num_view = opt.num_view
        self.out_path = opt.res_path
        os.makedirs(self.out_path, exist_ok=True)
        self.resolution = opt.resolution
        self.views = ['front_face', 'front_right', 'right', 'back', 'left', 'front_left']
        self.weights = torch.Tensor([1., 0.4, 0.8, 1.0, 0.8, 0.4]).view(6, 1, 1, 1).to(self.device)
        self.renderer = self.prepare_render()
        # pose prediction
        self.econ_dataset = econ_dataset
        self.smplx_face = torch.Tensor(econ_dataset.faces.astype(np.int64)).long().to(self.device)
    def prepare_render(self):
        ### ------------------- prepare camera and renderer ----------------------
        mv, proj = make_sparse_camera(self.opt.cam_path, self.opt.scale, views=[0, 1, 2, 4, 6, 7], device=self.device)
        renderer = NormalsRenderer(mv, proj, [self.resolution, self.resolution], device=self.device)
        return renderer
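
    # Project the fused multi-view images onto the mesh as vertex colors
    # (Unique3D-style multiview color projection over six fixed azimuths).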
    def proj_texture(self, fused_images, vertices, faces):
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        mesh = multiview_color_projection(mesh, fused_images, camera_focal=camera_focal, resolution=self.resolution,
                                          weights=self.weights.squeeze().cpu().numpy(), device=self.device,
                                          complete_unseen=True, confidence_threshold=0.2, cameras_list=cameras_list)
        return mesh
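
    # Accumulate view-weighted vertex colors and report which vertices are not
    # confidently seen by any camera (accumulated view weight < 1).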
    def get_invisible_idx(self, imgs, vertices, faces):
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        if self.num_view == 6:
            cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        elif self.num_view == 4:
            cameras_list = get_cameras_list([0, 90, 180, 270], device=self.device, focal=camera_focal)
        vertices_colors = torch.zeros((vertices.shape[0], 3)).float().to(self.device)
        valid_cnt = torch.zeros((vertices.shape[0])).to(self.device)
        for cam, img, weight in zip(cameras_list, imgs, self.weights.squeeze()):
            ret = project_color(mesh, cam, img, eps=0.01, resolution=self.resolution, device=self.device)
            valid_cnt[ret['valid_verts']] += weight
            vertices_colors[ret['valid_verts']] += ret['valid_colors'] * weight
        valid_mask = valid_cnt > 1
        vertices_colors[valid_mask] /= valid_cnt[valid_mask][:, None]
        # visibility: vertices whose accumulated weight stays below 1 are treated as unseen
        invisible_vert = valid_cnt < 1
        invisible_vert_indices = torch.nonzero(invisible_vert).squeeze(1)
        return vertices_colors, invisible_vert_indices
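
    # Fill colors of unseen vertices from their nearest visible neighbor via a
    # KD-tree lookup in vertex-position space.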
    def inpaint_missed_colors(self, all_vertices, all_colors, missing_indices):
        all_vertices = all_vertices.detach().cpu().numpy()
        all_colors = all_colors.detach().cpu().numpy()
        missing_indices = missing_indices.detach().cpu().numpy()
        non_missing_indices = np.setdiff1d(np.arange(len(all_vertices)), missing_indices)
        kdtree = KDTree(all_vertices[non_missing_indices])
        for missing_index in missing_indices:
            missing_vertex = all_vertices[missing_index]
            _, nearest_index = kdtree.query(missing_vertex.reshape(1, -1))
            interpolated_color = all_colors[non_missing_indices[nearest_index]]
            all_colors[missing_index] = interpolated_color
        return torch.from_numpy(all_colors).to(self.device)
    def load_training_data(self, case):
        ### ------------------- load target images -------------------------------
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for view in self.views:
            normal = Image.open(f'{self.opt.mv_path}/{case}/normals_{view}_masked.png')
            normal = normal.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.
            mask = normal[..., 3:]  # alpha
            # erode the alpha mask so boundary pixels do not leak background into the targets
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)  # (1, 1, H, W) for kornia
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()  # back to (H, W, 1)
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)
            color = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png')
            color = color.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.
            color_mask = color[..., 3:]  # alpha (unused; the eroded normal-map mask is applied instead)
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)
        masks = np.stack(masks, 0)
        masks = torch.from_numpy(masks).to(self.device)
        normals = np.stack(normals, 0)
        target_normals = torch.from_numpy(normals).to(self.device)
        colors = np.stack(colors, 0)
        target_colors = torch.from_numpy(colors).to(self.device)
        return masks, target_colors, target_normals
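
    # Same preparation as load_training_data, but for color/normal images already
    # in memory as RGBA PILs instead of files on disk.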
    def preprocess(self, color_pils, normal_pils):
        ### ------------------- prepare target images -------------------------------
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for normal, color in zip(normal_pils, color_pils):
            normal = normal.resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.
            mask = normal[..., 3:]  # alpha
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)  # (1, 1, H, W) for kornia
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()  # back to (H, W, 1)
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)
            color = color.resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.
            color_mask = color[..., 3:]  # alpha (unused; the eroded normal-map mask is applied instead)
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)
        masks = np.stack(masks, 0)
        masks = torch.from_numpy(masks).to(self.device)
        normals = np.stack(normals, 0)
        target_normals = torch.from_numpy(normals).to(self.device)
        colors = np.stack(colors, 0)
        target_colors = torch.from_numpy(colors).to(self.device)
        return masks, target_colors, target_normals
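
    # Reconstruct one subject: (1) fit SMPL-X pose/shape to the target normals,
    # (2) remesh free-form geometry against the normal maps, (3) optionally
    # replace hands with the fitted SMPL-X hands and fuse via Poisson
    # reconstruction, (4) bake multi-view colors onto the result.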
    def optimize_case(self, case, pose, clr_img, nrm_img, opti_texture=True):
        case_path = f'{self.out_path}/{case}'
        os.makedirs(case_path, exist_ok=True)
        if clr_img is not None:
            masks, target_colors, target_normals = self.preprocess(clr_img, nrm_img)
        else:
            masks, target_colors, target_normals = self.load_training_data(case)
        # rotations that flip the SMPL-X mesh into the renderer's coordinate frame
        rz = R.from_euler('z', 180, degrees=True).as_matrix()
        ry = R.from_euler('y', 180, degrees=True).as_matrix()
        rz = torch.from_numpy(rz).float().to(self.device)
        ry = torch.from_numpy(ry).float().to(self.device)
        scale, offset = None, None
        global_orient = pose["global_orient"]
        body_pose = pose["body_pose"]
        left_hand_pose = pose["left_hand_pose"]
        right_hand_pose = pose["right_hand_pose"]
        beta = pose["betas"]
        # The optimizer and variables (rotations are stored in 6D form)
        optimed_pose = torch.tensor(body_pose,
                                    device=self.device,
                                    requires_grad=True)  # [1, 21, 6]
        optimed_trans = torch.tensor(pose["trans"],
                                     device=self.device,
                                     requires_grad=True)  # [3]
        optimed_betas = torch.tensor(beta,
                                     device=self.device,
                                     requires_grad=True)  # [1, 200]
        optimed_orient = torch.tensor(global_orient,
                                      device=self.device,
                                      requires_grad=True)  # [1, 1, 6]
        optimed_rhand = torch.tensor(right_hand_pose,
                                     device=self.device,
                                     requires_grad=True)
        optimed_lhand = torch.tensor(left_hand_pose,
                                     device=self.device,
                                     requires_grad=True)
        optimed_params = [
            {'params': [optimed_lhand, optimed_rhand], 'lr': 1e-3},
            {'params': [optimed_betas, optimed_trans, optimed_orient, optimed_pose], 'lr': 3e-3},
        ]
        optimizer_smpl = torch.optim.Adam(
            optimed_params,
            amsgrad=True,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=5,
        )
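        # Stage 1: fit the SMPL-X parameters so the rendered body normals and
        # silhouette match the multi-view targets.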
        smpl_steps = 100
        for i in tqdm(range(smpl_steps)):
            optimizer_smpl.zero_grad()
            # 6d_rot to rot_mat
            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)
            smpl_verts, smpl_landmarks, smpl_joints = self.econ_dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(pose["exp"], self.device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(pose["jaw_pose"], self.device),
                left_hand_pose=optimed_lhand,
                right_hand_pose=optimed_rhand,
            )
            smpl_verts = smpl_verts + optimed_trans
            v_smpl = torch.matmul(torch.matmul(smpl_verts.squeeze(0), rz.T), ry.T)
            if scale is None:
                scale, offset = scale_mesh(v_smpl.detach())
            v_smpl = (v_smpl + offset) * scale * 2
            normals = calc_vertex_normals(v_smpl, self.smplx_face)
            nrm = self.renderer.render(v_smpl, self.smplx_face, normals=normals)
            masks_ = nrm[..., 3:]
            smpl_mask_loss = ((masks_ - masks) * self.weights).abs().mean()
            smpl_nrm_loss = ((nrm[..., :3] - target_normals) * self.weights).abs().mean()
            smpl_loss = smpl_mask_loss + smpl_nrm_loss
            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)
        mesh_smpl = trimesh.Trimesh(vertices=v_smpl.detach().cpu().numpy(), faces=self.smplx_face.detach().cpu().numpy())
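        # Stage 2: free-form refinement -- optimize and iteratively remesh the
        # geometry against the target normal maps, starting from the fitted
        # SMPL-X surface.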
        nrm_opt = MeshOptimizer(v_smpl.detach(), self.smplx_face.detach(), edge_len_lims=[0.01, 0.1])
        vertices, faces = nrm_opt.vertices, nrm_opt.faces
        ### ----------------------- optimization iterations -------------------------------------
        for i in tqdm(range(self.opt.iters)):
            nrm_opt.zero_grad()
            normals = calc_vertex_normals(vertices, faces)
            nrm = self.renderer.render(vertices, faces, normals=normals)
            normals = nrm[..., :3]
            loss = ((normals - target_normals) * self.weights).abs().mean()
            alpha = nrm[..., 3:]
            loss += ((alpha - masks) * self.weights).abs().mean()
            loss.backward()
            nrm_opt.step()
            vertices, faces = nrm_opt.remesh()
            if self.opt.debug and i % self.opt.snapshot_step == 0:
                import imageio
                os.makedirs(f'{case_path}/normals', exist_ok=True)
                imageio.imwrite(f'{case_path}/normals/{i:02d}.png',
                                (nrm.detach()[0, :, :, :3] * 255).clamp(max=255).type(torch.uint8).cpu().numpy())
        torch.cuda.empty_cache()
        mesh_remeshed = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy())
        mesh_remeshed.export(f'{case_path}/{case}_remeshed.obj')
        vertices = vertices.detach()
        faces = faces.detach()
        #### replace hands with the fitted SMPL-X hands if they are visible
        smpl_data = SMPLX()
        if self.opt.replace_hand and True in pose['hands_visibility'][0]:
            hand_mask = torch.zeros(smpl_data.smplx_verts.shape[0], )
            if pose['hands_visibility'][0][0]:
                hand_mask.index_fill_(
                    0, torch.tensor(smpl_data.smplx_mano_vid_dict["left_hand"]), 1.0
                )
            if pose['hands_visibility'][0][1]:
                hand_mask.index_fill_(
                    0, torch.tensor(smpl_data.smplx_mano_vid_dict["right_hand"]), 1.0
                )
            hand_mesh = apply_vertex_mask(mesh_smpl.copy(), hand_mask)
            body_mesh = part_removal(
                mesh_remeshed.copy(),
                hand_mesh,
                0.08,
                self.device,
                mesh_smpl.copy(),
                region="hand"
            )
            final = poisson(sum([hand_mesh, body_mesh]), f'{case_path}/{case}_final.obj', 10, False)
        else:
            final = poisson(mesh_remeshed, f'{case_path}/{case}_final.obj', 10, False)
        vertices = torch.from_numpy(final.vertices).float().to(self.device)
        faces = torch.from_numpy(final.faces).long().to(self.device)
        # Differing from paper, we use the texturing method in Unique3D
        if clr_img is None:
            # run() path: load the per-view masked color images from disk,
            # mirroring load_training_data (assumes the same file layout)
            clr_img = [Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png') for view in self.views]
        masked_color = []
        for tmp in clr_img:
            tmp = tmp.resize((self.resolution, self.resolution), Image.BILINEAR)
            tmp = np.array(tmp).astype(np.float32) / 255.
            masked_color.append(torch.from_numpy(tmp).permute(2, 0, 1).to(self.device))
        meshes = self.proj_texture(masked_color, vertices, faces)
        vertices = meshes.verts_packed().float()
        faces = meshes.faces_packed().long()
        colors = meshes.textures.verts_features_packed().float()
        save_mesh(f'{case_path}/result_clr_scale{self.opt.scale}_{case}.obj', vertices, faces, colors)
        self.evaluate(vertices, colors, faces, save_path=f'{case_path}/result_clr_scale{self.opt.scale}_{case}.mp4', save_nrm=True)
        return vertices, faces, colors
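
    # Render a 360-degree turntable of the textured (and optionally
    # normal-shaded) mesh; writes an mp4 when save_path is given.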
    def evaluate(self, target_vertices, target_colors, target_faces, save_path=None, save_nrm=False):
        mv, proj = make_round_views(60, self.opt.scale, device=self.device)
        renderer = NormalsRenderer(mv, proj, [512, 512], device=self.device)
        target_images = renderer.render(target_vertices, target_faces, colors=target_colors)
        target_images = target_images.detach().cpu().numpy()
        target_images = target_images[..., :3] * target_images[..., 3:4] + bg_color * (1 - target_images[..., 3:4])
        target_images = (target_images.clip(0, 1) * 255).astype(np.uint8)
        if save_nrm:
            target_normals = calc_vertex_normals(target_vertices, target_faces)
            target_normals = renderer.render(target_vertices, target_faces, normals=target_normals)
            target_normals = target_normals.detach().cpu().numpy()
            target_normals = target_normals[..., :3] * target_normals[..., 3:4] + bg_color * (1 - target_normals[..., 3:4])
            target_normals = (target_normals.clip(0, 1) * 255).astype(np.uint8)
            frames = [np.concatenate([img, nrm], 1) for img, nrm in zip(target_images, target_normals)]
        else:
            frames = [img for img in target_images]
        if save_path is not None:
            write_video(frames, fps=25, save_path=save_path)
        return frames
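
    # Batch entry point: one reconstruction per image in imgs_path, using the
    # SMPL-X pose predicted by the ECON dataset wrapper.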
    def run(self):
        cases = sorted(os.listdir(self.opt.imgs_path))
        for idx in range(len(cases)):
            case = cases[idx].split('.')[0]
            print(f'Processing {case}')
            pose = self.econ_dataset.__getitem__(idx)
            v, f, c = self.optimize_case(case, pose, None, None, opti_texture=True)
            self.evaluate(v, c, f, save_path=f'{self.opt.res_path}/{case}/result_clr_scale{self.opt.scale}_{case}.mp4', save_nrm=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path to the yaml configs file", default='config.yaml')
    args, extras = parser.parse_known_args()
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    from econdataset import SMPLDataset
    dataset_param = {'image_dir': opt.imgs_path, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
    econdata = SMPLDataset(dataset_param, device='cuda')
    EHuman = ReMesh(opt, econdata)
    EHuman.run()
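
# Example invocation (script name and paths are illustrative; any config key can
# also be overridden on the command line, since extras are parsed by
# OmegaConf.from_cli):
#   python remesh.py --config config.yaml imgs_path=./examples res_path=./results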