File size: 6,721 Bytes
f53b39e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Slower implementation of the global alignment that allows freezing partial poses/intrinsics
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from dust3r.cloud_opt_flow.base_opt import BasePCOptimizer
from dust3r.utils.geometry import geotrf
from dust3r.utils.device import to_cpu, to_numpy
from dust3r.utils.geometry import depthmap_to_pts3d
from dust3r.cloud_opt_flow.optimizer import PointCloudOptimizer, tum_to_pose_matrix, ParameterStack, xy_grid
class ModularPointCloudOptimizer(BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Every per-image quantity lives in its own entry of an ``nn.ParameterList`` so
    that it can be frozen individually through the ``preset_*`` methods:
      - ``im_depthmaps``: log(depth), one (H, W) tensor per image
      - ``im_poses``:     one pose parameter per image, built by ``self.rand_pose``
      - ``im_focals``:    ``focal_brake * log(focal)``, 1 or 2 values per image
      - ``im_pp``:        (principal point - image center) / 10, 2 values per image
    """

    def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to ``BasePCOptimizer.__init__``.
            optimize_pp: if True, the principal points receive gradients.
            fx_and_fy: if True, store two focal values per image instead of one.
            focal_brake: factor multiplying log(focal) in the stored parameter
                (see ``_set_focal`` / ``get_focals``).
        """
        super().__init__(*args, **kwargs)
        self.has_im_poses = True  # by definition of this class
        self.focal_brake = focal_brake

        # adding thing to optimize
        # (self.imshapes, self.n_imgs, self.rand_pose and self.POSE_DIM are not defined in
        #  this class; presumably they are provided by BasePCOptimizer)
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        # get_focals() decodes this initial value back to focal = max(H, W)
        default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
        self.im_focals = nn.ParameterList(
            torch.FloatTensor([f, f] if fx_and_fy else [f]) for f in default_focals)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """Set and freeze the pose of the selected images.

        Args:
            known_poses: a single pose or a sequence of poses. Each pose is either a
                pose matrix (its ``[:3, 3]`` entries are printed in verbose mode) or a
                7-vector (xyz wxyz), which is converted with ``tum_to_pose_matrix``.
                A 1D input, or a 2D input whose last dimension is not 7, is treated as
                one single pose; a 2D input with last dimension 7 is a batch of 7-vectors.
            pose_msk: selects the images to update, see ``_get_msk_indices``.
                Poses are paired with the selected indices in order (``zip``).

        Side effects:
            Updates ``self.norm_pw_scale``: True when at most one pose is frozen.
        """
        ndim = getattr(known_poses, 'ndim', None)
        if ndim == 1 or (ndim == 2 and known_poses.shape[-1] != 7):
            known_poses = [known_poses]

        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            # The 7-vector test is done per pose: testing `.shape` on the container
            # fails as soon as it is a list (including the single pose wrapped above).
            if pose.shape[-1] == 7:  # xyz wxyz
                pose = tum_to_pose_matrix(pose)
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            # torch.tensor(t) on a tensor is equivalent to t.detach().clone() but warns
            pose = pose.detach().clone() if isinstance(pose, torch.Tensor) else torch.tensor(pose)
            self._no_grad(self._set_pose(self.im_poses, idx, pose, force=True))

        # normalize scale if there is at most one known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

    def preset_intrinsics(self, known_intrinsics, msk=None):
        """Set and freeze focal and principal point from 3x3 intrinsic matrices.

        Args:
            known_intrinsics: one 2D torch.Tensor or a sequence of (3, 3) matrices.
            msk: selects the images to update, see ``_get_msk_indices``.

        NOTE: the focal passed on is the mean of the first two diagonal entries,
        also when the optimizer was built with ``fx_and_fy=True``.
        """
        if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:
            known_intrinsics = [known_intrinsics]
        for K in known_intrinsics:
            assert K.shape == (3, 3)
        self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
        self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)

    def preset_focal(self, known_focals, msk=None):
        """Set and freeze the focal of the images selected by ``msk``."""
        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal, force=True))

    def preset_principal_point(self, known_pp, msk=None):
        """Set and freeze the principal point of the images selected by ``msk``."""
        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp, force=True))

    def _no_grad(self, tensor):
        """Disable gradients on ``tensor`` in place and return it."""
        return tensor.requires_grad_(False)

    def _get_msk_indices(self, msk):
        """Turn an image selector into an iterable of image indices.

        Args:
            msk: None (all images), a single integer, a tuple/list, or an array/tensor
                that is either a boolean mask of length ``n_imgs`` or integer indices.

        Raises:
            ValueError: if ``msk`` is an array of any other dtype.
        """
        if msk is None:
            return range(self.n_imgs)
        # np.integer scalars would otherwise be returned as a non-iterable scalar
        if isinstance(msk, (int, np.integer)):
            return [msk]
        if isinstance(msk, (tuple, list)):
            msk = np.array(msk)
        elif isinstance(msk, torch.Tensor):
            # np.issubdtype cannot interpret torch dtypes, so work on a numpy view
            msk = msk.detach().cpu().numpy()

        if msk.dtype == np.bool_:
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        if np.issubdtype(msk.dtype, np.integer):
            return msk
        raise ValueError(f'bad {msk=}')

    def _set_focal(self, idx, focal, force=False):
        """Write ``focal_brake * log(focal)`` into focal parameter ``idx`` and return it."""
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_brake * np.log(focal)
        return param

    def get_focals(self):
        """Return the focals as ``exp(param / focal_brake)``, stacked over images."""
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_brake).exp()

    def _set_principal_point(self, idx, pp, force=False):
        """Write ``(pp - (W/2, H/2)) / 10`` into principal-point parameter ``idx`` and return it."""
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        """Return the principal points as ``(W/2, H/2) + 10 * param``, stacked over images."""
        # new_tensor keeps the dtype/device of the parameter (replaces the legacy Tensor.new)
        return torch.stack([pp.new_tensor((W/2, H/2)) + 10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])

    def get_intrinsics(self):
        """Assemble an (n_imgs, 3, 3) intrinsics tensor from focals and principal points."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().view(self.n_imgs, -1)
        # with a single focal per image, columns 0 and -1 are the same value
        K[:, 0, 0] = focals[:, 0]
        K[:, 1, 1] = focals[:, -1]
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        """Return ``self._get_poses`` applied to the stacked pose parameters."""
        cam2world = self._get_poses(torch.stack(list(self.im_poses)))
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        """Write ``log(depth)`` into depth parameter ``idx`` and return it."""
        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            # log(0) = -inf is stored as 0; nan_to_num also replaces NaN / +inf with finite values
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self):
        """Return the list of depthmaps, i.e. the exponential of each log-depth parameter."""
        return [d.exp() for d in self.im_depthmaps]

    def depth_to_pts3d(self):
        """Return one world-frame pointmap per image, computed from the current parameters."""
        # Get depths and projection params
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps()

        # convert focal to a (1, n_focals, H, W) constant field (n_focals is 1 or 2)
        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])

        # get pointmaps in camera frame
        rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]
        # project to world frame
        return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]

    def get_pts3d(self):
        """Alias of ``depth_to_pts3d``."""
        return self.depth_to_pts3d()
|