# ORIGINAL LICENSE # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # # Modified by Jiale Xu # The modifications are subject to the same license as the original. import itertools import torch import torch.nn as nn from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes class OSGDecoder(nn.Module): """ Triplane decoder that gives RGB and sigma values from sampled features. Using ReLU here instead of Softplus in the original implementation. Reference: EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 """ def __init__(self, n_features: int, hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): super().__init__() self.net_sdf = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain(*[[ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2)]), nn.Linear(hidden_dim, 1), ) self.net_rgb = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain(*[[ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2)]), nn.Linear(hidden_dim, 3), ) self.net_deformation = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain(*[[ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2)]), nn.Linear(hidden_dim, 3), ) self.net_weight = nn.Sequential( nn.Linear(8 * 3 * n_features, hidden_dim), activation(), *itertools.chain(*[[ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2)]), nn.Linear(hidden_dim, 21), ) # init all bias to zero for m in self.modules(): if isinstance(m, nn.Linear): nn.init.zeros_(m.bias) def get_geometry_prediction(self, sampled_features, flexicubes_indices): _N, n_planes, _M, _C = sampled_features.shape sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) sdf = self.net_sdf(sampled_features) deformation = self.net_deformation(sampled_features) grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) grid_features = grid_features.reshape( sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) weight = self.net_weight(grid_features) * 0.1 return sdf, deformation, weight def get_texture_prediction(self, sampled_features): _N, n_planes, _M, _C = sampled_features.shape sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) rgb = self.net_rgb(sampled_features) rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF return rgb class TriplaneSynthesizer(nn.Module): """ Synthesizer that renders a triplane volume with planes and a camera. Reference: EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 """ DEFAULT_RENDERING_KWARGS = { 'ray_start': 'auto', 'ray_end': 'auto', 'box_warp': 2., 'white_back': True, 'disparity_space_sampling': False, 'clamp_mode': 'softplus', 'sampler_bbox_min': -1., 'sampler_bbox_max': 1., } def __init__(self, triplane_dim: int, samples_per_ray: int): super().__init__() # attributes self.triplane_dim = triplane_dim self.rendering_kwargs = { **self.DEFAULT_RENDERING_KWARGS, 'depth_resolution': samples_per_ray // 2, 'depth_resolution_importance': samples_per_ray // 2, } # modules self.plane_axes = generate_planes() self.decoder = OSGDecoder(n_features=triplane_dim) def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): plane_axes = self.plane_axes.to(planes.device) sampled_features = sample_from_planes( plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) return sdf, deformation, weight def get_texture_prediction(self, planes, sample_coordinates): plane_axes = self.plane_axes.to(planes.device) sampled_features = sample_from_planes( plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) rgb = self.decoder.get_texture_prediction(sampled_features) return rgb