|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
# Switch off the TorchScript profiling executor and profiling mode for the
# whole process (module-level side effect, runs on import).
# NOTE(review): these are private torch._C hooks; presumably set for the
# @torch.jit.script function defined in this file -- confirm they are still
# needed with the PyTorch version in use.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
|
|
|
|
|
|
|
def weights_init(m):
    """Initialise a single ``nn.Linear`` layer in place.

    Meant to be handed to ``nn.Module.apply``: the weight matrix is re-drawn
    with Kaiming-normal initialisation and the bias (when the layer has one)
    is reset to zero. Modules of any other type are left untouched.

    :param m: module visited by ``apply``
    :return: None
    """
    # Guard clause: only fully-connected layers are (re-)initialised.
    if not isinstance(m, nn.Linear):
        return

    nn.init.kaiming_normal_(m.weight.data)

    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias.data)
|
|
|
|
|
@torch.jit.script
def fused_mean_variance(x, weight):
    """Weighted mean and weighted variance of ``x`` taken along dimension 2.

    :param x: values to pool; must have at least three dimensions
    :param weight: pooling weights, broadcastable against ``x``. The weights
                   are used as given (no normalisation happens here).
    :return: tuple ``(mean, var)``; both keep dimension 2 with size 1
    """
    weighted_mean = torch.sum(weight * x, dim=2, keepdim=True)
    # Deviation of every entry from the weighted mean (broadcast over dim 2).
    centered = x - weighted_mean
    weighted_var = torch.sum(weight * centered.pow(2), dim=2, keepdim=True)
    return weighted_mean, weighted_var
|
|
|
|
|
class GeneralRenderingNetwork(nn.Module):
    """
    Predicts per-view blending weights and blends the source-view colors into
    a single color for every sample on every ray.

    This model is not sensitive to finetuning
    """

    def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True):
        """
        :param in_geometry_feat_ch: channel count of the geometry feature
        :param in_rendering_feat_ch: channel count of the image feature; the
               per-view input to ``forward`` carries 3 more channels (rgb)
        :param anti_alias_pooling: if True, pool views with learned,
               direction-dependent weights; otherwise with the plain mask
        """
        super(GeneralRenderingNetwork, self).__init__()

        self.in_geometry_feat_ch = in_geometry_feat_ch
        self.in_rendering_feat_ch = in_rendering_feat_ch
        self.anti_alias_pooling = anti_alias_pooling

        if self.anti_alias_pooling:
            # Learnable sharpness of the direction-based pooling weights.
            self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
        # One stateless activation instance shared by all MLPs below.
        activation_func = nn.ELU(inplace=True)

        # Lifts the 4-channel ray difference to the width of rgb_feat so the
        # two can be added.
        self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16),
                                        activation_func,
                                        nn.Linear(16, in_rendering_feat_ch + 3),
                                        activation_func)

        # Input = geometry feature + pooled mean + pooled variance + per-view
        # feature (the last three are each in_rendering_feat_ch + 3 wide).
        self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64),
                                     activation_func,
                                     nn.Linear(64, 32),
                                     activation_func)

        # 33 outputs: a 32-channel residual plus 1 visibility logit.
        self.vis_fc = nn.Sequential(nn.Linear(32, 32),
                                    activation_func,
                                    nn.Linear(32, 33),
                                    activation_func,
                                    )

        self.vis_fc2 = nn.Sequential(nn.Linear(32, 32),
                                     activation_func,
                                     nn.Linear(32, 1),
                                     nn.Sigmoid()
                                     )

        # Input = 32 feature channels + 1 visibility + 4 ray-difference channels.
        self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16),
                                    activation_func,
                                    nn.Linear(16, 8),
                                    activation_func,
                                    nn.Linear(8, 1))

        # ray_dir_fc keeps PyTorch's default initialisation.
        self.base_fc.apply(weights_init)
        self.vis_fc2.apply(weights_init)
        self.vis_fc.apply(weights_init)
        self.rgb_fc.apply(weights_init)

    def forward(self, geometry_feat, rgb_feat, ray_diff, mask):
        '''
        :param geometry_feat: geometry features indicates sdf [n_rays, n_samples, n_feat]
        :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat];
                         the first 3 channels are the rgb values that get blended
        :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions,
                         last channel is inner product
        :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples];
                     entries equal to 0 are treated as invalid
        :return: tuple (rgb_out, valid_mask):
                 rgb_out is the blended color [n_rays, n_samples, 3];
                 valid_mask is a boolean tensor [n_rays, 1], True for rays with more than
                 8 samples that are each covered by at least 2 valid views
        '''

        # Bring the view axis next to the channels: [n_rays, n_samples, n_views, C].
        rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous()
        ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous()
        mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous()
        num_views = rgb_feat.shape[2]
        # expand() broadcasts without copying; torch.cat below materialises it.
        geometry_feat = geometry_feat[:, :, None, :].expand(-1, -1, num_views, -1)

        direction_feat = self.ray_dir_fc(ray_diff)
        # Keep the raw colors before the direction feature is added.
        rgb_in = rgb_feat[..., :3]
        rgb_feat = rgb_feat + direction_feat

        if self.anti_alias_pooling:
            _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
            exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
            # Subtract the per-sample minimum over views, drop invalid views,
            # then normalise over the view axis.
            weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask
            weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
        else:
            weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)

        # Weighted statistics over the views: [n_rays, n_samples, 1, C] each.
        mean, var = fused_mean_variance(rgb_feat, weight)
        globalfeat = torch.cat([mean, var], dim=-1)

        x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat],
                      dim=-1)
        x = self.base_fc(x)

        x_vis = self.vis_fc(x * weight)
        x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        vis = torch.sigmoid(vis) * mask
        x = x + x_res
        vis = self.vis_fc2(x * vis) * mask

        x = torch.cat([x, vis, ray_diff], dim=-1)
        x = self.rgb_fc(x)
        # Push invalid views to (numerically) zero weight in the softmax.
        x = x.masked_fill(mask == 0, -1e9)
        blending_weights_valid = F.softmax(x, dim=2)
        rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)

        mask = mask.detach().to(rgb_out.dtype)
        # Number of valid views per sample: [n_rays, n_samples, 1].
        mask = torch.sum(mask, dim=2, keepdim=False)
        # A sample counts when at least 2 views see it.
        mask = mask >= 2
        # Number of such samples per ray: [n_rays, 1].
        mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False)
        # A ray is valid when more than 8 of its samples count.
        valid_mask = mask > 8
        return rgb_out, valid_mask
|
|