|
""" |
|
The code is heavily borrowed from NeuS.
|
""" |
|
|
|
import os |
|
import cv2 as cv |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import logging |
|
import mcubes |
|
from icecream import ic |
|
from models.render_utils import sample_pdf |
|
|
|
from models.projector import Projector |
|
from tsparse.torchsparse_utils import sparse_to_dense_channel |
|
|
|
from models.fast_renderer import FastRenderer |
|
|
|
from models.patch_projector import PatchProjector |
|
|
|
|
|
class SparseNeuSRenderer(nn.Module):
    """
    Conditional NeuS renderer.

    Rendering is done in the normalized world space. The renderer is wrapped in
    ``nn.Module`` so that it can be used with DataParallel training.
    """

    def __init__(self,
                 rendering_network_outside,
                 sdf_network,
                 variance_network,
                 rendering_network,
                 n_samples,
                 n_importance,
                 n_outside,
                 perturb,
                 alpha_type='div',
                 conf=None
                 ):
        """
        :param rendering_network_outside: stored on the module; not used by the methods of this class
        :param sdf_network: conditional SDF network; an optional ``if_fitted_rendering`` attribute on it
            enables the blended-color branch of ``render_core``
        :param variance_network: maps SDF features to the inverse variance (first channel is used)
        :param rendering_network: generalizable rendering network used by ``render_core``
        :param n_samples: number of uniform (coarse) samples per ray
        :param n_importance: number of importance samples per ray (split over 4 up-sampling steps)
        :param n_outside: number of outside samples (no outside rendering is performed by this class)
        :param perturb: default stratified-sampling flag (> 0 enables jitter)
        :param alpha_type: 'div' or 'uniform'
        :param conf: config object supporting ``conf['general.base_exp_dir']`` and ``conf.get_int``
        """
        super(SparseNeuSRenderer, self).__init__()

        self.conf = conf
        self.base_exp_dir = conf['general.base_exp_dir']

        self.rendering_network_outside = rendering_network_outside
        self.sdf_network = sdf_network
        self.variance_network = variance_network
        self.rendering_network = rendering_network

        self.n_samples = n_samples
        self.n_importance = n_importance
        self.n_outside = n_outside
        self.perturb = perturb
        self.alpha_type = alpha_type

        self.rendering_projector = Projector()

        self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3)
        self.patch_projector = PatchProjector(self.h_patch_size)

        self.ray_tracer = FastRenderer()

        # The SDF network may optionally expose a fitted-rendering (color blending) switch.
        self.if_fitted_rendering = getattr(self.sdf_network, 'if_fitted_rendering', False)

    def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance,
                  conditional_valid_mask_volume=None):
        """
        Draw new sample depths along each ray, concentrated where the current SDF estimate crosses zero.

        :param rays_o: [N_rays, 3]
        :param rays_d: [N_rays, 3]
        :param z_vals: [N_rays, n_samples] sorted sample depths
        :param sdf: SDF values at ``z_vals`` (reshapeable to [N_rays, n_samples])
        :param n_importance: number of new samples per ray
        :param inv_variance: sharpness of the logistic CDF used to turn SDF into alpha
        :param conditional_valid_mask_volume: optional [1, 1, X, Y, Z] validity volume
        :return: [N_rays, n_importance] new depths (detached)
        """
        device = rays_o.device
        batch_size, n_samples = z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]

        # One mask entry per interval between consecutive samples: [batch_size, n_samples - 1].
        if conditional_valid_mask_volume is not None:
            pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
            pts_mask = pts_mask.reshape(batch_size, n_samples)
            pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:]
        else:
            pts_mask = torch.ones([batch_size, n_samples - 1], device=pts.device)

        sdf = sdf.reshape(batch_size, n_samples)
        prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
        prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
        mid_sdf = (prev_sdf + next_sdf) * 0.5

        if self.alpha_type == 'uniform':
            # Assume the SDF decreases at unit rate along the ray.
            dot_val = torch.full([batch_size, n_samples - 1], -1.0, device=device)
        else:
            # Finite-difference SDF slope; use the steeper of this and the previous interval.
            dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
            prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1)
            dot_val = torch.stack([prev_dot_val, dot_val], dim=-1)
            dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False)
            dot_val = dot_val.clip(-10.0, 0.0) * pts_mask

        dist = next_z_vals - prev_z_vals
        prev_esti_sdf = mid_sdf - dot_val * dist * 0.5
        next_esti_sdf = mid_sdf + dot_val * dist * 0.5
        prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance)
        next_cdf = torch.sigmoid(next_esti_sdf * inv_variance)
        alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
        alpha = pts_mask * alpha

        # Standard front-to-back compositing weights.
        weights = alpha * torch.cumprod(
            torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]

        z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
        return z_samples

    def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
                   sdf_network, gru_fusion,
                   conditional_volume=None,
                   conditional_valid_mask_volume=None
                   ):
        """
        Evaluate the SDF at new sample depths and merge them with the existing samples.

        Points outside the valid mask are not queried and keep a placeholder SDF of 100.

        :param z_vals: [N_rays, n_samples] existing depths
        :param new_z_vals: [N_rays, n_importance] new depths
        :param sdf: [N_rays, n_samples] SDF at ``z_vals``
        :param lod: level of detail; selects ``'sdf_pts_scale%d' % lod`` from the network output
        :param gru_fusion: unused; kept for interface compatibility
        :return: (z_vals, sdf), both [N_rays, n_samples + n_importance] and sorted by depth
        """
        device = rays_o.device
        batch_size, n_samples = z_vals.shape
        _, n_importance = new_z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]

        if conditional_valid_mask_volume is not None:
            pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
            pts_mask = pts_mask.reshape(batch_size, n_importance)
        else:
            pts_mask = torch.ones([batch_size, n_importance]).to(pts.device)
        # Defined for both branches so that the no-mask path works as well.
        pts_mask_bool = (pts_mask > 0).view(-1)

        new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100

        # The network is only queried when more than one point is valid.
        if torch.sum(pts_mask) > 1:
            new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod)
            new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod]

        new_sdf = new_sdf.view(batch_size, n_importance)

        z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
        sdf = torch.cat([sdf, new_sdf], dim=-1)

        # Sort depths and reorder the SDF values with the same permutation.
        z_vals, index = torch.sort(z_vals, dim=-1)
        sdf = torch.gather(sdf, -1, index)

        return z_vals, sdf

    @torch.no_grad()
    def get_pts_mask_for_conditional_volume(self, pts, mask_volume):
        """
        Sample a volume at the given points with nearest-neighbour lookup.

        :param pts: [N, 3] points ordered (x, y, z), in grid_sample's normalized [-1, 1] space
        :param mask_volume: [1, C, X, Y, Z] (C == 1 for masks)
        :return: [N, C] sampled values; points outside the volume get 0 (zero padding)
        """
        num_pts = pts.shape[0]
        pts = pts.view(1, 1, 1, num_pts, 3)

        # grid_sample's last grid dimension indexes (W, H, D) of a (D, H, W) volume, so flip
        # (x, y, z) -> (z, y, x) to make x index the X axis of the volume.
        pts = torch.flip(pts, dims=[-1])

        pts_mask = F.grid_sample(mask_volume, pts, mode='nearest', align_corners=False)
        pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous()

        return pts_mask

    def render_core(self,
                    rays_o,
                    rays_d,
                    z_vals,
                    sample_dist,
                    lod,
                    sdf_network,
                    rendering_network,
                    background_alpha=None,
                    background_sampled_color=None,
                    background_rgb=None,
                    alpha_inter_ratio=0.0,
                    conditional_volume=None,
                    conditional_valid_mask_volume=None,
                    feature_maps=None,
                    color_maps=None,
                    w2cs=None,
                    intrinsics=None,
                    img_wh=None,
                    query_c2w=None,
                    if_general_rendering=True,
                    if_render_with_grad=True,
                    img_index=None,
                    rays_uv=None,
                    bg_num=0
                    ):
        """
        Volume-render rays from the conditional SDF with NeuS-style alpha compositing.

        :param rays_o: [N_rays, 3]
        :param rays_d: [N_rays, 3]
        :param z_vals: [N_rays, n_samples] sorted sample depths
        :param sample_dist: length used for the last interval of every ray
        :param lod: level of detail; selects the ``'..._scale%d' % lod`` network outputs
        :param background_alpha: unused; kept for interface compatibility
        :param background_sampled_color: unused; kept for interface compatibility
        :param background_rgb: optional color composited with weight ``1 - weights_sum``
        :param alpha_inter_ratio: blends the two slope estimates used for alpha (0 -> "half" slope, 1 -> true slope)
        :param if_general_rendering: run the generalizable rendering network; requires
            ``conditional_volume`` and ``conditional_valid_mask_volume``
        :param if_render_with_grad: when False, the rendering network runs under ``torch.no_grad()``
        :param rays_uv: pixel coordinates of the rays; enables patch blending in the fitted branch
        :param bg_num: number of trailing samples excluded from ``weights_sum_fg``
        :return: dict of rendered quantities (colors, masks, sdf, depth, weights, eikonal error, ...)
        """
        device = rays_o.device
        N_rays = rays_o.shape[0]
        _, n_samples = z_vals.shape
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1)

        # Each interval is shaded at its midpoint.
        mid_z_vals = z_vals + dists * 0.5

        pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
        dirs = rays_d[:, None, :].expand(pts.shape)

        pts = pts.reshape(-1, 3)
        dirs = dirs.reshape(-1, 3)

        if conditional_valid_mask_volume is not None:
            pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume)
            pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach()
            pts_mask_bool = (pts_mask > 0).view(-1)

            # Presumably avoids calling the networks with an empty batch: force a few points to be queried.
            if torch.sum(pts_mask_bool.float()) < 1:
                pts_mask_bool[:100] = True
        else:
            pts_mask = torch.ones([N_rays, n_samples]).to(pts.device)
            pts_mask_bool = torch.ones([N_rays * n_samples], dtype=torch.bool, device=pts.device)

        sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod)

        # Points that are not queried keep a placeholder SDF of 100 and zero features / gradients.
        sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100
        sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod]
        feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod]
        feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device)
        feature_vector[pts_mask_bool] = feature_vector_valid

        gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
        gradients[pts_mask_bool] = sdf_network.gradient(
            pts[pts_mask_bool], conditional_volume, lod=lod).squeeze()

        sampled_color_mlp = None
        rendering_valid_mask_mlp = None
        sampled_color_patch = None
        rendering_patch_mask = None

        if self.if_fitted_rendering:
            position_latent = sdf_nn_output['sampled_latent_scale%d' % lod]
            sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
            sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device)

            # Per-point colors warped from the source color maps.
            pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp(
                pts[pts_mask_bool][:, None, :], color_maps, intrinsics,
                w2cs, img_wh=None)
            pts_pixel_color = pts_pixel_color[:, 0, :, :]
            pts_pixel_mask = pts_pixel_mask[:, 0, :]

            # Patch blending needs the pixel coordinates of the rays.
            if_patch_blending = rays_uv is not None
            pts_patch_color, pts_patch_mask = None, None
            if if_patch_blending:
                pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp(
                    pts.reshape([N_rays, n_samples, 3]),
                    rays_uv, gradients.reshape([N_rays, n_samples, 3]),
                    color_maps,
                    intrinsics[0], intrinsics,
                    query_c2w[0], torch.inverse(w2cs), img_wh=None
                )
                N_src, Npx = pts_patch_mask.shape[2:]
                pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool]
                pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool]

                sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device)
                sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device)

            (sampled_color_mlp_, sampled_color_mlp_mask_,
             sampled_color_patch_, sampled_color_patch_mask_) = sdf_network.color_blend(
                pts[pts_mask_bool],
                position_latent,
                gradients[pts_mask_bool],
                dirs[pts_mask_bool],
                feature_vector[pts_mask_bool],
                img_index=img_index,
                pts_pixel_color=pts_pixel_color,
                pts_pixel_mask=pts_pixel_mask,
                pts_patch_color=pts_patch_color,
                pts_patch_mask=pts_patch_mask
            )
            sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_
            sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float()
            sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3)
            sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples)
            # A ray is valid when more than half of its samples are valid.
            rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5

            if if_patch_blending:
                sampled_color_patch[pts_mask_bool] = sampled_color_patch_
                sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float()
                sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3)
                sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples)
                rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1,
                                                  keepdim=True) > 0.5
            else:
                sampled_color_patch, rendering_patch_mask = None, None

        if if_general_rendering:
            ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute(
                pts.view(N_rays, n_samples, 3),
                geometryVolume=conditional_volume[0],
                geometryVolumeMask=conditional_valid_mask_volume[0],
                rendering_feature_maps=feature_maps,
                color_maps=color_maps,
                w2cs=w2cs,
                intrinsics=intrinsics,
                img_wh=img_wh,
                query_img_idx=0,
                query_c2w=query_c2w,
            )

            if if_render_with_grad:
                sampled_color, rendering_valid_mask = rendering_network(
                    ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
            else:
                with torch.no_grad():
                    sampled_color, rendering_valid_mask = rendering_network(
                        ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
        else:
            sampled_color, rendering_valid_mask = None, None

        inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6)

        # Dot product between ray direction and SDF gradient.
        true_dot_val = (dirs * gradients).sum(-1, keepdim=True)

        # Blend a "half" slope estimate with the true one according to alpha_inter_ratio.
        iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu(
            -true_dot_val) * alpha_inter_ratio)
        iter_cos = iter_cos * pts_mask.view(-1, 1)

        # Estimated SDF at both ends of each interval.
        true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
        true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5

        prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance)
        next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance)

        p = prev_cdf - next_cdf
        c = prev_cdf

        if self.alpha_type == 'div':
            alpha = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0)
        elif self.alpha_type == 'uniform':
            uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5
            uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5
            uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance)
            uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance)
            alpha = F.relu(
                (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape(
                N_rays, n_samples).clip(0.0, 1.0)
        else:
            raise ValueError("Unknown alpha_type: %s" % self.alpha_type)

        alpha = alpha * pts_mask

        inside_sphere = pts_mask

        weights = alpha * torch.cumprod(
            torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
        weights_sum = weights.sum(dim=-1, keepdim=True)
        alpha_sum = alpha.sum(dim=-1, keepdim=True)

        # Foreground weight excludes the trailing background samples.
        if bg_num > 0:
            weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True)
        else:
            weights_sum_fg = weights_sum

        color = None
        if sampled_color is not None:
            color = (sampled_color * weights[:, :, None]).sum(dim=1)
            if background_rgb is not None:
                color = color + background_rgb * (1.0 - weights_sum)

        color_mlp = None
        if sampled_color_mlp is not None:
            color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1)
            if background_rgb is not None:
                color_mlp = color_mlp + background_rgb * (1.0 - weights_sum)

        blended_color_patch = None
        if sampled_color_patch is not None:
            blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1)

        # Eikonal term (|grad| - 1)^2, averaged over valid samples only.
        gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2,
                                            dim=-1) - 1.0) ** 2
        gradient_error = (pts_mask * gradient_error).sum() / (pts_mask.sum() + 1e-5)

        depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)

        return {
            'color': color,
            'color_mask': rendering_valid_mask,
            'color_mlp': color_mlp,
            'color_mlp_mask': rendering_valid_mask_mlp,
            'sdf': sdf,
            'depth': depth,
            'dists': dists,
            'gradients': gradients.reshape(N_rays, n_samples, 3),
            'variance': 1.0 / inv_variance,
            'mid_z_vals': mid_z_vals,
            'weights': weights,
            'weights_sum': weights_sum,
            'alpha_sum': alpha_sum,
            'alpha_mean': alpha.mean(),
            'cdf': c.reshape(N_rays, n_samples),
            'gradient_error': gradient_error,
            'inside_sphere': inside_sphere,
            'blended_color_patch': blended_color_patch,
            'blended_color_patch_mask': rendering_patch_mask,
            'weights_sum_fg': weights_sum_fg
        }

    def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network,
               perturb_overwrite=-1,
               background_rgb=None,
               alpha_inter_ratio=0.0,
               lod=None,
               conditional_volume=None,
               conditional_valid_mask_volume=None,
               feature_maps=None,
               color_maps=None,
               w2cs=None,
               intrinsics=None,
               img_wh=None,
               query_c2w=None,
               if_general_rendering=True,
               if_render_with_grad=True,
               img_index=None,
               rays_uv=None,
               pre_sample=False,
               bg_ratio=0.0
               ):
        """
        Render a batch of rays.

        Builds uniform samples between ``near`` and ``far``, optionally jitters them, importance-samples
        near the surface in 4 steps, and then calls ``render_core``.

        :param rays_o: [N_rays, 3]
        :param rays_d: [N_rays, 3]
        :param near: tensor broadcastable against [1, n_samples] (presumably [N_rays, 1]; TODO confirm)
        :param far: tensor broadcastable against [1, n_samples] (presumably [N_rays, 1]; TODO confirm)
        :param perturb_overwrite: if >= 0, overrides ``self.perturb`` for this call
        :param lod: level of detail; selects the ``'..._scale%d' % lod`` network outputs
        :param pre_sample: replace the coarse samples with samples drawn from the valid-mask volume
        :param bg_ratio: fraction of the uniform samples (the last ones along the ray) kept as background;
            they skip importance sampling and are excluded from ``weights_sum_fg``
        :return: dict of rendered outputs
        """
        device = rays_o.device
        N_rays = len(rays_o)

        sample_dist = ((far - near) / self.n_samples).mean().item()
        z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device)
        z_vals = near + (far - near) * z_vals[None, :]

        bg_num = int(self.n_samples * bg_ratio)

        if z_vals.shape[0] == 1:
            z_vals = z_vals.repeat(N_rays, 1)

        if bg_num > 0:
            z_vals_bg = z_vals[:, self.n_samples - bg_num:]
            z_vals = z_vals[:, :self.n_samples - bg_num]

        perturb = self.perturb

        if pre_sample:
            z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far,
                                                        conditional_valid_mask_volume)

        if perturb_overwrite >= 0:
            perturb = perturb_overwrite
        if perturb > 0:
            # Stratified sampling: jitter every sample uniformly inside its bin.
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            t_rand = torch.rand(z_vals.shape).to(device)
            z_vals = lower + (upper - lower) * t_rand

        background_alpha = None
        background_sampled_color = None

        if self.n_importance > 0:
            with torch.no_grad():
                pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
                sdf_outputs = sdf_network.sdf(pts.reshape(-1, 3), conditional_volume, lod=lod)
                sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, z_vals.shape[-1])

                # Up-sample with increasing sharpness (inverse variance 64, 128, 256, 512).
                n_steps = 4
                for i in range(n_steps):
                    new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps,
                                                64 * 2 ** i,
                                                conditional_valid_mask_volume=conditional_valid_mask_volume,
                                                )
                    z_vals, sdf = self.cat_z_vals(
                        rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
                        sdf_network, gru_fusion=False,
                        conditional_volume=conditional_volume,
                        conditional_valid_mask_volume=conditional_valid_mask_volume,
                    )

                del sdf

        if bg_num > 0:
            z_vals = torch.cat([z_vals, z_vals_bg], dim=1)

        # Take the sample count from the tensor itself; it differs from n_samples + n_importance when
        # n_importance is not divisible by n_steps or when pre-sampled depths are used.
        n_samples = z_vals.shape[-1]

        ret_fine = self.render_core(rays_o,
                                    rays_d,
                                    z_vals,
                                    sample_dist,
                                    lod,
                                    sdf_network,
                                    rendering_network,
                                    background_rgb=background_rgb,
                                    background_alpha=background_alpha,
                                    background_sampled_color=background_sampled_color,
                                    alpha_inter_ratio=alpha_inter_ratio,
                                    conditional_volume=conditional_volume,
                                    conditional_valid_mask_volume=conditional_valid_mask_volume,
                                    feature_maps=feature_maps,
                                    color_maps=color_maps,
                                    w2cs=w2cs,
                                    intrinsics=intrinsics,
                                    img_wh=img_wh,
                                    query_c2w=query_c2w,
                                    if_general_rendering=if_general_rendering,
                                    if_render_with_grad=if_render_with_grad,
                                    img_index=img_index,
                                    rays_uv=rays_uv,
                                    bg_num=bg_num
                                    )

        color_fine = ret_fine['color']
        # No outside (background network) pass is rendered here, so the mask comes from the main pass only.
        color_fine_mask = ret_fine['color_mask']

        weights = ret_fine['weights']
        weights_sum = ret_fine['weights_sum']
        gradients = ret_fine['gradients']
        mid_z_vals = ret_fine['mid_z_vals']

        depth = ret_fine['depth']
        depth_variance = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True)
        variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True)

        # SDF at random points in [-1, 1]^3, returned as 'sdf_random'.
        pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1
        sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod]

        result = {
            'depth': depth,
            'color_fine': color_fine,
            'color_fine_mask': color_fine_mask,
            'color_outside': None,
            'color_outside_mask': None,
            'color_mlp': ret_fine['color_mlp'],
            'color_mlp_mask': ret_fine['color_mlp_mask'],
            'variance': variance.mean(),
            'cdf_fine': ret_fine['cdf'],
            'depth_variance': depth_variance,
            'weights_sum': weights_sum,
            'weights_max': torch.max(weights, dim=-1, keepdim=True)[0],
            'alpha_sum': ret_fine['alpha_sum'].mean(),
            'alpha_mean': ret_fine['alpha_mean'],
            'gradients': gradients,
            'weights': weights,
            'gradient_error_fine': ret_fine['gradient_error'],
            'inside_sphere': ret_fine['inside_sphere'],
            'sdf': ret_fine['sdf'],
            'sdf_random': sdf_random,
            'blended_color_patch': ret_fine['blended_color_patch'],
            'blended_color_patch_mask': ret_fine['blended_color_patch_mask'],
            'weights_sum_fg': ret_fine['weights_sum_fg']
        }

        return result

    @torch.no_grad()
    def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume):
        """
        Sample ``self.n_samples`` depths per ray guided by a dense SDF volume.

        The SDF is read with nearest-neighbour lookup at 2 * n_samples uniform depths, then
        ``up_sample`` (inverse variance 200) draws the final depths.

        :return: [N_rays, self.n_samples] depths
        """
        device = rays_o.device
        N_rays = len(rays_o)
        n_samples = self.n_samples * 2

        z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
        z_vals = near + (far - near) * z_vals[None, :]

        if z_vals.shape[0] == 1:
            z_vals = z_vals.repeat(N_rays, 1)

        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]

        sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples])

        new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples,
                                    200,
                                    conditional_valid_mask_volume=mask_volume,
                                    )
        return new_z_vals

    @torch.no_grad()
    def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume):
        """
        Sample ``self.n_samples`` depths per ray, preferring intervals inside the valid mask.

        Intervals whose midpoint is inside the mask get weight 1, all others weight 0.1.

        :return: [N_rays, self.n_samples] depths
        """
        device = rays_o.device
        N_rays = len(rays_o)
        n_samples = self.n_samples * 2

        z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
        z_vals = near + (far - near) * z_vals[None, :]

        if z_vals.shape[0] == 1:
            z_vals = z_vals.repeat(N_rays, 1)

        mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5
        pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]

        pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape(
            [N_rays, n_samples - 1])

        weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask), torch.full_like(pts_mask, 0.1))

        z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach()
        return z_samples

    @torch.no_grad()
    def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices,
                                partial_vol_origin, voxel_size,
                                near, far, depth_interval, d_plane_nums):
        """
        Use the predicted depth maps to remove redundant points.

        Points pruned by the SDF lie on both sides of the surface, and the back side is useless.
        A point is kept when, in at least one view, it projects inside the image and its depth lies
        strictly inside ``sampled_depth -/+ d_plane_nums * depth_interval`` (clamped to [near, far]).

        :param coords: [n, 3] int coords
        :param pred_depth_maps: [N_views, 1, h, w]
        :param proj_matrices: [N_views, 4, 4]
        :param partial_vol_origin: [3] (or [1, 3]); added to ``coords * voxel_size``
        :param voxel_size: scalar
        :param near: single-element tensor
        :param far: single-element tensor
        :param depth_interval: scalar
        :param d_plane_nums: scalar
        :return: [n] bool mask of points to keep
        """
        device = pred_depth_maps.device
        n_views, _, sizeH, sizeW = pred_depth_maps.shape

        if len(partial_vol_origin.shape) == 1:
            partial_vol_origin = partial_vol_origin[None, :]
        pts = coords * voxel_size + partial_vol_origin

        # Homogeneous points, one copy per view: [n_views, 4, n]
        rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()
        nV = rs_grid.shape[-1]
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1)

        # Project into every view.
        im_p = proj_matrices @ rs_grid
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
        im_x = im_x / im_z
        im_y = im_y / im_z

        # Pixel coordinates normalized to [-1, 1] (align_corners=True convention).
        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
        im_grid = im_grid.view(n_views, 1, -1, 2)
        sampled_depths = F.grid_sample(pred_depth_maps, im_grid, mode='bilinear',
                                       padding_mode='zeros',
                                       align_corners=True)[:, 0, 0, :]
        sampled_depths_valid = (sampled_depths > 0.5 * near).float()
        valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(),
                                                                             far.item()) * sampled_depths_valid
        valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(),
                                                                             far.item()) * sampled_depths_valid

        mask = im_grid.abs() <= 1
        mask = mask[:, 0]
        mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max)

        # Keep a point if any view accepts it.
        return mask.view(n_views, -1).any(dim=0)

    @staticmethod
    def _dilate_occupancy(occupancy, mask_volume, dX, dY, dZ):
        """
        Dilate a flat occupancy grid with a 7x7x7 window and intersect it with ``mask_volume``.

        :param occupancy: X*Y*Z occupancy values (any dtype convertible to float)
        :param mask_volume: [X*Y*Z, 1] validity mask
        :return: ([X*Y*Z] bool mask, number of selected voxels as a float tensor)
        """
        occupancy = occupancy.float().view(1, 1, dX, dY, dZ)
        occupancy = F.avg_pool3d(occupancy, kernel_size=7, stride=1, padding=3)
        occupancy = occupancy.view(-1, 1) > 0

        final_mask = torch.logical_and(mask_volume, occupancy)[:, 0]
        return final_mask, torch.sum(final_mask.float())

    @staticmethod
    def _random_subsample(valid_coords, valid_feature, maximum_pts, device):
        """Randomly drop rows (numpy RNG) so that exactly ``maximum_pts`` remain; kept rows stay in order."""
        num = valid_coords.shape[0]
        drop = np.random.choice(num, num - maximum_pts, replace=False)
        keep = torch.ones([num], dtype=torch.bool, device=device)
        keep[torch.as_tensor(drop, dtype=torch.long, device=device)] = False
        return valid_coords[keep], valid_feature[keep]

    @torch.no_grad()
    def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume,
                                                   pred_depth_maps, proj_matrices,
                                                   partial_vol_origin, voxel_size,
                                                   near, far, depth_interval, d_plane_nums,
                                                   threshold=0.02, maximum_pts=110000):
        """
        Select sparse voxels near the surface, filtered by predicted depth maps.

        Assumes batch size == 1; used on the first lod to get sparse voxels. The SDF threshold is
        lowered in steps of 0.002 while too many voxels are selected; if the budget is still
        exceeded, voxels are dropped at random.

        :param sdf_volume: [1, X, Y, Z]
        :param coords_volume: [3, X, Y, Z]
        :param mask_volume: [1, X, Y, Z]
        :param feature_volume: [C, X, Y, Z]
        :param threshold: initial |sdf| threshold
        :param maximum_pts: voxel budget
        :return: (valid_coords [M, 4] with a leading batch-index column of 0, valid_feature [M, C])
        """
        device = coords_volume.device
        C, dX, dY, dZ = feature_volume.shape

        sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
        mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)

        def prune(threshold):
            occupancy_mask = (torch.abs(sdf_volume) < threshold).squeeze(1)
            valid_coords = coords_volume[occupancy_mask]

            # Drop near-surface voxels that are inconsistent with the predicted depth maps.
            mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices,
                                                         partial_vol_origin, voxel_size,
                                                         near, far, depth_interval, d_plane_nums)
            valid_coords = valid_coords[mask_filtered]

            occupancy = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device)
            return self._dilate_occupancy(occupancy, mask_volume, dX, dY, dZ)

        final_mask, valid_num = prune(threshold)

        while (valid_num > maximum_pts) and (threshold > 0.003):
            threshold = threshold - 0.002
            final_mask, valid_num = prune(threshold)

        valid_coords = coords_volume[final_mask]
        valid_feature = feature_volume[final_mask]

        # Prepend the batch index (always 0).
        valid_coords = torch.cat([torch.zeros([valid_coords.shape[0], 1]).to(valid_coords.device),
                                  valid_coords], dim=1)

        if valid_num > maximum_pts:
            valid_coords, valid_feature = self._random_subsample(valid_coords, valid_feature, maximum_pts, device)
            print(threshold, "randomly sample to save memory")

        return valid_coords, valid_feature

    @torch.no_grad()
    def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02,
                                       maximum_pts=110000):
        """
        Select sparse voxels whose (dilated) |sdf| is below ``threshold``.

        Assumes batch size == 1; used on the first lod to get sparse voxels. The threshold is
        lowered in steps of 0.002 while too many voxels are selected; if the budget is still
        exceeded, voxels are dropped at random.

        :param sdf_volume: [num_pts, 1], flattened in (X, Y, Z) row-major order
        :param coords_volume: [3, X, Y, Z]
        :param mask_volume: [1, X, Y, Z]
        :param feature_volume: [C, X, Y, Z]
        :param threshold: initial |sdf| threshold
        :param maximum_pts: voxel budget
        :return: (valid_coords [M, 4] with a leading batch-index column of 0, valid_feature [M, C])
        """
        C, dX, dY, dZ = feature_volume.shape
        coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
        mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)

        def prune(threshold):
            return self._dilate_occupancy(torch.abs(sdf_volume) < threshold, mask_volume, dX, dY, dZ)

        final_mask, valid_num = prune(threshold)

        while (valid_num > maximum_pts) and (threshold > 0.003):
            threshold = threshold - 0.002
            final_mask, valid_num = prune(threshold)

        valid_coords = coords_volume[final_mask]
        valid_feature = feature_volume[final_mask]

        # Prepend the batch index (always 0).
        valid_coords = torch.cat([torch.zeros([valid_coords.shape[0], 1]).to(valid_coords.device),
                                  valid_coords], dim=1)

        if valid_num > maximum_pts:
            valid_coords, valid_feature = self._random_subsample(valid_coords, valid_feature, maximum_pts,
                                                                 sdf_volume.device)
            print(threshold, "randomly sample to save memory")

        return valid_coords, valid_feature

    @torch.no_grad()
    def extract_fields(self, bound_min, bound_max, resolution, query_func, device,
                       **kwargs
                       ):
        """
        Evaluate ``query_func`` on a ``resolution``^3 grid spanning [bound_min, bound_max].

        The grid is processed in chunks of 64 values per axis.

        :param query_func: called as ``query_func(pts [n, 3], **kwargs)``; must return a dict containing
            ``'sdf_pts_scale%d' % kwargs['lod']``
        :return: float32 numpy array [resolution, resolution, resolution] holding ``-sdf``
        """
        N = 64
        X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N)
        Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N)
        Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N)

        u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
                    pts = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)
                    output = query_func(pts, **kwargs)
                    sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(
                        len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf
        return u

    @torch.no_grad()
    def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None,
                         **kwargs
                         ):
        """
        Extract a mesh with marching cubes from the SDF network.

        :param bound_min: tensor [3]
        :param bound_max: tensor [3]
        :param threshold: iso-level passed to marching cubes (applied to -sdf)
        :param occupancy_mask: optional [dX, dY, dZ] mask (bool or numeric); regions where it is 0
            are forced to -100 before meshing
        :return: (vertices in world space, triangles, field u)
        """
        u = self.extract_fields(bound_min, bound_max, resolution,
                                lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs),
                                device,
                                **kwargs
                                )
        if occupancy_mask is not None:
            dX, dY, dZ = occupancy_mask.shape
            # Cast first so that bool masks are supported.
            empty_mask = 1 - occupancy_mask.float()
            empty_mask = empty_mask.view(1, 1, dX, dY, dZ)
            empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest')
            empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0
            u[empty_mask] = -100
            del empty_mask

        vertices, triangles = mcubes.marching_cubes(u, threshold)
        b_max_np = bound_max.detach().cpu().numpy()
        b_min_np = bound_min.detach().cpu().numpy()

        # Grid index -> world coordinates.
        vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
        return vertices, triangles, u

    @torch.no_grad()
    def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far):
        """
        Extract depth maps from the density volume by ray tracing.

        :param con_volume: [1, 1+C, dX, dY, dZ]; can be con_volume or sdf_volume
        :param intrinsics: [B, 3, 3] or [B, 4, 4] (only the top-left 3x3 of the inverse is used)
        :param c2ws: [B, 4, 4]
        :param H: image height
        :param W: image width
        :param near: 1-D tensor, repeated for every ray
        :param far: 1-D tensor, repeated for every ray
        :return: (depth_maps [B, 1, H, W] with 0 where invalid, depth_masks [B, 1, H, W])
        """
        device = con_volume.device
        batch_size = intrinsics.shape[0]

        ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
                                torch.linspace(0, W - 1, W), indexing="ij")
        p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1)

        intrinsics_inv = torch.inverse(intrinsics)

        p = p.view(-1, 3).float().to(device)
        # squeeze(-1) keeps the batch dimension even when batch_size == 1.
        p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze(-1)
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)
        rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze(-1)
        rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape)
        rays_d = rays_v

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps(
            rays_o, rays_d,
            near[None, :].repeat(rays_o.shape[0], 1),
            far[None, :].repeat(rays_o.shape[0], 1),
            sdf_network, con_volume
        )

        depth_maps = depth_maps_sphere.view(batch_size, 1, H, W)
        depth_masks = depth_masks_sphere.view(batch_size, 1, H, W)

        depth_maps = torch.where(depth_masks, depth_maps, torch.zeros_like(depth_maps))

        return depth_maps, depth_masks
|
|