# Spaces:
# Paused
# Paused
import os
import time
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
from torch_scatter import segment_coo

from . import grid
from .dvgo import Raw2Alpha, Alphas2Weights
from .dmpigo import create_full_step_id
# Directory holding this module; the extension sources are resolved against it.
parent_dir = os.path.dirname(os.path.abspath(__file__))

# Source files of the `ub360_utils_cuda` extension, relative to `parent_dir`.
_ub360_relative_sources = ('cuda/ub360_utils.cpp', 'cuda/ub360_utils_kernel.cu')

# Loaded at import time through torch.utils.cpp_extension.load; the resulting
# module provides `cumdist_thres` and `segment_cumsum`, which are used below.
ub360_utils_cuda = load(
    name='ub360_utils_cuda',
    sources=[os.path.join(parent_dir, rel) for rel in _ub360_relative_sources],
    verbose=True)
# TODO ORIGINAL bg_len=0.2
'''Model'''
class DirectContractedVoxGO(nn.Module):
    """Voxel-grid radiance field on a contracted space, with segmentation grids.

    World coordinates are normalised by the scene centre/radius and points
    whose norm exceeds 1 are contracted into a shell of thickness ``bg_len``
    (see ``sample_ray``), so every grid spans ``[-1-bg_len, 1+bg_len]^3``.
    Besides the density and colour grids the model keeps two segmentation
    grids (``seg_mask_grid`` / ``dual_seg_mask_grid``) with ``num_objects``
    channels each.
    """

    def __init__(self, xyz_min, xyz_max,
                 num_voxels=0, num_voxels_base=0, num_objects=1,
                 alpha_init=None,
                 mask_cache_world_size=None,
                 fast_color_thres=0, bg_len=0.2,
                 contracted_norm='inf',
                 density_type='DenseGrid', k0_type='DenseGrid',
                 density_config=None, k0_config=None,
                 rgbnet_dim=0,
                 rgbnet_depth=3, rgbnet_width=128,
                 viewbase_pe=4,
                 **kwargs):
        """Build the grids, the optional colour MLP and the occupancy mask.

        Args:
            xyz_min, xyz_max: boundary separating the foreground and the
                background scene; the box must be a cube.
            num_voxels: voxel budget of the initial grids.
            num_voxels_base: voxel budget defining the reference voxel size.
            num_objects: channel count of the segmentation grids.
            alpha_init: initial alpha used to derive the density bias shift
                (required; ``None`` is not handled).
            mask_cache_world_size: resolution of the occupancy mask; defaults
                to the grid resolution.
            fast_color_thres: scalar threshold, or a dict mapping a global
                step to a threshold (key ``0`` gives the initial value).
            bg_len: thickness of the contracted background shell.
            contracted_norm: ``'inf'`` or ``'l2'``.
            density_type, k0_type, density_config, k0_config: forwarded to
                ``grid.create_grid``.
            rgbnet_dim: ``<= 0`` selects a plain RGB grid, otherwise the
                feature width consumed by the colour MLP.
            rgbnet_depth, rgbnet_width, viewbase_pe: colour MLP settings.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if the bounding box is not a cube.
        """
        super(DirectContractedVoxGO, self).__init__()
        # Avoid sharing one dict instance between every model built with defaults.
        density_config = {} if density_config is None else density_config
        k0_config = {} if k0_config is None else k0_config

        # xyz_min/max are the boundary that separates fg and bg scene
        xyz_min = torch.Tensor(xyz_min)
        xyz_max = torch.Tensor(xyz_max)
        # The previous `assert len(...)` was always true; compare the three
        # extents with a tolerance so float noise does not trigger a failure.
        extent = xyz_max - xyz_min
        if not torch.allclose(extent, extent[:1].expand_as(extent), rtol=1e-4, atol=1e-5):
            raise ValueError('scene bbox must be a cube in DirectContractedVoxGO')
        self.register_buffer('scene_center', (xyz_min + xyz_max) * 0.5)
        self.register_buffer('scene_radius', (xyz_max - xyz_min) * 0.5)
        # From here on xyz_min/xyz_max describe the contracted space.
        self.register_buffer('xyz_min', torch.Tensor([-1, -1, -1]) - bg_len)
        self.register_buffer('xyz_max', torch.Tensor([1, 1, 1]) + bg_len)
        if isinstance(fast_color_thres, dict):
            self._fast_color_thres = fast_color_thres
            self.fast_color_thres = fast_color_thres[0]
        else:
            self._fast_color_thres = None
            self.fast_color_thres = fast_color_thres
        self.bg_len = bg_len
        self.contracted_norm = contracted_norm

        # determine based grid resolution
        self.num_voxels_base = num_voxels_base
        self.voxel_size_base = ((self.xyz_max - self.xyz_min).prod() / self.num_voxels_base).pow(1/3)

        # determine init grid resolution
        self._set_grid_resolution(num_voxels)

        # determine the density bias shift
        self.alpha_init = alpha_init
        self.register_buffer('act_shift', torch.FloatTensor([np.log(1/(1-alpha_init) - 1)]))
        print('dcvgo: set density bias shift to', self.act_shift)

        # init density voxel grid
        self.density_type = density_type
        self.density_config = density_config
        self.density = grid.create_grid(
            density_type, channels=1, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)

        # init segmentation grids (same grid type/config as the density grid)
        self.mode = 'coarse'
        self.num_objects = num_objects
        self.seg_mask_grid = grid.create_grid(
            density_type, channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)
        self.mask_view_counts = torch.zeros_like(self.seg_mask_grid.grid, requires_grad=False)
        self.dual_seg_mask_grid = grid.create_grid(
            density_type, channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)

        # init color representation
        self.rgbnet_kwargs = {
            'rgbnet_dim': rgbnet_dim,
            'rgbnet_depth': rgbnet_depth, 'rgbnet_width': rgbnet_width,
            'viewbase_pe': viewbase_pe,
        }
        self.k0_type = k0_type
        self.k0_config = k0_config
        # 3 raw colour channels (coarse stage) or rgbnet_dim features (fine stage)
        self.k0_dim = 3 if rgbnet_dim <= 0 else rgbnet_dim
        self.k0 = grid.create_grid(
            k0_type, channels=self.k0_dim, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.k0_config)
        if rgbnet_dim <= 0:
            # color voxel grid only, no view-dependent MLP
            self.rgbnet = None
        else:
            # feature voxel grid + shallow MLP
            self.register_buffer('viewfreq', torch.FloatTensor([(2**i) for i in range(viewbase_pe)]))
            dim0 = (3+3*viewbase_pe*2)
            dim0 += self.k0_dim
            self.rgbnet = nn.Sequential(
                nn.Linear(dim0, rgbnet_width), nn.ReLU(inplace=True),
                *[
                    nn.Sequential(nn.Linear(rgbnet_width, rgbnet_width), nn.ReLU(inplace=True))
                    for _ in range(rgbnet_depth-2)
                ],
                nn.Linear(rgbnet_width, 3),
            )
            nn.init.constant_(self.rgbnet[-1].bias, 0)
        print('dcvgo: feature voxel grid', self.k0)
        print('dcvgo: mlp', self.rgbnet)

        # Using the coarse geometry if provided (used to determine known free space and unknown space)
        # Re-implement as occupancy grid (2021/1/31)
        if mask_cache_world_size is None:
            mask_cache_world_size = self.world_size
        mask = torch.ones(list(mask_cache_world_size), dtype=torch.bool)
        self.mask_cache = grid.MaskGrid(
            path=None, mask=mask,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max)

    def _set_grid_resolution(self, num_voxels):
        """Derive voxel size, grid shape and size ratio from a voxel budget."""
        self.num_voxels = num_voxels
        self.voxel_size = ((self.xyz_max - self.xyz_min).prod() / num_voxels).pow(1/3)
        self.world_size = ((self.xyz_max - self.xyz_min) / self.voxel_size).long()
        self.world_len = self.world_size[0].item()
        self.voxel_size_ratio = self.voxel_size / self.voxel_size_base
        print('dcvgo: voxel_size      ', self.voxel_size)
        print('dcvgo: world_size      ', self.world_size)
        print('dcvgo: voxel_size_base ', self.voxel_size_base)
        print('dcvgo: voxel_size_ratio', self.voxel_size_ratio)

    def get_kwargs(self):
        """Return constructor keyword arguments describing this model.

        ``xyz_min``/``xyz_max`` are the contracted-space buffers, not the
        values originally passed to the constructor. ``bg_len`` and
        ``num_objects`` are included so that a model rebuilt from these
        kwargs does not silently fall back to the constructor defaults.
        """
        return {
            'xyz_min': self.xyz_min.cpu().numpy(),
            'xyz_max': self.xyz_max.cpu().numpy(),
            'num_voxels': self.num_voxels,
            'num_voxels_base': self.num_voxels_base,
            'num_objects': self.num_objects,
            'alpha_init': self.alpha_init,
            'voxel_size_ratio': self.voxel_size_ratio,
            'mask_cache_world_size': list(self.mask_cache.mask.shape),
            'fast_color_thres': self.fast_color_thres,
            'bg_len': self.bg_len,
            'contracted_norm': self.contracted_norm,
            'density_type': self.density_type,
            'k0_type': self.k0_type,
            'density_config': self.density_config,
            'k0_config': self.k0_config,
            **self.rgbnet_kwargs,
        }

    def change_num_objects(self, num_obj):
        """Replace both segmentation grids with fresh ``num_obj``-channel grids.

        The new grids are moved to the device of the previous
        ``seg_mask_grid``. NOTE(review): ``mask_view_counts`` keeps its old
        shape here — confirm whether callers expect it to be reset as well.
        """
        self.num_objects = num_obj
        device = self.seg_mask_grid.grid.device
        self.seg_mask_grid = grid.create_grid(
            'DenseGrid', channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)
        self.dual_seg_mask_grid = grid.create_grid(
            'DenseGrid', channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)
        self.seg_mask_grid.to(device)
        self.dual_seg_mask_grid.to(device)
        print("Reset the seg_mask_grid with num_objects =", num_obj)

    def _require_single_object_seg(self):
        """Raise unless the segmentation grid has exactly one channel."""
        if self.seg_mask_grid.grid.shape[1] != 1:
            raise ValueError("multi-object seg label cannot be applied directly to the density grid")

    def segmentation_to_density(self):
        """Carve the density grid with the (single-object) segmentation grid.

        Density values are multiplied by a 0/1 mask (1 where the
        segmentation value is positive) and every resulting zero is set to
        ``-1e7``.

        Raises:
            ValueError: if the segmentation grid has more than one channel.
        """
        self._require_single_object_seg()
        keep = (self.seg_mask_grid.grid > 0).to(self.density.grid.dtype)
        # In-place edits of a grid that requires grad are rejected by autograd
        # unless gradient tracking is disabled.
        with torch.no_grad():
            self.density.grid *= keep
            self.density.grid[self.density.grid == 0] = -1e7

    def segmentation_only(self):
        """Validate that the segmentation grid is single-object; no other effect.

        Raises:
            ValueError: if the segmentation grid has more than one channel.
        """
        self._require_single_object_seg()

    def change_to_fine_mode(self):
        """Switch to 'fine' mode, which enables the dual segmentation output."""
        self.mode = 'fine'

    @torch.no_grad()
    def scale_volume_grid(self, num_voxels):
        """Resample every grid to a new voxel budget and refresh the mask.

        The occupancy mask is only rebuilt when the new grid holds at most
        ``256**3`` voxels.
        """
        print('dcvgo: scale_volume_grid start')
        ori_world_size = self.world_size
        self._set_grid_resolution(num_voxels)
        print('dcvgo: scale_volume_grid scale world_size from', ori_world_size.tolist(), 'to', self.world_size.tolist())

        self.density.scale_volume_grid(self.world_size)
        self.seg_mask_grid.scale_volume_grid(self.world_size)
        self.dual_seg_mask_grid.scale_volume_grid(self.world_size)
        self.k0.scale_volume_grid(self.world_size)

        if np.prod(self.world_size.tolist()) <= 256**3:
            self_grid_xyz = torch.stack(torch.meshgrid(
                torch.linspace(self.xyz_min[0], self.xyz_max[0], self.world_size[0]),
                torch.linspace(self.xyz_min[1], self.xyz_max[1], self.world_size[1]),
                torch.linspace(self.xyz_min[2], self.xyz_max[2], self.world_size[2]),
            ), -1)
            # 3x3x3 max-pool keeps voxels adjacent to any occupied voxel.
            self_alpha = F.max_pool3d(self.activate_density(self.density.get_dense_grid()), kernel_size=3, padding=1, stride=1)[0,0]
            self.mask_cache = grid.MaskGrid(
                path=None, mask=self.mask_cache(self_grid_xyz) & (self_alpha>self.fast_color_thres),
                xyz_min=self.xyz_min, xyz_max=self.xyz_max)
        print('dcvgo: scale_volume_grid finish')

    @torch.no_grad()
    def update_occupancy_cache(self):
        """Clear mask cells whose max-pooled alpha is not above the threshold."""
        ori_p = self.mask_cache.mask.float().mean().item()
        cache_grid_xyz = torch.stack(torch.meshgrid(
            torch.linspace(self.xyz_min[0], self.xyz_max[0], self.mask_cache.mask.shape[0]),
            torch.linspace(self.xyz_min[1], self.xyz_max[1], self.mask_cache.mask.shape[1]),
            torch.linspace(self.xyz_min[2], self.xyz_max[2], self.mask_cache.mask.shape[2]),
        ), -1)
        cache_grid_density = self.density(cache_grid_xyz)[None,None]
        cache_grid_alpha = self.activate_density(cache_grid_density)
        cache_grid_alpha = F.max_pool3d(cache_grid_alpha, kernel_size=3, padding=1, stride=1)[0,0]
        self.mask_cache.mask &= (cache_grid_alpha > self.fast_color_thres)
        new_p = self.mask_cache.mask.float().mean().item()
        print(f'dcvgo: update mask_cache {ori_p:.4f} => {new_p:.4f}')

    def update_occupancy_cache_lt_nviews(self, rays_o_tr, rays_d_tr, imsz, render_kwargs, maskout_lt_nviews):
        """Clear mask cells that are seen by fewer than ``maskout_lt_nviews`` views.

        Rays are split per view using ``imsz``. For each view a fresh
        one-channel grid is sampled along the rays in chunks of 8192 and
        back-propagated; cells whose accumulated gradient exceeds 1 count as
        seen by that view. ``render_kwargs`` is forwarded to ``sample_ray``
        and therefore has to contain ``stepsize``.
        """
        print('dcvgo: update mask_cache lt_nviews start')
        eps_time = time.time()
        count = torch.zeros_like(self.density.get_dense_grid()).long()
        device = count.device
        for rays_o_, rays_d_ in zip(rays_o_tr.split(imsz), rays_d_tr.split(imsz)):
            ones = grid.DenseGrid(1, self.world_size, self.xyz_min, self.xyz_max)
            for rays_o, rays_d in zip(rays_o_.split(8192), rays_d_.split(8192)):
                ray_pts, inner_mask, t = self.sample_ray(
                    ori_rays_o=rays_o.to(device), ori_rays_d=rays_d.to(device),
                    **render_kwargs)
                ones(ray_pts).sum().backward()
            # Gradients accumulate over the chunks of one view, so each view
            # contributes at most one count per cell.
            count.data += (ones.grid.grad > 1)
        ori_p = self.mask_cache.mask.float().mean().item()
        self.mask_cache.mask &= (count >= maskout_lt_nviews)[0,0]
        new_p = self.mask_cache.mask.float().mean().item()
        print(f'dcvgo: update mask_cache {ori_p:.4f} => {new_p:.4f}')
        eps_time = time.time() - eps_time
        print('dcvgo: update mask_cache lt_nviews finish (eps time:', eps_time, 'sec)')

    def density_total_variation_add_grad(self, weight, dense_mode):
        """Add total-variation gradients to the density grid (weight scaled by resolution/128)."""
        w = weight * self.world_size.max() / 128
        self.density.total_variation_add_grad(w, w, w, dense_mode)

    def k0_total_variation_add_grad(self, weight, dense_mode):
        """Add total-variation gradients to the colour grid (weight scaled by resolution/128)."""
        w = weight * self.world_size.max() / 128
        self.k0.total_variation_add_grad(w, w, w, dense_mode)

    def activate_density(self, density, interval=None):
        """Convert raw density to alpha; ``interval`` defaults to ``voxel_size_ratio``."""
        interval = interval if interval is not None else self.voxel_size_ratio
        shape = density.shape
        return Raw2Alpha.apply(density.flatten(), self.act_shift, interval).reshape(shape)

    def sample_ray(self, ori_rays_o, ori_rays_d, stepsize, is_train=False, **render_kwargs):
        '''Sample query points on rays.
        All the output points are sorted from near to far.
        Input:
            ori_rays_o, ori_rays_d: both in [N, 3] indicating ray configurations.
            stepsize:               the number of voxels of each sample step.
            is_train, render_kwargs: accepted but unused here.
        Output:
            ray_pts:    [N, S, 3] sampled points in the contracted space.
            inner_mask: [N, S] True where the point lies inside the unit region
                        (norm <= 1 before contraction).
            t:          [S] distance of each sample along the normalised ray.
        '''
        rays_o = (ori_rays_o - self.scene_center) / self.scene_radius
        rays_d = ori_rays_d / ori_rays_d.norm(dim=-1, keepdim=True)
        N_inner = int(2 / (2+2*self.bg_len) * self.world_len / stepsize) + 1
        N_outer = N_inner
        # Uniform bins on [0, 2], then bins uniform in inverse distance up to 256.
        b_inner = torch.linspace(0, 2, N_inner+1)
        b_outer = 2 / torch.linspace(1, 1/128, N_outer+1)
        t = torch.cat([
            (b_inner[1:] + b_inner[:-1]) * 0.5,
            (b_outer[1:] + b_outer[:-1]) * 0.5,
        ])
        ray_pts = rays_o[:,None,:] + rays_d[:,None,:] * t[None,:,None]
        if self.contracted_norm == 'inf':
            norm = ray_pts.abs().amax(dim=-1, keepdim=True)
        elif self.contracted_norm == 'l2':
            norm = ray_pts.norm(dim=-1, keepdim=True)
        else:
            raise NotImplementedError
        inner_mask = (norm<=1)
        # Points outside the unit region are squashed into a shell of width bg_len.
        ray_pts = torch.where(
            inner_mask,
            ray_pts,
            ray_pts / norm * ((1+self.bg_len) - self.bg_len/norm)
        )
        return ray_pts, inner_mask.squeeze(-1), t

    @staticmethod
    def _select(mask, *tensors):
        """Index every tensor with the same boolean mask."""
        return tuple(x[mask] for x in tensors)

    def _dense_sample_mask(self, ray_pts, inner_mask, stepsize):
        """Mask keeping inner samples plus outer samples picked by ``cumdist_thres``."""
        mask = inner_mask.clone()
        dist_thres = (2+2*self.bg_len) / self.world_len * stepsize * 0.95
        dist = (ray_pts[:,1:] - ray_pts[:,:-1]).norm(dim=-1)
        mask[:, 1:] |= ub360_utils_cuda.cumdist_thres(dist, dist_thres)
        return mask

    def _march_seg_mask(self, seg_grid, ray_pts, weights, ray_id, n_rays):
        """Query ``seg_grid`` at ``ray_pts`` and accumulate it per ray.

        A 1-D grid response (one value per point) is given a trailing channel
        axis so that the product with ``weights[:, None]`` is ``[M, C]`` for
        both the single- and the multi-object case.
        """
        pred = seg_grid(ray_pts)
        if pred.dim() == 1:
            pred = pred.unsqueeze(-1)
        return segment_coo(
            src=(weights.unsqueeze(-1) * pred),
            index=ray_id,
            out=torch.zeros([n_rays, self.num_objects]),
            reduce='sum')

    def forward(self, rays_o, rays_d, viewdirs, global_step=None, is_train=False, render_fct=0.0, **render_kwargs):
        '''Volume rendering
        @rays_o:   [N, 3] the starting point of the N shooting rays.
        @rays_d:   [N, 3] the shooting direction of the N rays.
        @viewdirs: [N, 3] viewing direction to compute positional embedding for MLP.
        @global_step: when present in the fast_color_thres schedule, updates the threshold.
        @is_train: enables the random background when render_kwargs['rand_bkgd'] is set.
        @render_fct: lower bound for the alpha/weight pruning threshold.
        @render_kwargs: must provide 'stepsize'; 'bg' is read unless the random
                        background is used; 'render_depth' adds 'depth'/'distance'.
        Returns a dict with the rendered colour, the per-point quantities and
        the marched segmentation masks ('dual_seg_mask_marched' is None unless
        the model is in fine mode).
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'
        if isinstance(self._fast_color_thres, dict) and global_step in self._fast_color_thres:
            print(f'dcvgo: update fast_color_thres {self.fast_color_thres} => {self._fast_color_thres[global_step]}')
            self.fast_color_thres = self._fast_color_thres[global_step]

        ret_dict = {}
        N = len(rays_o)

        # sample points on rays
        ray_pts, inner_mask, t = self.sample_ray(
            ori_rays_o=rays_o, ori_rays_d=rays_d, is_train=global_step is not None, **render_kwargs)
        n_max = len(t)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        ray_id, step_id = create_full_step_id(ray_pts.shape[:2])

        # cumsum of per-axis absolute steps along each ray (its norm is taken below)
        ray_distance = torch.zeros_like(ray_pts)
        ray_distance[:, 1:] = torch.abs(ray_pts[:, 1:] - ray_pts[:, :-1])
        ray_distance = torch.cumsum(ray_distance, dim=1)

        # skip oversampled points outside scene bbox
        mask = self._dense_sample_mask(ray_pts, inner_mask, render_kwargs['stepsize'])
        ray_pts = ray_pts[mask]
        ray_distance = ray_distance[mask]
        inner_mask = inner_mask[mask]
        t = t[None].repeat(N,1)[mask]
        ray_id = ray_id[mask.flatten()]
        step_id = step_id[mask.flatten()]

        # skip known free space
        mask = self.mask_cache(ray_pts)
        ray_pts, ray_distance, inner_mask, t, ray_id, step_id = self._select(
            mask, ray_pts, ray_distance, inner_mask, t, ray_id, step_id)

        render_fct = max(render_fct, self.fast_color_thres)

        # query for alpha w/ post-activation
        density = self.density(ray_pts)
        alpha = self.activate_density(density, interval)
        if render_fct > 0:
            mask = (alpha > render_fct)
            ray_pts, ray_distance, inner_mask, t, ray_id, step_id, density, alpha = self._select(
                mask, ray_pts, ray_distance, inner_mask, t, ray_id, step_id, density, alpha)

        # compute accumulated transmittance
        weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id, N)
        if render_fct > 0:
            mask = (weights > render_fct)
            ray_pts, ray_distance, inner_mask, t, ray_id, step_id, density, alpha, weights = self._select(
                mask, ray_pts, ray_distance, inner_mask, t, ray_id, step_id, density, alpha, weights)

        # query for color
        k0 = self.k0(ray_pts)
        if self.rgbnet is None:
            # no view-depend effect
            rgb = torch.sigmoid(k0)
        else:
            # view-dependent color emission
            viewdirs_emb = (viewdirs.unsqueeze(-1) * self.viewfreq).flatten(-2)
            viewdirs_emb = torch.cat([viewdirs, viewdirs_emb.sin(), viewdirs_emb.cos()], -1)
            viewdirs_emb = viewdirs_emb.flatten(0,-2)[ray_id]
            rgb_feat = torch.cat([k0, viewdirs_emb], -1)
            rgb_logit = self.rgbnet(rgb_feat)
            rgb = torch.sigmoid(rgb_logit)

        # Ray marching
        rgb_marched = segment_coo(
            src=(weights.unsqueeze(-1) * rgb),
            index=ray_id,
            out=torch.zeros([N, 3]),
            reduce='sum')

        # Segmentation masks. When the mask grid is trainable, gradients are
        # force-enabled and the rendering weights are detached so that only
        # the mask volume is optimised.
        seg_trainable = self.seg_mask_grid.grid.requires_grad
        seg_weights = weights.detach().clone() if seg_trainable else weights
        dual_seg_mask_marched = None
        with torch.set_grad_enabled(seg_trainable or torch.is_grad_enabled()):
            seg_mask_marched = self._march_seg_mask(
                self.seg_mask_grid, ray_pts, seg_weights, ray_id, N)
            if self.mode == 'fine':
                dual_seg_mask_marched = self._march_seg_mask(
                    self.dual_seg_mask_grid, ray_pts, seg_weights, ray_id, N)

        if render_kwargs.get('rand_bkgd', False) and is_train:
            rgb_marched += (alphainv_last.unsqueeze(-1) * torch.rand_like(rgb_marched))
        else:
            rgb_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])

        wsum_mid = segment_coo(
            src=weights[inner_mask],
            index=ray_id[inner_mask],
            out=torch.zeros([N]),
            reduce='sum')
        s = 1 - 1/(1+t)  # [0, inf] => [0, 1]
        ray_distance = ray_distance.norm(dim=-1)
        ret_dict.update({
            'alphainv_last': alphainv_last,
            'weights': weights,
            'wsum_mid': wsum_mid,
            'rgb_marched': rgb_marched,
            'raw_density': density,
            'raw_alpha': alpha,
            'raw_rgb': rgb,
            'ray_id': ray_id,
            'step_id': step_id,
            'n_max': n_max,
            't': t,
            's': s,
            'seg_mask_marched': seg_mask_marched,
            'dual_seg_mask_marched': dual_seg_mask_marched,
            'ray_distance': ray_distance
        })

        if render_kwargs.get('render_depth', False):
            with torch.no_grad():
                depth = segment_coo(
                    src=(weights * s),
                    index=ray_id,
                    out=torch.zeros([N]),
                    reduce='sum')
                distance = segment_coo(
                    src=(weights * ray_distance),
                    index=ray_id,
                    out=torch.zeros([N]),
                    reduce='sum')
            ret_dict.update({'depth': depth})
            ret_dict.update({'distance': distance})
        return ret_dict

    def forward_mask(self, rays_o, rays_d, render_fct=0.0, **render_kwargs):
        '''Render only the segmentation mask.
        @rays_o: [N, 3] the starting point of the N shooting rays.
        @rays_d: [N, 3] the shooting direction of the N rays.
        @render_fct: lower bound for the alpha/weight pruning threshold.
        @render_kwargs: must provide 'stepsize'.
        Returns {'seg_mask_marched': [N, num_objects]}. Unlike forward(), the
        rendering weights are not detached here.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'

        N = len(rays_o)

        # sample points on rays
        ray_pts, inner_mask, t = self.sample_ray(
            ori_rays_o=rays_o, ori_rays_d=rays_d, is_train=False, **render_kwargs)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        ray_id, _ = create_full_step_id(ray_pts.shape[:2])

        # skip oversampled points outside scene bbox
        mask = self._dense_sample_mask(ray_pts, inner_mask, render_kwargs['stepsize'])
        ray_pts = ray_pts[mask]
        ray_id = ray_id[mask.flatten()]

        # skip known free space
        mask = self.mask_cache(ray_pts)
        ray_pts, ray_id = self._select(mask, ray_pts, ray_id)

        render_fct = max(render_fct, self.fast_color_thres)

        # query for alpha w/ post-activation
        density = self.density(ray_pts)
        alpha = self.activate_density(density, interval)
        if render_fct > 0:
            mask = (alpha > render_fct)
            ray_pts, ray_id, alpha = self._select(mask, ray_pts, ray_id, alpha)

        # compute accumulated transmittance
        weights, _ = Alphas2Weights.apply(alpha, ray_id, N)
        if render_fct > 0:
            mask = (weights > render_fct)
            ray_pts, ray_id, weights = self._select(mask, ray_pts, ray_id, weights)

        # Force-enable gradients when the mask volume is being optimised.
        seg_trainable = self.seg_mask_grid.grid.requires_grad
        with torch.set_grad_enabled(seg_trainable or torch.is_grad_enabled()):
            seg_mask_marched = self._march_seg_mask(
                self.seg_mask_grid, ray_pts, weights, ray_id, N)

        return {
            'seg_mask_marched': seg_mask_marched,
        }
