import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

import pycolmap
from hloc import (
    extract_features,
    logger,
    match_features,
    pairs_from_retrieval,
    reconstruction,
    visualization,
)

from .viz import fig2im


class SfmEngine:
    def __init__(self, cfg: Optional[Dict[str, Any]] = None):
        self.cfg = cfg or {}
        if self.cfg.get("outputs"):
            outputs = Path(self.cfg["outputs"])
            outputs.mkdir(parents=True, exist_ok=True)
        else:
            outputs = tempfile.mkdtemp()
        self.outputs = Path(outputs)

    def call(
        self,
        key: str,
        images: List[Path],
        camera_model: str,
        camera_params: List[float],
        max_keypoints: int,
        keypoint_threshold: float,
        match_threshold: float,
        ransac_threshold: int,
        ransac_confidence: float,
        ransac_max_iter: int,
        scene_graph: bool,
        global_feature: str,
        top_k: int = 10,
        mapper_refine_focal_length: bool = False,
        mapper_refine_principle_points: bool = False,
        mapper_refine_extra_params: bool = False,
    ):
        """
        Run feature extraction, pair retrieval, matching, and reconstruction.

        Args:
            key (str): The key to retrieve the matcher and feature models.
            images (List[Path]): The input image paths.
            camera_model (str): The camera model.
            camera_params (List[float]): The camera parameters.
            max_keypoints (int): The maximum number of keypoints to extract.
            keypoint_threshold (float): The keypoint detection threshold.
            match_threshold (float): The match threshold.
            ransac_threshold (int): The RANSAC threshold.
            ransac_confidence (float): The RANSAC confidence.
            ransac_max_iter (int): The maximum number of RANSAC iterations.
            scene_graph (bool): Whether to compute the scene graph.
            global_feature (str): The global feature model used for retrieval.
            top_k (int): The number of image pairs to retrieve per image.
            mapper_refine_focal_length (bool): Whether to refine the focal length.
            mapper_refine_principle_points (bool): Whether to refine the principal points.
            mapper_refine_extra_params (bool): Whether to refine the extra parameters.

        Returns:
            Tuple[Path, np.ndarray]: The exported point cloud as an OBJ file
                and a 2D visualization of the reconstruction.
""" if len(images) == 0: logger.error(f"{images} does not exist.") temp_images = Path(tempfile.mkdtemp()) # copy images logger.info(f"Copying images to {temp_images}.") for image in images: shutil.copy(image, temp_images) matcher_zoo = self.cfg["matcher_zoo"] model = matcher_zoo[key] match_conf = model["matcher"] match_conf["model"]["max_keypoints"] = max_keypoints match_conf["model"]["match_threshold"] = match_threshold feature_conf = model["feature"] feature_conf["model"]["max_keypoints"] = max_keypoints feature_conf["model"]["keypoint_threshold"] = keypoint_threshold # retrieval retrieval_name = self.cfg.get("retrieval_name", "netvlad") retrieval_conf = extract_features.confs[retrieval_name] mapper_options = { "ba_refine_extra_params": mapper_refine_extra_params, "ba_refine_focal_length": mapper_refine_focal_length, "ba_refine_principal_point": mapper_refine_principle_points, "ba_local_max_num_iterations": 40, "ba_local_max_refinements": 3, "ba_global_max_num_iterations": 100, # below 3 options are for individual/video data, for internet photos, they should be left # default "min_focal_length_ratio": 0.1, "max_focal_length_ratio": 10, "max_extra_param": 1e15, } sfm_dir = self.outputs / "sfm_{}".format(key) sfm_pairs = self.outputs / "pairs-sfm.txt" sfm_dir.mkdir(exist_ok=True, parents=True) # extract features retrieval_path = extract_features.main( retrieval_conf, temp_images, self.outputs ) pairs_from_retrieval.main(retrieval_path, sfm_pairs, num_matched=top_k) feature_path = extract_features.main( feature_conf, temp_images, self.outputs ) # match features match_path = match_features.main( match_conf, sfm_pairs, feature_conf["output"], self.outputs ) # reconstruction already_sfm = False if sfm_dir.exists(): try: model = pycolmap.Reconstruction(str(sfm_dir)) already_sfm = True except ValueError: logger.info(f"sfm_dir not exists model: {sfm_dir}") if not already_sfm: model = reconstruction.main( sfm_dir, temp_images, sfm_pairs, feature_path, match_path, mapper_options=mapper_options, ) vertices = [] for point3D_id, point3D in model.points3D.items(): vertices.append([point3D.xyz, point3D.color]) model_3d = sfm_dir / "points3D.obj" with open(model_3d, "w") as f: for p, c in vertices: # Write vertex position f.write("v {} {} {}\n".format(p[0], p[1], p[2])) # Write vertex normal (color) f.write( "vn {} {} {}\n".format( c[0] / 255.0, c[1] / 255.0, c[2] / 255.0 ) ) viz_2d = visualization.visualize_sfm_2d( model, temp_images, color_by="visibility", n=2, dpi=300 ) return model_3d, fig2im(viz_2d) / 255.0