import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation

import trimesh

import torch
from torch.utils.data import DataLoader

from .utils import get_rays, safe_normalize


def visualize_poses(poses, size=0.1):
    """Show camera poses in an interactive trimesh scene.

    Each camera is drawn as 8 line segments (four from the camera position to
    the corners of a small square, plus the square's four edges) next to a
    reference axis and an icosphere of radius 1.

    Args:
        poses: iterable of 4x4 arrays, indexed as ``pose[:3, 3]`` (position)
            and ``pose[:3, k]`` (axes); overall shape [B, 4, 4].
        size: scale of each drawn camera.
    """
    axes = trimesh.creation.axis(axis_length=4)
    sphere = trimesh.creation.icosphere(radius=1)
    objects = [axes, sphere]

    for pose in poses:
        pos = pose[:3, 3]
        dx = size * pose[:3, 0]
        dy = size * pose[:3, 1]
        dz = size * pose[:3, 2]

        # four corners of the square drawn in front of the camera
        a = pos + dx + dy + dz
        b = pos - dx + dy + dz
        c = pos - dx - dy + dz
        d = pos + dx - dy + dz

        segs = np.array([
            [pos, a], [pos, b], [pos, c], [pos, d],
            [a, b], [b, c], [c, d], [d, a],
        ])
        objects.append(trimesh.load_path(segs))

    trimesh.Scene(objects).show()


def get_view_direction(thetas, phis, overhead, front):
    """Classify camera angles into six coarse view labels.

    Angles are compared against ``np.pi``, i.e. they are expected in radians.

    Labels (phi decides first, theta overrides):
        0 front         phi in [0, front)
        1 side (left)   phi in [front, pi)
        2 back          phi in [pi, pi + front)
        3 side (right)  phi in [pi + front, ...)
        4 top           theta <= overhead
        5 bottom        theta >= pi - overhead

    Args:
        thetas: tensor [B] of polar angles.
        phis: tensor [B] of azimuth angles.
        overhead: angular threshold for the top / bottom labels.
        front: angular width of the front (and back) sector.

    Returns:
        Long tensor [B] of labels, allocated on ``thetas.device``.
    """
    # Allocate on the same device as the angles: masking a CPU tensor with a
    # boolean mask living on another device raises in recent PyTorch versions.
    res = torch.zeros(thetas.shape[0], dtype=torch.long, device=thetas.device)

    # first determine by phis (label 0 is already the initial value)
    res[(phis >= front) & (phis < np.pi)] = 1
    res[(phis >= np.pi) & (phis < (np.pi + front))] = 2
    res[(phis >= (np.pi + front))] = 3

    # override by thetas
    res[thetas <= overhead] = 4
    res[thetas >= (np.pi - overhead)] = 5
    return res


def rand_poses(size, device, radius_range=(1, 1.5), theta_range=(0, 150), phi_range=(0, 360),
               return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
    """Generate random poses from an orbit camera looking towards the origin.

    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        radius_range: [min, max] camera distance from the origin.
        theta_range: [min, max] polar angle, in degrees.
        phi_range: [min, max] azimuth angle, in degrees.
        return_dirs: if True, also return the view label of each pose
            (see ``get_view_direction``).
        angle_overhead: top / bottom threshold in degrees.
        angle_front: front sector width in degrees.
        jitter: if True, perturb the camera center, the look-at target and
            the up vector with small random noise.

    Returns:
        poses: [size, 4, 4] float tensor; columns of the 3x3 part are the
            right, up and forward vectors, the last column is the center.
        dirs: long tensor [size] of view labels, or None if not requested.
    """
    theta_range = np.deg2rad(theta_range)
    phi_range = np.deg2rad(phi_range)
    angle_overhead = np.deg2rad(angle_overhead)
    angle_front = np.deg2rad(angle_front)

    # NOTE: the order of the random draws below is kept as-is so that seeded
    # runs remain reproducible.
    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
    thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
    phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]

    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * torch.sin(thetas) * torch.cos(phis),
    ], dim=-1)  # [B, 3]

    targets = 0

    # jitters
    if jitter:
        centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
        targets = targets + torch.randn_like(centers) * 0.2

    # lookat
    forward_vector = safe_normalize(targets - centers)
    up_vector = torch.tensor([0, -1, 0], dtype=torch.float32, device=device).unsqueeze(0).repeat(size, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))

    if jitter:
        up_noise = torch.randn_like(up_vector) * 0.02
    else:
        up_noise = 0

    # recompute up from right x forward, optionally perturbed
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    return poses, dirs


