# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Initialization functions for global alignment
# --------------------------------------------------------
from functools import cache

import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm

from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.viz import to_numpy

from dust3r.cloud_opt_flow.commons import edge_str, i_j_ij, compute_edge_scores
import matplotlib.pyplot as plt
import seaborn as sns


def draw_edge_scores_map(edge_scores, save_path, n_imgs=None):
    """Save an annotated heatmap of the pairwise edge scores.

    Args:
        edge_scores: mapping ``(i, j) -> score``. Cell ``[i, j]`` of the
            heatmap receives the score; every other cell stays NaN.
        save_path: destination handed to ``plt.savefig``.
        n_imgs: side length of the square matrix. When None it is inferred
            as the largest node index found in ``edge_scores`` plus one.
    """
    # Determine the size of the heatmap
    if n_imgs is None:
        n_imgs = max(max(edge) for edge in edge_scores) + 1

    # Matrix holding the scores; NaN marks pairs without an edge
    heatmap_matrix = np.full((n_imgs, n_imgs), np.nan)
    for (i, j), score in edge_scores.items():
        heatmap_matrix[i, j] = score

    # Figure and font sizes follow the original log-scaling formulas. They are
    # clamped to 1 because the raw formulas go non-positive for very small
    # (figure) or very large (font) image counts.
    width = 5.5 * np.log(n_imgs) - 2
    figsize = (max(1, int(width)), max(1, int(width * 3 / 4)))
    fontsize = max(1, int(-4.2 * np.log(n_imgs) + 22.4))

    fig = plt.figure(figsize=figsize)
    try:
        sns.heatmap(heatmap_matrix, annot=True, fmt=".1f", cmap="viridis", cbar=True,
                    annot_kws={"fontsize": fontsize})
        plt.title("Heatmap of Edge Scores")
        plt.xlabel("Node")
        plt.ylabel("Node")
        plt.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls accumulate figures in memory.
        plt.close(fig)


@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """Initialise pairwise poses and depthmaps using the known image poses.

    For every edge, the relative pose of the second view is estimated with
    PnP, the pair is aligned onto the corresponding known poses, and the
    result is stored as the pairwise pose. For each image, the depthmap of
    its highest mean-confidence edge (where it is the first view) is kept.

    Args:
        self: the optimizer object (reads ``device``, ``edges``, ``conf_i``,
            ``pred_i``, ``pred_j``, ``pw_poses``, ``n_imgs``, ``verbose``).
        niter_PnP: RANSAC iteration count forwarded to ``fast_pnp``.
        min_conf_thr: upper bound on the confidence threshold used to select
            the points given to PnP.
    """
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    # assert nkp == self.n_imgs, 'not all poses are known'

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    # assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        # the threshold is capped just below the minimum confidence, so the
        # mask never ends up empty when min_conf_thr is too strict
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        # assert known_poses_msk[n]
        if n in best_depthmaps:
            _, i_j, scale = best_depthmaps[n]
            depth = self.pred_i[i_j][:, :, 2]
            self._set_depthmap(n, depth * scale)


@torch.no_grad()
def init_minimum_spanning_tree(self, save_score_path=None, save_score_only=False, init_priors=None, **kw):
    """Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.

    Args:
        self: the optimizer object holding the pairwise predictions.
        save_score_path: where to save the edge-score heatmap (optional).
        save_score_only: when True, only the heatmap is drawn to
            ``save_score_path`` and the function returns None without
            initialising anything.
        init_priors: forwarded to ``minimum_spanning_tree``.
        **kw: extra keyword arguments forwarded to ``minimum_spanning_tree``.

    Returns:
        Whatever ``init_from_pts3d`` returns, or None in score-only mode.
    """
    device = self.device
    if save_score_only:
        eadge_and_scores = compute_edge_scores(map(i_j_ij, self.edges), self.conf_i, self.conf_j)
        draw_edge_scores_map(eadge_and_scores, save_score_path)
        return

    pts3d, _, im_focals, im_poses = minimum_spanning_tree(
        self.imshapes, self.edges,
        self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
        device, has_im_poses=self.has_im_poses, verbose=self.verbose, init_priors=init_priors,
        save_score_path=save_score_path, **kw)

    return init_from_pts3d(self, pts3d, im_focals, im_poses)


