# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R to colmap export functions
# --------------------------------------------------------
import os
import torch
import copy  # noqa: F401  (kept: file-level import, may be relied on elsewhere)
import numpy as np
import torchvision
from tqdm import tqdm
from scipy.cluster.hierarchy import DisjointSet
from scipy.spatial.transform import Rotation as R

from mast3r.utils.misc import hash_md5
from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf  # noqa


def _to_uint8_rgb(view):
    """Return ``view['img'][0]`` as an RGB numpy array after applying ``img * 0.5 + 0.5``.

    NOTE(review): presumably this undoes a mean=0.5 / std=0.5 input
    normalisation -- confirm against the image loader.
    """
    image_mean = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    image_std = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    rgb = view['img'] * image_std + image_mean
    rgb = torchvision.transforms.functional.to_pil_image(rgb[0])
    return np.array(rgb)


def _show_matches(img0, img1, matches_im0, matches_im1, n_viz=100):
    """Plot ``n_viz`` evenly spaced matches over the two images shown side by side (blocking)."""
    from matplotlib import pyplot as pl

    imgs = [_to_uint8_rgb(img0), _to_uint8_rgb(img1)]

    # visualize a few matches
    num_matches = matches_im0.shape[0]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]

    # pad the shorter image at the bottom so both can be concatenated horizontally
    H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
    rgb0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    rgb1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((rgb0, rgb1), axis=1)

    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    for ii in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T
        # the second image is drawn to the right of the first one, hence the W0 offset
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii / (n_viz - 1)), scalex=False, scaley=False)
    pl.show(block=True)


def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
    """Turn (x, y) matches of one image pair into raveled keypoint-id matches.

    Each match coordinate is rounded, clipped to the image bounds given by
    ``true_shape`` and raveled as ``x + W * y``. Every raveled keypoint is
    counted in ``im_keypoints[image idx]`` (mutated in place).

    Args:
        img0, img1: image dicts; the keys ``'idx'`` and ``'true_shape'`` are
            read (``'img'`` as well when ``viz`` is set).
        image_to_colmap: dict ``image idx -> {'colmap_imid': ...}``.
        im_keypoints: dict ``image idx -> {raveled keypoint: count}``.
        matches_im0, matches_im1: arrays of (x, y) coordinates, one row per match.
        viz: when truthy, display a sample of the matches with matplotlib.

    Returns:
        ``(imidx0, imidx1, colmap_matches)``: the two image indices ordered so
        that the first has the smaller colmap image id, and the unique rows of
        raveled keypoint pairs in that same order.
    """
    if viz:
        _show_matches(img0, img1, matches_im0, matches_im1)

    matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
    imgs = [img0, img1]
    imidx0 = img0['idx']
    imidx1 = img1['idx']

    ravel_matches = []
    for j in range(2):
        H, W = imgs[j]['true_shape'][0]
        # casting non-finite values to int would emit an "invalid value" warning
        with np.errstate(invalid='ignore'):
            qx, qy = matches[j].round().astype(np.int32).T
        ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
        ravel_matches.append(ravel_matches_j)

        # plain loop on purpose: it preserves first-seen insertion order of the keypoints
        counts = im_keypoints[imgs[j]['idx']]
        for m in ravel_matches_j:
            counts[m] = counts.get(m, 0) + 1

    # the ids are only compared and swapped locally, no copy is required
    imid0 = image_to_colmap[imidx0]['colmap_imid']
    imid1 = image_to_colmap[imidx1]['colmap_imid']
    if imid0 > imid1:
        colmap_matches = np.stack([ravel_matches[1], ravel_matches[0]], axis=-1)
        imid0, imid1 = imid1, imid0
        imidx0, imidx1 = imidx1, imidx0
    else:
        colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1)
    colmap_matches = np.unique(colmap_matches, axis=0)
    return imidx0, imidx1, colmap_matches


def _select_confident(corres, conf_thr):
    """Keep the correspondences of ``corres = (xy0, xy1, conf, ...)`` with ``conf >= conf_thr``."""
    mask = corres[2] >= conf_thr
    return corres[0][mask].cpu().numpy(), corres[1][mask].cpu().numpy()


