# Parts of this code are adapted from IBRNet.
import torch
import torch.nn.functional as F
from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume
def safe_l2_normalize(x, dim=None, eps=1e-6):
return F.normalize(x, p=2, dim=dim, eps=eps)
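# Example (illustrative only): turning raw SDF gradients into unit surface normals.
#   g = torch.randn(1024, 3)
#   n = safe_l2_normalize(g, dim=-1)  # rows are unit length; eps guards near-zero vectors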
class Projector:
    """
    Obtain features from geometryVolume and rendering_feature_maps for generalized rendering.
    """
def compute_angle(self, xyz, query_c2w, supporting_c2ws):
"""
:param xyz: [N_rays, n_samples,3 ]
:param query_c2w: [1,4,4]
:param supporting_c2ws: [n,4,4]
:return:
"""
N_rays, n_samples, _ = xyz.shape
num_views = supporting_c2ws.shape[0]
xyz = xyz.reshape(-1, 3)
ray2tar_pose = (query_c2w[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6)
ray_diff = ray2tar_pose - ray2support_pose
ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True)
ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4))  # last channel is the dot product; the first three are the normalized ray-difference direction
return ray_diff.detach()
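    # Shape sketch for compute_angle (illustrative stand-in tensors, not real poses):
    #   xyz = torch.rand(64, 32, 3)                            # [N_rays, n_samples, 3]
    #   query_c2w = torch.eye(4)[None]                         # [1, 4, 4]
    #   supporting_c2ws = torch.eye(4)[None].repeat(2, 1, 1)   # [n=2, 4, 4]
    #   Projector().compute_angle(xyz, query_c2w, supporting_c2ws).shape
    #   # -> torch.Size([2, 64, 32, 4])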
def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws):
"""
:param xyz: [N_rays, n_samples,3 ]
:param surface_normals: [N_rays, n_samples,3 ]
:param supporting_c2ws: [n,4,4]
:return:
"""
N_rays, n_samples, _ = xyz.shape
num_views = supporting_c2ws.shape[0]
xyz = xyz.reshape(-1, 3)
ray2tar_pose = surface_normals
ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6)
ray_diff = ray2tar_pose - ray2support_pose
ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True)
ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4))  # last channel is the dot product; the first three are the normalized ray-difference direction
return ray_diff.detach()
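    # Note: unlike compute_angle, the "query" direction here is the surface normal at
    # each point, so the feature carries no dependence on the query camera. The
    # [N_rays * n_samples, 3] normals broadcast against the [n, N_rays * n_samples, 3]
    # point-to-camera directions.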
@torch.no_grad()
def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values):
"""
compute the depth difference of query pts projected on the image and the predicted depth values of the image
:param xyz: [N_rays, n_samples,3 ]
:param w2cs: [N_views, 4, 4]
:param intrinsics: [N_views, 3, 3]
:param pred_depth_values: [N_views, N_rays, n_samples,1 ]
:param pred_depth_masks: [N_views, N_rays, n_samples]
:return:
"""
device = xyz.device
N_views = w2cs.shape[0]
N_rays, n_samples, _ = xyz.shape
proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])
proj_rot = proj_matrix[:, :3, :3]
proj_trans = proj_matrix[:, :3, 3:]
batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).repeat(N_views, 1, 1)
proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans
# X = proj_xyz[:, 0]
# Y = proj_xyz[:, 1]
Z = proj_xyz[:, 2].clamp(min=1e-3) # [N_views, N_rays*n_samples]
proj_z = Z.view(N_views, N_rays, n_samples, 1)
z_diff = proj_z - pred_depth_values # [N_views, N_rays, n_samples,1 ]
return z_diff
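    # Worked example (illustrative): with an identity w2c and identity intrinsics, a
    # point at (0, 0, 1.5) has projected camera depth Z = 1.5; if the image's predicted
    # depth at that pixel is 1.4, then z_diff = 1.5 - 1.4 = 0.1, i.e. the point lies
    # behind the predicted surface.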
def compute(self,
pts,
# * 3d geometry feature volumes
geometryVolume=None,
geometryVolumeMask=None,
vol_dims=None,
partial_vol_origin=None,
vol_size=None,
# * 2d rendering feature maps
rendering_feature_maps=None,
color_maps=None,
w2cs=None,
intrinsics=None,
img_wh=None,
query_img_idx=0, # the index of the N_views dim for rendering
query_c2w=None,
                pred_depth_maps=None,  # not used here
                pred_depth_masks=None  # not used here
):
"""
extract features of pts for rendering
:param pts:
:param geometryVolume:
:param vol_dims:
:param partial_vol_origin:
:param vol_size:
:param rendering_feature_maps:
:param color_maps:
:param w2cs:
:param intrinsics:
:param img_wh:
:param rendering_img_idx: by default, we render the first view of w2cs
:return:
"""
device = pts.device
c2ws = torch.inverse(w2cs)
if len(pts.shape) == 2:
pts = pts[None, :, :]
N_rays, n_samples, _ = pts.shape
N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W)
supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device)
query_img_idx = torch.LongTensor([query_img_idx]).to(device)
if query_c2w is None and query_img_idx > -1:
query_c2w = torch.index_select(c2ws, 0, query_img_idx)
supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs)
supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs)
supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs)
supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs)
supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs)
if pred_depth_maps is not None:
supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs)
supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs)
# print("N_supporting_views: ", N_views - 1)
N_supporting_views = N_views - 1
else:
supporting_c2ws = c2ws
supporting_w2cs = w2cs
supporting_rendering_feature_maps = rendering_feature_maps
supporting_color_maps = color_maps
supporting_intrinsics = intrinsics
            supporting_depth_maps = pred_depth_maps
supporting_depth_masks = pred_depth_masks
# print("N_supporting_views: ", N_views)
N_supporting_views = N_views
if geometryVolume is not None:
# * sample feature of pts from 3D feature volume
pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume(
pts, geometryVolume, vol_dims,
partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples]
if len(geometryVolumeMask.shape) == 3:
geometryVolumeMask = geometryVolumeMask[None, :, :, :]
pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume(
pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
partial_vol_origin, vol_size) # [N_rays, n_samples, C]
pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0)
else:
pts_geometry_feature = None
pts_geometry_masks = None
# * sample feature of pts from 2D feature maps
pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
pts, supporting_rendering_feature_maps, supporting_w2cs,
supporting_intrinsics, img_wh,
return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples]
        # * size (N_views, N_rays, n_samples, c)
pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()
pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs,
supporting_intrinsics, img_wh)
        # * size (N_views, N_rays, n_samples, c)
pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()
rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c]
ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4]
if pts_geometry_masks is not None:
final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \
pts_rendering_mask # [N_views, N_rays, n_samples]
else:
final_mask = pts_rendering_mask
z_diff, pts_pred_depth_masks = None, None
if pred_depth_maps is not None:
pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs,
supporting_intrinsics, img_wh)
            pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, 1).contiguous()  # (N_views, N_rays, n_samples, 1)

            # - pts_pred_depth_masks is stricter than final_mask:
            # - a ray containing even a few invalid pts will be treated as invalid
pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(),
supporting_w2cs,
supporting_intrinsics, img_wh)
            pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, 0]  # (N_views, N_rays, n_samples)
z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values)
return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks
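    # Return shapes (when geometryVolume is provided):
    #   pts_geometry_feature   [N_rays, n_samples, C_geo]
    #   rgb_feats              [N_views, N_rays, n_samples, 3 + C_render]
    #   ray_diff               [N_views, N_rays, n_samples, 4]
    #   final_mask             [N_views, N_rays, n_samples]
    # z_diff and pts_pred_depth_masks stay None unless pred_depth_maps is given.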
def compute_view_independent(
self,
pts,
# * 3d geometry feature volumes
geometryVolume=None,
geometryVolumeMask=None,
sdf_network=None,
lod=0,
vol_dims=None,
partial_vol_origin=None,
vol_size=None,
# * 2d rendering feature maps
rendering_feature_maps=None,
color_maps=None,
w2cs=None,
target_candidate_w2cs=None,
intrinsics=None,
img_wh=None,
query_img_idx=0, # the index of the N_views dim for rendering
query_c2w=None,
            pred_depth_maps=None,  # not used here
            pred_depth_masks=None  # not used here
):
"""
extract features of pts for rendering
:param pts:
:param geometryVolume:
:param vol_dims:
:param partial_vol_origin:
:param vol_size:
:param rendering_feature_maps:
:param color_maps:
:param w2cs:
:param intrinsics:
:param img_wh:
:param rendering_img_idx: by default, we render the first view of w2cs
:return:
"""
device = pts.device
c2ws = torch.inverse(w2cs)
if len(pts.shape) == 2:
pts = pts[None, :, :]
N_rays, n_samples, _ = pts.shape
N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W)
supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device)
query_img_idx = torch.LongTensor([query_img_idx]).to(device)
if query_c2w is None and query_img_idx > -1:
query_c2w = torch.index_select(c2ws, 0, query_img_idx)
supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs)
supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs)
supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs)
supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs)
supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs)
if pred_depth_maps is not None:
supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs)
supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs)
# print("N_supporting_views: ", N_views - 1)
N_supporting_views = N_views - 1
else:
supporting_c2ws = c2ws
supporting_w2cs = w2cs
supporting_rendering_feature_maps = rendering_feature_maps
supporting_color_maps = color_maps
supporting_intrinsics = intrinsics
            supporting_depth_maps = pred_depth_maps
supporting_depth_masks = pred_depth_masks
# print("N_supporting_views: ", N_views)
N_supporting_views = N_views
if geometryVolume is not None:
# * sample feature of pts from 3D feature volume
pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume(
pts, geometryVolume, vol_dims,
partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples]
if len(geometryVolumeMask.shape) == 3:
geometryVolumeMask = geometryVolumeMask[None, :, :, :]
pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume(
pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
partial_vol_origin, vol_size) # [N_rays, n_samples, C]
pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0)
else:
pts_geometry_feature = None
pts_geometry_masks = None
# * sample feature of pts from 2D feature maps
pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
pts, supporting_rendering_feature_maps, supporting_w2cs,
supporting_intrinsics, img_wh,
return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples]
        # * size (N_views, N_rays, n_samples, c)
pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()
pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs,
supporting_intrinsics, img_wh)
        # * size (N_views, N_rays, n_samples, c)
pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()
rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c]
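        # Query the SDF network for spatial gradients at the sample points; the
        # normalized gradients serve as surface normals below.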
gradients = sdf_network.gradient(
pts.reshape(-1, 3), # pts.squeeze(0),
geometryVolume.unsqueeze(0),
lod=lod
).squeeze()
surface_normals = safe_l2_normalize(gradients, dim=-1) # [npts, 3]
        # use the surface normals as the query direction, so the feature is view-independent
ren_ray_diff = self.compute_angle_view_independent(
xyz=pts,
surface_normals=surface_normals,
supporting_c2ws=supporting_c2ws
)
        # # choose closest target view direction from 32 candidate views
        # # choose the closest source view as view direction instead of the normal vectors
# pts2src_centers = safe_l2_normalize((supporting_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3]
# cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1]
# # choose the largest cosine distance as the view direction
# max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1]
# chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3]
# ren_ray_diff = self.compute_angle_view_independent(
# xyz=pts,
# surface_normals=chosen_view_direction,
# supporting_c2ws=supporting_c2ws
# )
        # # choose closest target view direction from 8 candidate views
        # # choose the closest source view as view direction instead of the normal vectors
# target_candidate_c2ws = torch.inverse(target_candidate_w2cs)
# pts2src_centers = safe_l2_normalize((target_candidate_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3]
# cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1]
# # choose the largest cosine distance as the view direction
# max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1]
# chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3]
# ren_ray_diff = self.compute_angle_view_independent(
# xyz=pts,
# surface_normals=chosen_view_direction,
# supporting_c2ws=supporting_c2ws
# )
# ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4]
# input_directions = safe_l2_normalize(pts)
# ren_ray_diff = self.compute_angle_view_independent(
# xyz=pts,
# surface_normals=input_directions,
# supporting_c2ws=supporting_c2ws
# )
if pts_geometry_masks is not None:
final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \
pts_rendering_mask # [N_views, N_rays, n_samples]
else:
final_mask = pts_rendering_mask
z_diff, pts_pred_depth_masks = None, None
if pred_depth_maps is not None:
pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs,
supporting_intrinsics, img_wh)
            pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, 1).contiguous()  # (N_views, N_rays, n_samples, 1)

            # - pts_pred_depth_masks is stricter than final_mask:
            # - a ray containing even a few invalid pts will be treated as invalid
pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(),
supporting_w2cs,
supporting_intrinsics, img_wh)
            pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, 0]  # (N_views, N_rays, n_samples)
z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values)
return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks
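if __name__ == "__main__":
    # Minimal smoke test for Projector.compute, sketched with random stand-in tensors.
    # It assumes the sampling helpers from models.render_utils follow the shapes
    # documented above; no real scene data or feature volume is involved
    # (geometryVolume / pred_depth_maps are left as None).
    N_views, C, H, W = 3, 8, 64, 80
    N_rays, n_samples = 16, 24

    pts = torch.rand(N_rays, n_samples, 3) * 0.1  # points clustered near the origin
    w2cs = torch.eye(4).unsqueeze(0).repeat(N_views, 1, 1)
    w2cs[:, 2, 3] = 2.0  # push the cameras back so projected depths are positive
    intrinsics = torch.tensor([[100.0, 0.0, W / 2],
                               [0.0, 100.0, H / 2],
                               [0.0, 0.0, 1.0]]).unsqueeze(0).repeat(N_views, 1, 1)

    projector = Projector()
    geo_feat, rgb_feats, ray_diff, final_mask, z_diff, depth_masks = projector.compute(
        pts,
        rendering_feature_maps=torch.rand(N_views, C, H, W),
        color_maps=torch.rand(N_views, 3, H, W),
        w2cs=w2cs,
        intrinsics=intrinsics,
        img_wh=[W, H],
        query_img_idx=0,
    )
    # Expected: rgb_feats [2, 16, 24, 11], ray_diff [2, 16, 24, 4], final_mask [2, 16, 24]
    print(rgb_feats.shape, ray_diff.shape, final_mask.shape)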