"""Match two images with a CasP model and save a visualization of the matches."""

import argparse
from argparse import Namespace
from typing import Any, Dict, Optional, Tuple, Type

import cv2
import numpy as np
import torch
from numpy import ndarray
from omegaconf import OmegaConf
from torch.nn import Module

from src.data.utils import load_image
from src.models.nets import CasP
from src.models.utils import make_matching_figure

# Matcher presets selectable via ``--method``; each entry is splatted into
# ``load_matcher`` as keyword arguments.
matcher_configs = {
    "casp_outdoor": {
        "matcher": CasP,
        "name": "casp",
        "ckpt_path": "weights/casp_outdoor.pth",
    },
    "casp_minima": {
        "matcher": CasP,
        "name": "casp",
        "ckpt_path": "weights/casp_minima.pth",
    },
}

# Minimum number of correspondences the OpenCV estimators accept: homography
# needs 4 points, RANSAC fundamental-matrix estimation needs 8. With fewer
# points OpenCV errors out instead of returning an inlier mask.
_MIN_RANSAC_POINTS = {"fundamental": 8, "homography": 4}


def load_matcher(
    matcher: Type[Module],
    name: str,
    ckpt_path: str,
    threshold: Optional[float] = None,
    device: str = "cpu",
) -> Module:
    """Build a matcher from its YAML config and load pretrained weights.

    Args:
        matcher: Matcher class; instantiated with the ``config`` node of
            ``configs/model/net/{name}.yaml``.
        name: Config file stem under ``configs/model/net/``.
        ckpt_path: Path to a state-dict checkpoint.
        threshold: If given, overrides ``config.threshold``.
        device: Device the model is moved to.

    Returns:
        The matcher instance in eval mode on ``device``.
    """
    config = OmegaConf.load(f"configs/model/net/{name}.yaml").config
    if threshold is not None:
        config.threshold = threshold
    model = matcher(config)
    # Deserialize onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only machines; the model is moved to ``device`` afterwards.
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    return model.eval().to(device)


def parse_args() -> Namespace:
    """Parse the command-line options of the matching demo."""
    parser = argparse.ArgumentParser(description="CasP")
    parser.add_argument("--path0", type=str, required=True)
    parser.add_argument("--path1", type=str, required=True)
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--image_size", type=int, default=1152)
    parser.add_argument(
        "--method",
        type=str,
        default="casp_outdoor",
        choices=["casp_outdoor", "casp_minima"],
    )
    parser.add_argument("--matching_threshold", type=float)
    parser.add_argument(
        "--ransac", type=str, choices=["fundamental", "homography"]
    )
    parser.add_argument(
        "--estimator",
        type=str,
        default="CV2_USAC_MAGSAC",
        choices=["CV2_RANSAC", "CV2_USAC_MAGSAC"],
    )
    parser.add_argument("--inlier_threshold", type=float, default=3.0)
    return parser.parse_args()


def ransac_optimize(
    points0: ndarray,
    points1: ndarray,
    model: str,
    estimator: str,
    threshold: float,
) -> Tuple[ndarray, ndarray]:
    """Robustly fit a two-view model to correspondences with OpenCV.

    Args:
        points0: Matched points in image 0.
        points1: Matched points in image 1, row-aligned with ``points0``.
        model: ``"fundamental"`` or ``"homography"``.
        estimator: ``"CV2_RANSAC"`` or ``"CV2_USAC_MAGSAC"``.
        threshold: Reprojection threshold passed as ``ransacReprojThreshold``.

    Returns:
        ``(matrix, inlier_mask)`` exactly as returned by the OpenCV function.
        NOTE(review): OpenCV may return ``None`` entries when estimation
        fails; callers should check.

    Raises:
        NotImplementedError: For an unsupported ``model`` or ``estimator``.
    """
    if model == "fundamental":
        func = cv2.findFundamentalMat
    elif model == "homography":
        func = cv2.findHomography
    else:
        raise NotImplementedError(f"Unsupported RANSAC model: {model!r}")

    if estimator == "CV2_RANSAC":
        method = cv2.RANSAC
    elif estimator == "CV2_USAC_MAGSAC":
        method = cv2.USAC_MAGSAC
    else:
        raise NotImplementedError(f"Unsupported estimator: {estimator!r}")

    mat, inlier_mask = func(
        points0,
        points1,
        method=method,
        ransacReprojThreshold=threshold,
        confidence=0.99999,
        maxIters=10000,
    )
    return mat, inlier_mask


