heheyas
init
cfb7702
raw history blame
No virus
14.7 kB
import os
import math
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
import datasets
from datasets.colmap_utils import \
read_cameras_binary, read_images_binary, read_points3d_binary
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
def get_center(pts):
    """Estimate a robust centroid of a point cloud.

    Every point is scored by its Euclidean distance to the plain mean. A point
    is kept when that distance lies within 1.5 standard deviations of the mean
    distance AND inside Tukey's fences (1.5 * IQR beyond the 25th / 75th
    percentiles of the distances). The centroid of the kept points is returned.

    Args:
        pts: (N, 3) tensor of points; must be non-empty.

    Returns:
        (3,) tensor. Falls back to the plain mean when no point survives the
        filter (e.g. a single point, or all points equidistant from the mean),
        instead of returning NaN.
    """
    mean_center = pts.mean(0)
    dis = (pts - mean_center[None, :]).norm(p=2, dim=-1)
    mean, std = dis.mean(), dis.std()
    q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
    iqr = q75 - q25
    within_std = (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std)
    # Tukey's fences are anchored at the quartiles: anchoring them at the mean
    # lets a few far outliers drag the whole band away from the bulk of points.
    within_iqr = (dis > q25 - 1.5 * iqr) & (dis < q75 + 1.5 * iqr)
    valid = within_std & within_iqr
    if not valid.any():
        # No inliers (std is 0 or NaN, or the bands do not overlap); averaging an
        # empty selection would produce NaN.
        return mean_center
    return pts[valid].mean(0)
def _homogeneous_poses(poses):
    """Append the row [0, 0, 0, 1] to (N, 3, 4) poses, giving (N, 4, 4) matrices."""
    bottom = poses.new_tensor([[[0., 0., 0., 1.]]]).expand(poses.shape[0], -1, -1)
    return torch.cat([poses, bottom], dim=1)


def _rigid_transform(R, t):
    """Build the 4x4 matrix [[R, t], [0, 0, 0, 1]] from a (3, 3) R and a (3, 1) t."""
    top = torch.cat([R, t], dim=1)
    return torch.cat([top, R.new_tensor([[0., 0., 0., 1.]])], dim=0)


def _apply_transform(transform, poses, pts):
    """Left-multiply (N, 3, 4) poses and (M, 3) points by a 4x4 ``transform``."""
    poses_out = (transform @ _homogeneous_poses(poses))[:, :3]
    pts_homo = torch.cat([pts, torch.ones_like(pts[:, 0:1])], dim=-1)
    pts_out = (transform @ pts_homo[..., None])[:, :3, 0]
    return poses_out, pts_out


