|
import os |
|
import math |
|
import numpy as np |
|
from PIL import Image |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader, IterableDataset |
|
import torchvision.transforms.functional as TF |
|
|
|
import pytorch_lightning as pl |
|
|
|
import datasets |
|
from datasets.colmap_utils import \ |
|
read_cameras_binary, read_images_binary, read_points3d_binary |
|
from models.ray_utils import get_ray_directions |
|
from utils.misc import get_rank |
|
|
|
|
|
def get_center(pts):
    """Estimate a robust centroid of a point cloud.

    Distances of all points to their plain mean are computed. Points whose
    distance is an outlier are discarded before averaging again. A point is
    kept only if its distance is within ``mean ± 1.5 * std`` AND inside
    Tukey's fences ``[q25 - 1.5 * IQR, q75 + 1.5 * IQR]``.

    Args:
        pts: point tensor indexed as ``pts[:, :3]`` (one point per row).

    Returns:
        The mean of the inlier points. If no point survives the filters (e.g. a
        single point, whose std is NaN, or all distances identical), the plain
        mean of ``pts`` is returned instead of NaN.
    """
    center = pts.mean(0)
    dis = (pts - center[None, :]).norm(p=2, dim=-1)
    mean, std = dis.mean(), dis.std()
    q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
    iqr = q75 - q25
    # Tukey fences are anchored at the quartiles. Anchoring them at the mean
    # (as before) gives a window that a single outlier can push past every
    # point, leaving nothing to average.
    valid = (
        (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std)
        & (dis > q25 - 1.5 * iqr) & (dis < q75 + 1.5 * iqr)
    )
    if not bool(valid.any()):
        return center
    return pts[valid].mean(0)
|
|
|
def normalize_poses(poses, pts, up_est_method, center_est_method):
    """Re-center, re-orient and re-scale camera poses and the sparse point cloud.

    The transform has three steps:
    - the scene is translated so that the estimated center sits at the origin;
    - it is rotated so that the estimated "up" direction becomes +z;
    - it is scaled so that the camera closest to the center lies at distance 1.

    Args:
        poses: camera-to-world matrices indexed as ``poses[:, :3, :4]``. The last
            column is the camera position. The code treats each camera's local
            -z axis as its viewing direction.
        pts: sparse point cloud indexed as ``pts[:, :3]``.
        up_est_method: how to estimate "up".
            - ``'ground'``: normal of a RANSAC plane fitted to ``pts``.
            - ``'camera'``: mean offset of the cameras from the center. If those
              offsets cancel out, the mean camera +y axis is used instead.
        center_est_method: how to estimate the scene center.
            - ``'camera'``: mean camera position.
            - ``'lookat'``: mean closest-approach point of consecutive viewing rays.
            - ``'point'``: robust center of the points that lie inside the
              cameras' x/y bounding box in the rotated frame.

    Returns:
        ``(poses_norm, pts)``: the transformed poses (same layout as ``poses``)
        and the transformed points.

    Raises:
        NotImplementedError: for an unknown estimation method.
    """
    center = _estimate_center(poses, center_est_method)
    z = _estimate_up(poses, pts, center, up_est_method).to(poses)
    # Rows of R are the new x / y / z axes, so R maps world coordinates into the
    # normalized, z-up frame.
    R = _orthonormal_frame(z).T

    if center_est_method == 'point':
        # Rotate first, then pick the center from points whose x/y lie strictly
        # inside the cameras' x/y bounding box in the rotated frame.
        zero = torch.zeros(3, 1, dtype=R.dtype, device=R.device)
        rotation = _make_transform(R, zero)
        poses = _transform_poses(rotation, poses)
        pts = _transform_points(rotation, pts)

        cams = poses[..., 3]
        lo, hi = cams.min(0)[0], cams.max(0)[0]
        in_box = (
            (lo[0] < pts[:, 0]) & (pts[:, 0] < hi[0])
            & (lo[1] < pts[:, 1]) & (pts[:, 1] < hi[1])
        )
        pts_fg = pts[in_box]
        if pts_fg.shape[0] == 0:
            # No point inside the box; use the whole cloud rather than letting
            # torch.quantile fail on an empty tensor.
            pts_fg = pts
        center = get_center(pts_fg)
        identity = torch.eye(3, dtype=R.dtype, device=R.device)
        transform = _make_transform(identity, -center.reshape(3, 1))
    else:
        # Rotate, and move the rotated center to the origin, in a single step.
        transform = _make_transform(R, -R @ center.reshape(3, 1))

    poses_norm = _transform_poses(transform, poses)
    pts = _transform_points(transform, pts)

    # Scale so that the closest camera ends up at unit distance from the origin.
    scale = poses_norm[..., 3].norm(p=2, dim=-1).min()
    poses_norm[..., 3] /= scale
    pts = pts / scale
    return poses_norm, pts


