File size: 7,265 Bytes
ec9a6bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import numpy as np
import torch
import torch.nn.functional as F
import cv2 as cv
import trimesh
import yaml
import tqdm
import smplx
from network.volume import CanoBlendWeightVolume
from utils.renderer import Renderer
import config
def save_pos_map(pos_map, path):
    """Export the non-empty pixels of a position map as a point cloud.

    A pixel counts as empty when its position vector has zero norm; every
    other pixel contributes one point to the exported cloud.

    Args:
        pos_map: numpy array whose last axis holds the position vector
            (assumes an (H, W, 3) map -- TODO confirm against callers).
        path: destination handed to ``trimesh.PointCloud.export``.
    """
    mask = np.linalg.norm(pos_map, axis = -1) > 0.
    positions = pos_map[mask]
    print('Point nums %d' % positions.shape[0])
    pc = trimesh.PointCloud(positions)
    # The cloud used to be built and then discarded, leaving `path` unused;
    # write it out so the function actually saves something.
    pc.export(path)
def interpolate_lbs(pts, vertices, faces, vertex_lbs):
    """Interpolate per-vertex skinning weights onto query points.

    Every query point is matched to a face of the mesh by
    ``nearest_face_pytorch3d``; the weights of that face's vertices are then
    blended by ``barycentric_interpolate`` using the returned barycentric
    coordinates. All tensors are moved to CUDA, so a GPU is required.

    Args:
        pts: numpy array of query points.
        vertices: numpy array of mesh vertex positions.
        faces: numpy integer array of mesh faces (vertex indices).
        vertex_lbs: torch tensor of per-vertex skinning weights.

    Returns:
        numpy array: the first batch element of the interpolated weights.
    """
    from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d
    from utils.geo_util import barycentric_interpolate

    # Convert once and reuse for both helper calls.
    faces_t = torch.from_numpy(faces).to(torch.int64).cuda()

    # NOTE(review): the argument list of this call was missing from the
    # reviewed text; it is reconstructed as (batched points, batched
    # vertices, unbatched faces) -- confirm against the helper's signature.
    _, indices, bc_coords = nearest_face_pytorch3d(
        torch.from_numpy(pts).to(torch.float32).cuda()[None],
        torch.from_numpy(vertices).to(torch.float32).cuda()[None],
        faces_t
    )
    lbs = barycentric_interpolate(
        vert_attris = vertex_lbs[None].to(torch.float32).cuda(),
        faces = faces_t[None],
        face_ids = indices,
        bc_coords = bc_coords
    )
    return lbs[0].cpu().numpy()
