# NeRF_ficus-pxl / rendering.py
# Last updated by Alesteba ("Update rendering.py", commit 014a7c9).
import streamlit as st
import tensorflow as tf
import numpy as np
from config import *
def encode_position(x):
    """Map coordinates to a Fourier-feature positional encoding.

    The result concatenates, along the last axis, the raw input followed by
    sin(2**k * x) and cos(2**k * x) for k = 0 .. POS_ENCODE_DIMS - 1.

    Args:
        x: The input coordinate tensor.

    Returns:
        A tensor whose last axis is (1 + 2 * POS_ENCODE_DIMS) times the size
        of the input's last axis.
    """
    features = [x]
    for freq_idx in range(POS_ENCODE_DIMS):
        scaled = 2.0 ** freq_idx * x
        # Sine first, then cosine, for each frequency band.
        features.extend((tf.sin(scaled), tf.cos(scaled)))
    return tf.concat(features, axis=-1)
def get_rays(height, width, focal, pose):
    """Compute the origin and direction of the ray through every pixel.

    Uses a pinhole camera model. Pixel (i, j) maps to the camera-space
    direction ((i - width/2) / focal, -(j - height/2) / focal, -1). That
    direction is then rotated by the 3x3 rotation block of ``pose``. Every ray
    starts at the translation column of ``pose``.

    Args:
        height: Height of the image in pixels.
        width: Width of the image in pixels.
        focal: Focal length of the camera, in pixel units.
        pose: Camera pose matrix. Only ``pose[:3, :3]`` (rotation) and
            ``pose[:3, -1]`` (translation) are read.

    Returns:
        Tuple ``(ray_origins, ray_directions)``, each of shape
        ``(height, width, 3)``. The directions are NOT normalized to unit
        length.
    """
    # A float64 pose (e.g. straight from NumPy) cannot be combined with the
    # float32 pixel grid below. The cast is a no-op for float32 input.
    pose = tf.cast(pose, tf.float32)

    # Pixel grid. With "xy" indexing both grids have shape (height, width):
    # i holds column (x) indices and j holds row (y) indices.
    i, j = tf.meshgrid(
        tf.range(width, dtype=tf.float32),
        tf.range(height, dtype=tf.float32),
        indexing="xy",
    )
    # Move the origin to the image centre and divide by the focal length to
    # get camera-space coordinates on the z = -1 plane.
    transformed_i = (i - width * 0.5) / focal
    transformed_j = (j - height * 0.5) / focal
    # Row index grows downward in the image, hence the negated y component.
    directions = tf.stack(
        [transformed_i, -transformed_j, -tf.ones_like(i)], axis=-1
    )

    rotation = pose[:3, :3]
    # Translation column of the pose: the shared origin of every ray.
    camera_origin = pose[:3, -1]

    # Broadcasting (H, W, 1, 3) * (3, 3) and summing over the last axis
    # computes rotation @ d for every pixel direction d.
    ray_directions = tf.reduce_sum(directions[..., None, :] * rotation, axis=-1)
    ray_origins = tf.broadcast_to(camera_origin, tf.shape(ray_directions))
    return (ray_origins, ray_directions)