def main(
    args: Dict[str, Any],
) -> Tuple[ndarray, ndarray, ndarray, Optional[ndarray]]:
    """Run the matcher on two images and optionally filter with RANSAC.

    Args:
        args: Options as produced by ``parse_args`` (converted to a dict).

    Returns:
        ``(points0, points1, scores, inlier_mask)``. ``inlier_mask`` is
        ``None`` when ``args["ransac"]`` is ``None``. Otherwise it is a
        boolean array with one entry per match; it is all ``False`` when too
        few matches exist or OpenCV returns no mask.

    Raises:
        ValueError: If the matcher reports an unknown ``data_mode``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = load_matcher(
        **matcher_configs[args["method"]],
        threshold=args["matching_threshold"],
        device=device,
    )

    data_configs = {
        "mode": matcher.data_mode,
        "size": args["image_size"],
        "factor": matcher.data_factor,
    }
    image0, mask0, scale0 = load_image(args["path0"], **data_configs)
    image1, mask1, scale1 = load_image(args["path1"], **data_configs)

    # Convert to channel-first layout scaled to [0, 1].
    if matcher.data_mode == "gray":
        image0, image1 = image0[None] / 255.0, image1[None] / 255.0
    elif matcher.data_mode == "color":
        image0 = image0.transpose(2, 0, 1) / 255.0
        image1 = image1.transpose(2, 0, 1) / 255.0
    else:
        raise ValueError(f"Unsupported data mode: {matcher.data_mode!r}")

    # Add a leading batch dimension of 1 to every input.
    data = {
        "image0": image0[None],
        "image1": image1[None],
        "scale0": scale0[None],
        "scale1": scale1[None],
    }
    if mask0 is not None:
        data["mask0"] = mask0[None]
    if mask1 is not None:
        # Each image must carry its own mask; this previously reused mask0.
        data["mask1"] = mask1[None]
    data = {
        key: torch.from_numpy(value).float().to(device)
        for key, value in data.items()
    }

    with torch.no_grad():
        results = matcher(data)

    points0 = results["points0"].cpu().numpy()
    points1 = results["points1"].cpu().numpy()
    scores = results["scores"].cpu().numpy()

    inlier_mask = None
    if args["ransac"] is not None:
        # Default: no inliers. Used when there are too few matches for the
        # estimator, or when OpenCV fails to return a mask.
        inlier_mask = np.zeros(len(points0), dtype=bool)
        # Unknown models fall through so ransac_optimize raises its own error.
        if len(points0) >= _MIN_RANSAC_POINTS.get(args["ransac"], 0):
            _, ransac_mask = ransac_optimize(
                points0,
                points1,
                args["ransac"],
                args["estimator"],
                args["inlier_threshold"],
            )
            if ransac_mask is not None:
                inlier_mask = ransac_mask.ravel() == 1

    return points0, points1, scores, inlier_mask


if __name__ == "__main__":
    args = dict(vars(parse_args()))
    points0, points1, scores, inlier_mask = main(args)
    if inlier_mask is not None:
        points0, points1, scores = [
            t[inlier_mask] for t in [points0, points1, scores]
        ]
    errors = 1 - scores
    # NOTE(review): ``text`` is built but never passed to
    # make_matching_figure; confirm whether the helper accepts a text/label
    # argument and wire it through, or drop this line.
    text = [args["method"], f"#matches: {len(points0)}"]
    make_matching_figure(
        args["path0"],
        args["path1"],
        points0,
        points1,
        errors,
        0.5,
        dpi=300,
        save_path=args["save_path"],
    )