|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import roma |
|
from copy import deepcopy |
|
import tqdm |
|
|
|
from dust3r.utils.geometry import inv, geotrf |
|
from dust3r.utils.device import to_numpy |
|
from dust3r.utils.image import rgb |
|
from dust3r.viz import SceneViz, segment_sky, auto_cam_size |
|
from dust3r.optim_factory import adjust_learning_rate_by_lr |
|
|
|
from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p, |
|
cosine_schedule, linear_schedule, get_conf_trf) |
|
import dust3r.cloud_opt.init_im_poses as init_fun |
|
from scipy.spatial.transform import Rotation |
|
from dust3r.utils.vo_eval import save_trajectory_tum_format |
|
import os |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
|
|
def c2w_to_tumpose(c2w):
    """Turn a camera-to-world matrix into a flat TUM-style pose vector.

    Args:
        c2w: 4x4 camera-to-world matrix (anything ``to_numpy`` accepts).

    Returns:
        1-D numpy array laid out as ``(x, y, z, qw, qx, qy, qz)``.
    """
    pose = to_numpy(c2w)
    translation = pose[:3, -1]
    # scipy hands the quaternion back scalar-last (x, y, z, w); rolling by one
    # slot moves w to the front, giving (w, x, y, z).
    quat_xyzw = Rotation.from_matrix(pose[:3, :3]).as_quat()
    quat_wxyz = np.roll(quat_xyzw, 1)
    return np.concatenate([translation, quat_wxyz])
