import glob
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

from .ray_utils import *


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def average_poses(poses):
    """
    Calculate the average pose, which is then used to center all poses
    using @center_poses. Its computation is as follows:
    1. Compute the center: the average of pose centers.
    2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute the x axis: z cross product y', then normalize it.
    5. Compute the y axis: x cross product z.
    Note that at step 3, we cannot directly use y' as the y axis since it is
    not necessarily orthogonal to the z axis. We need to pass from x to y.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose
    """
    # 1. Compute the center
    center = poses[..., 3].mean(0)  # (3)
    # 2. Compute the z axis
    z = normalize(poses[..., 2].mean(0))  # (3)
    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = poses[..., 1].mean(0)  # (3)
    # 4. Compute the x axis
    x = normalize(np.cross(z, y_))  # (3)
    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = np.cross(x, z)  # (3)

    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)
    return pose_avg
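

# Illustrative sketch (dummy poses, not part of the original pipeline): the last
# column of the result is the mean of the camera centers, and the first three
# columns form an orthonormal frame built from the averaged z and y axes.
#   poses = np.repeat(np.eye(4)[None, :3], 5, axis=0)   # five dummy (3, 4) poses
#   pose_avg = average_poses(poses)                     # -> shape (3, 4)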


def center_poses(poses, blender2opencv):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34
    Inputs:
        poses: (N_images, 3, 4)
        blender2opencv: (4, 4) coordinate-convention transform applied to the poses
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg_homo: (4, 4) the average pose in homogeneous coordinates
    """
    poses = poses @ blender2opencv
    pose_avg = average_poses(poses)  # (3, 4)
    pose_avg_homo = np.eye(4)
    # convert to homogeneous coordinates by simply adding [0, 0, 0, 1] as the last row
    pose_avg_homo[:3] = pose_avg
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinates
    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    # poses_centered = poses_centered @ blender2opencv
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)
    return poses_centered, pose_avg_homo
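

# Sanity-check sketch (hypothetical, not part of the original pipeline): after
# centering, the mean of the camera centers is at the origin because every pose
# is re-expressed in the average camera's frame.
#   poses_centered, pose_avg_homo = center_poses(poses, np.eye(4))
#   np.allclose(poses_centered[..., 3].mean(0), 0)   # -> True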


def viewmatrix(z, up, pos):
    """Construct a (4, 4) camera-to-world matrix from a viewing direction z,
    an up vector and a camera position pos."""
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.eye(4)
    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    """Generate N camera-to-world matrices on a spiral around the reference pose c2w.

    `rads` gives the spiral radii per axis, `focal` the depth of the look-at point,
    and `zrate` the frequency of the motion along z. `zdelta` is kept for interface
    compatibility but is not used here.
    """
    render_poses = []
    rads = np.array(list(rads) + [1.])
    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        # spiral offset in the reference camera's frame, mapped to a world-space camera center
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        # look towards a point `focal` units in front of the reference camera
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses
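

# Illustrative usage (hypothetical values): a short spiral around a single pose.
#   spiral = render_path_spiral(np.eye(4)[:3], up=np.array([0., 1., 0.]),
#                               rads=[0.1, 0.1, 0.05], focal=1.0,
#                               zdelta=0.1, zrate=0.5, N=30)
#   len(spiral), spiral[0].shape                    # -> 30, (4, 4)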


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    # average (reference) pose of the scene
    c2w = average_poses(c2ws_all)
    # average up vector
    up = normalize(c2ws_all[:, :3, 1].sum(0))
    # Find a reasonable "focus depth" for this dataset: a weighted harmonic mean
    # of the (padded) near and far bounds.
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth)
    # Get radii for the spiral path from the spread of the camera centers
    zdelta = near_fars.min() * .2
    tt = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(render_poses)
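

# Illustrative usage (dummy data): build a 120-frame spiral render path from a set
# of (3, 4) camera-to-world poses and per-image depth bounds.
#   c2ws = np.repeat(np.eye(4)[None, :3], 4, axis=0)
#   c2ws[:, 0, 3] = np.linspace(-1., 1., 4)          # spread the camera centers
#   near_fars = np.array([[1.0, 10.0]] * 4)
#   get_spiral(c2ws, near_fars).shape                # -> (120, 4, 4)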


def get_interpolation_path(c2ws_all, steps=30):
    # scene-specific view indices (the active pair is for "horns")
    # flower
    # idx0 = 1
    # idx1 = 10
    # trex
    # idx0 = 8
    # idx1 = 53
    # horns
    idx0 = 18
    idx1 = 47

    v = np.linspace(0, 1, num=steps)
    c2w0 = c2ws_all[idx0]
    c2w1 = c2ws_all[idx1]
    c2w_ = []
    for i in range(steps):
        # linear blend; as v goes from 0 to 1 the path moves from c2w1 to c2w0
        c2w_.append(c2w0 * v[i] + c2w1 * (1 - v[i]))
    return np.stack(c2w_)
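

# Illustrative usage (dummy poses): the hard-coded indices require at least 48 poses.
# Note that linearly blending rotation matrices is only approximate and does not
# preserve orthonormality exactly; it is acceptable for nearby views.
#   c2ws = np.repeat(np.eye(4)[None, :3], 48, axis=0)
#   get_interpolation_path(c2ws, steps=30).shape     # -> (30, 3, 4)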


class LLFFDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8, N_vis=-1):
        self.root_dir = datadir
        self.split = split
        self.hold_every = hold_every
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.blender2opencv = np.eye(4)  # np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.white_bg = False

        # rays are cast in NDC (see read_meta), so the near/far bounds are fixed
        # self.near_far = [np.min(self.near_fars[:,0]), np.max(self.near_fars[:,1])]
        self.near_far = [0.0, 1.0]
        self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
        # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_meta(self):
        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
        # load the pre-downsampled (x4) images, then resize to the target resolution
        if self.split in ['train', 'test']:
            assert len(poses_bounds) == len(self.image_paths), \
                'Mismatch between number of images and number of poses! Please rerun COLMAP!'

        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)

        # Step 1: rescale focal length according to training resolution
        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]

        # Step 2: correct poses
        # Original poses have rotation in the form "down right back"; change to "right up back"
        # See https://github.com/bmild/nerf/issues/34
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        # (N_images, 3, 4) exclude H, W, focal
        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)

        # Step 3: correct scale so that the nearest depth is at a little more than 1.0
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        # the nearest depth is at 1/0.75 = 1.33
        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor

        # build rendering path
        N_views, N_rots = 120, 2
        tt = self.poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
        up = normalize(self.poses[:, :3, 1].sum(0))
        rads = np.percentile(np.abs(tt), 90, 0)

        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
        # self.render_path = get_interpolation_path(self.poses)

        # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
        # val_idx = np.argmin(distances_from_center)  # choose val image as the closest to center image

        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)

        average_pose = average_poses(self.poses)
        dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
        # hold out every `hold_every`-th view for testing, train on the rest
        i_test = np.arange(0, self.poses.shape[0], self.hold_every)  # [np.argmin(dists)]
        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))

        self.all_rays = []
        self.all_rgbs = []
        for i in img_list:
            image_path = self.image_paths[i]
            c2w = torch.FloatTensor(self.poses[i])

            img = Image.open(image_path).convert('RGB')
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)
            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
            self.all_rgbs += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
            # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        # keep the per-image lists around for the optional stacked tensors below
        all_rays = self.all_rays
        all_rgbs = self.all_rgbs
        self.all_rays = torch.cat(self.all_rays, 0)  # (N_images*h*w, 6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (N_images*h*w, 3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1, *self.img_wh[::-1], 6)  # (N_images, h, w, 6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # (N_images, h/4, w/4, 6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (N_images, h, w, 3)
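
    # After read_meta(), the dataset exposes:
    #   all_rays (N*h*w, 6) and all_rgbs (N*h*w, 3)   flattened training tensors
    #   render_path (N_views, 4, 4)                   spiral poses for rendering
    # and, when is_stack=True, the per-image stacks all_rays_stack, ds_all_rays_stack
    # and all_rgbs_stack used by prepare_feature_data() below.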

    def prepare_feature_data(self, encoder, chunk=8):
        """
        Prepare feature maps as training data.
        """
        assert self.is_stack, 'Dataset should contain the original stacked training data!'
        print('====> prepare_feature_data ...')

        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []
        # encode the images in chunks to bound GPU memory usage
        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx * chunk: (chunk_idx + 1) * chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0, 3, 1, 2))).relu3_1
            # resize back to the RGB resolution so that features align with rays
            features_chunk = T.functional.resize(features_chunk, size=(h, w),
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu())

        self.all_features_stack = torch.cat(features).permute(0, 2, 3, 1)  # (N_images, h, w, 256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')
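
    # Illustrative usage (hypothetical encoder and data path): `encoder` can be any
    # module whose forward pass returns an object exposing a 256-channel `relu3_1`
    # feature map (e.g. a VGG feature extractor); `normalize_vgg` is assumed to be
    # provided elsewhere in this project.
    #   dataset = LLFFDataset('./data/nerf_llff_data/horns', split='train', is_stack=True)
    #   dataset.prepare_feature_data(vgg_encoder.cuda())   # fills all_features(_stack)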

    def define_transforms(self):
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}
        return sample
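

if __name__ == '__main__':
    # Minimal smoke test (sketch, not part of the original module): the data path
    # below is a placeholder, and since this file uses a relative import it must be
    # run with `python -m` from the project root so the import resolves.
    from torch.utils.data import DataLoader

    dataset = LLFFDataset('./data/nerf_llff_data/horns', split='train',
                          downsample=4, is_stack=False, hold_every=8)
    loader = DataLoader(dataset, batch_size=4096, shuffle=True)
    batch = next(iter(loader))
    print(batch['rays'].shape, batch['rgbs'].shape)  # -> (4096, 6) and (4096, 3)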