def normalize_poses(poses, pts, up_est_method, center_est_method):
    """Re-express camera poses and points in a normalized world frame.

    The new frame is centered on an estimated scene center, has its +z axis
    along an estimated up direction, and is scaled so that the camera closest
    to the new origin lies at distance 1.

    Args:
        poses: (N, 3, 4) camera-to-world matrices. Column 3 is the camera
            position; cameras look along their local -z axis and column 1 is
            taken as the camera up axis.
        pts: (M, 3) points in the same world frame as ``poses``.
        up_est_method: 'ground' (normal of a RANSAC plane fitted to ``pts``,
            oriented towards the points; needs ``pyransac3d``) or 'camera'
            (for 'lookat' centers: direction from the center to the mean camera
            position; otherwise the mean camera up axis, because the offset of
            the camera centroid from itself is zero).
        center_est_method: 'camera' (mean camera position), 'lookat' (mean of
            the closest points between the viewing rays of cameras i and i-1)
            or 'point' (robust center, via ``get_center``, of the points inside
            the cameras' x/y extent after rotation).

    Returns:
        (poses_norm, pts): (N, 3, 4) poses and (M, 3) points in the new frame.

    Raises:
        NotImplementedError: for an unknown center or up estimation method.
    """
    cams_ori = poses[..., 3]
    if center_est_method in ('camera', 'point'):
        # estimate the scene center as the average of all camera positions;
        # for 'point' this is only provisional: the final center is computed
        # from the points after rotation (below)
        center = cams_ori.mean(0)
    elif center_est_method == 'lookat':
        # estimate the scene center as the average of the closest points between
        # the viewing rays of each camera and the previous one (cyclically)
        cams_dir = F.normalize(poses[:, :3, :3] @ poses.new_tensor([0., 0., -1.]), dim=-1)
        prev_dir, prev_ori = cams_dir.roll(1, 0), cams_ori.roll(1, 0)
        # least squares for ori_i + t0 * dir_i ~= prev_ori_i + t1 * prev_dir_i
        A = torch.stack([cams_dir, -prev_dir], dim=-1)  # (N, 3, 2)
        b = prev_ori - cams_ori  # (N, 3)
        t = torch.linalg.lstsq(A, b).solution  # (N, 2)
        closest = torch.stack([cams_dir, prev_dir], dim=-1) * t[:, None, :] + torch.stack([cams_ori, prev_ori], dim=-1)
        center = closest.mean((0, 2))
    else:
        raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')

    if up_est_method == 'ground':
        # estimate up direction as the normal of a RANSAC plane fitted to the points
        import pyransac3d as pyrsc
        ground = pyrsc.Plane()
        plane_eq, _ = ground.fit(pts.numpy(), thresh=0.01)  # TODO: determine thresh based on scene scale
        plane_eq = torch.as_tensor(plane_eq)  # A, B, C, D in Ax + By + Cz + D = 0
        z = F.normalize(plane_eq[:3], dim=-1)  # plane normal as up direction
        signed_distance = (torch.cat([pts, torch.ones_like(pts[..., 0:1])], dim=-1) * plane_eq).sum(-1)
        if signed_distance.mean() < 0:
            z = -z  # flip so that the points lie, on average, on the +z side
    elif up_est_method == 'camera':
        if center_est_method == 'lookat':
            # direction from the look-at center towards the mean camera position
            z = F.normalize((cams_ori - center).mean(0), dim=0)
        else:
            # the center is the camera centroid, so the offset above would be
            # zero (pure rounding noise); average the camera up axes instead
            z = F.normalize(poses[:, :3, 1].mean(0), dim=0)
    else:
        raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')

    # orthonormal basis (x, y, z) around the up direction z
    y_ = torch.as_tensor([z[1], -z[0], 0.])  # horizontal vector perpendicular to z
    if y_.norm() < 1e-6:
        # z is (anti)parallel to the world z axis: y_ vanishes and x would
        # normalize to zero; the world y axis is perpendicular to z instead
        y_ = y_.new_tensor([0., 1., 0.])
    x = F.normalize(y_.cross(z, dim=0), dim=0)
    y = z.cross(x, dim=0)
    Rc = torch.stack([x, y, z], dim=1)  # columns: new axes in world coordinates
    R = Rc.T

    if center_est_method == 'point':
        # rotate first, then center on the points within the cameras' x/y extent
        rotation = _rigid_transform(R, R.new_zeros(3, 1))
        poses_norm, pts = _apply_transform(rotation, poses, pts)
        poses_min, poses_max = poses_norm[..., 3].min(0)[0], poses_norm[..., 3].max(0)[0]
        in_extent = (
            (poses_min[0] < pts[:, 0]) & (pts[:, 0] < poses_max[0])
            & (poses_min[1] < pts[:, 1]) & (pts[:, 1] < poses_max[1])
        )
        center = get_center(pts[in_extent])
        translation = _rigid_transform(torch.eye(3, dtype=R.dtype), -center.reshape(3, 1))
        poses_norm, pts = _apply_transform(translation, poses_norm, pts)
    else:
        # rotation and translation as a single rigid transform
        transform = _rigid_transform(R, -R @ center.reshape(3, 1))
        poses_norm, pts = _apply_transform(transform, poses, pts)  # poses_norm: (N_images, 3, 4)

    # scale so that the camera closest to the new origin is at distance 1
    scale = poses_norm[..., 3].norm(p=2, dim=-1).min()
    poses_norm[..., 3] /= scale
    pts = pts / scale
    return poses_norm, pts
def create_spheric_poses(cameras, n_steps=120):
    """Generate a circular orbit of camera-to-world poses looking at the origin.

    The orbit lies in a horizontal plane (normal to +z) at the mean height of
    ``cameras`` (``cameras[:, 2].mean()``). Its radius is chosen so that every
    orbit position is as far from the origin as the cameras are on average.

    Args:
        cameras: camera positions, indexed as rows with x/y/z in the last
            dimension (presumably (N, 3)).
        n_steps: number of poses. Angles come from
            ``torch.linspace(0, 2 * pi, n_steps)``, so the first and last poses
            coincide.

    Returns:
        (n_steps, 3, 4) tensor; per pose the columns are right, up, the negated
        viewing direction, and the camera position.
    """
    origin = torch.zeros(3, dtype=cameras.dtype, device=cameras.device)
    world_up = torch.tensor([0., 0., 1.], dtype=cameras.dtype, device=cameras.device)
    height = cameras[:, 2].mean()
    mean_dist = cameras.norm(p=2, dim=-1).mean()
    radius = (mean_dist ** 2 - height ** 2).sqrt()

    def pose_at(theta):
        eye = torch.stack([radius * theta.cos(), radius * theta.sin(), height])
        forward = F.normalize(origin - eye, p=2, dim=0)
        right = F.normalize(forward.cross(world_up, dim=0), p=2, dim=0)
        cam_up = F.normalize(right.cross(forward, dim=0), p=2, dim=0)
        rotation = torch.stack([right, cam_up, -forward], dim=1)
        return torch.cat([rotation, eye[:, None]], dim=1)

    thetas = torch.linspace(0, 2 * math.pi, n_steps)
    return torch.stack([pose_at(theta) for theta in thetas], dim=0)