class DistortionLoss(torch.autograd.Function):
    """Distortion loss over per-point weights, with a hand-written backward.

    ``forward`` and ``backward`` are static methods, as required by
    ``torch.autograd.Function``; the function is invoked through
    ``DistortionLoss.apply`` (aliased as ``distortion_loss``).
    """

    @staticmethod
    def forward(ctx, w, s, n_max, ray_id):
        """Return the loss averaged over ``ray_id.max() + 1`` rays.

        Args:
            w: per-point weights.
            s: per-point positions along the ray.
            n_max: number of intervals; ``1 / n_max`` is used as the interval width.
            ray_id: index of the ray each point belongs to.
        """
        n_rays = ray_id.max()+1
        interval = 1/n_max
        w_prefix, w_total, ws_prefix, ws_total = ub360_utils_cuda.segment_cumsum(w, s, ray_id)
        loss_uni = (1/3) * interval * w.pow(2)
        loss_bi = 2 * w * (s * w_prefix - ws_prefix)
        ctx.save_for_backward(w, s, w_prefix, w_total, ws_prefix, ws_total, ray_id)
        ctx.interval = interval
        return (loss_bi.sum() + loss_uni.sum()) / n_rays

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_back):
        """Return the gradient w.r.t. ``w``; the other inputs receive ``None``."""
        w, s, w_prefix, w_total, ws_prefix, ws_total, ray_id = ctx.saved_tensors
        interval = ctx.interval
        grad_uni = (1/3) * interval * 2 * w
        # Suffix sums are derived from the per-ray totals and the prefix sums.
        w_suffix = w_total[ray_id] - (w_prefix + w)
        ws_suffix = ws_total[ray_id] - (ws_prefix + w*s)
        grad_bi = 2 * (s * (w_prefix - w_suffix) + (ws_suffix - ws_prefix))
        # NOTE(review): forward divides by n_rays but this gradient does not;
        # left unchanged because existing loss weights may rely on it.
        grad = grad_back * (grad_bi + grad_uni)
        return grad, None, None, None


distortion_loss = DistortionLoss.apply