def _masked_pixel_grid(view, keep):
    """Return the (x, y) grid of ``view`` (sized by ``true_shape``) restricted to the flat mask ``keep``."""
    true_shape = view['true_shape'][0]
    return xy_grid(true_shape[1], true_shape[0]).reshape(-1, 2)[keep]


def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
                   is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
    """Extract 2D-2D matches for every predicted pair.

    When ``'desc'`` is present in ``pred1`` the local descriptors are matched
    (mast3r outputs), otherwise the 3D pointmaps are matched. ``is_sparse``
    selects ``extract_correspondences_nonsym`` (subsampled) over dense
    reciprocal nearest neighbours restricted to pixels with
    ``confidence >= conf_thr``.

    Args:
        pred1, pred2: prediction dicts for the first / second view of each pair.
        pairs: sequence of ``(view0, view1)`` image dicts aligned with the predictions.
        image_to_colmap, im_keypoints: forwarded to ``convert_im_matches_pairs``
            (``im_keypoints`` is updated in place).
        conf_thr: minimum confidence of a match / pixel.
        is_sparse, subsample, pixel_tol: matching options.
        viz: visualise the matches of each pair.
        device: device used by the matching routines.

    Returns:
        dict ``(imidx0, imidx1) -> colmap_matches``; pairs without any match
        are left out. A pair key seen twice keeps only its last matches.
    """
    im_matches = {}
    for i in range(len(pred1['pts3d'])):
        if 'desc' in pred1:  # mast3r
            descs = [pred1['desc'][i], pred2['desc'][i]]
            confidences = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
            desc_dim = descs[0].shape[-1]

            if is_sparse:
                corres = extract_correspondences_nonsym(descs[0], descs[1], confidences[0], confidences[1],
                                                        device=device, subsample=subsample, pixel_tol=pixel_tol)
                matches_im0, matches_im1 = _select_confident(corres, conf_thr)
            else:
                keep = [(conf >= conf_thr).cpu().numpy().flatten() for conf in confidences]
                pts2d_list = [_masked_pixel_grid(pairs[i][j], keep[j]) for j in range(2)]
                desc_list = [descs[j].detach().cpu().numpy().reshape(-1, desc_dim)[keep[j]]
                             for j in range(2)]
                if len(desc_list[0]) == 0 or len(desc_list[1]) == 0:
                    continue

                nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1],
                                                     device=device, dist='dot', block_size=2**13)
                # keep the points of view 0 whose nearest neighbour points back at them
                reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0)))

                matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0]
                matches_im0 = pts2d_list[0][reciprocal_in_P0]
        else:
            pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
            confidences = [pred1['conf'][i], pred2['conf'][i]]

            if is_sparse:
                corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confidences[0], confidences[1],
                                                        device=device, subsample=subsample, pixel_tol=pixel_tol,
                                                        ptmap_key='3d')
                matches_im0, matches_im1 = _select_confident(corres, conf_thr)
            else:
                keep = [(conf >= conf_thr).cpu().numpy().flatten() for conf in confidences]
                # find 2D-2D matches between the two images
                pts2d_list = [_masked_pixel_grid(pairs[i][j], keep[j]) for j in range(2)]
                pts3d_list = [pts3d[j].detach().cpu().numpy().reshape(-1, 3)[keep[j]] for j in range(2)]

                PQ, PM = pts3d_list[0], pts3d_list[1]
                if len(PQ) == 0 or len(PM) == 0:
                    continue
                reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(PQ, PM)

                matches_im1 = pts2d_list[1][reciprocal_in_PM]
                matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]

        if len(matches_im0) == 0:
            continue
        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
                                                                  image_to_colmap, im_keypoints,
                                                                  matches_im0, matches_im1, viz)
        im_matches[(imidx0, imidx1)] = colmap_matches
    return im_matches


