alexlau's picture
set DOWNSAMPLE=4 and then resize a bigger image
1a30119
raw
history blame
2.33 kB
import os
from functools import partial
import jax
from jax import random
import numpy as np
from PIL import Image
from jaxnerf.nerf import clip_utils
from jaxnerf.nerf import utils
from demo.src.config import NerfConfig
from demo.src.models import init_model
# Build the model once at import time. init_model() returns a pair; only the
# first element (the model) is kept, the second is discarded.
model, _ = init_model()
def render_predict_from_pose(state, theta, phi, radius):
    """Render the view seen from a spherical camera pose.

    Args:
        state: object whose ``state.optimizer.target`` is bound as the first
            argument of the pmapped render function.
        theta: pose angle forwarded to ``_render_rays_from_pose``.
        phi: pose angle forwarded to ``_render_rays_from_pose``.
        radius: camera distance forwarded to ``_render_rays_from_pose``.

    Returns:
        ``(pred_color, pred_disp)``: the first two of the three values
        returned by ``utils.render_image``; the third one is dropped.
    """
    # The key is always built from seed 0, so every call uses the same key.
    key = random.PRNGKey(0)
    # Pre-bind the model variables; the renderer then only supplies keys/rays.
    render_with_params = partial(render_pfn, state.optimizer.target)
    camera_rays = _render_rays_from_pose(theta, phi, radius)
    # NOTE(review): the positional ``False`` is presumably a
    # disparity-normalisation flag -- confirm against utils.render_image.
    color, disparity, _unused = utils.render_image(
        render_with_params,
        camera_rays,
        key,
        False,
        chunk=NerfConfig.CHUNK,
    )
    return color, disparity
def predict_to_image(pred_out) -> Image.Image:
    """Convert a float prediction array into an 8-bit PIL image.

    Args:
        pred_out: array-like prediction; values outside ``[0, 1]`` are
            clipped to that range before scaling.

    Returns:
        The ``PIL.Image.Image`` produced by ``Image.fromarray`` from the
        clipped values multiplied by 255 and cast to ``uint8`` (the cast
        truncates the fractional part).
    """
    # ``Image`` is the PIL *module*; the returned object is an ``Image.Image``,
    # which is why the annotation above names the class, not the module.
    scaled = np.clip(pred_out, 0., 1.) * 255.
    image_arr = np.array(scaled).astype(np.uint8)
    return Image.fromarray(image_arr)
def _render_rays_from_pose(theta, phi, radius):
    """Build the rays for a camera placed at a spherical pose.

    The pose is obtained from ``clip_utils.pose_spherical`` (which takes the
    radius first, then theta and phi), converted to a NumPy array and handed
    to ``_camtoworld_matrix_to_rays``.
    """
    pose = clip_utils.pose_spherical(radius, theta, phi)
    return _camtoworld_matrix_to_rays(np.array(pose))
def _camtoworld_matrix_to_rays(camtoworld):
    """Generate one set of rays from a camera-to-world matrix (4, 4).

    Pixel coordinates are sampled every ``NerfConfig.DOWNSAMPLE`` pixels over
    a ``NerfConfig.W`` x ``NerfConfig.H`` grid. Each sample becomes a
    camera-space direction ``((x - W/2) / f, -(y - H/2) / f, -1)`` which is
    then rotated by ``camtoworld[:3, :3]``.

    Returns:
        ``utils.Rays`` with ``origins`` (``camtoworld[:3, -1]`` broadcast to
        the shape of the directions), ``directions`` (rotated, unnormalised)
        and ``viewdirs`` (directions divided by their norm along the last
        axis).
    """
    width, height = NerfConfig.W, NerfConfig.H
    focal, stride = NerfConfig.FOCAL, NerfConfig.DOWNSAMPLE
    # Offset added to every pixel coordinate; 0 means no shift.
    center_offset = 0.

    cols = np.arange(0, width, stride, dtype=np.float32) + center_offset  # X-Axis
    rows = np.arange(0, height, stride, dtype=np.float32) + center_offset  # Y-Axis
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        cols, rows, indexing="xy")

    # Camera-space directions: y is flipped and z is fixed at -1.
    dir_x = (x - width * 0.5) / focal
    dir_y = -(y - height * 0.5) / focal
    dir_z = -np.ones_like(x)
    camera_dirs = np.stack([dir_x, dir_y, dir_z], axis=-1)

    rotation = camtoworld[:3, :3]
    translation = camtoworld[:3, -1]

    # Broadcasted multiply + sum over the last axis dots every direction with
    # each row of the rotation block.
    rotated = camera_dirs[..., None, :] * rotation
    directions = rotated.sum(axis=-1)
    origins = np.broadcast_to(translation, directions.shape)
    norms = np.linalg.norm(directions, axis=-1, keepdims=True)
    viewdirs = directions / norms
    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
def _render_fn(variables, key_0, key_1, rays):
    """Apply the model to ``rays`` and gather the result over the "batch" axis.

    Meant to run under a mapped axis named "batch" (see ``jax.pmap`` on this
    function): ``jax.lax.all_gather`` collects the model output from every
    participant along that axis.
    """
    # NOTE(review): the trailing ``False`` is presumably the model's
    # "randomized" flag -- confirm against the model definition.
    outputs = model.apply(variables, key_0, key_1, rays, False)
    return jax.lax.all_gather(outputs, axis_name="batch")
# Parallel-mapped version of `_render_fn`. Only the rays (argument index 3)
# are mapped over their leading axis; the variables and both keys use
# in_axes=None and are therefore broadcast. The rays argument is also donated.
render_pfn = jax.pmap(
    _render_fn,
    axis_name="batch",
    in_axes=(None, None, None, 0),
    donate_argnums=(3,),
)