Spaces:
Build error
Build error
File size: 2,329 Bytes
0d35ba8 19677a1 1a30119 19677a1 b5cab38 19677a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
from functools import partial
import jax
from jax import random
import numpy as np
from PIL import Image
from jaxnerf.nerf import clip_utils
from jaxnerf.nerf import utils
from demo.src.config import NerfConfig
from demo.src.models import init_model
# Build the model once, at import time. Only the first value returned by
# init_model() is kept; it is the `model` that _render_fn calls `.apply` on.
# The second value (presumably initial variables/state — confirm against
# demo.src.models.init_model) is discarded.
model, _ = init_model()
def render_predict_from_pose(state, theta, phi, radius):
    """Render colour and disparity predictions for a spherically-posed camera.

    Args:
      state: object whose ``optimizer.target`` holds the model variables; they
        are bound as the first argument of ``render_pfn``.
      theta, phi, radius: camera pose, forwarded to ``_render_rays_from_pose``.

    Returns:
      A ``(pred_color, pred_disp)`` pair taken from ``utils.render_image``; the
      third value it returns is discarded.
    """
    key = random.PRNGKey(0)  # the same fixed key is used on every call
    variables = state.optimizer.target
    render_fn = partial(render_pfn, variables)
    rays = _render_rays_from_pose(theta, phi, radius)
    color, disparity, _unused = utils.render_image(
        render_fn,
        rays,
        key,
        False,  # positional flag to render_image (presumably normalize_disp — confirm)
        chunk=NerfConfig.CHUNK,
    )
    return color, disparity
def predict_to_image(pred_out) -> Image.Image:
    """Convert a rendered prediction into an 8-bit PIL image.

    Values are clipped to [0, 1], scaled by 255 and truncated to ``uint8``.
    A trailing channel axis of size 1 is dropped first, because
    ``Image.fromarray`` cannot build an image from an ``(H, W, 1)`` uint8
    array. Such arrays presumably arise from single-channel outputs like the
    disparity returned by ``render_predict_from_pose`` (confirm its shape).

    Args:
      pred_out: array-like of floats, nominally in [0, 1]; e.g. ``(H, W)``,
        ``(H, W, 1)`` or ``(H, W, 3)``.

    Returns:
      A ``PIL.Image.Image`` built from the uint8 array.
    """
    image_arr = (np.clip(np.asarray(pred_out), 0., 1.) * 255.).astype(np.uint8)
    # PIL has no mode for (H, W, 1); a 2-D array maps to greyscale ("L").
    if image_arr.ndim == 3 and image_arr.shape[-1] == 1:
        image_arr = image_arr[..., 0]
    return Image.fromarray(image_arr)
def _render_rays_from_pose(theta, phi, radius):
    """Return the rays for the camera pose built by ``clip_utils.pose_spherical``.

    ``pose_spherical`` is called as ``(radius, theta, phi)``. Its result is
    converted to a NumPy array and handed to ``_camtoworld_matrix_to_rays``.
    """
    pose = clip_utils.pose_spherical(radius, theta, phi)
    return _camtoworld_matrix_to_rays(np.array(pose))
def _camtoworld_matrix_to_rays(camtoworld):
    """Generate one ray per sampled pixel for a single camera-to-world matrix (4, 4).

    Pixels are sampled every ``NerfConfig.DOWNSAMPLE`` steps over a
    ``NerfConfig.W`` x ``NerfConfig.H`` grid. Integer pixel coordinates are
    used as-is, with no half-pixel offset. Each pixel becomes a camera-space
    direction through a pinhole of focal length ``NerfConfig.FOCAL`` (+y up,
    camera looking along -z). That direction is rotated into world space by
    ``camtoworld[:3, :3]``. Every ray starts at ``camtoworld[:3, -1]``.

    Returns:
      ``utils.Rays`` whose ``origins``, ``directions`` and ``viewdirs`` all have
      shape (rows, cols, 3); ``viewdirs`` are the unit-length directions.
    """
    width, height = NerfConfig.W, NerfConfig.H
    focal = NerfConfig.FOCAL
    step = NerfConfig.DOWNSAMPLE
    pixel_offset = 0.

    cols = np.arange(0, width, step, dtype=np.float32) + pixel_offset
    rows = np.arange(0, height, step, dtype=np.float32) + pixel_offset
    grid_x, grid_y = np.meshgrid(cols, rows, indexing="xy")  # pylint: disable=unbalanced-tuple-unpacking

    # Camera-space directions; image rows grow downwards, hence the flipped y.
    dir_x = (grid_x - width * 0.5) / focal
    dir_y = -(grid_y - height * 0.5) / focal
    dir_z = -np.ones_like(grid_x)
    camera_dirs = np.stack([dir_x, dir_y, dir_z], axis=-1)

    rotation = camtoworld[:3, :3]
    translation = camtoworld[:3, -1]
    # Row-wise dot product with the rotation: direction_i = sum_j R[i, j] * d[j].
    directions = np.sum(
        camera_dirs[..., np.newaxis, :] * rotation[np.newaxis, np.newaxis],
        axis=-1)
    origins = np.broadcast_to(translation[np.newaxis, np.newaxis], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
def _render_fn(variables, key_0, key_1, rays):
    """Apply the module-level ``model`` to ``rays`` and all-gather the result.

    This function is wrapped by ``jax.pmap`` as ``render_pfn``, so the gather runs
    over the ``"batch"`` axis, i.e. across the devices of that pmap.
    """
    # The trailing False is presumably the model's `randomized` flag
    # (disables stochastic sampling) — confirm against the model definition.
    outputs = model.apply(variables, key_0, key_1, rays, False)
    return jax.lax.all_gather(outputs, axis_name="batch")
# Parallel renderer. The variables and both PRNG keys are broadcast to every
# device (in_axes=None). Only the rays (argument 3) are mapped over their
# leading axis, and that buffer is donated so JAX may reuse its memory.
render_pfn = jax.pmap(
    _render_fn,
    axis_name="batch",
    in_axes=(None, None, None, 0),
    donate_argnums=3,
)
|