def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
                              image_to_colmap, im_keypoints, conf_thr,
                              viz=False, device='cuda'):
    """Same output as ``get_im_matches`` but reading correspondences cached on disk.

    For each pair the file ``<hash0>-<hash1>.pth`` is loaded from the
    ``corres_conf=..._subsample=...`` sub-directory of ``cache_path``; when it
    does not exist the swapped name ``<hash1>-<hash0>.pth`` is loaded instead
    and the two coordinate sets are swapped accordingly.

    Args:
        pairs: sequence of ``(view0, view1)`` image dicts (``'instance'`` is hashed).
        cache_path: root directory of the correspondence cache.
        desc_conf, subsample: values that are part of the cache directory name.
        image_to_colmap, im_keypoints: forwarded to ``convert_im_matches_pairs``
            (``im_keypoints`` is updated in place).
        conf_thr: minimum confidence of a correspondence.
        viz: visualise the matches of each pair.
        device: ``map_location`` used when loading the cache files.

    Returns:
        dict ``(imidx0, imidx1) -> colmap_matches``; pairs without any match are left out.

    Raises:
        FileNotFoundError: raised by ``torch.load`` when neither file exists.
    """
    im_matches = {}
    corres_dir = cache_path + f'/corres_conf={desc_conf}_{subsample=}'
    for view0, view1 in ((pair[0], pair[1]) for pair in pairs):
        corres_idx1 = hash_md5(view0['instance'])
        corres_idx2 = hash_md5(view1['instance'])

        # NOTE: torch.load unpickles its input; the cache must come from a trusted run.
        path_corres = corres_dir + f'/{corres_idx1}-{corres_idx2}.pth'
        if os.path.isfile(path_corres):
            _, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
        else:
            # the pair was cached in the other direction
            path_corres = corres_dir + f'/{corres_idx2}-{corres_idx1}.pth'
            _, (xy2, xy1, confs) = torch.load(path_corres, map_location=device)

        mask = confs >= conf_thr
        matches_im0 = xy1[mask].cpu().numpy()
        matches_im1 = xy2[mask].cpu().numpy()

        if len(matches_im0) == 0:
            continue
        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                                  image_to_colmap, im_keypoints,
                                                                  matches_im0, matches_im1, viz)
        im_matches[(imidx0, imidx1)] = colmap_matches
    return im_matches


def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
    """Add one camera and one image per entry of ``image_paths`` to the colmap database.

    Args:
        db: database object exposing ``add_camera`` and ``add_image``.
        images: per-image dicts; ``'orig_shape'`` and ``'to_orig'`` are read.
        image_paths: image names stored in the database.
        focals: ``None`` (a default focal of ``1.2 * max(W, H)`` is used, without
            prior), or per-image values: either a 2D intrinsics matrix or a
            scalar focal (principal point then set to the image centre).
        ga_world_to_cam: per-image 4x4 world-to-camera matrices used as pose
            priors, or ``None`` for all-zero priors.
        camera_model: one of ``"SIMPLE_PINHOLE"``, ``"PINHOLE"``,
            ``"SIMPLE_RADIAL"``, ``"OPENCV"``.

    Returns:
        ``(image_to_colmap, im_keypoints)``: the mapping
        ``idx -> {'colmap_imid', 'colmap_camid'}`` and an empty keypoint dict per image.

    Raises:
        ValueError: if ``camera_model`` is not supported.
    """
    # add cameras/images to the db
    # with the output of ga as prior
    image_to_colmap = {}
    im_keypoints = {}
    for idx, image_path in enumerate(image_paths):
        im_keypoints[idx] = {}
        to_orig = images[idx]["to_orig"]
        H, W = images[idx]["orig_shape"]

        if focals is None:
            focal_x = focal_y = 1.2 * max(W, H)
            prior_focal_length = False
            cx = W / 2.0
            cy = H / 2.0
        elif isinstance(focals[idx], np.ndarray) and len(focals[idx].shape) == 2:
            # intrinsics
            focal_x = focals[idx][0, 0]
            focal_y = focals[idx][1, 1]
            cx = focals[idx][0, 2] * to_orig[0, 0]
            cy = focals[idx][1, 2] * to_orig[1, 1]
            prior_focal_length = True
        else:
            focal_x = focal_y = float(focals[idx])
            prior_focal_length = True
            cx = W / 2.0
            cy = H / 2.0
        # NOTE(review): this scaling also applies to the default focal above, which is
        # already derived from orig_shape -- confirm this is intended.
        focal_x = focal_x * to_orig[0, 0]
        focal_y = focal_y * to_orig[1, 1]

        if camera_model == "SIMPLE_PINHOLE":
            model_id = 0
            focal = (focal_x + focal_y) / 2.0
            params = np.asarray([focal, cx, cy], np.float64)
        elif camera_model == "PINHOLE":
            model_id = 1
            params = np.asarray([focal_x, focal_y, cx, cy], np.float64)
        elif camera_model == "SIMPLE_RADIAL":
            model_id = 2
            focal = (focal_x + focal_y) / 2.0
            params = np.asarray([focal, cx, cy, 0.0], np.float64)
        elif camera_model == "OPENCV":
            model_id = 4
            params = np.asarray([focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0], np.float64)
        else:
            raise ValueError(f"invalid camera model {camera_model}")

        H, W = int(H), int(W)
        camid = db.add_camera(model_id, W, H, params, prior_focal_length=prior_focal_length)

        if ga_world_to_cam is None:
            prior_t = np.zeros(3)
            prior_q = np.zeros(4)
        else:
            q = R.from_matrix(ga_world_to_cam[idx][:3, :3]).as_quat()
            prior_t = ga_world_to_cam[idx][:3, 3]
            # move the last quaternion component to the front
            # (scipy's as_quat is scalar-last by default, i.e. (x, y, z, w) -> (w, x, y, z))
            prior_q = np.array([q[-1], q[0], q[1], q[2]])
        imid = db.add_image(image_path, camid, prior_q=prior_q, prior_t=prior_t)
        image_to_colmap[idx] = {
            'colmap_imid': imid,
            'colmap_camid': camid
        }
    return image_to_colmap, im_keypoints


