Spaces:
Build error
Build error
import os | |
from functools import partial | |
import jax | |
from jax import random | |
import numpy as np | |
from PIL import Image | |
from jaxnerf.nerf import clip_utils | |
from jaxnerf.nerf import utils | |
from demo.src.config import NerfConfig | |
from demo.src.models import init_model | |
# Build the model once at import time. init_model() returns a pair; only the
# first element is kept, and it is the `model` applied inside _render_fn.
model, _ = init_model()
def render_predict_from_pose(state, theta, phi, radius):
    """Render color and disparity predictions for a spherical camera pose.

    Args:
        state: Training state whose ``optimizer.target`` holds the model
            variables that are bound as the first argument of ``render_pfn``.
        theta: Spherical pose angle forwarded to ``clip_utils.pose_spherical``.
        phi: Spherical pose angle forwarded to ``clip_utils.pose_spherical``.
        radius: Camera distance forwarded to ``clip_utils.pose_spherical``.

    Returns:
        Tuple ``(pred_color, pred_disp)``: the first two outputs of
        ``utils.render_image``. Its third output is discarded.
    """
    # A fixed seed means the same pose always yields the same rendering.
    key = random.PRNGKey(0)
    # Bind the model variables so render_image only has to supply keys + rays.
    bound_render = partial(render_pfn, state.optimizer.target)
    camera_rays = _render_rays_from_pose(theta, phi, radius)
    color, disparity, _unused = utils.render_image(
        bound_render,
        camera_rays,
        key,
        False,
        chunk=NerfConfig.CHUNK,
    )
    return color, disparity
def predict_to_image(pred_out) -> Image.Image:
    """Convert a rendered prediction array into a PIL image.

    Values are clipped to ``[0, 1]``, scaled by 255 and truncated to
    ``uint8`` before being passed to ``Image.fromarray``.

    Args:
        pred_out: Array-like prediction, e.g. ``pred_color`` or ``pred_disp``
            as returned by ``render_predict_from_pose``. Presumably shaped
            (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) -- TODO confirm with
            callers.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 array.
    """
    image_arr = np.array(np.clip(pred_out, 0., 1.) * 255.).astype(np.uint8)
    # Single-channel outputs may carry a trailing axis of size 1 (presumably
    # the case for disparity from utils.render_image -- confirm). PIL cannot
    # build an image from an (H, W, 1) uint8 array, so drop that axis and
    # produce a grayscale image instead of raising TypeError.
    if image_arr.ndim == 3 and image_arr.shape[-1] == 1:
        image_arr = image_arr[..., 0]
    return Image.fromarray(image_arr)
def _render_rays_from_pose(theta, phi, radius):
    """Build rays for a camera positioned by spherical coordinates.

    ``clip_utils.pose_spherical`` is called as ``(radius, theta, phi)``. Its
    result is converted to a NumPy camera-to-world matrix, which is then
    expanded into one image's worth of rays.
    """
    pose = clip_utils.pose_spherical(radius, theta, phi)
    return _camtoworld_matrix_to_rays(np.array(pose))
def _camtoworld_matrix_to_rays(camtoworld):
    """Generate one image's worth of rays from a (4, 4) camera-to-world matrix.

    Pixel coordinates are sampled every ``NerfConfig.DOWNSAMPLE`` pixels over a
    ``NerfConfig.W`` x ``NerfConfig.H`` grid with a sub-pixel offset of 0, and
    ``NerfConfig.FOCAL`` is used as the focal length.

    Args:
        camtoworld: Camera-to-world matrix. Only its upper-left 3x3 block
            (rotation) and the first three entries of its last column
            (origin) are read.

    Returns:
        ``utils.Rays`` whose ``origins``, ``directions`` and unit-length
        ``viewdirs`` each have shape (rows, cols, 3).
    """
    width, height = NerfConfig.W, NerfConfig.H
    focal = NerfConfig.FOCAL
    step = NerfConfig.DOWNSAMPLE
    offset = 0.  # sample at the integer pixel coordinate itself

    cols = np.arange(0, width, step, dtype=np.float32) + offset   # X-Axis
    rows = np.arange(0, height, step, dtype=np.float32) + offset  # Y-Axis
    grid_x, grid_y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        cols, rows, indexing="xy")

    # Camera-space directions: z is fixed at -1, and y is negated so that a
    # larger row index maps to a smaller camera-space y.
    camera_dirs = np.stack(
        [
            (grid_x - width * 0.5) / focal,
            -(grid_y - height * 0.5) / focal,
            -np.ones_like(grid_x),
        ],
        axis=-1,
    )

    # World component i is the dot product of rotation row i with the
    # camera-space direction (broadcast multiply, then sum the last axis).
    rotation = camtoworld[:3, :3][np.newaxis, np.newaxis]
    directions = np.sum(camera_dirs[..., np.newaxis, :] * rotation, axis=-1)

    # Every ray starts at the camera position (read-only broadcast view).
    origins = np.broadcast_to(camtoworld[:3, -1][np.newaxis, np.newaxis],
                              directions.shape)

    lengths = np.linalg.norm(directions, axis=-1, keepdims=True)
    viewdirs = directions / lengths
    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
def _render_fn(variables, key_0, key_1, rays):
    """Per-device render step wrapped by ``render_pfn``.

    Applies the module-level ``model`` to this device's ``rays`` and gathers
    the result from every device along the ``"batch"`` pmap axis.

    Args:
        variables: Model variables passed straight to ``model.apply``.
        key_0: First PRNG key forwarded to ``model.apply``.
        key_1: Second PRNG key forwarded to ``model.apply``.
        rays: This device's shard of rays.
    """
    # The trailing False is passed positionally to model.apply; presumably it
    # disables randomized sampling -- confirm against the model's __call__.
    local_out = model.apply(variables, key_0, key_1, rays, False)
    return jax.lax.all_gather(local_out, axis_name="batch")
# Device-parallel renderer. The variables and both PRNG keys are broadcast
# unchanged to every device (in_axes None). Only `rays` (argument 3) is split
# along its leading axis, and its buffer is donated so XLA may reuse it. The
# axis name matches the all_gather inside _render_fn.
render_pfn = jax.pmap(
    _render_fn,
    axis_name="batch",
    in_axes=(None, None, None, 0),
    donate_argnums=3,
)