import torch
import torch.nn as nn
import numpy as np
import pytorch3d.ops
import pytorch3d.transforms
import trimesh

import config
from network.mlp import MLPLinear, SdfMLP
from network.density import LaplaceDensity
from network.volume import CanoBlendWeightVolume
from network.hand_avatar import HandAvatar
from utils.embedder import get_embedder
import utils.nerf_util as nerf_util
import utils.smpl_util as smpl_util
import utils.geo_util as geo_util
from utils.posevocab_custom_ops.near_far_smpl import near_far_smpl
from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d
from utils.knn import knn_gather
import root_finding
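
# TemplateNet couples a canonical-space SDF + color field (the "body NeRF")
# with optional MANO hand avatars. Live-space samples are warped to the
# canonical space via inverse linear blend skinning (optionally refined by a
# custom root-finding op), queried by MLPs, fused with the hand fields, and
# volume-rendered.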

class TemplateNet(nn.Module):
    def __init__(self, opt):
        super(TemplateNet, self).__init__()
        self.opt = opt

        self.pos_embedder, self.pos_dim = get_embedder(opt['multires'], 3)

        # canonical blend weight volume
        self.cano_weight_volume = CanoBlendWeightVolume(config.opt['train']['data']['data_dir'] + '/cano_weight_volume.npz')

        self.pose_feat_dim = 0

        """ geometry networks """
        geo_mlp_opt = {
            'in_channels': self.pos_dim + self.pose_feat_dim,
            'out_channels': 256 + 1,  # 1 SDF channel + 256-dim geometry feature
            'inter_channels': [512, 256, 256, 256, 256, 256],
            'nlactv': nn.Softplus(beta = 100),
            'res_layers': [4],  # skip connection at layer 4
            'geometric_init': True,  # initialize the SDF to approximate a sphere of radius ~bias
            'bias': 0.7,
            'weight_norm': True
        }
        self.geo_mlp = SdfMLP(**geo_mlp_opt)

        """ texture networks """
        if self.opt['use_viewdir']:
            self.viewdir_embedder, self.viewdir_dim = get_embedder(self.opt['multires_viewdir'], 3)
        else:
            self.viewdir_embedder, self.viewdir_dim = None, 0
        tex_mlp_opt = {
            'in_channels': 256 + self.viewdir_dim,
            'out_channels': 3,
            'inter_channels': [256, 256, 256],
            'nlactv': nn.ReLU(),
            'last_op': nn.Sigmoid()
        }
        self.tex_mlp = MLPLinear(**tex_mlp_opt)

        print('# MLPs: ')
        print(self.geo_mlp)
        print(self.tex_mlp)

        # sdf2density
        self.density_func = LaplaceDensity(params_init = {'beta': 0.01})
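
        # LaplaceDensity maps raw SDF values (positive outside, under the
        # geometric init above) to volume density. The exact formula lives in
        # network.density; in the standard VolSDF parameterization -- a sketch,
        # assuming this repo follows it -- with alpha = 1 / beta:
        #     sigma(s) = alpha * (0.5 + 0.5 * sign(s) * (exp(-|s| / beta) - 1))
        # so sigma -> alpha deep inside the surface (s < 0) and sigma -> 0
        # outside (s > 0), with a transition band of width ~beta (init 0.01).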

        # hand avatars
        self.with_hand = self.opt.get('with_hand', False)
        self.left_hand = HandAvatar()
        self.right_hand = HandAvatar()

        # for root finding
        from network.volume import compute_gradient_volume
        if self.opt.get('volume_type', 'diff') == 'diff':
            self.weight_volume = self.cano_weight_volume.diff_weight_volume[0].permute(1, 2, 3, 0).contiguous()  # (X, Y, Z, J)
        else:
            self.weight_volume = self.cano_weight_volume.ori_weight_volume[0].permute(1, 2, 3, 0).contiguous()
        self.grad_volume = compute_gradient_volume(self.weight_volume.permute(3, 0, 1, 2), self.cano_weight_volume.voxel_size).permute(2, 3, 4, 0, 1)\
            .reshape(self.cano_weight_volume.res_x, self.cano_weight_volume.res_y, self.cano_weight_volume.res_z, -1).contiguous()  # (X, Y, Z, J * 3)
        self.res = torch.tensor([self.cano_weight_volume.res_x, self.cano_weight_volume.res_y, self.cano_weight_volume.res_z], dtype = torch.int32, device = config.device)

        self._initialize_hands()
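
        # Illustrative configuration (values are assumptions, not from the
        # repo's actual config; the keys are exactly those read by this class):
        #   opt = {
        #       'multires': 6,             # positional encoding bands for xyz
        #       'use_viewdir': True,
        #       'multires_viewdir': 4,     # positional encoding bands for view dirs
        #       'with_hand': False,        # enable MANO hand avatars + fusion
        #       'volume_type': 'diff',     # which blend weight volume to use
        #       'use_root_finding': True,  # refine live->cano correspondences
        #   }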

    def _initialize_hands(self):
        smplx_lhand_to_mano_rhand_data = np.load(config.PROJ_DIR + '/smpl_files/mano/smplx_lhand_to_mano_rhand.npz', allow_pickle = True)
        smplx_rhand_to_mano_rhand_data = np.load(config.PROJ_DIR + '/smpl_files/mano/smplx_rhand_to_mano_rhand.npz', allow_pickle = True)

        smpl_lhand_vert_id = np.copy(smplx_lhand_to_mano_rhand_data['smpl_vert_id_to_mano'])
        smpl_rhand_vert_id = np.copy(smplx_rhand_to_mano_rhand_data['smpl_vert_id_to_mano'])
        self.smpl_lhand_vert_id = torch.from_numpy(smpl_lhand_vert_id).to(config.device)
        self.smpl_rhand_vert_id = torch.from_numpy(smpl_rhand_vert_id).to(config.device)
        self.smpl_hands_vert_id = torch.cat([self.smpl_lhand_vert_id, self.smpl_rhand_vert_id], 0)

        mano_face_closed = np.loadtxt(config.PROJ_DIR + '/smpl_files/mano/mano_face_close.txt').astype(np.int64)
        self.mano_face_closed = torch.from_numpy(mano_face_closed).to(config.device)
        # left-hand faces reuse the right-hand MANO topology with flipped winding order
        self.mano_face_closed_2hand = torch.cat([self.mano_face_closed[:, [2, 1, 0]], self.mano_face_closed + self.smpl_lhand_vert_id.shape[0]], 0)

    def forward_cano_body_nerf(self, xyz, viewdirs, pose, compute_grad = False):
        """
        :param xyz: (B, N, 3)
        :param viewdirs: (B, N, 3)
        :param pose: (B, pose_dim)
        :param compute_grad: whether to compute SDF gradients (normals) w.r.t. xyz
        :return: dict with 'sdf', 'density', 'color', 'cano_xyz' (and 'normal' if compute_grad)
        """
        if compute_grad:
            xyz.requires_grad_()

        # pose_feat = self.pose_feat[None, None].expand(xyz.shape[0], xyz.shape[1], -1)
        # pose_feat = torch.cat([self.pos_embedder(xyz), pose_feat], -1)
        pose_feat = self.pos_embedder(xyz)
        geo_feat = self.geo_mlp(pose_feat)
        sdf, geo_feat = torch.split(geo_feat, [1, geo_feat.shape[-1] - 1], -1)

        if self.viewdir_embedder is not None:
            if viewdirs is None:
                viewdirs = torch.zeros_like(xyz)
            geo_feat = torch.cat([geo_feat, self.viewdir_embedder(viewdirs)], -1)
        color = self.tex_mlp(geo_feat)
        density = self.density_func(sdf)

        ret = {
            'sdf': -sdf,  # flip the sign so that outside is negative, inside is positive
            'density': density,
            'color': color,
            'cano_xyz': xyz.detach()
        }

        if compute_grad:
            d_output = torch.ones_like(sdf, requires_grad = False, device = sdf.device)
            normal = torch.autograd.grad(outputs = sdf,
                                         inputs = xyz,
                                         grad_outputs = d_output,
                                         create_graph = self.training,
                                         retain_graph = self.training,
                                         only_inputs = True)[0]
            ret.update({
                'normal': normal
            })

        return ret
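
    # Example (shapes only, illustrative): querying the canonical body field
    # with 4096 samples per batch element:
    #   ret = net.forward_cano_body_nerf(xyz, viewdirs, pose = None, compute_grad = True)
    # with xyz, viewdirs: (1, 4096, 3) ->
    #   ret['sdf']: (1, 4096, 1), ret['density']: (1, 4096, 1),
    #   ret['color']: (1, 4096, 3), ret['normal']: (1, 4096, 3)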

    def forward_cano_hand_nerf(self, xyz, sdf, viewdirs, hand_pose, module = 'left_hand'):
        net = getattr(self, module)
        return net(xyz, sdf, viewdirs, hand_pose)

    def fuse_hands(self, body_ret, posed_xyz, view_dirs, batch, space = 'live'):
        # get hand correspondences
        batch_size, n_pts = posed_xyz.shape[:2]

        def process_one_hand(side = 'left'):
            hand_v = batch['%s_live_mano_v' % side] if space == 'live' else batch['%s_cano_mano_v' % side]
            hand_n = batch['%s_live_mano_n' % side] if space == 'live' else batch['%s_cano_mano_n' % side]
            hand_f = self.mano_face_closed[:, [2, 1, 0]] if side == 'left' else self.mano_face_closed  # flip winding for the left hand
            dists, face_indices, bc_coords = nearest_face_pytorch3d(posed_xyz, hand_v, hand_f)
            face_vertex_ids = torch.gather(hand_f[None].expand(batch_size, -1, -1), 1, face_indices[:, :, None].long().expand(-1, -1, 3))  # (B, N, 3)
            # barycentric interpolation of canonical MANO positions (normalized to the hand bbox)
            cano_hand_v = geo_util.normalize_vert_bbox(batch['%s_cano_mano_v' % side], dim = 1, per_axis = True)
            face_cano_mano_v = knn_gather(cano_hand_v, face_vertex_ids)
            pts_cano_mano_v = (bc_coords[..., None] * face_cano_mano_v).sum(2)
            # barycentric interpolation of live MANO positions and normals
            face_live_mano_v = knn_gather(hand_v, face_vertex_ids)
            pts_live_mano_v = (bc_coords[..., None] * face_live_mano_v).sum(2)
            face_live_mano_n = knn_gather(hand_n, face_vertex_ids)
            pts_live_mano_n = (bc_coords[..., None] * face_live_mano_n).sum(2)
            # sign the unsigned distance by the interpolated surface normal
            pts_smpl_sdf = -torch.sign(torch.einsum('bni,bni->bn', pts_live_mano_n, posed_xyz - pts_live_mano_v)) * dists
            return pts_cano_mano_v, pts_smpl_sdf.unsqueeze(-1)
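
        # For each query point x with nearest surface point x* (barycentrically
        # interpolated above) and interpolated normal n*, the hand SDF is
        #     sdf(x) = -sign(n* . (x - x*)) * dist(x, x*),
        # matching the body field's convention: negative outside, positive inside.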
        left_cano_mano_v, left_mano_sdf = process_one_hand('left')
        right_cano_mano_v, right_mano_sdf = process_one_hand('right')

        # query both hand networks with a zero (flat) MANO hand pose
        zero_hand_pose = torch.zeros((1, 15 * 3)).to(left_cano_mano_v)
        color_lhand = self.forward_cano_hand_nerf(left_cano_mano_v, left_mano_sdf, view_dirs, zero_hand_pose, module = 'left_hand')
        color_rhand = self.forward_cano_hand_nerf(right_cano_mano_v, right_mano_sdf, view_dirs, zero_hand_pose, module = 'right_hand')

        # calculate the blending weights for fusing the body and hand networks:
        # a sigmoid ramp along the x axis of each bbox-normalized canonical hand
        cano_xyz = body_ret['cano_xyz']
        wl = torch.sigmoid(25 * (geo_util.normalize_vert_bbox(batch['left_cano_mano_v'], attris = cano_xyz, dim = 1, per_axis = True)[..., 0:1] + 0.8))
        wr = torch.sigmoid(-25 * (geo_util.normalize_vert_bbox(batch['right_cano_mano_v'], attris = cano_xyz, dim = 1, per_axis = True)[..., 0:1] - 0.8))
        # ignore hand blending for points below the canonical body center
        wl[cano_xyz[..., 1] < batch['cano_smpl_center'][0, 1]] = 0.
        wr[cano_xyz[..., 1] < batch['cano_smpl_center'][0, 1]] = 0.
        # normalize so that wl + wr <= 1
        s = torch.maximum(wl + wr, torch.ones_like(wl))
        wl, wr = wl / s, wr / s

        # blend the outputs of the body network and the hand networks (in place)
        w = wl + wr
        body_ret['sdf'] = wl * left_mano_sdf + wr * right_mano_sdf + (1.0 - w) * body_ret['sdf']
        body_ret['color'] = wl * color_lhand + wr * color_rhand + (1.0 - w) * body_ret['color']
        # re-derive density from the fused SDF (density_func expects the raw
        # convention, i.e. positive outside, hence the negation)
        body_ret['density'] = self.density_func(-body_ret['sdf'])
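
        # Net effect, per point:
        #     sdf   = wl * sdf_left + wr * sdf_right + (1 - wl - wr) * sdf_body
        #     color = wl * c_left   + wr * c_right   + (1 - wl - wr) * c_body
        # with 0 <= wl + wr <= 1, so the result interpolates between body and hands.
        # Note that fuse_hands mutates body_ret in place and returns None.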

    def forward_cano_radiance_field(self, xyz, view_dirs, batch):
        body_ret = self.forward_cano_body_nerf(xyz, view_dirs, None, compute_grad = self.training)
        return body_ret

    def transform_cano2live(self, cano_pts, batch, normals = None, near_thres = 0.08):
        cano2live_jnt_mats = batch['cano2live_jnt_mats'].clone()
        if not self.with_hand:
            # make sure the hand transformations are fully rigid:
            # finger joints (25-39 / 40-54) inherit the wrist transforms (20 / 21)
            cano2live_jnt_mats[:, 25: 40] = cano2live_jnt_mats[:, 20: 21]
            cano2live_jnt_mats[:, 40: 55] = cano2live_jnt_mats[:, 21: 22]
        pts_w = self.cano_weight_volume.forward_weight(cano_pts)  # (B, N, J)
        pt_mats = torch.einsum('bnj,bjxy->bnxy', pts_w, cano2live_jnt_mats)
        posed_pts = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], cano_pts) + pt_mats[..., :3, 3]
        if normals is None:
            return posed_pts
        else:
            posed_normals = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], normals)
            return posed_pts, posed_normals
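
    # Both warps implement linear blend skinning: with per-point weights w_j(x)
    # from the canonical weight volume and 4x4 joint transforms G_j,
    #     G(x) = sum_j w_j(x) * G_j,   x_live = G(x)[:3, :3] @ x_cano + G(x)[:3, 3].
    # transform_live2cano below inverts this by estimating the weights from the
    # nearest surface point and applying G(x)^{-1}.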

    def transform_live2cano(self, posed_pts, batch, normals = None, near_thres = 0.08):
        cano2live_jnt_mats = batch['cano2live_jnt_mats'].clone()
        if not self.with_hand:
            # rigid hand transformations, as in transform_cano2live
            cano2live_jnt_mats[:, 25: 40] = cano2live_jnt_mats[:, 20: 21]
            cano2live_jnt_mats[:, 40: 55] = cano2live_jnt_mats[:, 21: 22]

        """ live_pts -> cano_pts """
        batch_size, n_pts = posed_pts.shape[:2]
        with torch.no_grad():
            if 'live_mesh_v' in batch:
                tar_v = batch['live_mesh_v']
                tar_f = batch['live_mesh_f']
                tar_lbs = batch['live_mesh_lbs']
                pts_w, near_flag = smpl_util.calc_blending_weight(posed_pts, tar_v, tar_f, tar_lbs, near_thres, method = 'NN')
            else:
                tar_v = batch['live_smpl_v']
                tar_f = batch['smpl_faces']
                tar_lbs = None
                pts_w, near_flag = smpl_util.calc_blending_weight(posed_pts, tar_v, tar_f, tar_lbs, near_thres, method = 'barycentric')

        # invert the blended transformation for an initial canonical estimate
        pt_mats = torch.einsum('bnj,bjxy->bnxy', pts_w, cano2live_jnt_mats)
        pt_mats = torch.linalg.inv(pt_mats)
        cano_pts = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], posed_pts) + pt_mats[..., :3, 3]
        if normals is not None:
            cano_normals = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], normals)

        if self.opt['use_root_finding']:
            # exclude points dominated by the ankle / foot joints
            # (7, 8, 10, 11 in the SMPL-X joint order) from root finding
            argmax_lbs = torch.argmax(pts_w, -1)
            nonopt_bone_ids = [7, 8, 10, 11]
            nonopt_pts_flag = torch.zeros((batch_size, n_pts), dtype = torch.bool).to(argmax_lbs.device)
            for i in nonopt_bone_ids:
                nonopt_pts_flag = torch.logical_or(nonopt_pts_flag, argmax_lbs == i)
            root_finding_flag = torch.logical_not(nonopt_pts_flag)

            if root_finding_flag.any():
                cano_pts_ = cano_pts[root_finding_flag].unsqueeze(0)
                posed_pts_ = posed_pts[root_finding_flag].unsqueeze(0)
                if not cano_pts_.is_contiguous():
                    cano_pts_ = cano_pts_.contiguous()
                if not posed_pts_.is_contiguous():
                    posed_pts_ = posed_pts_.contiguous()
                # refine the canonical estimates in place (custom CUDA op)
                root_finding.root_finding(
                    self.weight_volume,
                    self.grad_volume,
                    posed_pts_,
                    cano_pts_,
                    cano2live_jnt_mats,
                    self.cano_weight_volume.volume_bounds,
                    self.res,
                    cano_pts_,  # output buffer
                    0.1,
                    10
                )
                cano_pts[root_finding_flag] = cano_pts_[0]

        if normals is None:
            return cano_pts, near_flag
        else:
            return cano_pts, cano_normals, near_flag
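
    # The root_finding op's internals are not shown here; conceptually it
    # solves, per point, LBS(x_cano) = x_posed for x_cano, starting from the
    # inverse-LBS estimate. A rough pure-PyTorch sketch of such an iteration
    # (illustrative only; sample_weights / apply / jacobian / solve are
    # hypothetical helpers, and the 0.1 / 10 mirror the step size and
    # iteration count passed above -- the real op samples self.weight_volume
    # and self.grad_volume on the GPU):
    #   for _ in range(10):
    #       w = sample_weights(x_cano)                    # (N, J) trilinear lookup
    #       G = torch.einsum('nj,jxy->nxy', w, jnt_mats)  # blended transforms
    #       residual = apply(G, x_cano) - x_posed         # (N, 3)
    #       x_cano = x_cano - 0.1 * solve(jacobian(x_cano), residual)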

    def render(self, batch, chunk_size = 2048, depth_guided_sampling = None, space = 'live', white_bkgd = False):
        ray_o = batch['ray_o']
        ray_d = batch['ray_d']
        near = batch['near']
        far = batch['far']

        if depth_guided_sampling is None:
            # guard against a missing config (assumed default: fall back to smpl-guided sampling)
            depth_guided_sampling = {'flag': False}

        if depth_guided_sampling['flag']:
            print('# depth-guided sampling')
            valid_dist_flag = batch['dist'] > 1e-6
            dist = batch['dist'][valid_dist_flag]
            near_dist = depth_guided_sampling['near_sur_dist']
            far_dist = depth_guided_sampling['near_sur_dist']
            near[valid_dist_flag] = dist - near_dist
            far[valid_dist_flag] = dist + far_dist
            N_ray_samples = depth_guided_sampling['N_ray_samples']
        else:
            if depth_guided_sampling.get('type', 'smpl') == 'smpl':
                print('# smpl-guided sampling')
                valid_dist_flag = torch.ones_like(near, dtype = torch.bool)
                # note: indexing [0] assumes batch_size == 1 for smpl-guided sampling
                near, far, intersect_flag = near_far_smpl(batch['live_smpl_v'][0], ray_o[0], ray_d[0])
                # fall back to the input near / far for rays that miss the SMPL mesh
                near[~intersect_flag] = batch['near'][0][~intersect_flag]
                far[~intersect_flag] = batch['far'][0][~intersect_flag]
                near = near.unsqueeze(0)
                far = far.unsqueeze(0)
                N_ray_samples = 64
            elif depth_guided_sampling.get('type', 'smpl') == 'uniform':
                print('# uniform sampling')
                valid_dist_flag = torch.ones_like(near, dtype = torch.bool)
                N_ray_samples = 64

        if self.training:
            chunk_size = batch['ray_o'].shape[1]

        batch_size, n_pixels = ray_o.shape[:2]
        output_list = []
        for i in range(0, n_pixels, chunk_size):
            near_chunk = near[:, i: i + chunk_size]
            far_chunk = far[:, i: i + chunk_size]
            ray_o_chunk = ray_o[:, i: i + chunk_size]
            ray_d_chunk = ray_d[:, i: i + chunk_size]
            valid_dist_flag_chunk = valid_dist_flag[:, i: i + chunk_size]

            # sample points on each ray
            pts, z_vals = nerf_util.sample_pts_on_rays(ray_o_chunk, ray_d_chunk, near_chunk, far_chunk,
                                                       N_samples = N_ray_samples,
                                                       perturb = self.training,
                                                       depth_guided_mask = valid_dist_flag_chunk)

            # # debug: visualize pts
            # pts_trimesh = trimesh.PointCloud(pts[0].cpu().numpy().reshape(-1, 3))
            # pts_trimesh.export('./debug/sampled_pts_%s.obj' % ('training' if self.training else 'testing'))
            # exit(1)

            # flatten the sample dimension for network queries
            _, n_pixels_chunk, n_samples = pts.shape[:3]
            pts = pts.view(batch_size, n_pixels_chunk * n_samples, -1)
            dists = z_vals[..., 1:] - z_vals[..., :-1]
            dists = torch.cat([dists, dists[..., -1:]], -1)  # pad the last interval

            # query
            if space == 'live':
                cano_pts, near_flag = self.transform_live2cano(pts, batch)
            elif space == 'cano':
                cano_pts = pts
            else:
                raise ValueError('Invalid rendering space!')

            viewdirs = ray_d_chunk / torch.norm(ray_d_chunk, dim = -1, keepdim = True)
            viewdirs = viewdirs[:, :, None, :].expand(-1, -1, n_samples, -1).reshape(batch_size, n_pixels_chunk * n_samples, -1)
            # apply gaussian noise to the view directions to avoid overfitting
            if self.training:
                with torch.no_grad():
                    noise = torch.randn_like(viewdirs) * 0.1
                viewdirs = viewdirs + noise
                viewdirs = viewdirs / torch.norm(viewdirs, dim = -1, keepdim = True)

            ret = self.forward_cano_radiance_field(cano_pts, viewdirs, batch)

            if self.with_hand:
                self.fuse_hands(ret, pts, viewdirs, batch, space)

            ret['color'] = ret['color'].view(batch_size, n_pixels_chunk, n_samples, -1)
            ret['density'] = ret['density'].view(batch_size, n_pixels_chunk, n_samples, -1)

            # integration
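            # Standard NeRF quadrature (a summary of what raw2outputs computes):
            #     alpha_i = 1 - exp(-sigma_i * delta_i)
            #     w_i     = alpha_i * prod_{j<i} (1 - alpha_j)
            #     rgb_map = sum_i w_i * c_i,   acc_map = sum_i w_i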
            alpha = 1. - torch.exp(-ret['density'] * dists[..., None])
            raw = torch.cat([ret['color'], alpha], dim = -1)
            rgb_map, disp_map, acc_map, weights, depth_map = nerf_util.raw2outputs(raw, z_vals, white_bkgd = white_bkgd)

            output_chunk = {
                'rgb_map': rgb_map,  # (batch_size, n_pixels_chunk, 3)
                'acc_map': acc_map  # (batch_size, n_pixels_chunk)
            }
            if 'normal' in ret:
                output_chunk.update({
                    'normal': ret['normal'].view(batch_size, n_pixels_chunk, -1, 3)
                })
            if 'tv_loss' in ret:
                output_chunk.update({
                    'tv_loss': ret['tv_loss'].view(1, 1, -1)
                })
            output_list.append(output_chunk)

        # concatenate the per-chunk outputs along the pixel dimension
        keys = output_list[0].keys()
        outputs = {k: torch.cat([r[k] for r in output_list], dim = 1) for k in keys}

        # processing for patch-based ray sampling: scatter the rendered rays back
        # into full patches and zero out the ground truth outside the patch mask
        if 'mask_within_patch' in batch:
            _, ray_num = batch['mask_within_patch'].shape
            rgb_map = torch.zeros((batch_size, ray_num, 3), dtype = torch.float32, device = config.device)
            acc_map = torch.zeros((batch_size, ray_num), dtype = torch.float32, device = config.device)
            rgb_map[batch['mask_within_patch']] = outputs['rgb_map'].reshape(-1, 3)
            acc_map[batch['mask_within_patch']] = outputs['acc_map'].reshape(-1)
            batch['color_gt'][~batch['mask_within_patch']] = 0.
            batch['mask_gt'][~batch['mask_within_patch']] = 0.
            outputs['rgb_map'] = rgb_map
            outputs['acc_map'] = acc_map

        return outputs
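

# Minimal usage sketch (illustrative; the key names match those read above, but
# the actual batch is produced by this repo's dataset pipeline):
#   net = TemplateNet(opt['model'])
#   batch = {
#       'ray_o': ..., 'ray_d': ...,              # (B, N, 3) rays in live space
#       'near': ..., 'far': ...,                 # (B, N) per-ray bounds
#       'dist': ...,                             # (B, N) depth, for depth-guided sampling
#       'live_smpl_v': ..., 'smpl_faces': ...,   # posed SMPL-X mesh
#       'cano2live_jnt_mats': ...,               # (B, 55, 4, 4) joint transforms
#   }
#   out = net.render(batch, chunk_size = 2048,
#                    depth_guided_sampling = {'flag': False, 'type': 'smpl'})
#   out['rgb_map'], out['acc_map']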