# SparseNeuS_demo_v1/models/rendering_network.py
# parts of this code are borrowed from IBRNet
import torch
import torch.nn as nn
import torch.nn.functional as F
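# turn off the TorchScript profiling executor/mode; this affects how the
# @torch.jit.script helper below is compiled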
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
# initialize linear layers with Kaiming (He) normal weights and zero biases
def weights_init(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias.data)
@torch.jit.script
def fused_mean_variance(x, weight):
mean = torch.sum(x * weight, dim=2, keepdim=True)
var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True)
return mean, var
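# Illustrative usage (shapes as used in forward() below): x is
# [n_rays, n_samples, n_views, n_feat] and weight is [n_rays, n_samples, n_views, 1]
# normalized over the view dim; mean/var come back as [n_rays, n_samples, 1, n_feat]:
#   mean, var = fused_mean_variance(torch.rand(2, 8, 3, 16),
#                                   torch.full((2, 8, 3, 1), 1.0 / 3))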
class GeneralRenderingNetwork(nn.Module):
"""
This model is not sensitive to finetuning
"""
def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True):
super(GeneralRenderingNetwork, self).__init__()
self.in_geometry_feat_ch = in_geometry_feat_ch
self.in_rendering_feat_ch = in_rendering_feat_ch
self.anti_alias_pooling = anti_alias_pooling
if self.anti_alias_pooling:
self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
activation_func = nn.ELU(inplace=True)
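        # lifts the 4-channel ray-direction difference into a feature that is
        # added onto the per-view image features in forward()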
self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16),
activation_func,
nn.Linear(16, in_rendering_feat_ch + 3),
activation_func)
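        # fuses per-view features with cross-view mean/variance statistics and
        # the geometry (sdf) feature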
self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64),
activation_func,
nn.Linear(64, 32),
activation_func)
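        # predicts a per-view visibility logit plus a 32-d residual feature update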
self.vis_fc = nn.Sequential(nn.Linear(32, 32),
activation_func,
nn.Linear(32, 33),
activation_func,
)
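        # refines the weighted features into a (0, 1) visibility weight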
self.vis_fc2 = nn.Sequential(nn.Linear(32, 32),
activation_func,
nn.Linear(32, 1),
nn.Sigmoid()
)
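        # scores each view for the softmax color blending in forward()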
self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16),
activation_func,
nn.Linear(16, 8),
activation_func,
nn.Linear(8, 1))
self.base_fc.apply(weights_init)
self.vis_fc2.apply(weights_init)
self.vis_fc.apply(weights_init)
self.rgb_fc.apply(weights_init)
def forward(self, geometry_feat, rgb_feat, ray_diff, mask):
        '''
        :param geometry_feat: geometry features indicating sdf [n_rays, n_samples, n_feat]
        :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat]
        :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4]; the first 3 channels are
            directions, the last channel is the inner product with the target ray
        :param mask: whether each projection is valid [n_views, n_rays, n_samples]
        :return: blended rgb [n_rays, n_samples, 3] and per-ray validity mask [n_rays, 1]
        '''
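        # move the view dim inward: [n_views, n_rays, n_samples, ...] -> [n_rays, n_samples, n_views, ...]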
rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous()
ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous()
mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous()
num_views = rgb_feat.shape[2]
geometry_feat = geometry_feat[:, :, None, :].repeat(1, 1, num_views, 1)
direction_feat = self.ray_dir_fc(ray_diff)
rgb_in = rgb_feat[..., :3]
rgb_feat = rgb_feat + direction_feat
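        # IBRNet-style anti-alias pooling: views whose ray direction agrees with
        # the target ray (inner product close to 1) receive larger pooling weights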
if self.anti_alias_pooling:
_, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask
weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
else:
weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
# compute mean and variance across different views for each point
mean, var = fused_mean_variance(rgb_feat, weight) # [n_rays, n_samples, 1, n_feat]
globalfeat = torch.cat([mean, var], dim=-1) # [n_rays, n_samples, 1, 2*n_feat]
x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat],
dim=-1) # [n_rays, n_samples, n_views, 3*n_feat+n_geo_feat]
x = self.base_fc(x)
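        # visibility branch: weight the features per view, predict a visibility
        # logit and a residual update of the per-view features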
x_vis = self.vis_fc(x * weight)
x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
        vis = torch.sigmoid(vis) * mask
x = x + x_res
vis = self.vis_fc2(x * vis) * mask
# rgb computation
x = torch.cat([x, vis, ray_diff], dim=-1)
x = self.rgb_fc(x)
x = x.masked_fill(mask == 0, -1e9)
blending_weights_valid = F.softmax(x, dim=2) # color blending
rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)
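        # validity: a sample must be seen by at least 2 views, and a ray must
        # contain more than 8 such samples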
mask = mask.detach().to(rgb_out.dtype) # [n_rays, n_samples, n_views, 1]
mask = torch.sum(mask, dim=2, keepdim=False)
        mask = mask >= 2  # at least 2 views see the point
mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False)
valid_mask = mask > 8 # valid rays, more than 8 valid samples
return rgb_out, valid_mask # (N_rays, n_samples, 3), (N_rays, 1)
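

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the channel widths match the constructor
    # defaults, while the view/ray/sample counts are illustrative assumptions,
    # not values taken from the SparseNeuS pipeline.
    n_views, n_rays, n_samples = 3, 4, 16
    geo_ch, feat_ch = 8, 56
    net = GeneralRenderingNetwork(in_geometry_feat_ch=geo_ch, in_rendering_feat_ch=feat_ch)
    geometry_feat = torch.rand(n_rays, n_samples, geo_ch)
    rgb_feat = torch.rand(n_views, n_rays, n_samples, feat_ch + 3)  # 3 rgb channels + image features
    ray_diff = torch.rand(n_views, n_rays, n_samples, 4)
    mask = torch.ones(n_views, n_rays, n_samples)
    rgb_out, valid_mask = net(geometry_feat, rgb_feat, ray_diff, mask)
    assert rgb_out.shape == (n_rays, n_samples, 3)
    assert valid_mask.shape == (n_rays, 1)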