Spaces:
Build error
Build error
# coding=utf-8 | |
# Copyright 2021 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Lint as: python3 | |
"""Different datasets implementation plus a general port for all the datasets.""" | |
# Module-level flag. Nothing in this file reads it; its only other mention is
# the commented-out `if not INTERNAL:` guard around the cv2 import.
# TODO(review): confirm whether other modules use it before removing.
INTERNAL = False  # pylint: disable=g-statement-before-imports
import json | |
import os | |
from os import path | |
import queue | |
import threading | |
# if not INTERNAL: | |
# import cv2 # pylint: disable=g-import-not-at-top | |
import jax | |
import numpy as np | |
from PIL import Image | |
from jaxnerf.nerf import utils | |
from jaxnerf.nerf import clip_utils | |
def get_dataset(split, args, clip_model=None):
  """Build the dataset registered under `args.dataset`.

  Args:
    split: "train" or "test"; forwarded to the dataset constructor.
    args: configuration object. `args.dataset` selects the entry of
      `dataset_dict`, and the whole object is forwarded as the dataset flags.
    clip_model: optional model forwarded to the dataset constructor.

  Returns:
    The constructed dataset instance.

  Raises:
    KeyError: if `args.dataset` is not a key of `dataset_dict`.
  """
  dataset_cls = dataset_dict[args.dataset]
  return dataset_cls(split, args, clip_model)
def convert_to_ndc(origins, directions, focal, w, h, near=1.):
  """Map a set of rays into normalized device coordinates (NDC).

  Every origin is first slid along its own direction onto the plane
  z = -near. Origins and directions are then projected with the pinhole
  parameters.

  Args:
    origins: array (..., 3) of ray origins.
    directions: array (..., 3) of ray directions.
    focal: focal length, in the same units as `w` and `h`.
    w: image width.
    h: image height.
    near: distance of the near plane.

  Returns:
    (origins, directions): NDC arrays, each stacked along a last axis of 3.
  """
  # Parameter t at which each ray meets the near plane z = -near.
  t_near = -(near + origins[..., 2]) / directions[..., 2]
  shifted = origins + t_near[..., None] * directions

  ox, oy, oz = np.moveaxis(shifted, -1, 0)
  dx, dy, dz = np.moveaxis(directions, -1, 0)

  # Per-axis projection scales.
  sx = -((2 * focal) / w)
  sy = -((2 * focal) / h)

  ndc_origins = np.stack(
      [sx * (ox / oz), sy * (oy / oz), 1 + 2 * near / oz], -1)
  ndc_directions = np.stack(
      [sx * (dx / dz - ox / oz), sy * (dy / dz - oy / oz), -2 * near / oz],
      -1)
  return ndc_origins, ndc_directions
