# Spaces: Build error
# (Page-status text captured together with this file; not part of the module.)
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Different model implementation plus a general port for all the models."""
from typing import Any, Callable | |
from flax import linen as nn | |
from jax import random | |
import jax.numpy as jnp | |
from jaxnerf.nerf import model_utils | |
from jaxnerf.nerf import utils | |
def get_model(key, example_batch, args):
  """Dispatch to the model constructor selected by `args.model`.

  Args:
    key: random number generator key, forwarded to the constructor.
    example_batch: example batch of data, forwarded to the constructor.
    args: hyperparameters; `args.model` names the constructor to use.

  Returns:
    Whatever the selected constructor returns.

  Raises:
    KeyError: if `args.model` is not a registered model name.
  """
  # The "model zoo": maps a model name to the function that builds it.
  constructors = {"nerf": construct_nerf}
  build = constructors[args.model]
  return build(key, example_batch, args)
class NerfModel(nn.Module):
  """Nerf NN Model with both coarse and fine MLPs.

  A call renders the given rays with a "coarse" MLP and, when
  `num_fine_samples > 0`, re-samples along the rays using the coarse weights
  and renders again with a separate "fine" MLP.
  """
  num_coarse_samples: int  # The number of samples for the coarse nerf.
  num_fine_samples: int  # The number of samples for the fine nerf.
  use_viewdirs: bool  # If True, use viewdirs as an input.
  near: float  # The distance to the near plane
  far: float  # The distance to the far plane
  noise_std: float  # The std dev of noise added to raw sigma.
  net_depth: int  # The depth of the first part of MLP.
  net_width: int  # The width of the first part of MLP.
  net_depth_condition: int  # The depth of the second part of MLP.
  net_width_condition: int  # The width of the second part of MLP.
  net_activation: Callable[..., Any]  # MLP activation
  skip_layer: int  # How often to add skip connections.
  num_rgb_channels: int  # The number of RGB channels.
  num_sigma_channels: int  # The number of density channels.
  white_bkgd: bool  # If True, use a white background.
  min_deg_point: int  # The minimum degree of positional encoding for positions.
  max_deg_point: int  # The maximum degree of positional encoding for positions.
  deg_view: int  # The degree of positional encoding for viewdirs.
  lindisp: bool  # If True, sample linearly in disparity rather than in depth.
  rgb_activation: Callable[..., Any]  # Output RGB activation.
  sigma_activation: Callable[..., Any]  # Output sigma activation.
  legacy_posenc_order: bool  # Keep the same ordering as the original tf code.

  # The MLP submodules are created inline below. Flax only allows that inside
  # `setup()` or a method wrapped in `@nn.compact`, so the decorator is needed.
  @nn.compact
  def __call__(self, rng_0, rng_1, rays, randomized):
    """Nerf Model.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
      randomized: bool, use randomized stratified sampling.

    Returns:
      ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)].
        The second (fine) tuple is only present when `num_fine_samples > 0`.
    """
    # Both MLPs share the same hyperparameters; they are still two separate
    # module instances, constructed in the order coarse, then fine.
    mlp_kwargs = dict(
        net_depth=self.net_depth,
        net_width=self.net_width,
        net_depth_condition=self.net_depth_condition,
        net_width_condition=self.net_width_condition,
        net_activation=self.net_activation,
        skip_layer=self.skip_layer,
        num_rgb_channels=self.num_rgb_channels,
        num_sigma_channels=self.num_sigma_channels,
    )

    # Stratified sampling along rays
    key, rng_0 = random.split(rng_0)
    z_vals, samples = model_utils.sample_along_rays(
        key,
        rays.origins,
        rays.directions,
        self.num_coarse_samples,
        self.near,
        self.far,
        randomized,
        self.lindisp,
    )
    samples_enc = model_utils.posenc(
        samples,
        self.min_deg_point,
        self.max_deg_point,
        self.legacy_posenc_order,
    )

    # Construct the "coarse" MLP.
    coarse_mlp = model_utils.MLP(**mlp_kwargs)

    # Point attribute predictions
    if self.use_viewdirs:
      # The encoded view directions are reused by the fine pass below.
      viewdirs_enc = model_utils.posenc(
          rays.viewdirs,
          0,
          self.deg_view,
          self.legacy_posenc_order,
      )
      raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
    else:
      viewdirs_enc = None
      raw_rgb, raw_sigma = coarse_mlp(samples_enc)

    # Add noises to regularize the density predictions if needed
    key, rng_0 = random.split(rng_0)
    raw_sigma = model_utils.add_gaussian_noise(
        key,
        raw_sigma,
        self.noise_std,
        randomized,
    )
    rgb = self.rgb_activation(raw_rgb)
    sigma = self.sigma_activation(raw_sigma)

    # Volumetric rendering.
    comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
        rgb,
        sigma,
        z_vals,
        rays.directions,
        white_bkgd=self.white_bkgd,
    )
    ret = [
        (comp_rgb, disp, acc),
    ]

    # Hierarchical sampling based on coarse predictions
    if self.num_fine_samples > 0:
      # Midpoints between consecutive coarse depths along the last axis.
      z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
      key, rng_1 = random.split(rng_1)
      # The first and last coarse weights are dropped before re-sampling.
      z_vals, samples = model_utils.sample_pdf(
          key,
          z_vals_mid,
          weights[..., 1:-1],
          rays.origins,
          rays.directions,
          z_vals,
          self.num_fine_samples,
          randomized,
      )
      samples_enc = model_utils.posenc(
          samples,
          self.min_deg_point,
          self.max_deg_point,
          self.legacy_posenc_order,
      )

      # Construct the "fine" MLP.
      fine_mlp = model_utils.MLP(**mlp_kwargs)

      if self.use_viewdirs:
        raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
      else:
        raw_rgb, raw_sigma = fine_mlp(samples_enc)

      key, rng_1 = random.split(rng_1)
      raw_sigma = model_utils.add_gaussian_noise(
          key,
          raw_sigma,
          self.noise_std,
          randomized,
      )
      rgb = self.rgb_activation(raw_rgb)
      sigma = self.sigma_activation(raw_sigma)
      comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
          rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )
      ret.append((comp_rgb, disp, acc))
    return ret
def construct_nerf(key, example_batch, args):
  """Construct a Neural Radiance Field and initialize its variables.

  Args:
    key: jnp.ndarray. Random number generator; split into three keys that are
      passed to `model.init`.
    example_batch: dict, an example of a batch of data. Only the "rays" entry
      is read here.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: NerfModel configured from `args`.
    init_variables: the result of `model.init` on the example rays.

  Raises:
    NotImplementedError: if the chosen rgb activation yields values outside
      [0, 1], or the chosen sigma activation yields negative values, on the
      probe inputs built below.
  """

  def _lookup_activation(name):
    # Activations are given by name and resolved as attributes of flax.linen.
    return getattr(nn, str(name))

  net_activation = _lookup_activation(args.net_activation)
  rgb_activation = _lookup_activation(args.rgb_activation)
  sigma_activation = _lookup_activation(args.sigma_activation)

  # Assert that rgb_activation always produces outputs in [0, 1], and
  # sigma_activation always produce non-negative outputs. The probe holds
  # exp(linspace(-90, 90)) together with its mirrored negation.
  magnitudes = jnp.exp(jnp.linspace(-90, 90, 1024))
  probe = jnp.concatenate([-magnitudes[::-1], magnitudes], 0)

  rgb_out = rgb_activation(probe)
  if jnp.any((rgb_out < 0) | (rgb_out > 1)):
    raise NotImplementedError(
        "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
        .format(args.rgb_activation))

  sigma_out = sigma_activation(probe)
  if jnp.any(sigma_out < 0):
    raise NotImplementedError(
        "Choice of sigma_activation `{}` produces negative densities".format(
            args.sigma_activation))

  # Hyperparameters forwarded unchanged from `args` to the module field of
  # the same name.
  forwarded_fields = (
      "min_deg_point",
      "max_deg_point",
      "deg_view",
      "num_coarse_samples",
      "num_fine_samples",
      "use_viewdirs",
      "near",
      "far",
      "noise_std",
      "white_bkgd",
      "net_depth",
      "net_width",
      "net_depth_condition",
      "net_width_condition",
      "skip_layer",
      "num_rgb_channels",
      "num_sigma_channels",
      "lindisp",
      "legacy_posenc_order",
  )
  hparams = {field: getattr(args, field) for field in forwarded_fields}
  model = NerfModel(
      net_activation=net_activation,
      rgb_activation=rgb_activation,
      sigma_activation=sigma_activation,
      **hparams)

  init_key, coarse_key, fine_key = random.split(key, num=3)
  # Keep only index 0 along the leading axis of every rays field.
  # NOTE(review): presumably this strips a per-device axis -- confirm.
  first_rays = utils.namedtuple_map(lambda x: x[0], example_batch["rays"])
  init_variables = model.init(
      init_key,
      rng_0=coarse_key,
      rng_1=fine_key,
      rays=first_rays,
      randomized=args.randomized)
  return model, init_variables