def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track,
                   skip_geometric_verification):
    """Build keypoint tracks from pairwise matches and export keypoints/matches to the db.

    Matches are chained into tracks; tracks shorter than ``min_len_track`` are
    dropped together with their keypoints. The surviving keypoints are converted
    to (x, y) pixel centres, mapped with ``to_orig``, clipped to the original
    image bounds and written with ``db.add_keypoints``; the matches between
    surviving keypoints are written with ``db.add_matches``.

    Args:
        db: database object exposing ``add_keypoints``, ``add_matches`` and
            ``add_two_view_geometry``.
        images: per-image dicts; ``'true_shape'``, ``'to_orig'``,
            ``'orig_shape'`` and ``'instance'`` are read.
        image_to_colmap: dict ``image idx -> {'colmap_imid': ...}``.
        im_keypoints: dict ``image idx -> {raveled keypoint: count}``.
        im_matches: dict ``(imidx0, imidx1) -> rows of raveled keypoint pairs``.
        min_len_track: minimum number of keypoints of a track to be exported.
        skip_geometric_verification: when truthy, the matches are also stored
            as two-view geometry.

    Returns:
        list of ``(instance0, instance1)`` for the pairs that got at least one match.
    """
    colmap_image_pairs = []
    # 2D-2D are quite dense
    # we want to remove the very small tracks
    # and export only kpt for which we have values
    # build tracks
    print("building tracks")
    keypoints_to_track_id = {}
    track_id_to_kpt_list = []
    to_merge = []
    for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
        tracks0 = keypoints_to_track_id.setdefault(imidx0, {})
        tracks1 = keypoints_to_track_id.setdefault(imidx1, {})

        for m in colmap_matches:
            kp0, kp1 = m[0], m[1]
            has_track0 = kp0 in tracks0
            has_track1 = kp1 in tracks1
            if not has_track0 and not has_track1:
                # new pair of kpts never seen before
                track_idx = len(track_id_to_kpt_list)
                tracks0[kp0] = track_idx
                tracks1[kp1] = track_idx
                track_id_to_kpt_list.append([(imidx0, kp0), (imidx1, kp1)])
            elif not has_track1:
                # 0 has a track, not 1
                track_idx = tracks0[kp0]
                tracks1[kp1] = track_idx
                track_id_to_kpt_list[track_idx].append((imidx1, kp1))
            elif not has_track0:
                # 1 has a track, not 0
                track_idx = tracks1[kp1]
                tracks0[kp0] = track_idx
                track_id_to_kpt_list[track_idx].append((imidx0, kp0))
            else:
                # both have tracks, merge them
                track_idx0 = tracks0[kp0]
                track_idx1 = tracks1[kp1]
                if track_idx0 != track_idx1:
                    # let's deal with them later
                    to_merge.append((track_idx0, track_idx1))

    # regroup merge targets
    print("merging tracks")
    unique = np.unique(to_merge)
    tree = DisjointSet(unique)
    for track_idx0, track_idx1 in tqdm(to_merge):
        tree.merge(track_idx0, track_idx1)

    subsets = tree.subsets()
    print("applying merge")
    merged_away = set()  # ids of the tracks that were absorbed into a merged track
    for setvals in tqdm(subsets):
        new_trackid = len(track_id_to_kpt_list)
        kpt_list = []
        for track_idx in setvals:
            kpt_list.extend(track_id_to_kpt_list[track_idx])
            for imidx, kpid in track_id_to_kpt_list[track_idx]:
                keypoints_to_track_id[imidx][kpid] = new_trackid
            merged_away.add(int(track_idx))
        track_id_to_kpt_list.append(kpt_list)

    # absorbed tracks are no longer referenced by any keypoint: do not count them
    num_valid_tracks = sum(
        1 for track_idx, kpts in enumerate(track_id_to_kpt_list)
        if track_idx not in merged_away and len(kpts) >= min_len_track)

    keypoints_to_idx = {}
    print(f"squashing keypoints - {num_valid_tracks} valid tracks")
    for imidx, keypoints_imid in tqdm(im_keypoints.items()):
        imid = image_to_colmap[imidx]['colmap_imid']
        keypoints_kept = []
        keypoints_to_idx[imidx] = {}
        # an image may hold keypoints without appearing in im_matches
        track_of_kp = keypoints_to_track_id.get(imidx, {})
        for kp in keypoints_imid:
            if kp not in track_of_kp:
                continue
            track_idx = track_of_kp[kp]
            if len(track_id_to_kpt_list[track_idx]) < min_len_track:
                continue
            keypoints_to_idx[imidx][kp] = len(keypoints_kept)
            keypoints_kept.append(kp)
        if len(keypoints_kept) == 0:
            continue

        # raveled ids (x + W * y) back to (x, y); unravel_index yields (row, col) = (y, x)
        true_shape = tuple(int(s) for s in images[imidx]['true_shape'][0])
        rows, cols = np.unravel_index(np.array(keypoints_kept), true_shape)
        keypoints_kept = np.stack([cols, rows], axis=-1).astype(np.float32)
        # rescale coordinates
        keypoints_kept += 0.5
        keypoints_kept = geotrf(images[imidx]['to_orig'], keypoints_kept, norm=True)

        H, W = images[imidx]['orig_shape']
        keypoints_kept[:, 0] = keypoints_kept[:, 0].clip(min=0, max=W - 0.01)
        keypoints_kept[:, 1] = keypoints_kept[:, 1].clip(min=0, max=H - 0.01)

        db.add_keypoints(imid, keypoints_kept)

    print("exporting im_matches")
    for (imidx0, imidx1), colmap_matches in im_matches.items():
        imid0, imid1 = image_to_colmap[imidx0]['colmap_imid'], image_to_colmap[imidx1]['colmap_imid']
        # invariant established by convert_im_matches_pairs
        assert imid0 < imid1
        kp_to_idx0 = keypoints_to_idx[imidx0]
        kp_to_idx1 = keypoints_to_idx[imidx1]
        final_matches = np.array([[kp_to_idx0[m[0]], kp_to_idx1[m[1]]]
                                  for m in colmap_matches
                                  if m[0] in kp_to_idx0 and m[1] in kp_to_idx1])
        if len(final_matches) > 0:
            colmap_image_pairs.append((images[imidx0]['instance'], images[imidx1]['instance']))
            db.add_matches(imid0, imid1, final_matches)
            if skip_geometric_verification:
                db.add_two_view_geometry(imid0, imid1, final_matches)
    return colmap_image_pairs