# NOTE: extraction artifact (file size / commit hash / line index) removed.
# the codes are partly borrowed from IBRNet
import torch
import torch.nn as nn
import torch.nn.functional as F
# Disable TorchScript's profiling executor and profiling mode globally.
# NOTE(review): presumably done so the @torch.jit.script function below runs
# without profile-guided re-specialization across varying tensor shapes —
# these are private torch._C APIs; confirm they are still needed/available
# on the targeted PyTorch version.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
# default tensorflow initialization of linear layers
def weights_init(m):
    """He(Kaiming)-normal initialize an nn.Linear's weight; zero its bias.

    Modules of any other type are left untouched, so this is safe to pass
    to nn.Module.apply().
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight.data)
    if m.bias is not None:
        nn.init.zeros_(m.bias.data)
@torch.jit.script
def fused_mean_variance(x, weight):
    """Weighted mean and variance of `x` reduced over dim 2.

    Dim 2 is kept as size 1 in both outputs so they broadcast against `x`.
    `weight` is expected to sum to ~1 along dim 2 (the caller normalizes it).
    """
    mean = (x * weight).sum(dim=2, keepdim=True)
    var = (weight * (x - mean).pow(2)).sum(dim=2, keepdim=True)
    return mean, var
class GeneralRenderingNetwork(nn.Module):
    """Blend per-view colors into one rgb per sample point (IBRNet-style).

    This model is not sensitive to finetuning.
    """

    def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56,
                 anti_alias_pooling=True):
        """
        :param in_geometry_feat_ch: channels of the geometry (sdf) feature
        :param in_rendering_feat_ch: channels of the per-view image feature
            (rgb channels are appended, hence the `+ 3` in the layers below)
        :param anti_alias_pooling: if True, learn a sharpness `s` and weight
            views by their ray-direction similarity instead of uniformly
        """
        super(GeneralRenderingNetwork, self).__init__()
        self.in_geometry_feat_ch = in_geometry_feat_ch
        self.in_rendering_feat_ch = in_rendering_feat_ch
        self.anti_alias_pooling = anti_alias_pooling
        if self.anti_alias_pooling:
            # learnable sharpness of the direction-similarity weighting
            self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
        activation_func = nn.ELU(inplace=True)
        # embeds the 4-channel ray-direction difference
        self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16),
                                        activation_func,
                                        nn.Linear(16, in_rendering_feat_ch + 3),
                                        activation_func)
        # input: [geometry feat, (mean, var) global feat, per-view feat]
        self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64),
                                     activation_func,
                                     nn.Linear(64, 32),
                                     activation_func)
        # predicts a 32-ch residual plus a 1-ch visibility logit (32 + 1 = 33)
        self.vis_fc = nn.Sequential(nn.Linear(32, 32),
                                    activation_func,
                                    nn.Linear(32, 33),
                                    activation_func,
                                    )
        self.vis_fc2 = nn.Sequential(nn.Linear(32, 32),
                                     activation_func,
                                     nn.Linear(32, 1),
                                     nn.Sigmoid()
                                     )
        # produces one blending logit per view (32 feat + 1 vis + 4 ray_diff)
        self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16),
                                    activation_func,
                                    nn.Linear(16, 8),
                                    activation_func,
                                    nn.Linear(8, 1))
        self.base_fc.apply(weights_init)
        self.vis_fc2.apply(weights_init)
        self.vis_fc.apply(weights_init)
        self.rgb_fc.apply(weights_init)

    def forward(self, geometry_feat, rgb_feat, ray_diff, mask):
        '''
        :param geometry_feat: geometry features indicates sdf [n_rays, n_samples, n_feat]
        :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat]
        :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions,
            last channel is inner product
        :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples]
        :return: (rgb_out, valid_mask) — blended colors [n_rays, n_samples, 3]
            and a per-ray boolean mask [n_rays, 1] that is True when more than
            8 samples on the ray are seen by at least 2 views
        '''
        # move the view axis inside: [n_rays, n_samples, n_views, ...]
        rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous()
        ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous()
        mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous()
        num_views = rgb_feat.shape[2]
        # broadcast the per-sample geometry feature to every view
        geometry_feat = geometry_feat[:, :, None, :].repeat(1, 1, num_views, 1)

        direction_feat = self.ray_dir_fc(ray_diff)
        rgb_in = rgb_feat[..., :3]
        rgb_feat = rgb_feat + direction_feat

        if self.anti_alias_pooling:
            _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
            # higher weight for views whose direction agrees with the target ray;
            # subtracting the per-point minimum zeroes out the worst view
            exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
            weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask
            weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
        else:
            # uniform average over the valid views
            weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)

        # compute mean and variance across different views for each point
        mean, var = fused_mean_variance(rgb_feat, weight)  # [n_rays, n_samples, 1, n_feat]
        globalfeat = torch.cat([mean, var], dim=-1)  # [n_rays, n_samples, 1, 2*n_feat]

        x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat],
                      dim=-1)  # [n_rays, n_samples, n_views, 3*n_feat+n_geo_feat]
        x = self.base_fc(x)

        x_vis = self.vis_fc(x * weight)
        x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
        # FIX: F.sigmoid is deprecated in PyTorch; torch.sigmoid is identical.
        vis = torch.sigmoid(vis) * mask
        x = x + x_res
        vis = self.vis_fc2(x * vis) * mask

        # rgb computation: softmax blending over views, invalid views excluded
        x = torch.cat([x, vis, ray_diff], dim=-1)
        x = self.rgb_fc(x)
        x = x.masked_fill(mask == 0, -1e9)  # -inf-like logit for invalid views
        blending_weights_valid = F.softmax(x, dim=2)  # color blending
        rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)

        # per-ray validity: count samples covered by >= 2 views, then require
        # more than 8 such samples along the ray
        mask = mask.detach().to(rgb_out.dtype)  # [n_rays, n_samples, n_views, 1]
        mask = torch.sum(mask, dim=2, keepdim=False)
        mask = mask >= 2  # more than 2 views see the point
        mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False)
        valid_mask = mask > 8  # valid rays, more than 8 valid samples
        return rgb_out, valid_mask  # (N_rays, n_samples, 3), (N_rays, 1)