def circle_poses(device, radius=1.25, theta=60, phi=0, return_dirs=False, angle_overhead=30, angle_front=60):
    """Build a single orbit-camera pose at fixed angles, looking at the origin.

    Args:
        device: where to allocate the output.
        radius: camera distance from the origin.
        theta: polar angle in degrees.
        phi: azimuth angle in degrees.
        return_dirs: if True, also return the view label of the pose
            (see ``get_view_direction``).
        angle_overhead: top / bottom threshold in degrees.
        angle_front: front sector width in degrees.

    Returns:
        poses: [1, 4, 4] float tensor.
        dirs: long tensor [1] with the view label, or None if not requested.
    """
    theta = np.deg2rad(theta)
    phi = np.deg2rad(phi)
    angle_overhead = np.deg2rad(angle_overhead)
    angle_front = np.deg2rad(angle_front)

    thetas = torch.tensor([theta], dtype=torch.float32, device=device)
    phis = torch.tensor([phi], dtype=torch.float32, device=device)

    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * torch.sin(thetas) * torch.cos(phis),
    ], dim=-1)  # [B, 3]

    # lookat
    forward_vector = -safe_normalize(centers)
    up_vector = torch.tensor([0, -1, 0], dtype=torch.float32, device=device).unsqueeze(0)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    return poses, dirs


class NeRFDataset:
    """Image-free dataset that produces camera rays for generated poses.

    In training mode (``type`` is 'train' or 'all') every batch uses a random
    pose and a random vertical field of view; otherwise poses walk around a
    circle (one azimuth step per index) with a fixed field of view.

    Reads from ``opt``: ``radius_range``, ``fovy_range``, ``dir_text``,
    ``angle_overhead``, ``angle_front`` and ``jitter_pose``.
    """

    def __init__(self, opt, device, type='train', H=256, W=256, size=100):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type  # train, val, test

        self.H = H
        self.W = W
        self.radius_range = opt.radius_range
        self.fovy_range = opt.fovy_range
        self.size = size

        self.training = self.type in ['train', 'all']

        # Principal point at the image center. The original code had these
        # swapped (cx = H / 2, cy = W / 2), which only matters when H != W.
        # NOTE(review): assumes get_rays reads intrinsics as [fx, fy, cx, cy]
        # with cx along the width axis -- confirm against .utils.get_rays.
        self.cx = self.W / 2
        self.cy = self.H / 2

        # [debug] visualize poses
        # poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
        # visualize_poses(poses.detach().cpu().numpy())

    def collate(self, index):
        """Build one batch of rays for the given list of sample indices.

        Args:
            index: list of integer sample indices (the loader built by
                ``dataloader`` uses batch_size=1, so normally one element).

        Returns:
            dict with keys 'H', 'W', 'rays_o', 'rays_d' and 'dir'
            ('dir' is None unless ``opt.dir_text`` is truthy).
        """
        B = len(index)  # always 1

        if self.training:
            # random pose on the fly
            poses, dirs = rand_poses(
                B, self.device,
                radius_range=self.radius_range,
                return_dirs=self.opt.dir_text,
                angle_overhead=self.opt.angle_overhead,
                angle_front=self.opt.angle_front,
                jitter=self.opt.jitter_pose,
            )
            # random focal
            fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
        else:
            # circle pose: the index selects the azimuth along a full turn
            phi = (index[0] / self.size) * 360
            poses, dirs = circle_poses(
                self.device,
                radius=self.radius_range[1] * 1.2,
                theta=60,
                phi=phi,
                return_dirs=self.opt.dir_text,
                angle_overhead=self.opt.angle_overhead,
                angle_front=self.opt.angle_front,
            )
            # fixed focal
            fov = (self.fovy_range[1] + self.fovy_range[0]) / 2

        # focal length in pixels from the vertical field of view (degrees)
        focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
        intrinsics = np.array([focal, focal, self.cx, self.cy])

        # sample a low-resolution but full image for CLIP
        rays = get_rays(poses, intrinsics, self.H, self.W, -1)

        data = {
            'H': self.H,
            'W': self.W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
        }

        return data

    def dataloader(self):
        """Return a DataLoader over ``range(size)`` that collates via ``collate``.

        Shuffles only in training mode. The dataset itself is attached to the
        loader as ``loader._data``.
        """
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate,
                            shuffle=self.training, num_workers=0)
        loader._data = self  # an ugly fix... we need to access dataset in trainer.
        return loader