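"""Localize query images against a 3D Structure-from-Motion model.

For each query, 2D-2D matches to retrieved database images are lifted to
2D-3D correspondences through the reconstruction, and the camera pose is
estimated with PnP + RANSAC via pycolmap.

Example invocation (module path and file names are only illustrative):
    python -m hloc.localize_sfm \
        --reference_sfm sfm/ --queries queries_with_intrinsics.txt \
        --retrieval pairs.txt --features features.h5 --matches matches.h5 \
        --results poses.txt
"""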
import argparse
import pickle
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import pycolmap
from tqdm import tqdm

from . import logger
from .utils.io import get_keypoints, get_matches
from .utils.parsers import parse_image_lists, parse_retrieval


def do_covisibility_clustering(
frame_ids: List[int], reconstruction: pycolmap.Reconstruction
):
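    """Group frames into connected components of the covisibility graph.

    Two frames are covisible if they observe at least one common 3D point.
    A graph search (flood fill) grows each component; components are
    returned as lists of frame ids, sorted by size (largest first).
    """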
clusters = []
visited = set()
for frame_id in frame_ids:
# Check if already labeled
if frame_id in visited:
continue
# New component
clusters.append([])
queue = {frame_id}
        while queue:
exploration_frame = queue.pop()
# Already part of the component
if exploration_frame in visited:
continue
visited.add(exploration_frame)
clusters[-1].append(exploration_frame)
observed = reconstruction.images[exploration_frame].points2D
connected_frames = {
obs.image_id
for p2D in observed
if p2D.has_point3D()
for obs in reconstruction.points3D[p2D.point3D_id].track.elements
}
connected_frames &= set(frame_ids)
connected_frames -= visited
queue |= connected_frames
clusters = sorted(clusters, key=len, reverse=True)
return clusters


class QueryLocalizer:
def __init__(self, reconstruction, config=None):
self.reconstruction = reconstruction
self.config = config or {}

    def localize(self, points2D_all, points2D_idxs, points3D_id, query_camera):
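        """Estimate the absolute pose of the query camera with PnP + RANSAC.

        points2D_idxs indexes the matched keypoints in points2D_all;
        points3D_id gives the id of the corresponding 3D model point for
        each match (a keypoint index may appear multiple times).
        """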
points2D = points2D_all[points2D_idxs]
points3D = [self.reconstruction.points3D[j].xyz for j in points3D_id]
ret = pycolmap.absolute_pose_estimation(
points2D,
points3D,
query_camera,
estimation_options=self.config.get("estimation", {}),
refinement_options=self.config.get("refinement", {}),
)
return ret


def pose_from_cluster(
localizer: QueryLocalizer,
qname: str,
query_camera: pycolmap.Camera,
db_ids: List[int],
features_path: Path,
matches_path: Path,
**kwargs,
):
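    """Localize a query image against one set (cluster) of database images.

    2D-2D matches between the query and each database image are lifted to
    2D-3D correspondences through the reconstruction; duplicate
    (keypoint, 3D point) observations are merged before running PnP.
    Returns the estimation result and a log dict for post-processing.
    """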
kpq = get_keypoints(features_path, qname)
    kpq += 0.5  # to COLMAP coordinates (pixel centers at half-integers)
kp_idx_to_3D = defaultdict(list)
kp_idx_to_3D_to_db = defaultdict(lambda: defaultdict(list))
num_matches = 0
for i, db_id in enumerate(db_ids):
image = localizer.reconstruction.images[db_id]
if image.num_points3D == 0:
logger.debug(f"No 3D points found for {image.name}.")
continue
points3D_ids = np.array(
[p.point3D_id if p.has_point3D() else -1 for p in image.points2D]
)
matches, _ = get_matches(matches_path, qname, image.name)
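        # keep only matches whose database keypoint has a triangulated 3D point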
matches = matches[points3D_ids[matches[:, 1]] != -1]
num_matches += len(matches)
for idx, m in matches:
id_3D = points3D_ids[m]
kp_idx_to_3D_to_db[idx][id_3D].append(i)
# avoid duplicate observations
if id_3D not in kp_idx_to_3D[idx]:
kp_idx_to_3D[idx].append(id_3D)
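    # flatten into one (query keypoint index, 3D point id) pair per correspondence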
idxs = list(kp_idx_to_3D.keys())
mkp_idxs = [i for i in idxs for _ in kp_idx_to_3D[i]]
mp3d_ids = [j for i in idxs for j in kp_idx_to_3D[i]]
ret = localizer.localize(kpq, mkp_idxs, mp3d_ids, query_camera, **kwargs)
if ret is not None:
ret["camera"] = query_camera
# mostly for logging and post-processing
mkp_to_3D_to_db = [
(j, kp_idx_to_3D_to_db[i][j]) for i in idxs for j in kp_idx_to_3D[i]
]
log = {
"db": db_ids,
"PnP_ret": ret,
"keypoints_query": kpq[mkp_idxs],
"points3D_ids": mp3d_ids,
"points3D_xyz": None, # we don't log xyz anymore because of file size
"num_matches": num_matches,
"keypoint_index_to_db": (mkp_idxs, mkp_to_3D_to_db),
}
return ret, log


def main(
reference_sfm: Union[Path, pycolmap.Reconstruction],
queries: Path,
retrieval: Path,
features: Path,
matches: Path,
results: Path,
    ransac_thresh: float = 12.0,
covisibility_clustering: bool = False,
prepend_camera_name: bool = False,
    config: Optional[Dict] = None,
):
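    """Localize every query against the reference SfM model and write poses.

    Poses are written as `name qw qx qy qz tx ty tz` lines to `results`;
    detailed per-query logs are pickled to `<results>_logs.pkl`.
    """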
assert retrieval.exists(), retrieval
assert features.exists(), features
assert matches.exists(), matches
queries = parse_image_lists(queries, with_intrinsics=True)
retrieval_dict = parse_retrieval(retrieval)
logger.info("Reading the 3D model...")
if not isinstance(reference_sfm, pycolmap.Reconstruction):
reference_sfm = pycolmap.Reconstruction(reference_sfm)
db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()}
config = {"estimation": {"ransac": {"max_error": ransac_thresh}}, **(config or {})}
localizer = QueryLocalizer(reference_sfm, config)
cam_from_world = {}
logs = {
"features": features,
"matches": matches,
"retrieval": retrieval,
"loc": {},
}
logger.info("Starting localization...")
for qname, qcam in tqdm(queries):
if qname not in retrieval_dict:
logger.warning(f"No images retrieved for query image {qname}. Skipping...")
continue
db_names = retrieval_dict[qname]
db_ids = []
for n in db_names:
if n not in db_name_to_id:
logger.warning(f"Image {n} was retrieved but not in database")
continue
db_ids.append(db_name_to_id[n])
if covisibility_clustering:
clusters = do_covisibility_clustering(db_ids, reference_sfm)
best_inliers = 0
best_cluster = None
logs_clusters = []
for i, cluster_ids in enumerate(clusters):
ret, log = pose_from_cluster(
localizer, qname, qcam, cluster_ids, features, matches
)
if ret is not None and ret["num_inliers"] > best_inliers:
best_cluster = i
best_inliers = ret["num_inliers"]
logs_clusters.append(log)
if best_cluster is not None:
ret = logs_clusters[best_cluster]["PnP_ret"]
cam_from_world[qname] = ret["cam_from_world"]
logs["loc"][qname] = {
"db": db_ids,
"best_cluster": best_cluster,
"log_clusters": logs_clusters,
"covisibility_clustering": covisibility_clustering,
}
else:
ret, log = pose_from_cluster(
localizer, qname, qcam, db_ids, features, matches
)
if ret is not None:
cam_from_world[qname] = ret["cam_from_world"]
else:
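                # pose estimation failed: fall back to the pose of the
                # top-retrieved database image as a coarse estimate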
closest = reference_sfm.images[db_ids[0]]
cam_from_world[qname] = closest.cam_from_world
log["covisibility_clustering"] = covisibility_clustering
logs["loc"][qname] = log
logger.info(f"Localized {len(cam_from_world)} / {len(queries)} images.")
logger.info(f"Writing poses to {results}...")
with open(results, "w") as f:
for query, t in cam_from_world.items():
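            # pycolmap stores quaternions as (x, y, z, w); reorder to the
            # (w, x, y, z) convention expected by the text format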
qvec = " ".join(map(str, t.rotation.quat[[3, 0, 1, 2]]))
tvec = " ".join(map(str, t.translation))
name = query.split("/")[-1]
if prepend_camera_name:
name = query.split("/")[-2] + "/" + name
f.write(f"{name} {qvec} {tvec}\n")
logs_path = f"{results}_logs.pkl"
logger.info(f"Writing logs to {logs_path}...")
# TODO: Resolve pickling issue with pycolmap objects.
with open(logs_path, "wb") as f:
pickle.dump(logs, f)
logger.info("Done!")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--reference_sfm", type=Path, required=True)
parser.add_argument("--queries", type=Path, required=True)
parser.add_argument("--features", type=Path, required=True)
parser.add_argument("--matches", type=Path, required=True)
parser.add_argument("--retrieval", type=Path, required=True)
parser.add_argument("--results", type=Path, required=True)
parser.add_argument("--ransac_thresh", type=float, default=12.0)
parser.add_argument("--covisibility_clustering", action="store_true")
parser.add_argument("--prepend_camera_name", action="store_true")
args = parser.parse_args()
main(**args.__dict__)