def _estimate_center(poses, center_est_method):
    """Return the initial scene-center estimate for ``center_est_method``."""
    cams_ori = poses[..., 3]
    if center_est_method in ('camera', 'point'):
        # For 'point', this mean camera position is only used for the up
        # estimate. The final center is taken from the point cloud later.
        return cams_ori.mean(0)
    if center_est_method == 'lookat':
        # Pair each camera with the previous one (roll by 1). Solve
        # o0 + t0*d0 ~= o1 + t1*d1 in the least-squares sense for the two ray
        # parameters. Then average the two resulting points over all pairs.
        view_dir = torch.as_tensor([0., 0., -1.], dtype=poses.dtype, device=poses.device)
        cams_dir = F.normalize(poses[:, :3, :3] @ view_dir, dim=-1)
        A = torch.stack([cams_dir, -cams_dir.roll(1, 0)], dim=-1)
        b = -cams_ori + cams_ori.roll(1, 0)
        t = torch.linalg.lstsq(A, b).solution
        closest = (
            torch.stack([cams_dir, cams_dir.roll(1, 0)], dim=-1) * t[:, None, :]
            + torch.stack([cams_ori, cams_ori.roll(1, 0)], dim=-1)
        )
        return closest.mean((0, 2))
    raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')


def _estimate_up(poses, pts, center, up_est_method):
    """Return a unit "up" vector estimated with ``up_est_method``."""
    if up_est_method == 'ground':
        # Normal of a RANSAC-fitted plane through the point cloud.
        import pyransac3d as pyrsc
        ground = pyrsc.Plane()
        plane_eq, _inliers = ground.fit(pts.numpy(), thresh=0.01)
        plane_eq = torch.as_tensor(plane_eq)
        z = F.normalize(plane_eq[:3], dim=-1)
        signed_distance = (torch.cat([pts, torch.ones_like(pts[..., 0:1])], dim=-1) * plane_eq).sum(-1)
        if signed_distance.mean() < 0:
            # Orient the normal towards the side where most points lie.
            z = -z
        return z
    if up_est_method == 'camera':
        offsets = poses[..., 3] - center
        mean_offset = offsets.mean(0)
        if mean_offset.norm() <= 1e-4 * offsets.norm(p=2, dim=-1).mean():
            # The offsets cancel out (always the case when the center is the
            # camera mean), so their direction is only rounding noise. Use the
            # mean camera +y axis (column 1 of the pose) instead.
            mean_offset = poses[:, :3, 1].mean(0)
        return F.normalize(mean_offset, dim=0)
    raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')


def _orthonormal_frame(z):
    """Build a 3x3 matrix whose columns are a right-handed orthonormal basis (x, y, z)."""
    helper = torch.stack([z[1], -z[0], torch.zeros_like(z[0])])
    if helper.norm() < 1e-6:
        # z is (anti)parallel to the world z-axis, so the helper vector vanishes.
        # Any horizontal direction works instead.
        helper = torch.as_tensor([0., 1., 0.], dtype=z.dtype, device=z.device)
    x = F.normalize(torch.cross(helper, z, dim=0), dim=0)
    y = torch.cross(z, x, dim=0)
    return torch.stack([x, y, z], dim=1)