def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
    """Sample points along each ray and return their positional encoding.

    Evaluates r(t) = o + t * d at ``num_samples`` depths t evenly spaced in
    [near, far]. With ``rand=True`` each depth is shifted by uniform noise in
    [0, (far - near) / num_samples). The noise is drawn independently for
    every ray and sample.

    Args:
        ray_origins: Ray origin points. The last axis holds xyz coordinates.
        ray_directions: Ray direction vectors, typically the same shape as
            ``ray_origins``.
        near: Near bound of the sampled depth range.
        far: Far bound of the sampled depth range.
        num_samples: Number of sample points per ray.
        rand: If True, jitter the sample depths with uniform noise.

    Returns:
        Tuple ``(rays_flat, t_vals)``.
        ``rays_flat`` holds the ``encode_position`` features of every sample
        point, flattened to one row per point.
        ``t_vals`` has shape ``(num_samples,)`` when ``rand`` is False.
        Otherwise it has shape ``ray_origins.shape[:-1] + (num_samples,)``.
    """
    # Depths "t" in r(t) = o + t * d.
    t_vals = tf.linspace(near, far, num_samples)
    if rand:
        # Build the noise shape from the dynamic shape so it also works when
        # leading dimensions are unknown at trace time (static
        # ``ray_origins.shape`` would contain None there).
        noise_shape = tf.concat(
            [tf.shape(ray_origins)[:-1], [num_samples]], axis=0
        )
        noise = tf.random.uniform(shape=noise_shape) * (far - near) / num_samples
        t_vals = t_vals + noise
    # Sample points: (..., 1, 3) + (..., 1, 3) * (..., num_samples, 1).
    rays = ray_origins[..., None, :] + (
        ray_directions[..., None, :] * t_vals[..., None]
    )
    rays_flat = tf.reshape(rays, [-1, 3])
    rays_flat = encode_position(rays_flat)
    return (rays_flat, t_vals)
def map_fn(pose):
    """Turn one camera pose into encoded sample points and their depths.

    Casts rays for an H x W image using the module-level ``focal``. It then
    samples ``NUM_SAMPLES`` jittered depths per ray in the range [2.0, 6.0].

    Args:
        pose: The pose matrix of the camera.

    Returns:
        Tuple ``(rays_flat, t_vals)`` as produced by ``render_flat_rays``.
    """
    origins, directions = get_rays(height=H, width=W, focal=focal, pose=pose)
    return render_flat_rays(
        ray_origins=origins,
        ray_directions=directions,
        near=2.0,
        far=6.0,
        num_samples=NUM_SAMPLES,
        rand=True,
    )
def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
    """Volume-render RGB images and depth maps from NeRF model predictions.

    Args:
        model: The MLP that predicts 4 values per sample point: three colour
            logits followed by one volume-density value.
        rays_flat: Encoded sample points, used as the model input.
        t_vals: Sample depths. Shape (batch, H, W, NUM_SAMPLES) when ``rand``
            is True; otherwise (batch, NUM_SAMPLES), shared by every pixel of
            an image.
        rand: Whether ``t_vals`` are per-pixel (True) or per-image (False).
        train: If True, call the model directly (presumably so gradients can
            be recorded); otherwise use ``model.predict``.

    Returns:
        Tuple ``(rgb, depth_map)`` with shapes (batch, H, W, 3) and
        (batch, H, W).
    """
    if train:
        predictions = model(rays_flat)
    else:
        predictions = model.predict(rays_flat)
    # Infer the batch dimension instead of assuming BATCH_SIZE, so a smaller
    # final batch can still be rendered.
    predictions = tf.reshape(predictions, shape=(-1, H, W, NUM_SAMPLES, 4))

    # First three channels are colour, the last one is density (sigma).
    rgb = tf.sigmoid(predictions[..., :-1])
    sigma_a = tf.nn.relu(predictions[..., -1])

    # Distances between adjacent samples. The final interval is padded with
    # 1e10 (effectively infinite). The padding takes its shape from delta, so
    # it matches both layouts of t_vals and any batch size.
    delta = t_vals[..., 1:] - t_vals[..., :-1]
    delta = tf.concat([delta, tf.ones_like(delta[..., :1]) * 1e10], axis=-1)
    if rand:
        # Per-pixel depths: delta is already (batch, H, W, NUM_SAMPLES).
        alpha = 1.0 - tf.exp(-sigma_a * delta)
    else:
        # Per-image depths: broadcast (batch, NUM_SAMPLES) over H and W.
        alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])

    # Transmittance before each sample: exclusive cumulative product of
    # (1 - alpha). The epsilon keeps the product from reaching exactly zero.
    exp_term = 1.0 - alpha
    epsilon = 1e-10
    transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
    weights = alpha * transmittance

    # Composite colour: weighted sum over the samples axis.
    rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
    # Depth: weighted sum of sample depths.
    if rand:
        depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
    else:
        depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
    return (rgb, depth_map)