# Size passed as both dimensions of the canonical renderer. The front and back
# renders are concatenated along axis 1, so the combined maps are
# map_size x (2 * map_size) before any downsampling.
map_size = 1024
if __name__ == '__main__':
    # Script entry point. Renders front/back canonical position and normal
    # maps of the template mesh, computes skinning weights for every valid
    # pixel, then writes one posed position map per frame into
    # <data_dir>/smpl_pos_map.
    from argparse import ArgumentParser
    import importlib

    arg_parser = ArgumentParser()
    arg_parser.add_argument('-c', '--config_path', type = str, help = 'Configuration file path.')
    args = arg_parser.parse_args()

    # NOTE(review): FullLoader can build more than plain data; only point
    # this script at trusted configuration files.
    with open(args.config_path, encoding = 'UTF-8') as config_file:
        opt = yaml.load(config_file, Loader = yaml.FullLoader)
    dataset_module = opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
    MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
    dataset = MvRgbDataset(**opt['train']['data'])
    data_dir, frame_list = dataset.data_dir, dataset.pose_list

    os.makedirs(data_dir + '/smpl_pos_map', exist_ok = True)

    cano_renderer = Renderer(map_size, map_size, shader_name = 'vertex_attribute')

    smpl_model = smplx.SMPLX(config.PROJ_DIR + '/smpl_files/smplx', gender = 'neutral', use_pca = False, num_pca_comps = 45, flat_hand_mean = True, batch_size = 1)
    smpl_data = np.load(data_dir + '/smpl_params.npz')
    smpl_data = {k: torch.from_numpy(v.astype(np.float32)) for k, v in smpl_data.items()}

    # Canonical-pose body, using the subject's shape and the pose from config.
    with torch.no_grad():
        cano_smpl = smpl_model.forward(
            betas = smpl_data['betas'],
            global_orient = config.cano_smpl_global_orient[None],
            transl = config.cano_smpl_transl[None],
            body_pose = config.cano_smpl_body_pose[None]
        )
    cano_smpl_v = cano_smpl.vertices[0].cpu().numpy()
    # Centre of the canonical body's axis-aligned bounding box.
    cano_center = 0.5 * (cano_smpl_v.min(0) + cano_smpl_v.max(0))
    smpl_faces = smpl_model.faces.astype(np.int64)

    # Prefer a user-provided template mesh; otherwise fall back to the
    # canonical SMPL-X surface itself.
    if os.path.exists(data_dir + '/template.ply'):
        print('# Loading template from %s' % (data_dir + '/template.ply'))
        template = trimesh.load(data_dir + '/template.ply', process = False)
        using_template = True
    else:
        print(f'# Cannot find template.ply from {data_dir}, using SMPL-X as template')
        template = trimesh.Trimesh(cano_smpl_v, smpl_faces, process = False)
        using_template = False

    cano_smpl_v = template.vertices.astype(np.float32)
    smpl_faces = template.faces.astype(np.int64)
    # Duplicate vertices/normals per face corner (one entry per face index).
    flat_face_ids = smpl_faces.reshape(-1)
    cano_smpl_v_dup = cano_smpl_v[flat_face_ids]
    cano_smpl_n_dup = template.vertex_normals.astype(np.float32)[flat_face_ids]

    # define front & back view matrices
    cam_offset = np.array([0, 0, -10], np.float32)
    front_mv = np.identity(4, np.float32)
    front_mv[:3, 3] = -cano_center + cam_offset
    # Negate the y and z rows (presumably the renderer's camera convention
    # -- confirm against utils.renderer).
    front_mv[1:3] *= -1

    back_mv = np.identity(4, np.float32)
    # Rotation by pi about the y axis, so the body is seen from behind.
    rot_y = cv.Rodrigues(np.array([0, np.pi, 0], np.float32))[0]
    back_mv[:3, :3] = rot_y
    back_mv[:3, 3] = -rot_y @ cano_center + cam_offset
    back_mv[1:3] *= -1

    # render canonical smpl position maps
    # NOTE(review): the view matrices above were never handed to the renderer
    # in the reviewed text, which would make both renders identical. The
    # `set_camera` calls are reconstructed -- confirm the method name and
    # signature against utils.renderer.Renderer.
    cano_renderer.set_model(cano_smpl_v_dup, cano_smpl_v_dup)
    cano_renderer.set_camera(front_mv)
    front_cano_pos_map = cano_renderer.render()[:, :, :3]
    cano_renderer.set_camera(back_mv)
    back_cano_pos_map = cano_renderer.render()[:, :, :3]
    # Mirror the back view horizontally before placing it beside the front.
    back_cano_pos_map = cv.flip(back_cano_pos_map, 1)
    cano_pos_map = np.concatenate([front_cano_pos_map, back_cano_pos_map], 1)
    cv.imwrite(data_dir + '/smpl_pos_map/cano_smpl_pos_map.exr', cano_pos_map)

    # render canonical smpl normal maps
    cano_renderer.set_model(cano_smpl_v_dup, cano_smpl_n_dup)
    cano_renderer.set_camera(front_mv)
    front_cano_nml_map = cano_renderer.render()[:, :, :3]
    cano_renderer.set_camera(back_mv)
    back_cano_nml_map = cano_renderer.render()[:, :, :3]
    back_cano_nml_map = cv.flip(back_cano_nml_map, 1)
    cano_nml_map = np.concatenate([front_cano_nml_map, back_cano_nml_map], 1)
    cv.imwrite(data_dir + '/smpl_pos_map/cano_smpl_nml_map.exr', cano_nml_map)

    # Pixels with a non-zero position are the ones covered by the body.
    body_mask = np.linalg.norm(cano_pos_map, axis = -1) > 0.
    cano_pts = cano_pos_map[body_mask]

    # Skinning weights for every covered pixel: query the precomputed weight
    # volume when a template is used, otherwise interpolate the SMPL-X
    # per-vertex weights on the mesh surface.
    if using_template:
        weight_volume = CanoBlendWeightVolume(data_dir + '/cano_weight_volume.npz')
        pts_lbs = weight_volume.forward_weight(torch.from_numpy(cano_pts)[None].cuda())[0]
    else:
        pts_lbs = interpolate_lbs(cano_pts, cano_smpl_v, smpl_faces, smpl_model.lbs_weights)
        pts_lbs = torch.from_numpy(pts_lbs).cuda()
    # NOTE(review): this line was fused with the previous one in the reviewed
    # text; it is reconstructed as saving the weights next to the maps.
    np.save(data_dir + '/smpl_pos_map/init_pts_lbs.npy', pts_lbs.cpu().numpy())

    inv_cano_smpl_A = torch.linalg.inv(cano_smpl.A).cuda()
    body_mask = torch.from_numpy(body_mask).cuda()
    cano_pts = torch.from_numpy(cano_pts).cuda()
    pts_lbs = pts_lbs.cuda()

    with torch.no_grad():
        for pose_idx in tqdm.tqdm(frame_list, desc = 'Generating positional maps...'):
            # Root orientation/translation and hand poses are deliberately
            # not passed (kept below as disabled arguments).
            live_smpl_woRoot = smpl_model.forward(
                betas = smpl_data['betas'],
                # global_orient = smpl_data['global_orient'][pose_idx][None],
                # transl = smpl_data['transl'][pose_idx][None],
                body_pose = smpl_data['body_pose'][pose_idx][None],
                jaw_pose = smpl_data['jaw_pose'][pose_idx][None],
                expression = smpl_data['expression'][pose_idx][None],
                # left_hand_pose = smpl_data['left_hand_pose'][pose_idx][None],
                # right_hand_pose = smpl_data['right_hand_pose'][pose_idx][None]
            )

            # Per-joint canonical-to-posed transforms.
            cano2live_jnt_mats_woRoot = torch.matmul(live_smpl_woRoot.A.cuda(), inv_cano_smpl_A)[0]
            # Blend the joint transforms with each point's skinning weights,
            # then apply the rotation and translation parts to the points.
            pt_mats = torch.einsum('nj,jxy->nxy', pts_lbs, cano2live_jnt_mats_woRoot)
            live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], cano_pts) + pt_mats[..., :3, 3]

            live_pos_map = torch.zeros((map_size, 2 * map_size, 3)).to(live_pts)
            live_pos_map[body_mask] = live_pts
            # Halve both spatial dimensions with nearest-neighbour sampling.
            live_pos_map = F.interpolate(live_pos_map.permute(2, 0, 1)[None], None, [0.5, 0.5], mode = 'nearest')[0]
            live_pos_map = live_pos_map.permute(1, 2, 0).cpu().numpy()
            cv.imwrite(data_dir + '/smpl_pos_map/%08d.exr' % pose_idx, live_pos_map)