# Author: alexlau
# Commit: 19677a1 ("first deploy demo")
# Original file size: 23.3 kB
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Different datasets implementation plus a general port for all the datasets."""
# When True, the OpenCV (cv2) import below is skipped; in this module cv2 is
# only used to downsample Blender frames in Blender.load_files.
INTERNAL = False  # pylint: disable=g-statement-before-imports
import json
import os
from os import path
import queue
import threading
if not INTERNAL:
  import cv2  # pylint: disable=g-import-not-at-top
import jax
import numpy as np
from PIL import Image
from jaxnerf.nerf import utils
from jaxnerf.nerf import clip_utils
def get_dataset(split, args, clip_model = None):
  """Build the dataset selected by `args.dataset` for the given split.

  Args:
    split: "train" or "test"; forwarded to the dataset constructor.
    args: config object; `args.dataset` is the key into `dataset_dict`.
    clip_model: optional CLIP model forwarded to the dataset constructor.

  Returns:
    An instance of the dataset class registered under `args.dataset`.
  """
  dataset_cls = dataset_dict[args.dataset]
  return dataset_cls(split, args, clip_model)
def convert_to_ndc(origins, directions, focal, w, h, near=1.):
  """Convert a set of rays to NDC coordinates.

  Every origin is first slid along its own ray onto the plane z = -near, then
  origins and directions are projected with focal length `focal` for an image
  of width `w` and height `h`.

  Args:
    origins: array whose last axis holds xyz ray origins.
    directions: array, same shape as `origins`, holding xyz ray directions.
    focal: focal length in pixels.
    w: image width in pixels.
    h: image height in pixels.
    near: depth of the near plane the origins are moved onto.

  Returns:
    (ndc_origins, ndc_directions), both shaped like the inputs.
  """
  # Move each origin along its ray until its z coordinate equals -near.
  t = -(near + origins[..., 2]) / directions[..., 2]
  origins = origins + t[..., None] * directions

  ox, oy, oz = (origins[..., axis] for axis in range(3))
  dx, dy, dz = (directions[..., axis] for axis in range(3))
  scale_x = -((2 * focal) / w)
  scale_y = -((2 * focal) / h)

  ndc_origins = np.stack(
      [scale_x * (ox / oz), scale_y * (oy / oz), 1 + 2 * near / oz], -1)
  ndc_directions = np.stack(
      [scale_x * (dx / dz - ox / oz),
       scale_y * (dy / dz - oy / oz),
       -2 * near / oz], -1)
  return ndc_origins, ndc_directions
class Dataset(threading.Thread):
  """Dataset base class that prefetches batches on a background thread.

  The constructor loads the requested split and then starts this thread. Its
  `run` loop keeps a bounded queue filled with training batches or test
  examples. Consumers read them with `next()` (or iteration), or inspect the
  front of the queue with `peek()`.

  Subclasses implement `_load_renderings(flags, clip_model=None)`. It must set
  the attributes read by the methods below: `images`, `camtoworlds`, `focal`,
  `h`, `w`, `resolution` and `n_examples`.
  """

  def __init__(self, split, flags, clip_model):
    """Load `split` and start the prefetch thread.

    Args:
      split: "train" or "test".
      flags: config object. Read here: use_pixel_centers, batch_size,
        batching, render_path, near, far, max_steps and sc_loss_factor, plus
        whatever the subclass loader reads.
      clip_model: optional CLIP model, forwarded to the loader for training.

    Raises:
      ValueError: if `split` is neither "train" nor "test".
    """
    super(Dataset, self).__init__()
    self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
    # Daemon thread: the endless producer loop must not keep the process alive.
    self.daemon = True
    self.use_pixel_centers = flags.use_pixel_centers
    self.split = split
    if split == "train":
      self._train_init(flags, clip_model)
    elif split == "test":
      self._test_init(flags)
    else:
      raise ValueError(
          "the split argument should be either \"train\" or \"test\", set "
          "to {} here.".format(split))
    # Each process produces only its share of the global batch.
    self.batch_size = flags.batch_size // jax.process_count()
    self.batching = flags.batching
    self.render_path = flags.render_path
    self.far = flags.far
    self.near = flags.near
    self.max_steps = flags.max_steps
    self.sc_loss_factor = flags.sc_loss_factor
    self.start()

  def __iter__(self):
    return self

  def __next__(self):
    """Get the next training batch or test example.

    Blocks until the producer thread has queued one.

    Returns:
      batch: dict with "rays" and, unless a render path is being produced,
      "pixels". Train batches are passed through `utils.shard`; test examples
      are passed through `utils.to_device`.
    """
    x = self.queue.get()
    if self.split == "train":
      return utils.shard(x)
    return utils.to_device(x)

  def peek(self):
    """Peek at the next training batch or test example without dequeuing it.

    Returns:
      batch: same structure as `__next__`, built from a shallow copy of the
      front of the queue. Raises IndexError if the queue is currently empty.
    """
    x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
    if self.split == "train":
      return utils.shard(x)
    return utils.to_device(x)

  def run(self):
    """Producer loop: keep the queue filled, blocking while it is full."""
    next_func = self._next_train if self.split == "train" else self._next_test
    while True:
      self.queue.put(next_func())

  @property
  def size(self):
    """Number of examples (images, or render poses when rendering a path)."""
    return self.n_examples

  def _train_init(self, flags, clip_model):
    """Load the training split and reshape data for the batching strategy.

    Raises:
      NotImplementedError: for an unknown `flags.batching` value.
    """
    self._load_renderings(flags, clip_model)
    self._generate_rays()
    if flags.batching == "all_images":
      # flatten the ray and image dimension together.
      self.images = self.images.reshape([-1, 3])
      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                       self.rays)
    elif flags.batching == "single_image":
      # Keep one row of `resolution` pixels/rays per image.
      self.images = self.images.reshape([-1, self.resolution, 3])
      self.rays = utils.namedtuple_map(
          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
    else:
      raise NotImplementedError(
          f"{flags.batching} batching strategy is not implemented.")

  def _test_init(self, flags):
    """Load the test split; test examples are then served in order."""
    self._load_renderings(flags, clip_model=None)
    self._generate_rays()
    self.it = 0

  def _next_train(self):
    """Sample next training batch.

    Returns:
      dict with "pixels", "rays" and "image_index" (the index of the image
      the rays and pixels were drawn from).

    Raises:
      NotImplementedError: for "all_images" batching, whose rays come from
        many images and so have no single image index, and for any unknown
        batching strategy.
    """
    if self.batching == "all_images":
      # Fail before sampling: the result could not carry an image_index.
      raise NotImplementedError(
          "image_index not implemented for batching=all_images")
    if self.batching != "single_image":
      raise NotImplementedError(
          f"{self.batching} batching strategy is not implemented.")
    image_index = np.random.randint(0, self.n_examples, ())
    ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                    (self.batch_size,))
    batch_pixels = self.images[image_index][ray_indices]
    batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                      self.rays)
    return {"pixels": batch_pixels, "rays": batch_rays,
            "image_index": image_index}

  def _next_test(self):
    """Return the next test example, cycling through all of them in order."""
    idx = self.it
    self.it = (self.it + 1) % self.n_examples
    if self.render_path:
      return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
    return {"pixels": self.images[idx],
            "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
            "image_index": idx}

  # TODO(bydeng): Swap this function with a more flexible camera model.
  def _generate_rays(self):
    """Generate one ray per pixel for every pose in `self.camtoworlds`.

    Uses a pinhole camera with focal length `self.focal`, looking down -z
    with +y up in camera space. The resulting `utils.Rays` is stored in
    `self.rays`; each field has shape [n_poses, h, w, 3].
    """
    pixel_center = 0.5 if self.use_pixel_centers else 0.0
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(self.w, dtype=np.float32) + pixel_center,  # X-Axis (columns)
        np.arange(self.h, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
        indexing="xy")
    # Image rows grow downward, so y is negated to make +y point up.
    camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
                            -(y - self.h * 0.5) / self.focal,
                            -np.ones_like(x)],
                           axis=-1)
    # Rotate camera-space directions into world space for every pose.
    directions = ((camera_dirs[None, ..., None, :] *
                   self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
    origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
                              directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    self.rays = utils.Rays(
        origins=origins, directions=directions, viewdirs=viewdirs)

  def camtoworld_matrix_to_rays(self, camtoworld, downsample = 1):
    """Render one set of rays for a single camera-to-world matrix.

    Args:
      camtoworld: camera-to-world matrix; only its top 3 rows are read
        (rotation in [:3, :3], translation in [:3, -1]).
      downsample: stride of the pixel grid. Every `downsample`-th pixel is
        kept in each direction, at the full-resolution focal length.

    Returns:
      utils.Rays whose fields have shape
      [ceil(h / downsample), ceil(w / downsample), 3].
    """
    pixel_center = 0.5 if self.use_pixel_centers else 0.0
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(self.w, step=downsample, dtype=np.float32) + pixel_center,
        np.arange(self.h, step=downsample, dtype=np.float32) + pixel_center,
        indexing="xy")
    camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
                            -(y - self.h * 0.5) / self.focal,
                            -np.ones_like(x)],
                           axis=-1)
    directions = (camera_dirs[..., None, :] *
                  camtoworld[None, None, :3, :3]).sum(axis=-1)
    origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
class Blender(Dataset):
  """Blender dataset: transforms_{split}.json plus per-frame PNG images."""

  def _load_renderings(self, flags, clip_model = None):
    """Load images and camera poses from disk.

    When `flags.use_semantic_loss` is set and a `clip_model` is given, this
    also precomputes one CLIP image embedding per loaded image. It then sets
    `self.embeddings` and a shuffled `self.image_idx` for `get_clip_data`.

    Args:
      flags: config object. Reads render_path, data_dir, factor, few_shot,
        white_bkgd and use_semantic_loss.
      clip_model: optional model exposing `get_image_features`.

    Raises:
      ValueError: if `flags.render_path` is set.
    """
    if flags.render_path:
      raise ValueError("render_path cannot be used for the blender dataset.")
    cams, images, meta = self.load_files(flags.data_dir, self.split,
                                         flags.factor, flags.few_shot)
    self.images = np.stack(images, axis=0)
    if flags.white_bkgd:
      # Composite onto white using the last channel as alpha.
      self.images = (self.images[..., :3] * self.images[..., -1:] +
                     (1. - self.images[..., -1:]))
    else:
      self.images = self.images[..., :3]
    self.h, self.w = self.images.shape[1:3]
    self.resolution = self.h * self.w
    self.camtoworlds = np.stack(cams, axis=0)
    camera_angle_x = float(meta["camera_angle_x"])
    self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
    self.n_examples = self.images.shape[0]
    if flags.use_semantic_loss and clip_model is not None:
      embs = []
      for img in self.images:
        # [H, W, C] -> [1, C, H, W] before CLIP preprocessing.
        img = np.expand_dims(np.transpose(img, [2, 0, 1]), 0)
        embs.append(clip_model.get_image_features(
            pixel_values=clip_utils.preprocess_for_CLIP(img)))
      self.embeddings = np.concatenate(embs, 0)
      self._reshuffle_image_idx()

  def _reshuffle_image_idx(self):
    """Refill `self.image_idx` with a random permutation of image indices."""
    image_idx = np.arange(self.images.shape[0])
    np.random.shuffle(image_idx)
    self.image_idx = image_idx.tolist()

  @staticmethod
  def load_files(data_dir, split, factor, few_shot):
    """Read camera poses and images for one split.

    Args:
      data_dir: dataset root containing transforms_{split}.json.
      split: split name used to select the transforms file.
      factor: downsampling factor. A value <= 0 keeps full resolution; 2 or 4
        shrink each image by that factor with area interpolation.
      few_shot: if > 0 and split == "train", keep only a random subset of
        this many frames.

    Returns:
      (cams, images, meta): a list of float32 camera-to-world matrices, a
      list of float32 images divided by 255 (so [0, 1] for 8-bit PNGs), and
      the parsed JSON metadata.

    Raises:
      ValueError: for a positive factor other than 2 or 4. This is checked
        before any file is read.
    """
    if factor > 0 and factor not in (2, 4):
      raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
                       "set.".format(factor))
    transforms_path = path.join(data_dir, "transforms_{}.json".format(split))
    with utils.open_file(transforms_path, "r") as fp:
      meta = json.load(fp)
    frames = np.arange(len(meta["frames"]))
    if few_shot > 0 and split == "train":
      np.random.shuffle(frames)
      frames = frames[:few_shot]
    images = []
    cams = []
    for i in frames:
      frame = meta["frames"][i]
      fname = os.path.join(data_dir, frame["file_path"] + ".png")
      with utils.open_file(fname, "rb") as imgin:
        image = np.array(Image.open(imgin)).astype(np.float32) / 255.
      if factor in (2, 4):
        new_h, new_w = (hw // int(factor) for hw in image.shape[:2])
        image = cv2.resize(image, (new_w, new_h),
                           interpolation=cv2.INTER_AREA)
      cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
      images.append(image)
    return cams, images, meta

  def _next_train(self):
    """Sample a training batch and drop its "image_index" entry.

    CLIP supervision (embedding and random source rays) is not attached
    here; it is built by `get_clip_data`.

    Raises:
      NotImplementedError: unless batching is "single_image".
    """
    batch_dict = super(Blender, self)._next_train()
    if self.batching != "single_image":
      raise NotImplementedError
    batch_dict.pop("image_index")
    return batch_dict

  def get_clip_data(self):
    """Build one CLIP supervision sample.

    Pops the next index from a shuffled list of image indices, refilling it
    when it is empty. That image's precomputed CLIP embedding is paired with
    rays cast from a random camera pose on a pixel grid strided by 16.

    Returns:
      dict with "embedding" (CLIP features of the selected image) and
      "random_rays" (utils.Rays with every field reshaped to [-1, 3]).
    """
    if not self.image_idx:
      self._reshuffle_image_idx()
    image_index = self.image_idx.pop()
    batch_dict = {"embedding": self.embeddings[image_index]}
    # source rays for CLIP (for constructing source image later)
    src_seed = int(np.random.randint(0, self.max_steps, ()))
    src_rng = jax.random.PRNGKey(src_seed)
    src_camtoworld = np.array(
        clip_utils.random_pose(src_rng, (self.near, self.far)))
    random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample=16)
    batch_dict["random_rays"] = utils.Rays(
        origins=np.reshape(random_rays.origins, [-1, 3]),
        directions=np.reshape(random_rays.directions, [-1, 3]),
        viewdirs=np.reshape(random_rays.viewdirs, [-1, 3]))
    return batch_dict
class LLFF(Dataset):
  """LLFF dataset: an images[_<factor>] folder plus poses_bounds.npy."""

  def _load_renderings(self, flags, clip_model=None):
    """Load images, camera poses and depth bounds from disk.

    Args:
      flags: config object. Reads factor, data_dir, spherify, llffhold and
        render_path.
      clip_model: accepted so that `Dataset._train_init` (positional) and
        `Dataset._test_init` (keyword) can call every subclass the same way.
        Unused by LLFF.

    Raises:
      ValueError: if the image folder does not exist.
      RuntimeError: if the number of images and poses disagree.
    """
    del clip_model  # Unused by LLFF.
    # Load images.
    if flags.factor > 0:
      imgdir_suffix = "_{}".format(flags.factor)
      factor = flags.factor
    else:
      imgdir_suffix = ""
      factor = 1
    imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
    if not utils.file_exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    imgfiles = [
        path.join(imgdir, f)
        for f in sorted(utils.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    images = []
    for imgfile in imgfiles:
      with utils.open_file(imgfile, "rb") as imgin:
        image = np.array(Image.open(imgin), dtype=np.float32) / 255.
        images.append(image)
    # Images go on a trailing axis so they share the pose layout
    # [..., n_images] until both are moved to axis 0 below.
    images = np.stack(images, axis=-1)

    # Load poses and bds. Each row is a flattened 3x5 pose matrix followed by
    # two depth bounds.
    with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
                         "rb") as fp:
      poses_arr = np.load(fp)
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    if poses.shape[-1] != images.shape[-1]:
      raise RuntimeError("Mismatch between imgs {} and poses {}".format(
          images.shape[-1], poses.shape[-1]))

    # Update poses according to downsampling: column 4 holds the image
    # height, width and focal length.
    poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    # Correct rotation matrix ordering and move variable dim to axis 0.
    poses = np.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(images, -1, 0)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale according to a default bd factor.
    scale = 1. / (bds.min() * .75)
    poses[:, :3, 3] *= scale
    bds *= scale

    # Recenter poses.
    poses = self._recenter_poses(poses)

    # Generate a spiral/spherical ray path for rendering videos.
    if flags.spherify:
      poses = self._generate_spherical_poses(poses, bds)
      self.spherify = True
    else:
      self.spherify = False
      if self.split == "test":
        self._generate_spiral_poses(poses, bds)

    # Select the split: every `llffhold`-th image is held out for testing.
    n_images = images.shape[0]
    i_test = np.arange(n_images)[::flags.llffhold]
    # setdiff1d is linear-log, sorted, and keeps an integer dtype even when
    # every image is held out (an empty list would become float64).
    i_train = np.setdiff1d(np.arange(n_images), i_test)
    indices = i_train if self.split == "train" else i_test
    images = images[indices]
    poses = poses[indices]

    self.images = images
    self.camtoworlds = poses[:, :3, :4]
    self.focal = poses[0, -1, -1]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w
    # render_poses are only generated for the test split, and training must
    # sample from the loaded images, so the render path applies to test only.
    if flags.render_path and self.split == "test":
      self.n_examples = self.render_poses.shape[0]
    else:
      self.n_examples = images.shape[0]

  def _generate_rays(self):
    """Generate normalized device coordinate rays for llff.

    Rays are converted to NDC unless the scene was spherified. For the test
    split, rays for `self.render_poses` are generated in the same pass and
    split off into `self.render_rays`.
    """
    if self.split == "test":
      n_render_poses = self.render_poses.shape[0]
      self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
                                        axis=0)
    super()._generate_rays()
    if not self.spherify:
      ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
                                                   self.rays.directions,
                                                   self.focal, self.w, self.h)
      self.rays = utils.Rays(
          origins=ndc_origins,
          directions=ndc_directions,
          viewdirs=self.rays.viewdirs)
    # Split poses from the dataset and generated poses
    if self.split == "test":
      self.camtoworlds = self.camtoworlds[n_render_poses:]
      split = [np.split(r, [n_render_poses], 0) for r in self.rays]
      split0, split1 = zip(*split)
      self.render_rays = utils.Rays(*split0)
      self.rays = utils.Rays(*split1)

  def _recenter_poses(self, poses):
    """Recenter poses according to the original NeRF code.

    Expresses every pose relative to the average pose (`_poses_avg`); the
    column 4 entries are left untouched.
    """
    poses_ = poses.copy()
    bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    c2w = self._poses_avg(poses)
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)
    poses = np.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses

  def _poses_avg(self, poses):
    """Average poses according to the original NeRF code.

    Returns a 3x5 matrix: a look-at frame built from the mean camera center,
    summed z axes and summed up vectors, plus the first pose's column 4.
    """
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = self._normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
    return c2w

  def _viewmatrix(self, z, up, pos):
    """Construct a 3x4 look-at view matrix [x | y | z | pos]."""
    vec2 = self._normalize(z)
    vec1_avg = up
    vec0 = self._normalize(np.cross(vec1_avg, vec2))
    vec1 = self._normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m

  def _normalize(self, x):
    """Normalization helper function: scale `x` to unit L2 norm."""
    return x / np.linalg.norm(x)

  def _generate_spiral_poses(self, poses, bds):
    """Generate a spiral path for rendering; sets `self.render_poses`."""
    c2w = self._poses_avg(poses)
    # Get average pose.
    up = self._normalize(poses[:, :3, 1].sum(0))
    # Find a reasonable "focus depth" for this dataset.
    close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
    dt = .75
    mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
    focal = mean_dz
    # Get radii for spiral path.
    tt = poses[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0)
    c2w_path = c2w
    n_views = 120
    n_rots = 2
    # Generate poses for spiral path.
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w_path[:, 4:5]
    zrate = .5
    for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
      c = np.dot(c2w[:3, :4], (np.array(
          [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
      # Each camera looks at a point `focal` in front of the average pose.
      z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
      render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
    self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]

  def _generate_spherical_poses(self, poses, bds):
    """Generate a 360 degree spherical path for rendering.

    Re-centers and rescales `poses` around the point closest to all camera
    axes. For the test split it also sets `self.render_poses` to a 120-pose
    circle.

    Returns:
      The re-centered poses, with the first pose's column 4 appended to each.
    """
    # pylint: disable=g-long-lambda
    p34_to_44 = lambda p: np.concatenate([
        p,
        np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
    ], 1)
    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
      # Least-squares point closest to all lines (rays_o + t * rays_d).
      a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
      b_i = -a_i @ rays_o
      pt_mindist = np.squeeze(-np.linalg.inv(
          (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
      return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)
    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)
    vec0 = self._normalize(up)
    vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
    vec2 = self._normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)
    poses_reset = (
        np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
    sc = 1. / rad
    poses_reset[:, :3, 3] *= sc
    # NOTE(review): scales the caller's `bds` array in place; the caller in
    # this file does not read it afterwards.
    bds *= sc
    rad *= sc
    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad ** 2 - zh ** 2)
    new_poses = []
    for th in np.linspace(0., 2. * np.pi, 120):
      camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
      up = np.array([0, 0, -1.])
      vec2 = self._normalize(camorigin)
      vec0 = self._normalize(np.cross(vec2, up))
      vec1 = self._normalize(np.cross(vec2, vec0))
      pos = camorigin
      p = np.stack([vec0, vec1, vec2, pos], 1)
      new_poses.append(p)
    new_poses = np.stack(new_poses, 0)
    new_poses = np.concatenate([
        new_poses,
        np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
    ], -1)
    poses_reset = np.concatenate([
        poses_reset[:, :3, :4],
        np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
    ], -1)
    if self.split == "test":
      self.render_poses = new_poses[:, :3, :4]
    return poses_reset
# Registry mapping the `dataset` flag value to its loader class.
dataset_dict = {
    "blender": Blender,
    "llff": LLFF,
}