|
|
|
class BasePCOptimizer (nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    This base class stores the pairwise predictions, their confidences and the
    per-edge pose/scale parameters. Per-image quantities (``get_im_poses``,
    ``get_focals``, ``get_depthmaps``, ``depth_to_pts3d``, ...) raise
    ``NotImplementedError`` here and must come from a subclass. A few methods
    (``get_intrinsics``, ``get_init_conf``, ``dynamic_masks``) are used but not
    defined in this class either; they are presumably supplied by subclasses.
    """

    # Attributes transferred by the copy-constructor. This is the original list
    # (de-duplicated) plus the attributes `_init_from_views` sets and that
    # `forward`, `get_adaptors` and `get_conf` rely on.
    _COPIED_ATTRS = (
        'edges', 'is_symmetrized', 'dist', 'n_imgs', 'pred_i', 'pred_j', 'imshapes',
        'min_conf_thr', 'conf_thr', 'conf_trf', 'conf_i', 'conf_j', 'im_conf',
        'base_scale', 'norm_pw_scale', 'pw_break', 'POSE_DIM', 'pw_poses',
        'pw_adaptors', 'has_im_poses', 'rand_pose', 'imgs', 'verbose',
        'if_use_mono', 'mono_depths', 'camera_poses',
    )

    def __init__(self, *args, **kwargs):
        """Build the optimizer from views/predictions, or copy another optimizer.

        A single positional argument with no keyword is treated as an existing
        optimizer (or a mapping of attributes) to deep-copy. Anything else is
        forwarded to ``_init_from_views``.
        """
        if len(args) == 1 and len(kwargs) == 0:
            self._init_from_other(deepcopy(args[0]))
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_other(self, other):
        """Copy the attributes listed in ``_COPIED_ATTRS`` from ``other``.

        ``other`` may be an object (read with ``getattr``) or a dict (read by
        key). Names that ``other`` does not have are skipped, e.g. ``conf_thr``
        which this class never sets itself.
        """
        # nn.Module must be initialised before parameters/sub-modules can be
        # assigned; going through setattr (not __dict__) registers them properly.
        super().__init__()
        from_mapping = isinstance(other, dict)
        for name in self._COPIED_ATTRS:
            if from_mapping:
                if name in other:
                    setattr(self, name, other[name])
            elif hasattr(other, name):
                setattr(self, name, getattr(other, name))

    def _init_from_views(self, view1, view2, pred1, pred2,
                         if_use_mono,
                         mono_depths,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None,
                         verbose=True):
        """Initialise the optimizer from pairwise views and network predictions.

        Args:
            view1, view2: dicts with an ``'idx'`` entry (list or something with
                ``.tolist()``), and optionally ``'img'`` and ``'camera_pose'``.
                ``'idx'`` is converted to a list in place.
            pred1: dict providing ``'pts3d'`` and ``'conf'`` (one item per edge).
            pred2: dict providing ``'pts3d_in_other_view'`` and ``'conf'``.
            if_use_mono: stored as-is on the instance.
            mono_depths: stored as-is; when not None, gradients are disabled on
                each element.
            dist: key into ``ALL_DISTS`` selecting the distance function.
            conf: name passed to ``get_conf_trf`` to select the confidence transform.
            min_conf_thr: threshold used by ``get_masks``.
            base_scale: target scale used by ``get_pw_norm_scale_factor``.
            allow_pw_adaptors: whether ``pw_adaptors`` is trainable.
            pw_break: divisor applied to the adaptors before exponentiation.
            rand_pose: callable used to draw the initial pairwise poses.
            iterationsCount: accepted for interface compatibility; unused here.
            verbose: stored for use by the optimisation loop.
        """
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        # symmetrized == every (i, j) edge has its (j, i) counterpart
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]
        self.verbose = verbose

        self.if_use_mono = if_use_mono
        self.mono_depths = mono_depths
        # Guard against callers passing None when mono depth is not used.
        if self.mono_depths is not None:
            for i in range(len(self.mono_depths)):
                self.mono_depths[i].requires_grad = False

        self.n_imgs = self._check_edges()

        # pairwise point-map predictions, keyed by edge string
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # pairwise confidences and the per-image confidence derived from them
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
        for i in range(len(self.im_conf)):
            self.im_conf[i].requires_grad = False

        # pairwise pose parameters: POSE_DIM pose entries + 1 trailing log-scale
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # optionally keep the images (used by `show`, `mask_sky`, `save_rgb_imgs`)
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            for v, (i, j) in enumerate(self.edges):
                imgs[i] = view1['img'][v]
                imgs[j] = view2['img'][v]
            self.imgs = rgb(imgs)

        # optionally keep the camera poses supplied with the views
        self.camera_poses = None
        if 'camera_pose' in view1 and 'camera_pose' in view2:
            camera_poses = [torch.zeros((4, 4)) for _ in range(self.n_imgs)]
            for v, (i, j) in enumerate(self.edges):
                camera_poses[i] = view1['camera_pose'][v]
                camera_poses[j] = view2['camera_pose'][v]
            self.camera_poses = camera_poses

    @property
    def n_edges(self):
        """Number of pairwise observations."""
        return len(self.edges)

    @property
    def str_edges(self):
        """Edges rendered through ``edge_str``, in the same order as ``self.edges``."""
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        """Image sizes as ``(w, h)`` tuples (``imshapes`` stores ``(h, w)``)."""
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        """Return a filtered state dict.

        Keys starting with ``'_'``, ``'pred_i.'``, ``'pred_j.'``, ``'conf_i.'``
        or ``'conf_j.'`` form the non-trainable set. With ``trainable=True``
        everything else is returned; with ``trainable=False`` only that set is.
        """
        all_params = super().state_dict()
        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        """Load ``data`` on top of the current non-trainable entries."""
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        """Check that image indices are exactly ``0..n-1`` and return ``n``."""
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
        return len(indices)

    @torch.no_grad()
    def _compute_img_conf(self, pred1_conf, pred2_conf):
        """Per-image confidence: element-wise max over every edge touching the image."""
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        """Return the per-edge 3-component multiplicative adaptors.

        The first adaptor column is duplicated, so the two stored values become
        three factors. When ``norm_pw_scale`` is set the log-factors are
        centred per edge, so their product is 1.
        """
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)
        if self.norm_pw_scale:
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        """Decode pose parameters into homogeneous rigid transforms.

        Columns 0:4 are a quaternion (normalised here), columns 4:7 the
        translation stored through ``signed_log1p`` (decoded with ``signed_expm1``).
        """
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Write a rotation / translation / scale into ``poses[idx]`` in place.

        Args:
            poses: pose parameter tensor (one row per pose).
            idx: row to update.
            R: 3x3 rotation, a full 4x4 transform (then ``T`` must be None), or
                None to leave the rotation untouched.
            T: translation, or None to leave it untouched.
            scale: if given, the translation is divided by it and its log is
                stored in the last pose entry.
            force: update even if the pose does not require gradients.

        Returns:
            The (possibly updated) pose row.
        """
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        # `R` may legitimately be None (see the check below), so test that
        # before touching `.shape`.
        if R is not None and R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            # translation is stored relative to the scale
            pose.data[4:7] = signed_log1p(T / (scale or 1))

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        """Factor that brings the mean pairwise log-scale to ``log(base_scale)``.

        Returns 1 when ``norm_pw_scale`` is disabled.
        """
        if self.norm_pw_scale:
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1

    def get_pw_scale(self):
        """Per-edge scale: exp of the last pose entry times the normalisation factor."""
        scale = self.pw_poses[:, -1].exp()
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):
        """Pairwise homogeneous poses with the first three rows multiplied by the edge scale."""
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        # rows 0..2 hold rotation AND translation, so both get scaled
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)
        return scaled_RT

    def get_masks(self):
        """Boolean masks of pixels whose image confidence exceeds ``min_conf_thr``."""
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        """Return the per-image 3D points from ``depth_to_pts3d``.

        With ``raw=False`` each entry is truncated to ``h*w`` rows and viewed
        as ``(h, w, 3)``; with ``raw=True`` the subclass output is returned as-is.
        """
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        """Transformed per-image confidences.

        Uses the instance's ``conf_trf`` unless ``mode`` names another
        transform (resolved through ``get_conf_trf``).
        """
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    def clean_pointcloud(self, **kw):
        """Run the module-level ``clean_pointcloud`` and store the result in ``im_conf``.

        Keyword arguments are forwarded unchanged. Returns ``self``.
        """
        cams = inv(self.get_im_poses())
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        all_pts3d = self.get_pts3d()

        new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw)

        for i, new_conf in enumerate(new_im_confs):
            self.im_conf[i].data[:] = new_conf
        return self

    def get_tum_poses(self):
        """Return ``[poses, timestamps]``.

        ``poses`` stacks one ``c2w_to_tumpose`` row per image pose;
        ``timestamps`` is simply ``0, 1, 2, ...`` as floats.
        """
        poses = self.get_im_poses()
        tt = np.arange(len(poses)).astype(float)
        tum_poses = [c2w_to_tumpose(p) for p in poses]
        tum_poses = np.stack(tum_poses, 0)
        return [tum_poses, tt]

    def save_tum_poses(self, path):
        """Write the trajectory via ``save_trajectory_tum_format``; return the pose array."""
        traj = self.get_tum_poses()
        save_trajectory_tum_format(traj, path)
        return traj[0]

    def save_focals(self, path):
        """Save the focals to ``path`` as text (6 decimals); return them."""
        focals = self.get_focals()
        np.savetxt(path, focals.detach().cpu().numpy(), fmt='%.6f')
        return focals

    def save_intrinsics(self, path):
        """Save the intrinsics, one flattened 9-value row per image; return the unflattened tensor."""
        K_raw = self.get_intrinsics()
        K = K_raw.reshape(-1, 9)
        np.savetxt(path, K.detach().cpu().numpy(), fmt='%.6f')
        return K_raw

    def save_conf_maps(self, path):
        """Save each transformed confidence map as ``{path}/conf_{i}.npy``; return them."""
        conf = self.get_conf()
        for i, c in enumerate(conf):
            np.save(f'{path}/conf_{i}.npy', c.detach().cpu().numpy())
        return conf

    def save_init_conf_maps(self, path):
        """Save each map from ``get_init_conf()`` as ``{path}/init_conf_{i}.npy``; return them."""
        conf = self.get_init_conf()
        for i, c in enumerate(conf):
            np.save(f'{path}/init_conf_{i}.npy', c.detach().cpu().numpy())
        return conf

    def save_rgb_imgs(self, path):
        """Write every stored image as ``{path}/frame_{i:04d}.png``; return the images."""
        # cv2 is not among the module-level imports, so bring it in locally.
        import cv2

        imgs = self.imgs
        for i, img in enumerate(imgs):
            # reverse the channel axis: OpenCV writes images in BGR order
            img = img[..., ::-1]
            cv2.imwrite(f'{path}/frame_{i:04d}.png', img*255)
        return imgs

    def save_dynamic_masks(self, path):
        """Write the dynamic masks as ``{path}/dynamic_mask_{i}.png``; return them.

        Uses ``sam2_dynamic_masks`` when that attribute exists and is not None,
        otherwise ``dynamic_masks``.
        """
        # cv2 is not among the module-level imports, so bring it in locally.
        import cv2

        dynamic_masks = self.dynamic_masks if getattr(self, 'sam2_dynamic_masks', None) is None else self.sam2_dynamic_masks
        for i, dynamic_mask in enumerate(dynamic_masks):
            cv2.imwrite(f'{path}/dynamic_mask_{i}.png', (dynamic_mask * 255).detach().cpu().numpy().astype(np.uint8))
        return dynamic_masks

    def save_depth_maps(self, path):
        """Save depth maps as colour PNGs, raw ``.npy`` files and an animated GIF.

        For each depth map ``i`` this writes ``{path}/frame_{i:04d}.png``
        (JET colour map) and ``{path}/frame_{i:04d}.npy`` (raw values). If at
        least one map exists it also writes ``{path}/_depth_maps.gif``.
        Returns the depth maps.
        """
        # cv2 is not among the module-level imports, so bring it in locally.
        import cv2

        depth_maps = self.get_depthmaps()
        images = []

        for i, depth_map in enumerate(depth_maps):
            # NOTE(review): the uint8 cast assumes depths roughly in [0, 1];
            # larger values would wrap around - confirm the expected range.
            depth_map_colored = cv2.applyColorMap((depth_map * 255).detach().cpu().numpy().astype(np.uint8), cv2.COLORMAP_JET)
            img_path = f'{path}/frame_{(i):04d}.png'
            cv2.imwrite(img_path, depth_map_colored)
            images.append(Image.open(img_path))
            np.save(f'{path}/frame_{(i):04d}.npy', depth_map.detach().cpu().numpy())

        # Nothing to animate when there is no depth map at all.
        if images:
            images[0].save(f'{path}/_depth_maps.gif', save_all=True, append_images=images[1:], duration=100, loop=0)

        return depth_maps

    def forward(self, ret_details=False):
        """Compute the alignment loss averaged over all edges.

        For each edge ``(i, j)`` both pairwise predictions are scaled by the
        edge adaptor, transformed by the edge pose, and compared (with
        ``self.dist``, weighted by the transformed confidences) to the
        per-image points of images ``i`` and ``j``.

        Args:
            ret_details: also return an ``(n_imgs, n_imgs)`` tensor holding the
                per-edge loss, with -1 where no edge exists.

        Returns:
            ``loss`` or ``(loss, details)``.
        """
        pw_poses = self.get_pw_poses()
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()

        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)

            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges

        if ret_details:
            return loss, details
        return loss

    @torch.cuda.amp.autocast(enabled=False)
    def compute_global_alignment(self, init=None, init_priors=None, niter_PnP=10, **kw):
        """Optionally initialise the scene, then run ``global_alignment_loop``.

        Args:
            init: None (no initialisation), ``'mst'``/``'msp'`` (minimum
                spanning tree) or ``'known_poses'``.
            init_priors: forwarded to the spanning-tree initialisation.
            niter_PnP: forwarded to the chosen initialisation.
            **kw: forwarded to ``global_alignment_loop``.

        Raises:
            ValueError: for any other ``init`` value.
        """
        if init is None:
            pass
        elif init == 'msp' or init == 'mst':
            init_fun.init_minimum_spanning_tree(self, init_priors=init_priors, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
                                           niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        return global_alignment_loop(self, **kw)

    @torch.no_grad()
    def mask_sky(self):
        """Return a deep copy whose confidence is zeroed wherever ``segment_sky`` fires."""
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        """Display the scene in a ``SceneViz`` and return the viewer.

        Args:
            show_pw_cams: also draw the pairwise cameras.
            show_pw_pts3d: also draw the pairwise point clouds; only honoured
                together with ``show_pw_cams``.
            cam_size: camera glyph size; derived from the poses when None.
            **kw: forwarded to ``viz.show``.
        """
        viz = SceneViz()
        # computed once and reused, instead of once per image
        pts3d = self.get_pts3d()
        masks = self.get_masks()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            for n in range(self.n_imgs):
                viz.add_pointcloud(pts3d[n], colors[n], masks[n])
        else:
            viz.add_pointcloud(pts3d, self.imgs, masks)
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera frustums
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
        if show_pw_cams:
            pw_poses = self.get_pw_poses()
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

            if show_pw_pts3d:
                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
                viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz
|
|
|
|
|
def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
    """Optimise the trainable parameters of ``net`` with Adam for ``niter`` steps.

    Args:
        net: module whose call returns the loss; must expose ``verbose``.
        lr: base learning rate handed to the scheduler.
        niter: number of optimisation steps.
        schedule: learning-rate schedule name, passed to ``global_alignment_iter``.
        lr_min: final learning rate handed to the scheduler.

    Returns:
        The last loss as a float (``inf`` if no step ran), or ``net`` itself
        when it has no trainable parameter.
    """
    trainable = [p for p in net.parameters() if p.requires_grad]
    if len(trainable) == 0:
        return net

    show_progress = net.verbose
    if show_progress:
        print('Global alignement - optimizing for:')
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(trainable, lr=lr_base, betas=(0.9, 0.9))

    loss = float('inf')
    if not show_progress:
        # silent mode: plain counted loop
        for step in range(niter):
            loss, _ = global_alignment_iter(net, step, niter, lr_base, lr_min, optimizer, schedule)
        return loss

    # verbose mode: the progress bar's own counter drives the loop
    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
            bar.update()
    return loss
|
|
|
|
|
def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    """Run one optimisation step with a scheduled learning rate.

    Args:
        net: module whose call returns the loss.
        cur_iter: index of the current step.
        niter: total number of steps (used to compute the schedule progress).
        lr_base, lr_min: learning-rate bounds handed to the scheduler.
        optimizer: optimizer stepping the parameters of ``net``.
        schedule: ``'cosine'`` or ``'linear'``.

    Returns:
        ``(loss, lr)`` with the loss converted to a Python float.

    Raises:
        ValueError: if ``schedule`` is neither ``'cosine'`` nor ``'linear'``.
    """
    progress = cur_iter / niter
    if schedule == 'linear':
        lr = linear_schedule(progress, lr_base, lr_min)
    elif schedule == 'cosine':
        lr = cosine_schedule(progress, lr_base, lr_min)
    else:
        raise ValueError(f'bad lr {schedule=}')

    adjust_learning_rate_by_lr(optimizer, lr)

    # standard gradient step
    optimizer.zero_grad()
    step_loss = net()
    step_loss.backward()
    optimizer.step()

    return float(step_loss), lr
|
|
|
|
|
@torch.no_grad()
def clean_pointcloud(im_confs, K, cams, depthmaps, all_pts3d,
                     tol=0.001, bad_conf=0, dbg=()):
    """ Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence

    Args:
        im_confs: per-image confidence maps (2-D tensors).
        K: per-image intrinsics, applied with ``geotrf``.
        cams: per-image transforms applied to the points with ``geotrf``.
        depthmaps: per-image depth, viewable with the confidence-map shape.
        all_pts3d: per-image 3D points, viewable as ``(*conf.shape, 3)``.
        tol: relative depth tolerance, must satisfy ``0 <= tol < 1``.
        bad_conf: upper bound applied to the confidence of offending points.
        dbg: unused.

    Returns:
        New list of confidence maps; the inputs are not modified.
    """
    n_views = len(all_pts3d)
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == n_views
    assert 0 <= tol < 1
    cleaned = [conf.clone() for conf in im_confs]

    # bring points and depth onto the confidence-map grid
    points = [pts.view(*conf.shape, 3) for pts, conf in zip(all_pts3d, im_confs)]
    depths = [depth.view(*conf.shape) for depth, conf in zip(depthmaps, im_confs)]

    for src in range(n_views):
        src_pts = points[src]
        for dst in range(n_views):
            if src == dst:
                continue

            # points of view `src` expressed in camera `dst`, then projected to pixels
            in_dst = geotrf(cams[dst], src_pts)
            z = in_dst[:, :, 2]
            pix = geotrf(K[dst], in_dst, norm=1, ncol=2).round().long()
            u, v = pix.unbind(-1)

            # keep points in front of the camera that land inside the image
            height, width = im_confs[dst].shape
            visible = (z > 0) & (0 <= u) & (u < width) & (0 <= v) & (v < height)
            hit = (v[visible], u[visible])

            # offending = closer than the target depth AND less confident than the target pixel
            closer = z[visible] < (1 - tol) * depths[dst][hit]
            less_sure = cleaned[src][visible] < cleaned[dst][hit]

            offending = visible.clone()
            offending[visible] = closer & less_sure
            cleaned[src][offending] = cleaned[src][offending].clip(max=bad_conf)

    return cleaned
|
|