def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Initialise the optimizer state from per-image world pointmaps.

    Args:
        self: the optimizer object to initialise.
        pts3d: list with one pointmap per image (modified in place).
        im_focals: per-image focals (entries may be None), or None.
        im_poses: stacked camera poses, or None when the optimizer has no
            image poses (as returned by ``minimum_spanning_tree``).

    Raises:
        NotImplementedError: if exactly one pose is known.
    """
    # init poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)
    # else: no known poses, nothing to align

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    if im_poses is not None:  # None when the optimizer has no image poses
        im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None and not self.shared_focal:
                self._set_focal(i, im_focals[i])
        if self.shared_focal:
            # average only the focals that could be estimated; a None entry
            # would otherwise make the sum fail
            known_focals = [f for f in im_focals if f is not None]
            if known_focals:
                self._set_focal(0, sum(known_focals) / len(known_focals))
        if self.n_imgs > 2:
            self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(' init loss =', float(self()))


def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
                          device, init_priors=None, has_im_poses=True, niter_PnP=10, verbose=True,
                          save_score_path=None):
    """Build an initial world pointcloud by walking the best-score spanning tree.

    Edge scores are negated so that scipy's *minimum* spanning tree keeps the
    highest-scoring edges. Starting from a seed edge, each remaining tree edge
    is attached by registering its prediction onto the already-placed
    pointmap of one of its two images.

    Args:
        imshapes: per-image shapes (only the length is used here).
        edges: iterable of ``(i, j)`` image pairs.
        pred_i, pred_j: dicts keyed by ``edge_str(i, j)`` holding the
            pointmaps of the first / second image of each pair.
        conf_i, conf_j: dicts with the matching confidence maps.
        im_conf: per-image confidence maps (used to mask points for PnP).
        min_conf_thr: confidence threshold for the PnP mask.
        device: torch device for the created poses.
        init_priors: optional priors. ``init_priors[0]`` is read as a 4x4
            pose and ``init_priors[2][0]`` as a focal, both applied to
            image 0 (presumably the reference frame; confirm with caller).
            When given, the seed edge is the best tree edge touching image 0.
        has_im_poses: whether to also estimate per-image poses and focals.
        niter_PnP: RANSAC iteration count for ``fast_pnp``.
        verbose: print the order in which edges are attached.
        save_score_path: if set, save the score heatmap there and write the
            edge order to a sibling ``*_tree.txt`` file.

    Returns:
        ``(pts3d, msp_edges, im_focals, im_poses)``; the last two are None
        when ``has_im_poses`` is False.

    Raises:
        ValueError: if ``init_priors`` is given but no tree edge touches
            image 0.
    """
    n_imgs = len(imshapes)
    eadge_and_scores = compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j)
    sparse_graph = -dict_to_sparse_graph(eadge_and_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges, best one last
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with specific edge
    if init_priors is None:
        score, i, j = todo.pop()
    else:
        # Rotate the queue (best edge first) until an edge touching image 0
        # shows up. The rotation is bounded so that a tree without image 0
        # fails loudly instead of looping forever.
        for _ in range(len(todo)):
            score, i, j = todo.pop()
            if i == 0 or j == 0:
                break
            todo.insert(0, (score, i, j))
        else:
            raise ValueError("init_priors was given but no spanning-tree edge contains image 0")

    if verbose:
        print(f' init edge ({i}*,{j}*) {score=}')
    if save_score_path is not None:
        draw_edge_scores_map(eadge_and_scores, save_score_path, n_imgs=n_imgs)
        # strip only a trailing ".png" so the text log can never end up with
        # the same path as the image when the extension differs
        save_tree_path = save_score_path.removesuffix(".png") + "_tree.txt"
        with open(save_tree_path, "w") as f:
            f.write(f'init edge ({i}*,{j}*) {score=}\n')

    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()  # the first one is set to be world coordinate
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        if init_priors is None:
            im_poses[i] = torch.eye(4, device=device)
            im_focals[i] = estimate_focal(pred_i[i_j])
        else:
            init_keypose = np.array(init_priors[0]).astype(np.float32)
            init_keyfocal = init_priors[2][0]
            if i == 0:
                im_poses[i] = torch.from_numpy(init_keypose).to(device)
                im_focals[i] = float(init_keyfocal)
                pts3d[i] = geotrf(im_poses[i], pts3d[i])
                pts3d[j] = geotrf(im_poses[i], pts3d[j])
            elif j == 0:
                im_poses[j] = torch.from_numpy(init_keypose).to(device)
                im_focals[j] = float(init_keyfocal)
                # use the reversed pair so that the predictions are expressed
                # in the frame of image 0 (requires edge (j, i) to exist)
                j_i = edge_str(j, i)
                pts3d[i] = geotrf(im_poses[j], pred_j[j_i].clone())
                pts3d[j] = geotrf(im_poses[j], pred_i[j_i].clone())

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # key of the edge being processed; it must be computed before the
        # focal estimate so that image i's focal comes from its own pointmap
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            # the first frame is already set, align the second frame with the first frame
            if verbose:
                print(f' init edge ({i},{j}*) {score=}')
            if save_score_path is not None:
                with open(save_tree_path, "a") as f:
                    f.write(f'init edge ({i},{j}*) {score=}\n')
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            # the second frame is already set, align the first frame with the second frame
            if verbose:
                print(f' init edge ({i}*,{j}) {score=}')
            if save_score_path is not None:
                with open(save_tree_path, "a") as f:
                    f.write(f'init edge ({i}*,{j}) {score=}\n')
            assert i not in done
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither image is placed yet: let's try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                # PnP failed: fall back to the identity pose
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses


def dict_to_sparse_graph(dic):
    """Convert a ``{(i, j): value}`` dict into a square scipy DOK sparse array.

    The side length is the largest index found in the keys plus one.
    """
    n_imgs = max(max(e) for e in dic) + 1
    res = sp.dok_array((n_imgs, n_imgs))
    for edge, value in dic.items():
        res[edge] = value
    return res


def rigid_points_registration(pts1, pts2, conf):
    """Confidence-weighted similarity registration of ``pts1`` onto ``pts2``.

    Both pointmaps are flattened to ``(-1, 3)`` and ``conf`` is flattened to
    per-point weights before calling ``roma.rigid_points_registration`` with
    scaling enabled.

    Returns:
        ``(s, R, T)``: scale, rotation and translation.
    """
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
    return s, R, T  # return un-scaled (R, T)


def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 transform whose linear part is ``R * scale`` and translation is ``T``."""
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()  # doesn't need scaling
    return trf