class ColmapDatasetBase():
    """Shared COLMAP scene loading, mixed into the dataset classes of this module.

    A scene is read from ``<root_dir>/sparse/0/{cameras,images,points3D}.bin``,
    ``<root_dir>/images`` and, if the directory exists, ``<root_dir>/masks``.
    Parsing and pose normalization run once per process; the results are cached
    on the class in ``properties`` and copied onto every instance by ``setup``.
    """
    # the data only has to be processed once; the cache is shared by all instances
    initialized = False
    properties = {}

    def setup(self, config, split):
        """Attach the (cached) scene data for ``split`` to this instance.

        Args:
            config: dataset config. Reads ``root_dir``, ``img_wh`` or
                ``img_downscale``, ``apply_mask``, ``load_data_on_gpu``,
                ``up_est_method``, ``center_est_method`` and, for the 'test'
                split, ``n_test_traj_steps``.
            split: 'train' and 'val' load images and masks from disk; 'test'
                replaces the poses with a circular trajectory and uses
                zero-filled images and masks.

        Raises:
            KeyError: if neither ``img_wh`` nor ``img_downscale`` is configured.
            ValueError: for an unsupported camera model, an ``img_wh`` whose
                aspect ratio does not match the images, or an ambiguous mask.
            FileNotFoundError: if masks exist but one is missing for an image.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()
        needs_images = self.split in ['train', 'val']

        if not ColmapDatasetBase.initialized:
            ColmapDatasetBase.properties = self._load_scene(load_images=needs_images)
            ColmapDatasetBase.initialized = True
        elif needs_images and not ColmapDatasetBase.properties['images_loaded']:
            # The cache was built by a split that skipped images (e.g. 'test');
            # load them now instead of stacking the empty cached lists below.
            props = ColmapDatasetBase.properties
            props['all_images'], props['all_fg_masks'] = self._load_images(
                props['image_names'], props['img_wh'], props['has_mask'])
            props['images_loaded'] = True

        for k, v in ColmapDatasetBase.properties.items():
            setattr(self, k, v)

        if self.split == 'test':
            self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
            self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
        else:
            self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float()

        """
        # for debug use
        from models.ray_utils import get_rays
        rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True)
        pts_out = []
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()]))
        t_vals = torch.linspace(0, 1, 8)
        z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals
        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()]))
        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :])
        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
        open('cameras.txt', 'w').write('\n'.join(pts_out))
        open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()]))
        exit(1)
        """

        self.all_c2w = self.all_c2w.float().to(self.rank)
        if self.config.load_data_on_gpu:
            self.all_images = self.all_images.to(self.rank)
            self.all_fg_masks = self.all_fg_masks.to(self.rank)

    def _load_scene(self, load_images):
        """Read intrinsics, poses, sparse points and (optionally) images.

        Args:
            load_images: whether to read images and masks from disk now.

        Returns:
            The dict cached in ``ColmapDatasetBase.properties``; its keys become
            instance attributes in ``setup``.
        """
        camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin'))
        # only camera id 1 is read: all images are assumed to share it
        camera = camdata[1]
        H = int(camera.height)
        W = int(camera.width)

        if 'img_wh' in self.config:
            w, h = self.config.img_wh
            if round(W / w * h) != H:
                raise ValueError(f"img_wh {(w, h)} does not match the aspect ratio of the {W}x{H} images.")
        elif 'img_downscale' in self.config:
            w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")
        img_wh = (w, h)
        factor = w / W

        fx, fy, cx, cy = self._scaled_intrinsics(camera, factor)
        directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)

        imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin'))
        mask_dir = os.path.join(self.config.root_dir, 'masks')
        has_mask = os.path.exists(mask_dir) # TODO: support partial masks
        apply_mask = has_mask and self.config.apply_mask

        all_c2w, image_names = [], []
        for d in imdata.values():
            R = d.qvec2rotmat()
            t = d.tvec.reshape(3, 1)
            # invert [R | t] into the camera-to-world matrix [R^T | -R^T t]
            c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float()
            c2w[:,1:3] *= -1. # COLMAP => OpenGL
            all_c2w.append(c2w)
            image_names.append(d.name)
        all_c2w = torch.stack(all_c2w, dim=0)

        if load_images:
            all_images, all_fg_masks = self._load_images(image_names, img_wh, has_mask)
        else:
            all_images, all_fg_masks = [], []

        pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin'))
        pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()
        all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method)

        return {
            'w': w,
            'h': h,
            'img_wh': img_wh,
            'factor': factor,
            'has_mask': has_mask,
            'apply_mask': apply_mask,
            'directions': directions,
            'pts3d': pts3d,
            'all_c2w': all_c2w,
            'all_images': all_images,
            'all_fg_masks': all_fg_masks,
            'image_names': image_names,  # in the same order as all_c2w
            'images_loaded': load_images,
        }

    def _load_images(self, image_names, img_wh, has_mask):
        """Load images and foreground masks, bicubically resized to ``img_wh``.

        Args:
            image_names: file names under ``<root_dir>/images`` (and ``masks``).
            img_wh: target (width, height).
            has_mask: read masks from ``<root_dir>/masks``; otherwise use
                all-ones masks.

        Returns:
            (all_images, all_fg_masks): lists of (h, w, <=3) float images and
            (h, w) float masks. Images are moved to ``self.rank`` when
            ``config.load_data_on_gpu`` is set; masks read from disk stay on
            the CPU (``setup`` moves them afterwards).
        """
        mask_dir = os.path.join(self.config.root_dir, 'masks')
        all_images, all_fg_masks = [], []
        for name in image_names:
            img_path = os.path.join(self.config.root_dir, 'images', name)
            with Image.open(img_path) as raw:
                img = raw.resize(img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
            img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()
            if has_mask:
                # a mask is named like its image, or like the image name minus its first 3 characters
                candidates = [os.path.join(mask_dir, name), os.path.join(mask_dir, name[3:])]
                mask_paths = [p for p in candidates if os.path.isfile(p)]
                if not mask_paths:
                    raise FileNotFoundError(f"No mask found for image {name}; tried {candidates}.")
                if len(mask_paths) > 1:
                    raise ValueError(f"Ambiguous masks for image {name}: {mask_paths}.")
                with Image.open(mask_paths[0]) as raw_mask:
                    mask = raw_mask.convert('L') # single channel
                mask = mask.resize(img_wh, Image.BICUBIC)
                mask = TF.to_tensor(mask)[0]
            else:
                mask = torch.ones_like(img[...,0], device=img.device)
            all_fg_masks.append(mask) # (h, w)
            all_images.append(img)
        return all_images, all_fg_masks

    @staticmethod
    def _scaled_intrinsics(camera, factor):
        """Return (fx, fy, cx, cy) of a COLMAP camera, scaled by ``factor``.

        Distortion parameters (SIMPLE_RADIAL's k, OPENCV's k1/k2/p1/p2) are
        ignored.

        Raises:
            ValueError: for camera models whose parameters are not parsed here.
        """
        params = camera.params
        if camera.model in ['SIMPLE_PINHOLE', 'SIMPLE_RADIAL']:
            # parameters start with f, cx, cy
            fx = fy = params[0] * factor
            cx = params[1] * factor
            cy = params[2] * factor
        elif camera.model in ['PINHOLE', 'OPENCV']:
            # parameters start with fx, fy, cx, cy
            fx = params[0] * factor
            fy = params[1] * factor
            cx = params[2] * factor
            cy = params[3] * factor
        else:
            raise ValueError(f"Please parse the intrinsics for camera model {camera.model}!")
        return fx, fy, cx, cy
class ColmapDataset(Dataset, ColmapDatasetBase):
    """Map-style COLMAP dataset whose items are frame indices only.

    ``setup`` attaches the image, mask and pose tensors to the dataset object;
    each item is just ``{'index': i}`` (presumably the consumer looks the
    tensors up by that index — confirm in the training/validation code).
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # one item per frame of the stacked image tensor
        return len(self.all_images)

    def __getitem__(self, index):
        return dict(index=index)
class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
    """Endless iterable COLMAP dataset that yields empty dicts.

    Each element is a fresh empty placeholder; the data itself lives on the
    dataset object (attached by ``setup``). Iteration never stops, so the
    consumer must bound it (e.g. with a step limit).
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # iter(callable, sentinel): calls dict() forever, since it never returns None
        yield from iter(dict, None)
@datasets.register('colmap')
class ColmapDataModule(pl.LightningDataModule):
    """LightningDataModule serving a COLMAP scene for fit/validate/test/predict.

    Config keys read here: ``val_split`` (default 'train'), ``test_split``
    (default 'test') and ``num_workers`` (default: CPU count, or 0 when it
    cannot be determined). The whole config is passed on to the datasets.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        # stage=None sets up the datasets of every stage
        if stage in [None, 'fit']:
            self.train_dataset = ColmapIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = ColmapDataset(self.config, 'train')

    def prepare_data(self):
        # nothing to download or preprocess; all reading happens in setup
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a DataLoader with pinned memory and no sampler."""
        # os.cpu_count() returns None when the count is undeterminable, and
        # DataLoader requires an integer worker count.
        num_workers = self.config.get('num_workers', os.cpu_count() or 0)
        return DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)