# FNeVR_demo/modules/nerf_verts_util.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
import einops
from modules.util import UpBlock2d, DownBlock2d


def make_coordinate_grid(spatial_size, type):
    """
    Create a (d, h, w, 3) mesh grid whose last axis holds (x, y, z)
    coordinates normalized to [-1, 1].
    """
    d, h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)
    z = torch.arange(d).type(type)
    # Normalize each axis from [0, n - 1] to [-1, 1]
    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)
    z = (2 * (z / (d - 1)) - 1)
    # Broadcast each 1-D axis over the full (d, h, w) volume
    yy = y.view(1, -1, 1).repeat(d, 1, w)
    xx = x.view(1, 1, -1).repeat(d, h, 1)
    zz = z.view(-1, 1, 1).repeat(1, h, w)
    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
    return meshed
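

# A minimal sanity-check sketch, assuming float32 inputs and the small,
# arbitrary volume size (2, 4, 8); the shapes follow from the code above.
def _demo_make_coordinate_grid():
    grid = make_coordinate_grid((2, 4, 8), torch.float32)
    assert grid.shape == (2, 4, 8, 3)  # (d, h, w, 3)
    assert grid[..., 0].min() == -1 and grid[..., 0].max() == 1  # x spans [-1, 1]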


def kp2gaussian_3d(kp, spatial_size, kp_variance):
    """
    Transform keypoints into a Gaussian-like heatmap representation.
    Args:
        kp: keypoint coordinates in [-1, 1], e.g. bs * num_kp * 3
        spatial_size: (d, h, w) of the output volume
        kp_variance: variance of the Gaussian blobs
    Returns:
        out: heatmaps, e.g. bs * num_kp * d * h * w
    """
    # mean = kp['value']
    mean = kp
    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
    number_of_leading_dimensions = len(mean.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
    coordinate_grid = coordinate_grid.view(*shape)
    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
    coordinate_grid = coordinate_grid.repeat(*repeats)
    # Reshape the keypoints so they broadcast against the grid
    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
    mean = mean.view(*shape)
    mean_sub = (coordinate_grid - mean)
    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
    return out
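

# A minimal sanity-check sketch, assuming keypoints at the origin and an
# odd-sized grid so that a grid point coincides with each keypoint.
def _demo_kp2gaussian_3d():
    kp = torch.zeros(2, 10, 3)  # bs * num_kp * (x, y, z), all at the origin
    heat = kp2gaussian_3d(kp, (9, 17, 17), kp_variance=0.01)
    assert heat.shape == (2, 10, 9, 17, 17)  # bs * num_kp * d * h * w
    assert torch.isclose(heat.max(), torch.tensor(1.0))  # peak sits on the keypoint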


class ResBlock3d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock3d, self).__init__()
self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.norm1 = BatchNorm3d(in_features, affine=True)
self.norm2 = BatchNorm3d(in_features, affine=True)

    def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out


class rgb_predictor(nn.Module):
    """
    Predicts the per-floor appearance (rgb) encoding from a 2D feature map.
    """
    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(rgb_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)

    def forward(self, feature):
        """
        Args:
            feature: warped feature map, bs * c * h * w
        Returns:
            rgb: bs * h * w * floor_num * c, where c = simpled_channel // floor_num
        """
        feature = self.down_conv(feature)
        # Split the channels into floor_num depth "floors", then move to a
        # per-point layout: (bs, h, w, floor, channel)
        feature = einops.rearrange(feature, 'b (c f) h w -> b c f h w', f=self.floor_num)
        feature = einops.rearrange(feature, 'b c f h w -> b h w f c')
        return feature


class sigma_predictor(nn.Module):
    """
    Predicts the per-floor density (sigma) encoding from a 2D feature map.
    """
    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(sigma_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)
        # The 3D residual stack assumes simpled_channel // floor_num == 16
        self.res_conv3d = nn.Sequential(
            ResBlock3d(16, 3, 1),
            nn.BatchNorm3d(16),
            ResBlock3d(16, 3, 1),
            nn.BatchNorm3d(16),
            ResBlock3d(16, 3, 1),
            nn.BatchNorm3d(16)
        )

    def forward(self, feature):
        """
        Args:
            feature: warped feature map, bs * c * h * w (same input as rgb_predictor)
        Returns:
            sigma: bs * h * w * floor_num * c, where c = simpled_channel // floor_num
            point_dict: dict carrying the intermediate 3D 'sigma_map'
        """
        heatmap = self.down_conv(feature)
        heatmap = einops.rearrange(heatmap, "b (c f) h w -> b c f h w", f=self.floor_num)
        heatmap = self.res_conv3d(heatmap)
        sigma = einops.rearrange(heatmap, "b c f h w -> b h w f c")
        point_dict = {'sigma_map': heatmap}
        # point_pred = einops.rearrange(point_pred, 'b p n -> b n p')
        return sigma, point_dict
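

# A minimal shape sketch for the two predictors, assuming the default
# simpled_channel=128 and floor_num=8; in_channels=64 is an arbitrary pick.
def _demo_predictors():
    feat = torch.randn(1, 64, 16, 16)  # bs * c * h * w warped feature
    rgb = rgb_predictor(64)(feat)
    assert rgb.shape == (1, 16, 16, 8, 16)  # bs * h * w * floor * c
    sigma, point_dict = sigma_predictor(64)(feat)
    assert sigma.shape == (1, 16, 16, 8, 16)
    assert point_dict['sigma_map'].shape == (1, 16, 8, 16, 16)  # bs * c * floor * h * w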


class MultiHeadNeRFModel(torch.nn.Module):
    """
    NeRF-style MLP with two heads: one predicts the density ("sigma"),
    the other a 256-channel color/feature vector.
    """
    def __init__(self, hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16):
        super(MultiHeadNeRFModel, self).__init__()
        # self.xyz_encoding_dims = 1 + 1 * 2 * num_encoding_functions + num_encoding_rgb
        self.xyz_encoding_dims = num_encoding_sigma
        self.viewdir_encoding_dims = num_encoding_rgb
# Input layer (default: 16 -> 128)
self.layer1 = torch.nn.Linear(self.xyz_encoding_dims, hidden_size)
# Layer 2 (default: 128 -> 128)
self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 3_1 (default: 128 -> 1): predicts volume density ("sigma")
        self.layer3_1 = torch.nn.Linear(hidden_size, 1)
# Layer 3_2 (default: 128 -> 32): Predicts a feature vector (used for color)
self.layer3_2 = torch.nn.Linear(hidden_size, hidden_size // 4)
        # Layer 3_3 (default: 16 -> 128): encodes the rgb-branch input
        self.layer3_3 = torch.nn.Linear(self.viewdir_encoding_dims, hidden_size)
# Layer 4 (default: 32 + 128 -> 128)
self.layer4 = torch.nn.Linear(
hidden_size // 4 + hidden_size, hidden_size
)
# Layer 5 (default: 128 -> 128)
self.layer5 = torch.nn.Linear(hidden_size, hidden_size)
# Layer 6 (default: 128 -> 256): Predicts RGB color
self.layer6 = torch.nn.Linear(hidden_size, 256)
        # Shorthand for torch.nn.functional.relu
        self.relu = torch.nn.functional.relu

    def forward(self, rgb_in, sigma_in):
        """
        Args:
            rgb_in: output of rgb_predictor, bs * h * w * floor * c
            sigma_in: output of sigma_predictor, bs * h * w * floor * c
        Returns:
            x: per-point color/feature vector, bs * h * w * floor * 256
            sigma: per-point density, bs * h * w * floor * 1
        """
        # x = torch.cat((x, point3D), dim=-1)
        # Density branch
        out = self.relu(self.layer1(sigma_in))
        out = self.relu(self.layer2(out))
        sigma = self.layer3_1(out)
        # Color branch: fuse the density features with the encoded rgb input
        feat_sigma = self.relu(self.layer3_2(out))
        feat_rgb = self.relu(self.layer3_3(rgb_in))
        x = torch.cat((feat_sigma, feat_rgb), dim=-1)
        x = self.relu(self.layer4(x))
        x = self.relu(self.layer5(x))
        x = self.layer6(x)
        return x, sigma
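

# A minimal shape sketch for the NeRF MLP with the default 16-channel
# encodings; the leading bs * h * w * floor sizes are arbitrary.
def _demo_nerf_mlp():
    model = MultiHeadNeRFModel(hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16)
    rgb_in = torch.randn(1, 16, 16, 8, 16)
    sigma_in = torch.randn(1, 16, 16, 8, 16)
    x, sigma = model(rgb_in, sigma_in)
    assert x.shape == (1, 16, 16, 8, 256)  # per-point color/feature
    assert sigma.shape == (1, 16, 16, 8, 1)  # per-point density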


def volume_render(rgb_pred, sigma_pred):
    """
    Composite the per-floor predictions front to back, NeRF-style:
    each floor contributes alpha_i * T_i * rgb_i, where
    alpha_i = 1 - exp(-sigma_i) and T_i = exp(-sum_{j<=i} sigma_j).
    Args:
        rgb_pred: output of the NeRF MLP, [bs, h, w, floor, rgb_channel]
        sigma_pred: output of the NeRF MLP, [bs, h, w, floor, sigma_channel]
    Returns:
        c: rendered feature map, [bs, rgb_channel, h, w]
    """
    _, _, _, floor, _ = sigma_pred.size()
    c = 0
    T = 0
    for i in range(floor):
        sigma_mid = torch.nn.functional.relu(sigma_pred[:, :, :, i, :])
        # Accumulate the (negative) optical depth up to and including this floor
        T = T + (-sigma_mid)
        # exp(T) is the transmittance, (1 - exp(-sigma)) the floor's alpha
        c = c + torch.exp(T) * (1 - torch.exp(-sigma_mid)) * rgb_pred[:, :, :, i, :]
    c = einops.rearrange(c, 'b h w c -> b c h w')
    return c
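

# A minimal sanity-check sketch of the compositing rule, assuming arbitrary
# sizes: with zero density every floor is transparent and the output is zero.
def _demo_volume_render():
    rgb = torch.ones(1, 4, 4, 8, 256)
    sigma = torch.zeros(1, 4, 4, 8, 1)
    out = volume_render(rgb, sigma)
    assert out.shape == (1, 256, 4, 4)  # b c h w
    assert torch.allclose(out, torch.zeros_like(out))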


class RenderModel(nn.Module):
    def __init__(self, in_channels, simpled_channel_rgb, simpled_channel_sigma, floor_num, hidden_size):
        super(RenderModel, self).__init__()
        self.rgb_predict = rgb_predictor(in_channels=in_channels, simpled_channel=simpled_channel_rgb,
                                         floor_num=floor_num)
        self.sigma_predict = sigma_predictor(in_channels=in_channels, simpled_channel=simpled_channel_sigma,
                                             floor_num=floor_num)
        # Each encoding width is the per-floor channel count of its predictor
        num_encoding_rgb, num_encoding_sigma = simpled_channel_rgb // floor_num, simpled_channel_sigma // floor_num
        self.nerf_module = MultiHeadNeRFModel(hidden_size=hidden_size, num_encoding_rgb=num_encoding_rgb,
                                              num_encoding_sigma=num_encoding_sigma)
        # Decodes the 256-channel render into a 4x-upsampled RGB preview
        self.mini_decoder = nn.Sequential(
            UpBlock2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            UpBlock2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, feature):
        rgb_in = self.rgb_predict(feature)
        # sigma_in, point_dict = self.sigma_predict(feature.detach())
        sigma_in, point_dict = self.sigma_predict(feature)
        rgb_out, sigma_out = self.nerf_module(rgb_in, sigma_in)
        render_result = volume_render(rgb_out, sigma_out)
        render_result = torch.sigmoid(render_result)
        mini_pred = self.mini_decoder(render_result)
        out_dict = {'render': render_result, 'mini_pred': mini_pred, 'point_pred': point_dict}
        return out_dict
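

# End-to-end smoke test (a sketch with assumed sizes): in_channels=64 is an
# arbitrary pick, and simpled_channel_sigma must equal 16 * floor_num so the
# hardcoded 16-channel ResBlock3d stack in sigma_predictor lines up.
if __name__ == "__main__":
    _demo_make_coordinate_grid()
    _demo_kp2gaussian_3d()
    _demo_predictors()
    _demo_nerf_mlp()
    _demo_volume_render()
    model = RenderModel(in_channels=64, simpled_channel_rgb=128,
                        simpled_channel_sigma=128, floor_num=8, hidden_size=128)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 64, 16, 16))
    print(out['render'].shape)     # torch.Size([1, 256, 16, 16])
    print(out['mini_pred'].shape)  # torch.Size([1, 3, 64, 64])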