def estimate_focal(pts3d_i, pp=None):
    """Estimate a focal length from an ``(H, W, 3)`` pointmap.

    Args:
        pts3d_i: pointmap of shape ``(H, W, 3)``.
        pp: principal point; defaults to the image centre ``(W/2, H/2)``.

    Returns:
        The focal as a Python float (Weiszfeld mode of
        ``estimate_focal_knowing_depth``).
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
    return float(focal)


@cache
def pixel_grid(H, W):
    """Return the float32 ``(H, W, 2)`` grid of ``(x, y)`` pixel coordinates.

    The result is cached per ``(H, W)`` and shared between callers, so it
    must not be modified in place.
    """
    return np.mgrid[:W, :H].T.astype(np.float32)


def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Estimate a camera pose (and optionally a focal) with RANSAC-PnP.

    Args:
        pts3d: ``(H, W, 3)`` pointmap.
        focal: known focal, or None to try 21 geometrically spaced candidates
            between ``max(W, H) / 2`` and ``max(W, H) * 3``.
        msk: ``(H, W)`` boolean mask selecting the points to use.
        device: torch device of the returned pose.
        pp: principal point; defaults to the image centre.
        niter_PnP: number of RANSAC iterations.

    Returns:
        ``(best_focal, cam2world)``, or None when fewer than 4 points are
        selected or no candidate focal yields a successful solve.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)
    else:
        pp = to_numpy(pp)

    best = (0,)  # (inlier count, R, T, focal) of the best candidate so far
    for candidate_focal in tentative_focals:
        K = np.float32([(candidate_focal, 0, pp[0]), (0, candidate_focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5,
                                                    flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        score = len(inliers)
        if score > best[0]:
            best = score, R, T, candidate_focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world


def get_known_poses(self):
    """Return ``(count, mask, poses)`` for the image poses that are frozen.

    A pose counts as known when its parameter does not require grad.
    Returns ``(0, None, None)`` when the optimizer has no image poses.
    """
    if self.has_im_poses:
        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
        known_poses = self.get_im_poses()
        return known_poses_msk.sum(), known_poses_msk, known_poses
    else:
        return 0, None, None


def get_known_focals(self):
    """Return ``(count, mask, focals)`` for the known focals.

    Returns ``(0, None, None)`` when the optimizer has no image poses.
    """
    if self.has_im_poses:
        known_focal_msk = self.get_known_focal_mask()
        known_focals = self.get_focals()
        return known_focal_msk.sum(), known_focal_msk, known_focals
    else:
        return 0, None, None


def align_multiple_poses(src_poses, target_poses):
    """Find the similarity transform aligning ``src_poses`` onto ``target_poses``.

    Each ``(N, 4, 4)`` pose set is turned into ``2N`` points: the camera
    centres plus the centres shifted by a small step along the third column
    of the rotation, so that orientation also constrains the registration.

    Returns:
        ``(s, R, T)``: scale, rotation and translation.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # step size: 1% of the median distance between the poses
        eps = get_med_dist_between_poses(poses) / 100
        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps * poses[:, :3, 2]))

    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses),
                                             compute_scaling=True)
    return s, R, T