|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume |
|
|
|
def safe_l2_normalize(x, dim=None, eps=1e-6):
    """
    Scale ``x`` to unit L2 norm along ``dim``.

    Thin wrapper over ``torch.nn.functional.normalize`` with p=2. The norm is clamped
    from below by ``eps`` so all-zero slices stay zero instead of producing NaNs.
    ``dim`` (including ``None``) is forwarded unchanged.
    """
    unit = F.normalize(x, dim=dim, eps=eps, p=2)
    return unit
|
|
|
class Projector:
    """
    Obtain features from geometryVolume and rendering_feature_maps for generalized rendering.

    ``compute`` and ``compute_view_independent`` sample the same per-point inputs:
    - geometry-volume features,
    - image features and colours from the supporting views,
    - predicted-depth consistency.

    They differ only in the per-view angular encoding. ``compute`` compares the direction
    to the query camera with the direction to each supporting camera.
    ``compute_view_independent`` compares an SDF surface normal with the direction to
    each supporting camera.
    """

    # ----------------------------------------------------------------- angle encodings

    @staticmethod
    def _unit_dirs_to_cameras(c2ws, pts):
        """
        Unit vectors from every point to every camera centre.

        :param c2ws: [V, 4, 4] camera-to-world poses (translation column = camera centre)
        :param pts: [M, 3] points
        :return: [V, M, 3]
        """
        offsets = c2ws[:, :3, 3].unsqueeze(1) - pts.unsqueeze(0)
        # Adding 1e-6 (rather than clamping) keeps the original numerics while avoiding /0.
        return offsets / (torch.norm(offsets, dim=-1, keepdim=True) + 1e-6)

    @staticmethod
    def _encode_direction_pair(dir_a, dir_b):
        """
        Build a 4-channel encoding of two direction fields.

        The first 3 channels are the unit direction of ``dir_a - dir_b``. The last channel
        is their dot product, i.e. the cosine of the angle between them for unit inputs.
        Inputs must be broadcast-compatible; the output has the broadcast shape with the
        last axis equal to 4.
        """
        diff = dir_a - dir_b
        diff_norm = torch.norm(diff, dim=-1, keepdim=True)
        dot = torch.sum(dir_a * dir_b, dim=-1, keepdim=True)
        # Clamp so identical directions give a zero vector instead of NaN.
        return torch.cat([diff / torch.clamp(diff_norm, min=1e-6), dot], dim=-1)

    def compute_angle(self, xyz, query_c2w, supporting_c2ws):
        """
        Encode the angular difference between the query view and each supporting view.

        :param xyz: [N_rays, n_samples, 3] query points
        :param query_c2w: [1, 4, 4] camera-to-world pose of the query view
        :param supporting_c2ws: [n, 4, 4] camera-to-world poses of the supporting views
        :return: [n, N_rays, n_samples, 4] detached tensor; channels are the unit direction of
            (dir-to-query-camera - dir-to-supporting-camera) and the cosine between the two
        """
        n_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        pts = xyz.reshape(-1, 3)

        to_query = self._unit_dirs_to_cameras(query_c2w, pts)          # [1, M, 3]
        to_support = self._unit_dirs_to_cameras(supporting_c2ws, pts)  # [n, M, 3]
        ray_diff = self._encode_direction_pair(to_query, to_support)   # [n, M, 4]

        return ray_diff.reshape(num_views, n_rays, n_samples, 4).detach()

    def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws):
        """
        Encode the angle between a per-point surface normal and each supporting view direction.

        :param xyz: [N_rays, n_samples, 3] query points
        :param surface_normals: [N_rays, n_samples, 3] or flattened [N_rays * n_samples, 3]
            normals; the dot-product channel is a cosine only if they are unit length
        :param supporting_c2ws: [n, 4, 4] camera-to-world poses of the supporting views
        :return: [n, N_rays, n_samples, 4] detached tensor, same layout as ``compute_angle``
        """
        n_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        pts = xyz.reshape(-1, 3)

        # Flatten so normals line up with the flattened points whatever layout was passed in;
        # an unflattened [N_rays, n_samples, 3] tensor would otherwise mis-broadcast.
        normals = surface_normals.reshape(-1, 3)                       # [M, 3]
        to_support = self._unit_dirs_to_cameras(supporting_c2ws, pts)  # [n, M, 3]
        ray_diff = self._encode_direction_pair(normals, to_support)    # [n, M, 4]

        return ray_diff.reshape(num_views, n_rays, n_samples, 4).detach()

    # ------------------------------------------------------------------ depth consistency

    @torch.no_grad()
    def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values):
        """
        Difference between each point's projected depth and the predicted depth of each view.

        :param xyz: [N_rays, n_samples, 3]
        :param w2cs: [N_views, 4, 4]
        :param intrinsics: [N_views, 3, 3]
        :param pred_depth_values: [N_views, N_rays, n_samples, 1] predicted depth sampled at
            the projections of ``xyz``
        :return: [N_views, N_rays, n_samples, 1] projected depth minus predicted depth. The
            projected depth is clamped to >= 1e-3, so points behind a camera get a tiny
            positive depth.
        """
        n_views = w2cs.shape[0]
        n_rays, n_samples, _ = xyz.shape

        proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])  # [N_views, 3, 4]
        proj_rot = proj_matrix[:, :3, :3]
        proj_trans = proj_matrix[:, :3, 3:]

        pts_t = xyz.reshape(-1, 3).t()                           # [3, M]
        proj_xyz = torch.matmul(proj_rot, pts_t) + proj_trans    # broadcasts -> [N_views, 3, M]

        proj_z = proj_xyz[:, 2].clamp(min=1e-3).view(n_views, n_rays, n_samples, 1)
        return proj_z - pred_depth_values

    # --------------------------------------------------------------- shared sampling steps

    @staticmethod
    def _select_supporting_views(w2cs, rendering_feature_maps, color_maps, intrinsics,
                                 query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks):
        """
        Split the input views into the query pose and the supporting-view tensors.

        If ``query_c2w`` is None and ``query_img_idx > -1``, the query pose is taken from view
        ``query_img_idx`` and that view is removed from every per-view tensor. Otherwise all
        views are supporting views and ``query_c2w`` is returned unchanged (possibly None).

        :return: dict with keys "query_c2w", "c2ws", "w2cs", "feature_maps", "color_maps",
            "intrinsics", "depth_maps", "depth_masks" (the last two may be None)
        """
        query_img_idx = int(query_img_idx)
        c2ws = torch.inverse(w2cs)
        n_views = rendering_feature_maps.shape[0]

        views = {
            "c2ws": c2ws,
            "w2cs": w2cs,
            "feature_maps": rendering_feature_maps,
            "color_maps": color_maps,
            "intrinsics": intrinsics,
            "depth_maps": pred_depth_maps,
            "depth_masks": pred_depth_masks,
        }

        if query_c2w is None and query_img_idx > -1:
            query_idx = torch.tensor([query_img_idx], dtype=torch.long, device=c2ws.device)
            query_c2w = torch.index_select(c2ws, 0, query_idx)
            keep = torch.tensor([i for i in range(n_views) if i != query_img_idx], dtype=torch.long)
            # Move the index onto each tensor's own device; index_select requires it.
            views = {name: None if t is None else torch.index_select(t, 0, keep.to(t.device))
                     for name, t in views.items()}

        views["query_c2w"] = query_c2w
        return views

    @staticmethod
    def _sample_geometry(pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size):
        """
        Sample geometry-volume features at ``pts`` together with a validity mask.

        The mask combines the mask returned by ``sample_ptsFeatures_from_featureVolume`` with
        ``geometryVolumeMask`` sampled at the same points (> 0). If ``geometryVolumeMask`` is
        None, only the first mask is used.

        :return: (features, masks), or (None, None) when ``geometryVolume`` is None
        """
        if geometryVolume is None:
            return None, None

        features, masks = sample_ptsFeatures_from_featureVolume(
            pts, geometryVolume, vol_dims, partial_vol_origin, vol_size)

        if geometryVolumeMask is None:
            return features, masks

        if geometryVolumeMask.dim() == 3:
            geometryVolumeMask = geometryVolumeMask[None, :, :, :]
        volume_mask_values, _ = sample_ptsFeatures_from_featureVolume(
            pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims, partial_vol_origin, vol_size)

        return features, masks & (volume_mask_values[..., 0] > 0)

    @staticmethod
    def _sample_image_features(pts, views, img_wh):
        """
        Sample colours and image features of the supporting views at ``pts``.

        :return: (rgb_feats, pts_rendering_mask). rgb_feats holds colours followed by
            features on the last axis; presumably [n_views, N_rays, n_samples, C_rgb + C_feat]
            (layout set by sample_ptsFeatures_from_featureMaps -- confirm).
        """
        w2cs, intrinsics = views["w2cs"], views["intrinsics"]
        feats, mask = sample_ptsFeatures_from_featureMaps(
            pts, views["feature_maps"], w2cs, intrinsics, img_wh, return_mask=True)
        colors = sample_ptsFeatures_from_featureMaps(
            pts, views["color_maps"], w2cs, intrinsics, img_wh)

        # The sampler returns channels on axis 1; move them last so they can be concatenated.
        feats = feats.permute(0, 2, 3, 1).contiguous()
        colors = colors.permute(0, 2, 3, 1).contiguous()
        return torch.cat([colors, feats], dim=-1), mask

    @staticmethod
    def _combine_masks(pts_geometry_masks, pts_rendering_mask, n_supporting_views):
        """AND the per-point geometry mask (shared by all views) into the per-view rendering mask."""
        if pts_geometry_masks is None:
            return pts_rendering_mask
        return pts_geometry_masks[None, :, :].repeat(n_supporting_views, 1, 1) & pts_rendering_mask

    def _sample_pred_depth(self, pts, views, img_wh):
        """
        Sample predicted depth and depth masks of the supporting views, and compute z_diff.

        :return: (z_diff, pts_pred_depth_masks), or (None, None) if no depth maps were given
        :raises ValueError: if depth maps were given without depth masks
        """
        depth_maps = views["depth_maps"]
        if depth_maps is None:
            return None, None
        depth_masks = views["depth_masks"]
        if depth_masks is None:
            raise ValueError("pred_depth_masks must be provided together with pred_depth_maps")

        w2cs, intrinsics = views["w2cs"], views["intrinsics"]
        depth_values = sample_ptsFeatures_from_featureMaps(pts, depth_maps, w2cs, intrinsics, img_wh)
        depth_values = depth_values.permute(0, 2, 3, 1).contiguous()

        pts_masks = sample_ptsFeatures_from_featureMaps(pts, depth_masks.float(), w2cs, intrinsics, img_wh)
        pts_masks = pts_masks.permute(0, 2, 3, 1).contiguous()[..., 0]

        z_diff = self.compute_z_diff(pts, w2cs, intrinsics, depth_values)
        return z_diff, pts_masks

    # -------------------------------------------------------------------- entry points

    def compute(self,
                pts,
                geometryVolume=None,
                geometryVolumeMask=None,
                vol_dims=None,
                partial_vol_origin=None,
                vol_size=None,
                rendering_feature_maps=None,
                color_maps=None,
                w2cs=None,
                intrinsics=None,
                img_wh=None,
                query_img_idx=0,
                query_c2w=None,
                pred_depth_maps=None,
                pred_depth_masks=None
                ):
        """
        Extract per-point features for rendering the query view.

        :param pts: [N_rays, n_samples, 3] points, or [M, 3] (treated as a single ray)
        :param geometryVolume: feature volume to sample; skipped when None
        :param geometryVolumeMask: occupancy mask of the volume (3-D or 4-D); optional
        :param vol_dims, partial_vol_origin, vol_size: volume placement, forwarded to the sampler
        :param rendering_feature_maps: per-view feature maps; dim 0 indexes views
        :param color_maps: per-view colour images; dim 0 indexes views
        :param w2cs: [N_views, 4, 4] world-to-camera poses
        :param intrinsics: [N_views, 3, 3]
        :param img_wh: image size, forwarded to the sampler
        :param query_img_idx: view used as the query (and excluded from the supporting views)
            when ``query_c2w`` is None; by default the first view
        :param query_c2w: explicit [1, 4, 4] query pose; when given, all views are supporting views
        :param pred_depth_maps: optional per-view predicted depth maps
        :param pred_depth_masks: per-view validity masks; required with ``pred_depth_maps``
        :return: (pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff,
            pts_pred_depth_masks); the first and last two may be None
        :raises ValueError: if no query pose is available (``query_c2w`` None and
            ``query_img_idx`` < 0)
        """
        if pts.dim() == 2:
            pts = pts[None, :, :]

        views = self._select_supporting_views(w2cs, rendering_feature_maps, color_maps, intrinsics,
                                              query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks)
        if views["query_c2w"] is None:
            raise ValueError("query_c2w must be given when query_img_idx < 0")

        pts_geometry_feature, pts_geometry_masks = self._sample_geometry(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size)
        rgb_feats, pts_rendering_mask = self._sample_image_features(pts, views, img_wh)

        ray_diff = self.compute_angle(pts, views["query_c2w"], views["c2ws"])

        final_mask = self._combine_masks(pts_geometry_masks, pts_rendering_mask,
                                         views["feature_maps"].shape[0])
        z_diff, pts_pred_depth_masks = self._sample_pred_depth(pts, views, img_wh)

        return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks

    def compute_view_independent(
            self,
            pts,
            geometryVolume=None,
            geometryVolumeMask=None,
            sdf_network=None,
            lod=0,
            vol_dims=None,
            partial_vol_origin=None,
            vol_size=None,
            rendering_feature_maps=None,
            color_maps=None,
            w2cs=None,
            target_candidate_w2cs=None,
            intrinsics=None,
            img_wh=None,
            query_img_idx=0,
            query_c2w=None,
            pred_depth_maps=None,
            pred_depth_masks=None
    ):
        """
        Like ``compute``, but the angular encoding uses SDF surface normals instead of the
        query-camera direction.

        :param sdf_network: object exposing ``gradient(points, volume, lod=...)``; required
        :param lod: level of detail forwarded to ``sdf_network.gradient``
        :param target_candidate_w2cs: accepted for interface compatibility; unused
        :param query_c2w: only affects which views are treated as supporting views
        Other parameters are as in ``compute``; ``geometryVolume`` is required here.
        :return: (pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff,
            pts_pred_depth_masks)
        :raises ValueError: if ``sdf_network`` or ``geometryVolume`` is None
        """
        if sdf_network is None or geometryVolume is None:
            raise ValueError("compute_view_independent requires sdf_network and geometryVolume")

        if pts.dim() == 2:
            pts = pts[None, :, :]

        views = self._select_supporting_views(w2cs, rendering_feature_maps, color_maps, intrinsics,
                                              query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks)

        pts_geometry_feature, pts_geometry_masks = self._sample_geometry(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size)
        rgb_feats, pts_rendering_mask = self._sample_image_features(pts, views, img_wh)

        gradients = sdf_network.gradient(
            pts.reshape(-1, 3),
            geometryVolume.unsqueeze(0),
            lod=lod
        ).squeeze()
        surface_normals = safe_l2_normalize(gradients, dim=-1)

        ren_ray_diff = self.compute_angle_view_independent(
            xyz=pts,
            surface_normals=surface_normals,
            supporting_c2ws=views["c2ws"]
        )

        final_mask = self._combine_masks(pts_geometry_masks, pts_rendering_mask,
                                         views["feature_maps"].shape[0])
        z_diff, pts_pred_depth_masks = self._sample_pred_depth(pts, views, img_wh)

        return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks
|
|