def _make_transform(R, t):
    """Assemble a 4x4 homogeneous transform from a 3x3 ``R`` and a (3, 1) ``t``."""
    top = torch.cat([R, t], dim=1)
    bottom = torch.as_tensor([[0., 0., 0., 1.]], dtype=top.dtype, device=top.device)
    return torch.cat([top, bottom], dim=0)


def _transform_poses(transform, poses):
    """Left-multiply each (3, 4) pose by a 4x4 ``transform``; returns (N, 3, 4)."""
    bottom = torch.as_tensor([[[0., 0., 0., 1.]]], dtype=poses.dtype, device=poses.device)
    poses_homo = torch.cat([poses, bottom.expand(poses.shape[0], -1, -1)], dim=1)
    return (transform @ poses_homo)[:, :3]


def _transform_points(transform, pts):
    """Apply a 4x4 ``transform`` to every row of ``pts``; returns (N, 3)."""
    pts_homo = torch.cat([pts, torch.ones_like(pts[:, 0:1])], dim=-1)
    return (transform @ pts_homo[..., None])[:, :3, 0]
|
|
|
def create_spheric_poses(cameras, n_steps=120):
    """Generate a circular orbit of camera-to-world poses around the world origin.

    The orbit lies in the plane ``z = mean camera height``. Its radius is chosen
    so that every orbit position is at the cameras' mean distance from the
    origin. Each pose's third column points away from the origin, so the camera
    looks at the origin along its local -z axis.

    Args:
        cameras: camera positions, indexed as ``cameras[:, :3]``.
        n_steps: number of poses; angles are ``linspace(0, 2*pi, n_steps)``
            (both endpoints included).

    Returns:
        Tensor of shape ``(n_steps, 3, 4)``: rotation columns followed by the
        position column.
    """
    origin = torch.as_tensor([0., 0., 0.], dtype=cameras.dtype, device=cameras.device)
    world_up = torch.as_tensor([0., 0., 1.], dtype=origin.dtype, device=origin.device)

    avg_dist = (cameras - origin[None, :]).norm(p=2, dim=-1).mean()
    height = cameras[:, 2].mean()
    radius = (avg_dist ** 2 - height ** 2).sqrt()

    def pose_at(angle):
        # Build a look-at pose for the orbit position at ``angle``.
        eye = torch.stack([radius * angle.cos(), radius * angle.sin(), height])
        forward = F.normalize(origin - eye, p=2, dim=0)
        right = F.normalize(torch.cross(forward, world_up, dim=0), p=2, dim=0)
        cam_up = F.normalize(torch.cross(right, forward, dim=0), p=2, dim=0)
        rotation = torch.stack([right, cam_up, -forward], dim=1)
        return torch.cat([rotation, eye[:, None]], dim=1)

    angles = torch.linspace(0, 2 * math.pi, n_steps)
    return torch.stack([pose_at(a) for a in angles], dim=0)
