# SparseNeuS_demo_v1/models/sparse_sdf_network.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsparse.tensor import PointTensor, SparseTensor
import torchsparse.nn as spnn
from tsparse.modules import SparseCostRegNet
from tsparse.torchsparse_utils import sparse_to_dense_channel
from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d
# from .gru_fusion import GRUFusion
from ops.back_project import back_project_sparse_type
from ops.generate_grids import generate_grid
from inplace_abn import InPlaceABN
from models.embedder import Embedding
from models.featurenet import ConvBnReLU
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
@torch.jit.script
def fused_mean_variance(x, weight):
mean = torch.sum(x * weight, dim=1, keepdim=True)
var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True)
return mean, var
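# Usage sketch (illustrative shapes, not from the original file): the weights
# are expected to be non-negative and to sum to 1 along the view dimension.
#   x = torch.randn(2, 4, 8)               # [n_pts, n_views, C]
#   w = torch.full((2, 4, 1), 0.25)        # uniform per-view weights
#   mean, var = fused_mean_variance(x, w)  # each [n_pts, 1, 8]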
class LatentSDFLayer(nn.Module):
def __init__(self,
d_in=3,
d_out=129,
d_hidden=128,
n_layers=4,
skip_in=(4,),
multires=0,
bias=0.5,
geometric_init=True,
weight_norm=True,
activation='softplus',
d_conditional_feature=16):
super(LatentSDFLayer, self).__init__()
self.d_conditional_feature = d_conditional_feature
# concat the latent code to the input of each layer, except the first and the last layer
dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden]
dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out]
self.embed_fn_fine = None
if multires > 0:
embed_fn = Embedding(in_channels=d_in, N_freqs=multires) # * include the input
self.embed_fn_fine = embed_fn
dims_in[0] = embed_fn.out_channels
self.num_layers = n_layers
self.skip_in = skip_in
for l in range(0, self.num_layers - 1):
if l in self.skip_in:
in_dim = dims_in[l] + dims_in[0]
else:
in_dim = dims_in[l]
out_dim = dims_out[l]
lin = nn.Linear(in_dim, out_dim)
if geometric_init: # - from IDR code,
if l == self.num_layers - 2:
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001)
torch.nn.init.constant_(lin.bias, -bias)
# the channels for latent codes are set to 0
torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)
torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0)
elif multires > 0 and l == 0: # the first layer
torch.nn.init.constant_(lin.bias, 0.0)
# * the channels for position embeddings are set to 0
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
# * the channels for the xyz coordinates (3 channels) are initialized from a normal distribution
torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
elif multires > 0 and l in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
# * the channels for position embeddings (and conditional_feature) are initialized to 0
torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
# the channels for latent code are initialized to 0
torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, "lin" + str(l), lin)
if activation == 'softplus':
self.activation = nn.Softplus(beta=100)
else:
assert activation == 'relu'
self.activation = nn.ReLU()
def forward(self, inputs, latent):
if self.embed_fn_fine is not None:
inputs = self.embed_fn_fine(inputs)
# - pad the latent (by duplication) so the lod1 network can reuse the pretrained params of the lod0 network
if latent.shape[1] != self.d_conditional_feature:
latent = torch.cat([latent, latent], dim=1)
x = inputs
for l in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(l))
# * due to the conditional latent input, this differs from the original NeuS version
if l in self.skip_in:
x = torch.cat([x, inputs], 1) / np.sqrt(2)
if 0 < l < self.num_layers - 1:
x = torch.cat([x, latent], 1)
x = lin(x)
if l < self.num_layers - 2:
x = self.activation(x)
return x
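# Usage sketch (hypothetical tensors): the layer maps 3D points plus a
# per-point latent code to an SDF value and a feature vector. Note that with
# this loop structure the final linear layer outputs d_hidden channels (the
# trailing d_out entry of dims_out is never reached), so downstream code
# splits the output as 1 SDF channel + (d_hidden - 1) feature channels.
#   layer = LatentSDFLayer(d_in=3, d_hidden=128, n_layers=4, multires=6,
#                          d_conditional_feature=16)
#   pts = torch.rand(1024, 3) * 2 - 1      # points normalized to (-1, 1)
#   latent = torch.randn(1024, 16)         # interpolated volume features
#   out = layer(pts, latent)               # [1024, 128]
#   sdf, feat = out[:, :1], out[:, 1:]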
class SparseSdfNetwork(nn.Module):
'''
Coarse-to-fine sparse cost regularization network
returns sparse volume features for extracting the sdf
'''
def __init__(self, lod, ch_in, voxel_size, vol_dims,
hidden_dim=128, activation='softplus',
cost_type='variance_mean',
d_pyramid_feature_compress=16,
regnet_d_out=8, num_sdf_layers=4,
multires=6,
):
super(SparseSdfNetwork, self).__init__()
self.lod = lod # - gradually training, the current regularization lod
self.ch_in = ch_in
self.voxel_size = voxel_size # - the voxel size of the current volume
self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume
self.selected_views_num = 2 # the number of selected views for feature aggregation
self.hidden_dim = hidden_dim
self.activation = activation
self.cost_type = cost_type
self.d_pyramid_feature_compress = d_pyramid_feature_compress
self.gru_fusion = None
self.regnet_d_out = regnet_d_out
self.multires = multires
self.pos_embedder = Embedding(3, self.multires)
self.compress_layer = ConvBnReLU(
self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1,
norm_act=InPlaceABN)
sparse_ch_in = self.d_pyramid_feature_compress * 2
sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in
self.sparse_costreg_net = SparseCostRegNet(
d_in=sparse_ch_in, d_out=self.regnet_d_out)
# self.regnet_d_out = self.sparse_costreg_net.d_out
if activation == 'softplus':
self.activation = nn.Softplus(beta=100)
else:
assert activation == 'relu'
self.activation = nn.ReLU()
self.sdf_layer = LatentSDFLayer(d_in=3,
d_out=self.hidden_dim + 1,
d_hidden=self.hidden_dim,
n_layers=num_sdf_layers,
multires=multires,
geometric_init=True,
weight_norm=True,
activation=activation,
d_conditional_feature=16 # self.regnet_d_out
)
def upsample(self, pre_feat, pre_coords, interval, num=8):
'''
:param pre_feat: (Tensor), features from last level, (N, C)
:param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : Batch ind, x, y, z)
:param interval: interval of voxels, interval = scale ** 2
:param num: number of upsampled children per voxel (1 -> 8)
:return: up_feat : (Tensor), upsampled features, (N*8, C)
:return: up_coords: (N*8, 4), upsampled coordinates, (4 : Batch ind, x, y, z)
'''
with torch.no_grad():
pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]]
n, c = pre_feat.shape
up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous()
up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous()
for i in range(num - 1):
up_coords[:, i + 1, pos_list[i]] += interval
up_feat = up_feat.view(-1, c)
up_coords = up_coords.view(-1, 4)
return up_feat, up_coords
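# Usage sketch (toy input, assuming `net` is a constructed SparseSdfNetwork):
# one voxel becomes its eight octree children, offset by `interval` along
# every axis combination in pos_list.
#   feat = torch.randn(1, 8)                  # [N, C]
#   coords = torch.tensor([[0, 0, 0, 0]])     # [N, 4] = (batch, x, y, z)
#   up_feat, up_coords = net.upsample(feat, coords, interval=1)
#   # up_feat: [8, 8] (features repeated); up_coords: the 8 corners of a
#   # 2x2x2 cube, from (0, 0,0,0) up to (0, 1,1,1)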
def aggregate_multiview_features(self, multiview_features, multiview_masks):
"""
aggregate multi-view features by computing their cost variance
:param multiview_features: (num of voxels, num_of_views, c)
:param multiview_masks: (num of voxels, num_of_views)
:return:
"""
num_pts, n_views, C = multiview_features.shape
counts = torch.sum(multiview_masks, dim=1, keepdim=False) # [num_pts]
assert torch.all(counts > 0) # the point is visible for at least 1 view
volume_sum = torch.sum(multiview_features, dim=1, keepdim=False) # [num_pts, C]
volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False)
assert not volume_sum.isnan().any(), 'NaN encountered when aggregating multi-view features'
del multiview_features
counts = 1. / (counts + 1e-5)
costvar = volume_sq_sum * counts[:, None] - (volume_sum * counts[:, None]) ** 2
costvar_mean = torch.cat([costvar, volume_sum * counts[:, None]], dim=1)
del volume_sum, volume_sq_sum, counts
return costvar_mean
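# Usage sketch (toy tensors): per-voxel variance and mean over the views that
# see each voxel, assuming masked-out entries of `multiview_features` were
# already zeroed by the back-projection step.
#   feats = torch.randn(100, 3, 16)   # [num_voxels, n_views, C]
#   masks = torch.ones(100, 3)        # every view sees every voxel
#   cost = net.aggregate_multiview_features(feats, masks)  # [100, 32]
#   # cost[:, :16] holds the per-channel variance, cost[:, 16:] the mean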
def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None):
"""
convert the sparse volume into a dense volume to enable trilinear sampling
(the volume is kept sparse elsewhere to save GPU memory);
:param coords: [num_pts, 3]
:param feature: [num_pts, C]
:param vol_dims: [3] dX, dY, dZ
:param interval:
:return:
"""
# * assume batch size is 1
if device is None:
device = feature.device
coords_int = (coords / interval).to(torch.int64)
vol_dims = (vol_dims / interval).to(torch.int64)
# - if stored on the CPU, this is too slow
dense_volume = sparse_to_dense_channel(
coords_int.to(device), feature.to(device), vol_dims.to(device),
feature.shape[1], 0, device) # [X, Y, Z, C]
valid_mask_volume = sparse_to_dense_channel(
coords_int.to(device),
torch.ones([feature.shape[0], 1]).to(feature.device),
vol_dims.to(device),
1, 0, device) # [X, Y, Z, 1]
dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z]
valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z]
return dense_volume, valid_mask_volume
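# Usage sketch: scatter two voxels into a dense 8^3 grid.
#   coords = torch.tensor([[0, 0, 0], [2, 4, 6]])    # [num_pts, 3]
#   feature = torch.randn(2, 8)                      # [num_pts, C]
#   dense, mask = net.sparse_to_dense_volume(
#       coords, feature, torch.tensor([8, 8, 8]), interval=1)
#   # dense: [1, 8, 8, 8, 8]; mask: [1, 1, 8, 8, 8] with two non-zero cells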
def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0,
pre_coords=None, pre_feats=None,
):
"""
:param feature_maps: fused pyramid feature maps, (B, V, C0+C1+C2, H, W)
:param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0)
:param proj_mats: projection matrices that transform world points into image space, [B, V, 4, 4], matching the original image size
:param sizeH: the H of original image size
:param sizeW: the W of original image size
:param pre_coords: the coordinates of sparse volume from the prior lod
:param pre_feats: the features of sparse volume from the prior lod
:return:
"""
device = proj_mats.device
bs = feature_maps.shape[0]
N_views = feature_maps.shape[1]
minimum_visible_views = np.min([1, N_views - 1])
outputs = {}
pts_samples = []
# ----coarse to fine----
# * using the fused pyramid feature maps is very important
if self.compress_layer is not None:
feats = self.compress_layer(feature_maps[0])
else:
feats = feature_maps[0]
feats = feats[:, None, :, :, :] # [V, B, C, H, W]
KRcam = proj_mats.permute(1, 0, 2, 3).contiguous() # [V, B, 4, 4]
interval = 1
if self.lod == 0:
# ----generate new coords----
coords = generate_grid(self.vol_dims, 1)[0]
coords = coords.view(3, -1).to(device) # [3, num_pts]
up_coords = []
for b in range(bs):
up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords]))
up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous()
# * since we only estimate the geometry of one input reference image at a time,
# * mask the outside of the camera frustum
frustum_mask = back_project_sparse_type(
up_coords, partial_vol_origin, self.voxel_size,
feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True) # [num_pts, n_views]
frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views  # ! this threshold should arguably be larger
up_coords = up_coords[frustum_mask] # [num_pts_valid, 4]
else:
# ----upsample coords----
assert pre_feats is not None
assert pre_coords is not None
up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1)
# ----back project----
# give each valid 3d grid point all valid 2D features and masks
multiview_features, multiview_masks = back_project_sparse_type(
up_coords, partial_vol_origin, self.voxel_size, feats,
KRcam, sizeH=sizeH, sizeW=sizeW) # (num of voxels, num_of_views, c), (num of voxels, num_of_views)
# num_of_views = all views
if self.lod > 0:
# ! another round of invalid-voxel filtering is needed
frustum_mask = torch.sum(multiview_masks, dim=-1) > 1
up_feat = up_feat[frustum_mask]
up_coords = up_coords[frustum_mask]
multiview_features = multiview_features[frustum_mask]
multiview_masks = multiview_masks[frustum_mask]
volume = self.aggregate_multiview_features(multiview_features, multiview_masks)  # variance + mean over all view features
del multiview_features, multiview_masks
# ----concat feature from last stage----
if self.lod != 0:
feat = torch.cat([volume, up_feat], dim=1)
else:
feat = volume
# batch index is in the last position
r_coords = up_coords[:, [1, 2, 3, 0]]
sparse_feat = SparseTensor(feat, r_coords.to(
torch.int32)) # - directly use sparse tensor to avoid point2voxel operations
feat = self.sparse_costreg_net(sparse_feat)
dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval,
device=None) # [1, C/1, X, Y, Z]
outputs['dense_volume_scale%d' % self.lod] = dense_volume # [1, 16, 96, 96, 96]
outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96]
outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96]
outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device)
return outputs
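# Usage sketch (hypothetical shapes, lod 0): feature_maps come from the 2D
# pyramid feature network and proj_mats map world points into image space.
#   out = net.get_conditional_volume(
#       feature_maps,         # [B, V, C0+C1+C2, H, W]
#       partial_vol_origin,   # [B, 3]
#       proj_mats,            # [B, V, 4, 4]
#       sizeH=H, sizeW=W)
#   cond = out['dense_volume_scale0']       # [1, regnet_d_out, X, Y, Z]
#   mask = out['valid_mask_volume_scale0']  # [1, 1, X, Y, Z]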
def sdf(self, pts, conditional_volume, lod):
num_pts = pts.shape[0]
device = pts.device
pts_ = pts.clone()
pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
pts = torch.flip(pts, dims=[-1])
sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts]
sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device)
sdf_pts = self.sdf_layer(pts_, sampled_feature)
outputs = {}
outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
outputs['sampled_latent_scale%d' % lod] = sampled_feature
return outputs
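# Usage sketch: query points must be normalized to (-1, 1) in volume space.
# The flip above reorders (x, y, z) queries into the (z, y, x) index order
# that grid sampling over a [1, C, X, Y, Z] volume expects.
#   pts = torch.rand(4096, 3) * 2 - 1
#   out = net.sdf(pts, cond, lod=0)         # `cond` as sketched above
#   sdf_vals = out['sdf_pts_scale0']        # [4096, 1]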
@torch.no_grad()
def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0):
num_pts = pts.shape[0]
device = pts.device
pts_ = pts.clone()
pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
pts = torch.flip(pts, dims=[-1])
sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True,
padding_mode='border')
sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device)
outputs = {}
outputs['sdf_pts_scale%d' % lod] = sdf
return outputs
@torch.no_grad()
def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin):
"""
:param conditional_volume: [1,C, dX,dY,dZ]
:param mask_volume: [1,1, dX,dY,dZ]
:param coords_volume: [1,3, dX,dY,dZ]
:return:
"""
device = conditional_volume.device
chunk_size = 10240
_, C, dX, dY, dZ = conditional_volume.shape
conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous()
mask_volume = mask_volume.view(-1)
coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous()
pts = coords_volume * self.voxel_size + partial_origin # [dX*dY*dZ, 3]
sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device)
conditional_volume = conditional_volume[mask_volume > 0]
pts = pts[mask_volume > 0]
conditional_volume = conditional_volume.split(chunk_size)
pts = pts.split(chunk_size)
sdf_all = []
for pts_part, feature_part in zip(pts, conditional_volume):
sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
sdf_all.append(sdf_part)
sdf_all = torch.cat(sdf_all, dim=0)
sdf_volume[mask_volume > 0] = sdf_all
sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ)
return sdf_volume
def gradient(self, x, conditional_volume, lod):
"""
return the gradient of specific lod
:param x:
:param lod:
:return:
"""
x.requires_grad_(True)
with torch.enable_grad():
output = self.sdf(x, conditional_volume, lod)
y = output['sdf_pts_scale%d' % lod]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
# ! Distributed Data Parallel doesn’t work with torch.autograd.grad()
# ! (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters).
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients.unsqueeze(1)
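# Sanity-check sketch (hypothetical): the analytic normal can be compared
# against a central finite difference along x.
#   eps = 1e-4
#   pts = torch.rand(16, 3) * 2 - 1
#   g = net.gradient(pts, cond, lod=0)[:, 0, :]   # [16, 3]
#   dx = torch.zeros_like(pts); dx[:, 0] = eps
#   fd = (net.sdf(pts + dx, cond, 0)['sdf_pts_scale0']
#         - net.sdf(pts - dx, cond, 0)['sdf_pts_scale0']) / (2 * eps)
#   # fd ≈ g[:, :1] up to the non-smoothness of trilinear interpolation
# Module-level duplicate of SparseSdfNetwork.sparse_to_dense_volume below;
# it is used by the finetuning networks further down.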
def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None):
"""
convert the sparse volume into a dense volume to enable trilinear sampling
(the volume is kept sparse elsewhere to save GPU memory);
:param coords: [num_pts, 3]
:param feature: [num_pts, C]
:param vol_dims: [3] dX, dY, dZ
:param interval:
:return:
"""
# * assume batch size is 1
if device is None:
device = feature.device
coords_int = (coords / interval).to(torch.int64)
vol_dims = (vol_dims / interval).to(torch.int64)
# - if stored on the CPU, this is too slow
dense_volume = sparse_to_dense_channel(
coords_int.to(device), feature.to(device), vol_dims.to(device),
feature.shape[1], 0, device) # [X, Y, Z, C]
valid_mask_volume = sparse_to_dense_channel(
coords_int.to(device),
torch.ones([feature.shape[0], 1]).to(feature.device),
vol_dims.to(device),
1, 0, device) # [X, Y, Z, 1]
dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z]
valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z]
return dense_volume, valid_mask_volume
class SdfVolume(nn.Module):
def __init__(self, volume, coords=None, type='dense'):
super(SdfVolume, self).__init__()
self.volume = torch.nn.Parameter(volume, requires_grad=True)
self.coords = coords
self.type = type
def forward(self):
return self.volume
class FinetuneOctreeSdfNetwork(nn.Module):
'''
After obtaining the conditional volume from the generalized network,
directly optimize the conditional volume itself.
The conditional volume is still sparse.
'''
def __init__(self, voxel_size, vol_dims,
origin=[-1., -1., -1.],
hidden_dim=128, activation='softplus',
regnet_d_out=8,
multires=6,
if_fitted_rendering=True,
num_sdf_layers=4,
):
super(FinetuneOctreeSdfNetwork, self).__init__()
self.voxel_size = voxel_size # - the voxel size of the current volume
self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume
self.origin = torch.tensor(origin).to(torch.float32)
self.hidden_dim = hidden_dim
self.activation = activation
self.regnet_d_out = regnet_d_out
self.if_fitted_rendering = if_fitted_rendering
self.multires = multires
# d_in_embedding = self.regnet_d_out if self.pos_add_type == 'latent' else 3
# self.pos_embedder = Embedding(d_in_embedding, self.multires)
# - the optimized parameters
self.sparse_volume_lod0 = None
self.sparse_coords_lod0 = None
if activation == 'softplus':
self.activation = nn.Softplus(beta=100)
else:
assert activation == 'relu'
self.activation = nn.ReLU()
self.sdf_layer = LatentSDFLayer(d_in=3,
d_out=self.hidden_dim + 1,
d_hidden=self.hidden_dim,
n_layers=num_sdf_layers,
multires=multires,
geometric_init=True,
weight_norm=True,
activation=activation,
d_conditional_feature=16 # self.regnet_d_out
)
# - add mlp rendering when finetuning
self.renderer = None
d_in_renderer = 3 + self.regnet_d_out + 3 + 3
self.renderer = BlendingRenderingNetwork(
d_feature=self.hidden_dim - 1,
mode='idr',  # ! the view direction influences the result a lot
d_in=d_in_renderer,
d_out=50, # maximum 50 images
d_hidden=self.hidden_dim,
n_layers=3,
weight_norm=True,
multires_view=4,
squeeze_out=True,
)
def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0,
sparse_volume_lod0=None, sparse_coords_lod0=None):
"""
:param dense_volume_lod0: [1,C,dX,dY,dZ]
:param dense_volume_mask_lod0: [1,1,dX,dY,dZ]
:param sparse_volume_lod0: optional precomputed sparse volume features
:param sparse_coords_lod0: optional precomputed sparse volume coordinates
:return:
"""
if sparse_volume_lod0 is None:
device = dense_volume_lod0.device
_, C, dX, dY, dZ = dense_volume_lod0.shape
dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous()
mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0
self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse')
coords = generate_grid(self.vol_dims, 1)[0] # [3, dX, dY, dZ]
coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device)
self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False)
else:
self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse')
self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False)
def get_conditional_volume(self):
dense_volume, valid_mask_volume = sparse_to_dense_volume(
self.sparse_coords_lod0,
self.sparse_volume_lod0(), self.vol_dims, interval=1,
device=None) # [1, C/1, X, Y, Z]
outputs = {}
outputs['dense_volume_scale%d' % 0] = dense_volume
outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume
return outputs
def tv_regularizer(self):
dense_volume, valid_mask_volume = sparse_to_dense_volume(
self.sparse_coords_lod0,
self.sparse_volume_lod0(), self.vol_dims, interval=1,
device=None) # [1, C/1, X, Y, Z]
dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2 # [1, C/1, X-1, Y, Z]
dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2 # [1, C/1, X, Y-1, Z]
dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2 # [1, C/1, X, Y, Z-1]
tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :] # [1, C/1, X-1, Y-1, Z-1]
mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \
valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:]
tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask
# tv = tv.mean(dim=1, keepdim=True) * mask
assert torch.all(~torch.isnan(tv))
return torch.mean(tv)
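# Usage sketch: during finetuning the TV term is simply added to the data
# loss (the 0.1 weight and `color_loss` are hypothetical):
#   loss = color_loss + 0.1 * ft_net.tv_regularizer()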
def sdf(self, pts, conditional_volume, lod):
outputs = {}
num_pts = pts.shape[0]
device = pts.device
pts_ = pts.clone()
pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
pts = torch.flip(pts, dims=[-1])
sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts]
sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous()
outputs['sampled_latent_scale%d' % lod] = sampled_feature
sdf_pts = self.sdf_layer(pts_, sampled_feature)
lod = 0  # the finetuning network only supports lod 0
outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
return outputs
def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index,
pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors,
img_index, pts_pixel_color, pts_pixel_mask,
pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask)
def gradient(self, x, conditional_volume, lod):
"""
return the gradient of specific lod
:param x:
:param lod:
:return:
"""
x.requires_grad_(True)
output = self.sdf(x, conditional_volume, lod)
y = output['sdf_pts_scale%d' % 0]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients.unsqueeze(1)
@torch.no_grad()
def prune_dense_mask(self, threshold=0.02):
"""
Gradually prune the mask of the dense volume to decrease the number of SDF network inferences
:return:
"""
chunk_size = 10240
coords = generate_grid(self.vol_dims_lod0, 1)[0] # [3, dX, dY, dZ]
_, dX, dY, dZ = coords.shape
pts = coords.view(3, -1).permute(1,
0).contiguous() * self.voxel_size_lod0 + self.origin[None, :] # [dX*dY*dZ, 3]
# dense_volume = self.dense_volume_lod0() # [1,C,dX,dY,dZ]
dense_volume, _ = sparse_to_dense_volume(
self.sparse_coords_lod0,
self.sparse_volume_lod0(), self.vol_dims_lod0, interval=1,
device=None) # [1, C/1, X, Y, Z]
sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(dense_volume.device) * 100
mask = self.dense_volume_mask_lod0.view(-1) > 0
pts_valid = pts[mask].to(dense_volume.device)
feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask]
pts_valid = pts_valid.split(chunk_size)
feature_valid = feature_valid.split(chunk_size)
sdf_list = []
for pts_part, feature_part in zip(pts_valid, feature_valid):
sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
sdf_list.append(sdf_part)
sdf_list = torch.cat(sdf_list, dim=0)
sdf_volume[mask] = sdf_list
occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1]
# - dilate
occupancy_mask = occupancy_mask.float()
occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
occupancy_mask = occupancy_mask > 0
self.dense_volume_mask_lod0 = torch.logical_and(self.dense_volume_mask_lod0,
occupancy_mask).float() # (1, 1, dX, dY, dZ)
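# Note on the dilation above: avg_pool3d with kernel 7 / stride 1 / padding 3
# preserves the spatial size and marks every voxel within a Chebyshev radius
# of 3 around the near-surface set, so the mask shrinks conservatively.
#   m = torch.zeros(1, 1, 9, 9, 9); m[0, 0, 4, 4, 4] = 1.
#   dilated = F.avg_pool3d(m, kernel_size=7, stride=1, padding=3) > 0
#   # dilated is True exactly on the 7^3 cube centred at (4, 4, 4)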
class BlendingRenderingNetwork(nn.Module):
def __init__(
self,
d_feature,
mode,
d_in,
d_out,
d_hidden,
n_layers,
weight_norm=True,
multires_view=0,
squeeze_out=True,
):
super(BlendingRenderingNetwork, self).__init__()
self.mode = mode
self.squeeze_out = squeeze_out
dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
self.embedder = None
if multires_view > 0:
self.embedder = Embedding(3, multires_view)
dims[0] += (self.embedder.out_channels - 3)
self.num_layers = len(dims)
for l in range(0, self.num_layers - 1):
out_dim = dims[l + 1]
lin = nn.Linear(dims[l], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, "lin" + str(l), lin)
self.relu = nn.ReLU()
self.color_volume = None
self.softmax = nn.Softmax(dim=1)
self.type = 'blending'
def sample_pts_from_colorVolume(self, pts):
device = pts.device
num_pts = pts.shape[0]
pts_ = pts.clone()
pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
pts = torch.flip(pts, dims=[-1])
sampled_color = grid_sample_3d(self.color_volume, pts) # [1, c, 1, 1, num_pts]
sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device)
return sampled_color
def forward(self, position, normals, view_dirs, feature_vectors, img_index,
pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
"""
:param position: can be 3d coord or interpolated volume latent
:param normals:
:param view_dirs:
:param feature_vectors:
:param img_index: [N_views], used to extract corresponding weights
:param pts_pixel_color: [N_pts, N_views, 3]
:param pts_pixel_mask: [N_pts, N_views]
:param pts_patch_color: [N_pts, N_views, Npx, 3]
:param pts_patch_mask: [N_pts, N_views, Npx]
:return:
"""
if self.embedder is not None:
view_dirs = self.embedder(view_dirs)
rendering_input = None
if self.mode == 'idr':
rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1)
elif self.mode == 'no_view_dir':
rendering_input = torch.cat([position, normals, feature_vectors], dim=-1)
elif self.mode == 'no_normal':
rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1)
elif self.mode == 'no_points':
rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
elif self.mode == 'no_points_no_view_dir':
rendering_input = torch.cat([normals, feature_vectors], dim=-1)
x = rendering_input
for l in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(l))
x = lin(x)
if l < self.num_layers - 2:
x = self.relu(x) # [n_pts, d_out]
# extract the per-view logits selected by img_index
x_extracted = torch.index_select(x, 1, img_index.long())
weights_pixel = self.softmax(x_extracted) # [n_pts, N_views]
weights_pixel = weights_pixel * pts_pixel_mask
weights_pixel = weights_pixel / (
torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views]
final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1,
keepdim=False) # [N_pts, 3]
final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0 # [N_pts, 1]
final_patch_color, final_patch_mask = None, None
# pts_patch_color [N_pts, N_views, Npx, 3]; pts_patch_mask [N_pts, N_views, Npx]
if pts_patch_color is not None:
N_pts, N_views, Npx, _ = pts_patch_color.shape
patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1  # [N_pts, N_views], all patch pixels visible
weights_patch = self.softmax(x_extracted) # [N_pts, N_views]
weights_patch = weights_patch * patch_mask
weights_patch = weights_patch / (
torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views]
final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1,
keepdim=False) # [N_pts, Npx, 3]
final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0  # [N_pts, 1], at least one view sees the full patch
return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask
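# Usage sketch (toy tensors, mode='idr', matching the construction in
# FinetuneOctreeSdfNetwork): blend per-view pixel colors with softmax weights
# predicted from geometry features.
#   rnet = BlendingRenderingNetwork(d_feature=127, mode='idr', d_in=17,
#                                   d_out=50, d_hidden=128, n_layers=3,
#                                   multires_view=4)
#   N, V = 1024, 4
#   position = torch.randn(N, 11)      # pts (3) + sampled latent (8)
#   img_index = torch.arange(V)        # which of the 50 weight slots to use
#   color, mask, _, _ = rnet(position, torch.randn(N, 3), torch.randn(N, 3),
#                            torch.randn(N, 127), img_index,
#                            torch.rand(N, V, 3), torch.ones(N, V))
#   # color: [N, 3]; mask: [N, 1]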