class Dataset(threading.Thread):
  """Dataset base class that prefetches batches on a background thread.

  Subclasses implement `_load_renderings(flags, clip_model)`. The methods here
  read `self.images`, `self.camtoworlds`, `self.focal`, `self.h`, `self.w`,
  `self.resolution` and `self.n_examples`, which that method is expected to
  set.
  """

  def __init__(self, split, flags, clip_model):
    """Load the requested split and start the prefetch thread.

    Args:
      split: str, either "train" or "test".
      flags: configuration object. This method reads `use_pixel_centers`,
        `batch_size`, `batching`, `render_path`, `far`, `near`, `max_steps`
        and `sc_loss_factor`, plus whatever the subclass loader reads.
      clip_model: passed to `_load_renderings` for the training split only;
        may be None.

    Raises:
      ValueError: if `split` is neither "train" nor "test".
    """
    super(Dataset, self).__init__()
    self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
    # Daemon thread: it does not keep the interpreter alive on exit.
    self.daemon = True
    self.use_pixel_centers = flags.use_pixel_centers
    self.split = split
    if split == "train":
      self._train_init(flags, clip_model)
    elif split == "test":
      self._test_init(flags)
    else:
      # Adjacent literals are concatenated: keep the space after "set".
      raise ValueError(
          "the split argument should be either \"train\" or \"test\", set "
          "to {} here.".format(split))
    # Split the global batch evenly across JAX processes.
    self.batch_size = flags.batch_size // jax.process_count()
    self.batching = flags.batching
    self.render_path = flags.render_path
    self.far = flags.far
    self.near = flags.near
    self.max_steps = flags.max_steps
    self.sc_loss_factor = flags.sc_loss_factor
    self.start()

  def __iter__(self):
    return self

  def __next__(self):
    """Get the next training batch or test example.

    Blocks until the prefetch thread has produced an item.

    Returns:
      batch: dict produced by `_next_train` / `_next_test`. It always has
        "rays", and has "pixels" unless a render path is being served.
    """
    x = self.queue.get()
    if self.split == "train":
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def peek(self):
    """Peek at the next training batch or test example without dequeuing it.

    This reads the head of the queue's underlying deque directly, so it raises
    IndexError when nothing has been prefetched yet.

    Returns:
      batch: same structure as `__next__`.
    """
    x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
    if self.split == "train":
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def run(self):
    """Thread body: produce batches forever.

    `queue.put` blocks while the buffer is full.
    """
    next_func = self._next_train if self.split == "train" else self._next_test
    while True:
      self.queue.put(next_func())

  def size(self):
    return self.n_examples

  def _train_init(self, flags, clip_model):
    """Load the training split and reshape rays/pixels for the batching mode.

    Raises:
      NotImplementedError: for an unknown `flags.batching`.
    """
    self._load_renderings(flags, clip_model)
    self._generate_rays()
    if flags.batching == "all_images":
      # flatten the ray and image dimension together.
      self.images = self.images.reshape([-1, 3])
      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                       self.rays)
    elif flags.batching == "single_image":
      # One row per image, with that image's pixels flattened.
      self.images = self.images.reshape([-1, self.resolution, 3])
      self.rays = utils.namedtuple_map(
          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
    else:
      raise NotImplementedError(
          f"{flags.batching} batching strategy is not implemented.")

  def _test_init(self, flags):
    """Load the test split; examples are served in order starting at 0."""
    self._load_renderings(flags, clip_model=None)
    self._generate_rays()
    self.it = 0

  def _next_train(self):
    """Sample the next training batch.

    Returns:
      dict with "pixels", "rays" and "image_index" (the index of the image
      the rays were sampled from).

    Raises:
      NotImplementedError: for batching="all_images", where no single source
        image index exists, or for any unknown batching strategy.
    """
    if self.batching == "all_images":
      # Rays of all images are pooled, so no image index can be reported.
      # Fail before doing any sampling work.
      raise NotImplementedError(
          "image_index not implemented for batching=all_images")
    if self.batching != "single_image":
      raise NotImplementedError(
          f"{self.batching} batching strategy is not implemented.")
    image_index = np.random.randint(0, self.n_examples, ())
    ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                    (self.batch_size,))
    batch_pixels = self.images[image_index][ray_indices]
    batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                      self.rays)
    return {"pixels": batch_pixels, "rays": batch_rays,
            "image_index": image_index}

  def _next_test(self):
    """Return the next test example, cycling through `n_examples`."""
    idx = self.it
    self.it = (self.it + 1) % self.n_examples
    if self.render_path:
      return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
    else:
      return {"pixels": self.images[idx],
              "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
              "image_index": idx}

  def _camera_dirs(self, step=1):
    """Camera-space ray directions for every `step`-th pixel row and column.

    Returns:
      float32 array of shape (ceil(h / step), ceil(w / step), 3).
    """
    pixel_center = 0.5 if self.use_pixel_centers else 0.0
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(self.w, step=step, dtype=np.float32) + pixel_center,  # X-Axis (columns)
        np.arange(self.h, step=step, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
        indexing="xy")
    return np.stack([(x - self.w * 0.5) / self.focal,
                     -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
                    axis=-1)

  # TODO(bydeng): Swap this function with a more flexible camera model.
  def _generate_rays(self):
    """Generate world-space rays for every pixel of every pose.

    Sets `self.rays` to a utils.Rays whose fields have shape (N, H, W, 3),
    where N = len(self.camtoworlds).
    """
    camera_dirs = self._camera_dirs()
    # Rotate into world space: (1, H, W, 1, 3) * (N, 1, 1, 3, 3) summed over
    # the last axis gives (N, H, W, 3).
    directions = ((camera_dirs[None, ..., None, :] *
                   self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
    # All rays of a camera start at that pose's translation column.
    origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
                              directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    self.rays = utils.Rays(
        origins=origins, directions=directions, viewdirs=viewdirs)

  def camtoworld_matrix_to_rays(self, camtoworld, downsample=1):
    """Render one set of rays for a single camera-to-world matrix.

    Args:
      camtoworld: pose array whose [:3, :3] block is the rotation and whose
        [:3, -1] column is the camera origin, e.g. a (4, 4) matrix.
      downsample: int stride applied to pixel rows and columns.

    Returns:
      utils.Rays with fields of shape
      (ceil(h / downsample), ceil(w / downsample), 3).
    """
    camera_dirs = self._camera_dirs(downsample)
    directions = (camera_dirs[..., None, :] *
                  camtoworld[None, None, :3, :3]).sum(axis=-1)
    origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
class Blender(Dataset):
  """Blender dataset: `transforms_{split}.json` plus one PNG per frame."""

  def _load_renderings(self, flags, clip_model=None):
    """Load images and camera poses from disk.

    Args:
      flags: configuration object. Reads `render_path`, `data_dir`, `factor`,
        `few_shot`, `white_bkgd` and `use_semantic_loss`.
      clip_model: optional model. When given and `flags.use_semantic_loss` is
        set, its `get_image_features` precomputes one embedding per loaded
        image into `self.embeddings`.

    Raises:
      ValueError: if `flags.render_path` is set.
    """
    if flags.render_path:
      raise ValueError("render_path cannot be used for the blender dataset.")
    cams, images, meta = self.load_files(flags.data_dir, self.split,
                                         flags.factor, flags.few_shot)
    self.images = np.stack(images, axis=0)
    if flags.white_bkgd:
      # Alpha-composite RGB over a white background.
      self.images = (self.images[..., :3] * self.images[..., -1:] +
                     (1. - self.images[..., -1:]))
    else:
      self.images = self.images[..., :3]
    self.h, self.w = self.images.shape[1:3]
    self.resolution = self.h * self.w
    self.camtoworlds = np.stack(cams, axis=0)
    camera_angle_x = float(meta["camera_angle_x"])
    # Pinhole focal length: (w / 2) / tan(fov_x / 2).
    self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
    self.n_examples = self.images.shape[0]
    if flags.use_semantic_loss and clip_model is not None:
      # load in CLIP precomputed image features
      embs = []
      for img in self.images:
        # (H, W, 3) -> (1, 3, H, W) before CLIP preprocessing.
        img = np.expand_dims(np.transpose(img, [2, 0, 1]), 0)
        embs.append(clip_model.get_image_features(
            pixel_values=clip_utils.preprocess_for_CLIP(img)))
      self.embeddings = np.concatenate(embs, 0)
      # Shuffled pool of image indices consumed by get_clip_data().
      self.image_idx = np.arange(self.images.shape[0])
      np.random.shuffle(self.image_idx)
      self.image_idx = self.image_idx.tolist()

  @staticmethod
  def _area_downsample(image, factor):
    """Downsample an (H, W, ...) image by an integer factor via block means.

    This replaces cv2.resize(..., interpolation=cv2.INTER_AREA); cv2's import
    is commented out at module level, so the old call raised NameError. When
    H and W are divisible by `factor`, both compute the mean of each
    factor x factor block. Otherwise the trailing rows/columns that do not
    fill a whole block are dropped, keeping the (H // factor, W // factor)
    output size.
    """
    out_h, out_w = image.shape[0] // factor, image.shape[1] // factor
    cropped = image[:out_h * factor, :out_w * factor]
    blocks = cropped.reshape((out_h, factor, out_w, factor) + image.shape[2:])
    return blocks.mean(axis=(1, 3))

  @staticmethod
  def load_files(data_dir, split, factor, few_shot):
    """Read the camera poses and frames of one split.

    Declared static: the signature has no `self`, and callers invoke it as
    `self.load_files(data_dir, split, factor, few_shot)`.

    Args:
      data_dir: dataset root containing `transforms_{split}.json`.
      split: split name used to select the json file.
      factor: 0 for full resolution, or 2 / 4 to downsample by that factor.
      few_shot: if > 0 and split == "train", keep only this many frames,
        chosen with NumPy's global RNG.

    Returns:
      (cams, images, meta): list of float32 pose arrays from
      "transform_matrix", list of float32 images (pixel values / 255), and
      the parsed json dict.

    Raises:
      ValueError: for a positive `factor` other than 2 or 4.
    """
    with utils.open_file(
        path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
      meta = json.load(fp)
    frames = np.arange(len(meta["frames"]))
    if few_shot > 0 and split == 'train':
      np.random.shuffle(frames)
      frames = frames[:few_shot]
    images = []
    cams = []
    for i in frames:
      frame = meta["frames"][i]
      fname = os.path.join(data_dir, frame["file_path"] + ".png")
      with utils.open_file(fname, "rb") as imgin:
        image = np.array(Image.open(imgin)).astype(np.float32) / 255.
      if factor in (2, 4):
        image = Blender._area_downsample(image, factor)
      elif factor > 0:
        raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
                         "set.".format(factor))
      cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
      images.append(image)
    return cams, images, meta

  def _next_train(self):
    """Sample a training batch via Dataset._next_train, minus "image_index".

    Raises:
      NotImplementedError: for any batching strategy other than
        "single_image".
    """
    batch_dict = super(Blender, self)._next_train()
    if self.batching != "single_image":
      raise NotImplementedError(
          f"{self.batching} batching strategy is not implemented.")
    # The image index is dropped from the batch. CLIP targets are presumably
    # fetched separately through get_clip_data().
    batch_dict.pop("image_index")
    return batch_dict

  def get_clip_data(self):
    """Return CLIP supervision data for one randomly chosen image.

    Indices are drawn without replacement from a shuffled pool, which is
    refilled and reshuffled once empty. Requires `self.embeddings` and
    `self.image_idx`. `_load_renderings` only sets these when
    `flags.use_semantic_loss` is on and a clip_model was given.

    Returns:
      dict with "embedding" (the precomputed features of the chosen image)
      and "random_rays" (rays for a random pose, every 16th pixel, each field
      flattened to shape (-1, 3)).
    """
    if len(self.image_idx) == 0:
      self.image_idx = np.arange(self.images.shape[0])
      np.random.shuffle(self.image_idx)
      self.image_idx = self.image_idx.tolist()
    image_index = self.image_idx.pop()
    batch_dict = {}
    batch_dict["embedding"] = self.embeddings[image_index]
    # source rays for CLIP (for constructing source image later)
    # The JAX key is seeded from NumPy's global RNG.
    src_seed = int(np.random.randint(0, self.max_steps, ()))
    src_rng = jax.random.PRNGKey(src_seed)
    src_camtoworld = np.array(
        clip_utils.random_pose(src_rng, (self.near, self.far)))
    random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample=16)
    random_rays = utils.Rays(origins=np.reshape(random_rays[0], [-1, 3]),
                             directions=np.reshape(random_rays[1], [-1, 3]),
                             viewdirs=np.reshape(random_rays[2], [-1, 3]))
    batch_dict["random_rays"] = random_rays
    return batch_dict
class LLFF(Dataset):
  """LLFF Dataset."""

  def _load_renderings(self, flags, clip_model=None):
    """Load images and poses from disk and select the requested split.

    Args:
      flags: configuration object. Reads `factor`, `data_dir`, `spherify`,
        `llffhold` and `render_path`.
      clip_model: unused. It is accepted because Dataset._train_init and
        Dataset._test_init always pass it; without this parameter both splits
        failed with a TypeError. No CLIP embeddings are computed for LLFF.

    Raises:
      ValueError: if the image folder does not exist.
      RuntimeError: if the numbers of images and poses differ.
    """
    del clip_model  # Unused.
    # Load images.
    imgdir_suffix = ""
    if flags.factor > 0:
      imgdir_suffix = "_{}".format(flags.factor)
      factor = flags.factor
    else:
      factor = 1
    imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
    if not utils.file_exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    imgfiles = [
        path.join(imgdir, f)
        for f in sorted(utils.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    images = []
    for imgfile in imgfiles:
      with utils.open_file(imgfile, "rb") as imgin:
        image = np.array(Image.open(imgin), dtype=np.float32) / 255.
      images.append(image)
    # Stack along a trailing axis: (H, W, C, N).
    images = np.stack(images, axis=-1)

    # Load poses and bds.
    with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
                         "rb") as fp:
      poses_arr = np.load(fp)
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    if poses.shape[-1] != images.shape[-1]:
      raise RuntimeError("Mismatch between imgs {} and poses {}".format(
          images.shape[-1], poses.shape[-1]))

    # Update poses according to downsampling.
    poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    # Correct rotation matrix ordering and move variable dim to axis 0.
    poses = np.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(images, -1, 0)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale according to a default bd factor.
    scale = 1. / (bds.min() * .75)
    poses[:, :3, 3] *= scale
    bds *= scale

    # Recenter poses.
    poses = self._recenter_poses(poses)

    # Generate a spiral/spherical ray path for rendering videos.
    if flags.spherify:
      poses = self._generate_spherical_poses(poses, bds)
      self.spherify = True
    else:
      self.spherify = False
    if not flags.spherify and self.split == "test":
      self._generate_spiral_poses(poses, bds)

    # Select the split: every `llffhold`-th image is held out for testing.
    # setdiff1d keeps an integer dtype even when no training image remains.
    n_images = images.shape[0]
    i_test = np.arange(n_images)[::flags.llffhold]
    i_train = np.setdiff1d(np.arange(n_images), i_test)
    if self.split == "train":
      indices = i_train
    else:
      indices = i_test
    images = images[indices]
    poses = poses[indices]

    self.images = images
    self.camtoworlds = poses[:, :3, :4]
    self.focal = poses[0, -1, -1]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w
    if flags.render_path:
      self.n_examples = self.render_poses.shape[0]
    else:
      self.n_examples = images.shape[0]

  def _generate_rays(self):
    """Generate normalized device coordinate rays for llff.

    On the test split, the render-path poses are temporarily prepended to
    `self.camtoworlds`. Their rays are then split off into
    `self.render_rays`.
    """
    if self.split == "test":
      n_render_poses = self.render_poses.shape[0]
      self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
                                        axis=0)

    super()._generate_rays()

    if not self.spherify:
      ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
                                                   self.rays.directions,
                                                   self.focal, self.w, self.h)
      self.rays = utils.Rays(
          origins=ndc_origins,
          directions=ndc_directions,
          viewdirs=self.rays.viewdirs)

    # Split poses from the dataset and generated poses
    if self.split == "test":
      self.camtoworlds = self.camtoworlds[n_render_poses:]
      split = [np.split(r, [n_render_poses], 0) for r in self.rays]
      split0, split1 = zip(*split)
      self.render_rays = utils.Rays(*split0)
      self.rays = utils.Rays(*split1)

  def _recenter_poses(self, poses):
    """Recenter poses according to the original NeRF code.

    Left-multiplies every pose by the inverse of the average pose. Columns
    beyond the first four (the hwf column) are kept unchanged.
    """
    poses_ = poses.copy()
    bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    c2w = self._poses_avg(poses)
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)
    poses = np.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses

  def _poses_avg(self, poses):
    """Average poses according to the original NeRF code.

    Returns a 3x5 matrix: the look-at frame built from the mean camera
    center, the summed z axes and the summed up vectors, followed by the
    first pose's hwf column.
    """
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = self._normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
    return c2w

  def _viewmatrix(self, z, up, pos):
    """Construct lookat view matrix.

    Returns a 3x4 matrix with columns [x, y, z, pos], where x and y are
    orthonormalized against z and the `up` hint.
    """
    vec2 = self._normalize(z)
    vec1_avg = up
    vec0 = self._normalize(np.cross(vec1_avg, vec2))
    vec1 = self._normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m

  def _normalize(self, x):
    """Normalization helper function: divide `x` by its L2 norm."""
    return x / np.linalg.norm(x)

  def _generate_spiral_poses(self, poses, bds):
    """Generate a spiral path for rendering and store it in self.render_poses.

    Produces 120 poses spanning 2 rotations around the average pose.
    """
    c2w = self._poses_avg(poses)
    # Get average pose.
    up = self._normalize(poses[:, :3, 1].sum(0))
    # Find a reasonable "focus depth" for this dataset: a weighted harmonic
    # blend of the near and far depths.
    close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
    dt = .75
    mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
    focal = mean_dz
    # Get radii for spiral path.
    tt = poses[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0)
    c2w_path = c2w
    n_views = 120
    n_rots = 2
    # Generate poses for spiral path.
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w_path[:, 4:5]
    zrate = .5
    for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
      c = np.dot(c2w[:3, :4], (np.array(
          [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
      z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
      render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
    self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]

  def _generate_spherical_poses(self, poses, bds):
    """Generate a 360 degree spherical path for rendering.

    Rescales `bds` in place. On the test split, it sets self.render_poses to
    120 poses on a circle.

    Returns:
      The input poses re-expressed in the spherified frame, with the first
      pose's hwf column appended.
    """
    # pylint: disable=g-long-lambda
    p34_to_44 = lambda p: np.concatenate([
        p,
        np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
    ], 1)
    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
      # Least-squares point closest to all camera z-axis lines (assumes the
      # direction vectors are unit length).
      a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
      b_i = -a_i @ rays_o
      pt_mindist = np.squeeze(-np.linalg.inv(
          (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
      return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)
    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)
    vec0 = self._normalize(up)
    vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
    vec2 = self._normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)
    poses_reset = (
        np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
    # Scale so the RMS camera distance from the new origin is 1.
    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
    sc = 1. / rad
    poses_reset[:, :3, 3] *= sc
    bds *= sc  # In place: the caller's array is rescaled too.
    rad *= sc
    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad ** 2 - zh ** 2)
    new_poses = []
    for th in np.linspace(0., 2. * np.pi, 120):
      camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
      up = np.array([0, 0, -1.])
      vec2 = self._normalize(camorigin)
      vec0 = self._normalize(np.cross(vec2, up))
      vec1 = self._normalize(np.cross(vec2, vec0))
      pos = camorigin
      p = np.stack([vec0, vec1, vec2, pos], 1)
      new_poses.append(p)
    new_poses = np.stack(new_poses, 0)
    new_poses = np.concatenate([
        new_poses,
        np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
    ], -1)
    poses_reset = np.concatenate([
        poses_reset[:, :3, :4],
        np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
    ], -1)
    if self.split == "test":
      self.render_poses = new_poses[:, :3, :4]
    return poses_reset
# Registry mapping the `dataset` config value to its Dataset subclass; looked
# up by get_dataset().
dataset_dict = {
    "blender": Blender,
    "llff": LLFF,
}