|
|
|
class ColmapDatasetBase():
    """Shared loader for a COLMAP sparse reconstruction (cameras, poses, points, images).

    Everything parsed from disk is cached in class-level attributes the first
    time ``setup`` runs. As a result, every split created in the same process
    shares one normalized coordinate frame. ``setup`` copies the cached values
    onto the instance and then specializes them for the requested split.
    """

    # True once the class-level cache below has been populated.
    initialized = False
    # Class-level cache shared by every instance/split; copied onto instances in `setup`.
    properties = {}
    # True once 'all_images' / 'all_fg_masks' in `properties` hold loaded images.
    # Needed because a split that initializes the cache (e.g. 'test') may skip images.
    _images_loaded = False
    # Image file names in COLMAP order, so images can be loaded after initialization.
    _image_names = ()

    def setup(self, config, split):
        """Populate this instance with the (cached) scene data for ``split``.

        Args:
            config: dataset config. Read with ``in`` checks and attribute access
                (``root_dir``, ``img_wh`` / ``img_downscale``, ``apply_mask``,
                ``load_data_on_gpu``, ``up_est_method``, ``center_est_method``,
                ``n_test_traj_steps``).
            split: ``'train'`` and ``'val'`` need images and masks. ``'test'``
                replaces the poses with a circular trajectory and uses all-zero
                images and masks.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        if not ColmapDatasetBase.initialized:
            self._init_shared_scene()
        if self.split in ['train', 'val'] and not ColmapDatasetBase._images_loaded:
            # The cache may have been built by a split that did not load images.
            # Load them now without re-normalizing the (already shared) poses.
            self._load_shared_images()

        for k, v in ColmapDatasetBase.properties.items():
            setattr(self, k, v)

        if self.split == 'test':
            n_steps = self.config.n_test_traj_steps
            self.all_c2w = create_spheric_poses(self.all_c2w[:, :, 3], n_steps=n_steps)
            self.all_images = torch.zeros((n_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros((n_steps, self.h, self.w), dtype=torch.float32)
        else:
            self.all_images = torch.stack(self.all_images, dim=0).float()
            self.all_fg_masks = torch.stack(self.all_fg_masks, dim=0).float()

        self.all_c2w = self.all_c2w.float().to(self.rank)
        if self.config.load_data_on_gpu:
            self.all_images = self.all_images.to(self.rank)
            self.all_fg_masks = self.all_fg_masks.to(self.rank)

    def _init_shared_scene(self):
        """Parse intrinsics, poses and sparse points once and cache them on the class."""
        root = self.config.root_dir
        camdata = read_cameras_binary(os.path.join(root, 'sparse/0/cameras.bin'))
        # Only camera id 1 is used, i.e. all images are assumed to share its intrinsics.
        cam = camdata[1]
        H = int(cam.height)
        W = int(cam.width)

        if 'img_wh' in self.config:
            w, h = self.config.img_wh
            if round(W / w * h) != H:
                raise ValueError(
                    f"img_wh={(w, h)} does not match the aspect ratio of the {W}x{H} COLMAP images."
                )
        elif 'img_downscale' in self.config:
            w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        img_wh = (w, h)
        factor = w / W

        # Distortion parameters are ignored; only the pinhole part is used.
        if cam.model in ['SIMPLE_PINHOLE', 'SIMPLE_RADIAL', 'RADIAL']:
            # params: f, cx, cy[, k...]
            fx = fy = cam.params[0] * factor
            cx = cam.params[1] * factor
            cy = cam.params[2] * factor
        elif cam.model in ['PINHOLE', 'OPENCV']:
            # params: fx, fy, cx, cy[, k..., p...]
            fx = cam.params[0] * factor
            fy = cam.params[1] * factor
            cx = cam.params[2] * factor
            cy = cam.params[3] * factor
        else:
            raise ValueError(f"Please parse the intrinsics for camera model {cam.model}!")

        directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)

        imdata = read_images_binary(os.path.join(root, 'sparse/0/images.bin'))

        mask_dir = os.path.join(root, 'masks')
        has_mask = os.path.exists(mask_dir)
        apply_mask = has_mask and self.config.apply_mask

        all_c2w = []
        for d in imdata.values():
            # Invert the stored [R | t] to obtain a camera-to-world matrix.
            R = d.qvec2rotmat()
            t = d.tvec.reshape(3, 1)
            c2w = torch.from_numpy(np.concatenate([R.T, -R.T @ t], axis=1)).float()
            # Flip the camera y and z axes (presumably COLMAP -> OpenGL-style convention).
            c2w[:, 1:3] *= -1.
            all_c2w.append(c2w)
        all_c2w = torch.stack(all_c2w, dim=0)

        pts3d = read_points3d_binary(os.path.join(root, 'sparse/0/points3D.bin'))
        pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()
        all_c2w, pts3d = normalize_poses(
            all_c2w, pts3d,
            up_est_method=self.config.up_est_method,
            center_est_method=self.config.center_est_method,
        )

        ColmapDatasetBase.properties = {
            'w': w,
            'h': h,
            'img_wh': img_wh,
            'factor': factor,
            'has_mask': has_mask,
            'apply_mask': apply_mask,
            'directions': directions,
            'pts3d': pts3d,
            'all_c2w': all_c2w,
            'all_images': [],
            'all_fg_masks': [],
        }
        ColmapDatasetBase._image_names = tuple(d.name for d in imdata.values())
        ColmapDatasetBase._images_loaded = False
        ColmapDatasetBase.initialized = True

    def _load_shared_images(self):
        """Load every image (and foreground mask) of the cached scene into the class cache."""
        props = ColmapDatasetBase.properties
        img_wh = props['img_wh']
        mask_dir = os.path.join(self.config.root_dir, 'masks')

        all_images, all_fg_masks = [], []
        for name in ColmapDatasetBase._image_names:
            img_path = os.path.join(self.config.root_dir, 'images', name)
            with Image.open(img_path) as raw:
                img = raw.resize(img_wh, Image.BICUBIC)
            # (h, w, C) tensor; any channels beyond the first three (e.g. alpha) are dropped.
            img = TF.to_tensor(img).permute(1, 2, 0)[..., :3]
            img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()
            if props['has_mask']:
                mask = self._load_fg_mask(mask_dir, name, img_wh)
            else:
                mask = torch.ones_like(img[..., 0], device=img.device)
            all_fg_masks.append(mask)
            all_images.append(img)

        props['all_images'] = all_images
        props['all_fg_masks'] = all_fg_masks
        ColmapDatasetBase._images_loaded = True

    def _load_fg_mask(self, mask_dir, name, img_wh):
        """Load the grayscale mask for image ``name`` as an (h, w) tensor in [0, 1]."""
        # The mask may be stored under the image name or with its first 3 characters removed.
        candidates = [os.path.join(mask_dir, name), os.path.join(mask_dir, name[3:])]
        found = [p for p in candidates if os.path.isfile(p)]
        if not found:
            raise FileNotFoundError(f"No mask found for image {name} in {mask_dir}.")
        if len(found) > 1:
            raise ValueError(f"Ambiguous masks for image {name}: {found}")
        with Image.open(found[0]) as raw:
            mask = raw.convert('L')
        mask = mask.resize(img_wh, Image.BICUBIC)
        return TF.to_tensor(mask)[0]
|
|
|
|
|
class ColmapDataset(Dataset, ColmapDatasetBase):
    """Map-style dataset over the frames of a COLMAP scene.

    Items carry only the frame index. The actual image/pose data stays on the
    dataset instance, which is populated by ``setup``.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per frame in `all_images`.
        return len(self.all_images)

    def __getitem__(self, index):
        return dict(index=index)
|
|
|
|
|
class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
    """Endless iterable dataset that yields empty dicts.

    Each item is a fresh, empty dict; the data itself lives on the instance,
    populated by ``setup``.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # iter(callable, sentinel) calls dict() forever: a new {} never equals None.
        return iter(dict, None)
|
|
|
|
|
@datasets.register('colmap')
class ColmapDataModule(pl.LightningDataModule):
    """Lightning data module wiring the COLMAP datasets to their data loaders."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``; ``None`` creates all of them."""
        if stage in [None, 'fit']:
            self.train_dataset = ColmapIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = ColmapDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Build a DataLoader for ``dataset`` with one worker per available CPU."""
        sampler = None
        return DataLoader(
            dataset,
            # os.cpu_count() returns None when the CPU count cannot be determined;
            # DataLoader rejects None, so fall back to loading in the main process.
            num_workers=os.cpu_count() or 0,
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
|
|