pengc02's picture
all
ec9a6bc
raw
history blame
19.9 kB
import torch
import torch.nn as nn
import numpy as np
import pytorch3d.ops
import pytorch3d.transforms
import trimesh
import config
from network.mlp import MLPLinear, SdfMLP
from network.density import LaplaceDensity
from network.volume import CanoBlendWeightVolume
from network.hand_avatar import HandAvatar
from utils.embedder import get_embedder
import utils.nerf_util as nerf_util
import utils.smpl_util as smpl_util
import utils.geo_util as geo_util
from utils.posevocab_custom_ops.near_far_smpl import near_far_smpl
from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d
from utils.knn import knn_gather
import root_finding
class TemplateNet(nn.Module):
    """Canonical-space template avatar.

    Combines an SDF body radiance field (geometry MLP + texture MLP, with a
    Laplace SDF-to-density mapping), two `HandAvatar` networks that can be
    blended in for the hands, and the canonical blend-weight volume used to
    warp points between live (posed) space and canonical space.
    """

    def __init__(self, opt):
        """
        :param opt: options dict. Keys read by this class: 'multires',
            'use_viewdir', 'multires_viewdir' (only when 'use_viewdir' is true),
            'use_root_finding' (in `transform_live2cano`), and the optional
            'with_hand' (default False) and 'volume_type' (default 'diff').
        """
        super(TemplateNet, self).__init__()
        self.opt = opt

        self.pos_embedder, self.pos_dim = get_embedder(opt['multires'], 3)

        # canonical blend weight volume
        self.cano_weight_volume = CanoBlendWeightVolume(config.opt['train']['data']['data_dir'] + '/cano_weight_volume.npz')
        # no pose conditioning: only the positional embedding is fed to the geometry MLP
        self.pose_feat_dim = 0

        # geometry network: 1 SDF channel followed by a 256-d feature for the texture network
        geo_mlp_opt = {
            'in_channels': self.pos_dim + self.pose_feat_dim,
            'out_channels': 256 + 1,
            'inter_channels': [512, 256, 256, 256, 256, 256],
            'nlactv': nn.Softplus(beta = 100),
            'res_layers': [4],
            'geometric_init': True,
            'bias': 0.7,
            'weight_norm': True
        }
        self.geo_mlp = SdfMLP(**geo_mlp_opt)

        # texture network: geometry feature (+ optional view-direction embedding) -> sigmoid RGB
        if self.opt['use_viewdir']:
            self.viewdir_embedder, self.viewdir_dim = get_embedder(self.opt['multires_viewdir'], 3)
        else:
            self.viewdir_embedder, self.viewdir_dim = None, 0
        tex_mlp_opt = {
            'in_channels': 256 + self.viewdir_dim,
            'out_channels': 3,
            'inter_channels': [256, 256, 256],
            'nlactv': nn.ReLU(),
            'last_op': nn.Sigmoid()
        }
        self.tex_mlp = MLPLinear(**tex_mlp_opt)

        print('# MLPs: ')
        print(self.geo_mlp)
        print(self.tex_mlp)

        # sdf2density
        self.density_func = LaplaceDensity(params_init = {'beta': 0.01})

        # hand avatars: always constructed, but only blended in by `render` when 'with_hand' is set
        self.with_hand = self.opt.get('with_hand', False)
        self.left_hand = HandAvatar()
        self.right_hand = HandAvatar()

        # weight and weight-gradient volumes consumed by root finding in `transform_live2cano`
        from network.volume import compute_gradient_volume
        wv = self.cano_weight_volume
        if self.opt.get('volume_type', 'diff') == 'diff':
            src_volume = wv.diff_weight_volume
        else:
            src_volume = wv.ori_weight_volume
        # take element 0 and move its first (channel) axis last
        self.weight_volume = src_volume[0].permute(1, 2, 3, 0).contiguous()
        grad_volume = compute_gradient_volume(self.weight_volume.permute(3, 0, 1, 2), wv.voxel_size)
        self.grad_volume = grad_volume.permute(2, 3, 4, 0, 1).reshape(wv.res_x, wv.res_y, wv.res_z, -1).contiguous()
        self.res = torch.tensor([wv.res_x, wv.res_y, wv.res_z], dtype = torch.int32, device = config.device)

        self._initialize_hands()

    def _initialize_hands(self):
        """Load the SMPL-X hand vertex ids and the closed MANO face list onto `config.device`.

        Sets `smpl_lhand_vert_id`, `smpl_rhand_vert_id`, their concatenation
        `smpl_hands_vert_id`, `mano_face_closed`, and `mano_face_closed_2hand`.
        In the two-hand face list, the left-hand faces have reversed winding and the
        right-hand faces are offset by the number of left-hand vertices.
        """
        smplx_lhand_to_mano_rhand_data = np.load(config.PROJ_DIR + '/smpl_files/mano/smplx_lhand_to_mano_rhand.npz', allow_pickle = True)
        smplx_rhand_to_mano_rhand_data = np.load(config.PROJ_DIR + '/smpl_files/mano/smplx_rhand_to_mano_rhand.npz', allow_pickle = True)
        smpl_lhand_vert_id = np.copy(smplx_lhand_to_mano_rhand_data['smpl_vert_id_to_mano'])
        smpl_rhand_vert_id = np.copy(smplx_rhand_to_mano_rhand_data['smpl_vert_id_to_mano'])
        self.smpl_lhand_vert_id = torch.from_numpy(smpl_lhand_vert_id).to(config.device)
        self.smpl_rhand_vert_id = torch.from_numpy(smpl_rhand_vert_id).to(config.device)
        self.smpl_hands_vert_id = torch.cat([self.smpl_lhand_vert_id, self.smpl_rhand_vert_id], 0)

        mano_face_closed = np.loadtxt(config.PROJ_DIR + '/smpl_files/mano/mano_face_close.txt').astype(np.int64)
        self.mano_face_closed = torch.from_numpy(mano_face_closed).to(config.device)
        self.mano_face_closed_2hand = torch.cat([self.mano_face_closed[:, [2, 1, 0]], self.mano_face_closed + self.smpl_lhand_vert_id.shape[0]], 0)

    def forward_cano_body_nerf(self, xyz, viewdirs, pose, compute_grad = False):
        """Query the body radiance field at canonical points.

        :param xyz: (B, N, 3)
        :param viewdirs: (B, N, 3), or None. If a view-direction embedder
            exists and `viewdirs` is None, zeros are used instead.
        :param pose: (B, pose_dim); currently unused.
        :param compute_grad: whether to compute the gradient of the SDF w.r.t. `xyz`.
        :return: dict with
            'sdf' (negated network output: outside negative, inside positive),
            'density', 'color', 'cano_xyz' (detached `xyz`), and, if
            `compute_grad`, 'normal' (gradient of the raw network SDF w.r.t. `xyz`).
        """
        if compute_grad:
            xyz.requires_grad_()
        geo_feat = self.geo_mlp(self.pos_embedder(xyz))
        sdf, geo_feat = torch.split(geo_feat, [1, geo_feat.shape[-1] - 1], -1)
        if self.viewdir_embedder is not None:
            if viewdirs is None:
                viewdirs = torch.zeros_like(xyz)
            geo_feat = torch.cat([geo_feat, self.viewdir_embedder(viewdirs)], -1)
        color = self.tex_mlp(geo_feat)
        density = self.density_func(sdf)

        ret = {
            'sdf': -sdf,  # assume outside is negative, inside is positive
            'density': density,
            'color': color,
            'cano_xyz': xyz.detach()
        }
        if compute_grad:
            # the graph is created/retained only in training mode
            normal = torch.autograd.grad(outputs = sdf,
                                         inputs = xyz,
                                         grad_outputs = torch.ones_like(sdf),
                                         create_graph = self.training,
                                         retain_graph = self.training,
                                         only_inputs = True)[0]
            ret['normal'] = normal
        return ret

    def forward_cano_hand_nerf(self, xyz, sdf, viewdirs, hand_pose, module = 'left_hand'):
        """Run the hand avatar submodule named `module` ('left_hand' or 'right_hand')."""
        net = getattr(self, module)
        return net(xyz, sdf, viewdirs, hand_pose)

    def fuse_hands(self, body_ret, posed_xyz, view_dirs, batch, space = 'live'):
        """Blend the body outputs with the two hand avatars, modifying `body_ret` in place.

        :param body_ret: output dict of `forward_cano_body_nerf`. Its 'sdf',
            'color' and 'density' entries are overwritten.
        :param posed_xyz: (B, N, 3) query points used to find hand correspondences.
        :param view_dirs: view directions forwarded to the hand avatars.
        :param batch: dict providing '<side>_live_mano_v/n' (when `space` is
            'live'), '<side>_cano_mano_v/n', and 'cano_smpl_center'.
        :param space: 'live' uses the live MANO meshes for correspondences; any
            other value uses the canonical ones.
        """
        batch_size = posed_xyz.shape[0]

        def process_one_hand(side = 'left'):
            """Return (normalized canonical MANO coordinate, signed distance) of each query point for one hand."""
            hand_v = batch['%s_live_mano_v' % side] if space == 'live' else batch['%s_cano_mano_v' % side]
            hand_n = batch['%s_live_mano_n' % side] if space == 'live' else batch['%s_cano_mano_n' % side]
            # the left hand reuses the MANO face list with reversed winding
            hand_f = self.mano_face_closed[:, [2, 1, 0]] if side == 'left' else self.mano_face_closed

            # per point: distance, nearest face index and barycentric coordinates on that face
            dists, face_indices, bc_coords = nearest_face_pytorch3d(posed_xyz, hand_v, hand_f)
            face_vertex_ids = torch.gather(hand_f[None].expand(batch_size, -1, -1), 1, face_indices[:, :, None].long().expand(-1, -1, 3))  # (B, N, 3)

            def interp(vert_attr):
                # barycentric interpolation of a per-vertex attribute over the nearest face
                return (bc_coords[..., None] * knn_gather(vert_attr, face_vertex_ids)).sum(2)

            cano_hand_v = geo_util.normalize_vert_bbox(batch['%s_cano_mano_v' % side], dim = 1, per_axis = True)
            pts_cano_mano_v = interp(cano_hand_v)
            pts_live_mano_v = interp(hand_v)
            pts_live_mano_n = interp(hand_n)
            # negative on the side the interpolated normal points to, positive on the other side
            pts_smpl_sdf = -torch.sign(torch.einsum('bni,bni->bn', pts_live_mano_n, posed_xyz - pts_live_mano_v)) * dists
            return pts_cano_mano_v, pts_smpl_sdf.unsqueeze(-1)

        left_cano_mano_v, left_mano_sdf = process_one_hand('left')
        right_cano_mano_v, right_mano_sdf = process_one_hand('right')

        # query both hand avatars with a zero hand pose
        zero_hand_pose = torch.zeros((1, 15 * 3)).to(left_cano_mano_v)
        color_lhand = self.forward_cano_hand_nerf(left_cano_mano_v, left_mano_sdf, view_dirs, zero_hand_pose, module = 'left_hand')
        color_rhand = self.forward_cano_hand_nerf(right_cano_mano_v, right_mano_sdf, view_dirs, zero_hand_pose, module = 'right_hand')

        # blending weights: sigmoid ramps on the x-coordinate of cano_xyz normalized by each hand's canonical bbox
        cano_xyz = body_ret['cano_xyz']
        wl = torch.sigmoid(25 * (geo_util.normalize_vert_bbox(batch['left_cano_mano_v'], attris = cano_xyz, dim = 1, per_axis = True)[..., 0:1] + 0.8))
        wr = torch.sigmoid(-25 * (geo_util.normalize_vert_bbox(batch['right_cano_mano_v'], attris = cano_xyz, dim = 1, per_axis = True)[..., 0:1] - 0.8))
        # no hand contribution where the canonical y-coordinate is below that of cano_smpl_center
        below_center = cano_xyz[..., 1] < batch['cano_smpl_center'][0, 1]
        wl[below_center] = 0.
        wr[below_center] = 0.
        # renormalize only where the two weights sum to more than 1
        s = torch.maximum(wl + wr, torch.ones_like(wl))
        wl, wr = wl / s, wr / s

        # blend the outputs of body network and hand networks
        w = wl + wr
        body_ret['sdf'] = wl * left_mano_sdf + wr * right_mano_sdf + (1.0 - w) * body_ret['sdf']
        body_ret['color'] = wl * color_lhand + wr * color_rhand + (1.0 - w) * body_ret['color']
        # body_ret['sdf'] is the negated network SDF (see forward_cano_body_nerf), so negate back for density
        body_ret['density'] = self.density_func(-body_ret['sdf'])

    def forward_cano_radiance_field(self, xyz, view_dirs, batch):
        """Query the canonical body field; SDF gradients are computed only in training mode. `batch` is unused."""
        return self.forward_cano_body_nerf(xyz, view_dirs, None, compute_grad = self.training)

    def _get_cano2live_jnt_mats(self, batch):
        """Return a copy of batch['cano2live_jnt_mats'], with rigid hands when hands are disabled.

        Without hand modeling, joints 25-39 take joint 20's transform and joints
        40-54 take joint 21's (presumably the SMPL-X wrists), so each hand moves
        rigidly. The caller's tensor is never modified.
        """
        cano2live_jnt_mats = batch['cano2live_jnt_mats'].clone()
        if not self.with_hand:
            cano2live_jnt_mats[:, 25: 40] = cano2live_jnt_mats[:, 20: 21]
            cano2live_jnt_mats[:, 40: 55] = cano2live_jnt_mats[:, 21: 22]
        return cano2live_jnt_mats

    def transform_cano2live(self, cano_pts, batch, normals = None, near_thres = 0.08):
        """Warp canonical points (and optionally normals) to live space by linear blend skinning.

        :param cano_pts: (B, N, 3) canonical points; blend weights come from `cano_weight_volume`.
        :param batch: must contain 'cano2live_jnt_mats' (B, J, 4, 4).
        :param normals: optional (B, N, 3) normals, rotated by the same per-point matrices.
        :param near_thres: unused; kept for symmetry with `transform_live2cano`.
        :return: posed points, or (posed points, posed normals) when `normals` is given.
        """
        cano2live_jnt_mats = self._get_cano2live_jnt_mats(batch)
        pts_w = self.cano_weight_volume.forward_weight(cano_pts)
        pt_mats = torch.einsum('bnj,bjxy->bnxy', pts_w, cano2live_jnt_mats)
        posed_pts = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], cano_pts) + pt_mats[..., :3, 3]
        if normals is None:
            return posed_pts
        posed_normals = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], normals)
        return posed_pts, posed_normals

    def transform_live2cano(self, posed_pts, batch, normals = None, near_thres = 0.08):
        """Warp live-space points (and optionally normals) back to canonical space. Runs under torch.no_grad().

        Blend weights come from `smpl_util.calc_blending_weight`:
        - with method 'NN' on the live mesh, when 'live_mesh_v' is in `batch`;
        - otherwise with method 'barycentric' on the live SMPL mesh.

        Points are mapped by the inverse of the blended joint transform. When
        opt['use_root_finding'] is set, they are then refined by
        `root_finding.root_finding`, except points whose dominant bone is
        7, 8, 10 or 11. Normals are not refined.

        :param posed_pts: (B, N, 3) live-space points.
        :param batch: data dict (see above).
        :param normals: optional (B, N, 3) normals.
        :param near_thres: forwarded to `calc_blending_weight`.
        :return: (cano_pts, near_flag), or (cano_pts, cano_normals, near_flag) when `normals` is given.
        """
        cano2live_jnt_mats = self._get_cano2live_jnt_mats(batch)
        batch_size, n_pts = posed_pts.shape[:2]
        with torch.no_grad():
            # live_pts -> cano_pts
            if 'live_mesh_v' in batch:
                pts_w, near_flag = smpl_util.calc_blending_weight(posed_pts, batch['live_mesh_v'], batch['live_mesh_f'], batch['live_mesh_lbs'], near_thres, method = 'NN')
            else:
                pts_w, near_flag = smpl_util.calc_blending_weight(posed_pts, batch['live_smpl_v'], batch['smpl_faces'], None, near_thres, method = 'barycentric')

            pt_mats = torch.linalg.inv(torch.einsum('bnj,bjxy->bnxy', pts_w, cano2live_jnt_mats))
            cano_pts = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], posed_pts) + pt_mats[..., :3, 3]
            if normals is not None:
                cano_normals = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], normals)

            if self.opt['use_root_finding']:
                # points dominated by these bones are not refined (presumably SMPL ankle/foot joints; confirm)
                argmax_lbs = torch.argmax(pts_w, -1)
                nonopt_pts_flag = torch.zeros((batch_size, n_pts), dtype = torch.bool, device = argmax_lbs.device)
                for bone_id in (7, 8, 10, 11):
                    nonopt_pts_flag = torch.logical_or(nonopt_pts_flag, argmax_lbs == bone_id)
                root_finding_flag = torch.logical_not(nonopt_pts_flag)
                if root_finding_flag.any():
                    cano_pts_ = cano_pts[root_finding_flag].unsqueeze(0).contiguous()
                    posed_pts_ = posed_pts[root_finding_flag].unsqueeze(0).contiguous()
                    # cano_pts_ is passed twice (presumably initial guess and output buffer) and read back below;
                    # 0.1 and 10 are presumably a tolerance and an iteration budget (confirm against root_finding)
                    root_finding.root_finding(
                        self.weight_volume,
                        self.grad_volume,
                        posed_pts_,
                        cano_pts_,
                        cano2live_jnt_mats,
                        self.cano_weight_volume.volume_bounds,
                        self.res,
                        cano_pts_,
                        0.1,
                        10
                    )
                    cano_pts[root_finding_flag] = cano_pts_[0]

        if normals is None:
            return cano_pts, near_flag
        return cano_pts, cano_normals, near_flag

    def _get_sampling_bounds(self, batch, depth_guided_sampling):
        """Resolve per-ray sampling bounds for `render`.

        :return: (near, far, valid_dist_flag, N_ray_samples). `valid_dist_flag`
            has the shape of batch['near'] (depth-guided: batch['dist']).
        :raises ValueError: if the sampling type is neither 'smpl' nor 'uniform'.
        """
        cfg = {} if depth_guided_sampling is None else depth_guided_sampling
        if cfg.get('flag', False):
            print('# depth-guided sampling')
            # clone so that the caller's batch['near'] / batch['far'] are left untouched
            near = batch['near'].clone()
            far = batch['far'].clone()
            valid_dist_flag = batch['dist'] > 1e-6
            dist = batch['dist'][valid_dist_flag]
            near_dist = cfg['near_sur_dist']
            # previously 'near_sur_dist' was used for both sides; it stays the fallback
            far_dist = cfg.get('far_sur_dist', near_dist)
            near[valid_dist_flag] = dist - near_dist
            far[valid_dist_flag] = dist + far_dist
            return near, far, valid_dist_flag, cfg['N_ray_samples']

        sampling_type = cfg.get('type', 'smpl')
        if sampling_type == 'smpl':
            print('# smpl-guided sampling')
            valid_dist_flag = torch.ones_like(batch['near'], dtype = bool)
            # NOTE: only batch element 0 is used, i.e. batch size 1 is assumed
            near, far, intersect_flag = near_far_smpl(batch['live_smpl_v'][0], batch['ray_o'][0], batch['ray_d'][0])
            # rays missing the SMPL mesh keep the bounds from the batch
            near[~intersect_flag] = batch['near'][0][~intersect_flag]
            far[~intersect_flag] = batch['far'][0][~intersect_flag]
            return near.unsqueeze(0), far.unsqueeze(0), valid_dist_flag, 64
        if sampling_type == 'uniform':
            print('# uniform sampling')
            return batch['near'], batch['far'], torch.ones_like(batch['near'], dtype = bool), 64
        raise ValueError('Invalid sampling type: %s' % sampling_type)

    def render(self, batch, chunk_size = 2048, depth_guided_sampling = None, space = 'live', white_bkgd = False):
        """Volume-render the rays in `batch`.

        :param batch: dict with 'ray_o', 'ray_d' (B, N_rays, ...), 'near',
            'far' (B, N_rays), plus the keys needed by the chosen sampling mode,
            by `transform_live2cano` and (optionally) by `fuse_hands`.
        :param chunk_size: rays per chunk at inference. In training all rays form one chunk.
        :param depth_guided_sampling: sampling config dict, or None (treated as {}).
            - If 'flag' is true: rays with batch['dist'] > 1e-6 are sampled in
              [dist - 'near_sur_dist', dist + 'far_sur_dist'] with 'N_ray_samples'
              samples. 'far_sur_dist' defaults to 'near_sur_dist'.
            - Otherwise 'type' selects 'smpl' (default: bounds from the live SMPL
              mesh) or 'uniform' (batch near/far). Both use 64 samples.
        :param space: 'live' (points are warped to canonical space) or 'cano'.
        :param white_bkgd: forwarded to `nerf_util.raw2outputs`.
        :return: dict with 'rgb_map', 'acc_map' and, when produced, 'normal' / 'tv_loss'.
        :raises ValueError: on an unknown rendering space or sampling type.
        """
        if space not in ('live', 'cano'):
            raise ValueError('Invalid rendering space!')
        ray_o = batch['ray_o']
        ray_d = batch['ray_d']
        near, far, valid_dist_flag, N_ray_samples = self._get_sampling_bounds(batch, depth_guided_sampling)

        if self.training:
            chunk_size = ray_o.shape[1]
        batch_size, n_pixels = ray_o.shape[:2]

        chunks = []
        for i in range(0, n_pixels, chunk_size):
            sl = slice(i, i + chunk_size)
            ray_d_chunk = ray_d[:, sl]

            # sample points on each ray
            pts, z_vals = nerf_util.sample_pts_on_rays(ray_o[:, sl], ray_d_chunk, near[:, sl], far[:, sl],
                                                       N_samples = N_ray_samples,
                                                       perturb = self.training,
                                                       depth_guided_mask = valid_dist_flag[:, sl])

            # flatten samples: (B, n_pixels_chunk, n_samples, C) -> (B, n_pixels_chunk * n_samples, C)
            _, n_pixels_chunk, n_samples = pts.shape[:3]
            pts = pts.view(batch_size, n_pixels_chunk * n_samples, -1)
            dists = z_vals[..., 1:] - z_vals[..., :-1]
            # repeat the last interval so every sample has a width
            dists = torch.cat([dists, dists[..., -1:]], -1)

            # query
            if space == 'live':
                cano_pts, _ = self.transform_live2cano(pts, batch)
            else:
                cano_pts = pts

            viewdirs = ray_d_chunk / torch.norm(ray_d_chunk, dim = -1, keepdim = True)
            viewdirs = viewdirs[:, :, None, :].expand(-1, -1, n_samples, -1).reshape(batch_size, n_pixels_chunk * n_samples, -1)
            # apply gaussian noise to avoid overfitting
            if self.training:
                with torch.no_grad():
                    viewdirs = viewdirs + torch.randn_like(viewdirs) * 0.1
                    viewdirs = viewdirs / torch.norm(viewdirs, dim = -1, keepdim = True)

            ret = self.forward_cano_radiance_field(cano_pts, viewdirs, batch)
            if self.with_hand:
                self.fuse_hands(ret, pts, viewdirs, batch, space)
            color = ret['color'].view(batch_size, n_pixels_chunk, n_samples, -1)
            density = ret['density'].view(batch_size, n_pixels_chunk, n_samples, -1)

            # integration
            alpha = 1. - torch.exp(-density * dists[..., None])
            raw = torch.cat([color, alpha], dim = -1)
            rgb_map, _, acc_map, _, _ = nerf_util.raw2outputs(raw, z_vals, white_bkgd = white_bkgd)

            output_chunk = {
                'rgb_map': rgb_map,  # (batch_size, n_pixel_chunk, 3)
                'acc_map': acc_map
            }
            if 'normal' in ret:
                output_chunk['normal'] = ret['normal'].view(batch_size, n_pixels_chunk, -1, 3)
            if 'tv_loss' in ret:
                output_chunk['tv_loss'] = ret['tv_loss'].view(1, 1, -1)
            chunks.append(output_chunk)

        outputs = {k: torch.cat([c[k] for c in chunks], dim = 1) for k in chunks[0]}

        # patch-based ray sampling: scatter rendered rays back into the full patch
        if 'mask_within_patch' in batch:
            mask = batch['mask_within_patch']
            _, ray_num = mask.shape
            rgb_map = torch.zeros((batch_size, ray_num, 3), dtype = torch.float32, device = config.device)
            acc_map = torch.zeros((batch_size, ray_num), dtype = torch.float32, device = config.device)
            rgb_map[mask] = outputs['rgb_map'].reshape(-1, 3)
            acc_map[mask] = outputs['acc_map'].reshape(-1)
            # the ground truth outside the mask is zeroed in place as well
            batch['color_gt'][~mask] = 0.
            batch['mask_gt'][~mask] = 0.
            outputs['rgb_map'] = rgb_map
            outputs['acc_map'] = acc_map
        return outputs