import os import sys ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) from src.ASpanFormer.aspanformer import ASpanFormer from src.config.default import get_cfg_defaults from src.utils.misc import lower_config import demo_utils import cv2 import torch import numpy as np import argparse parser = argparse.ArgumentParser() parser.add_argument( "--config_path", type=str, default="../configs/aspan/outdoor/aspan_test.py", help="path for config file.", ) parser.add_argument( "--img0_path", type=str, default="../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg", help="path for image0.", ) parser.add_argument( "--img1_path", type=str, default="../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg", help="path for image1.", ) parser.add_argument( "--weights_path", type=str, default="../weights/outdoor.ckpt", help="path for model weights.", ) parser.add_argument( "--long_dim0", type=int, default=1024, help="resize for longest dim of image0." ) parser.add_argument( "--long_dim1", type=int, default=1024, help="resize for longest dim of image1." ) args = parser.parse_args() if __name__ == "__main__": config = get_cfg_defaults() config.merge_from_file(args.config_path) _config = lower_config(config) matcher = ASpanFormer(config=_config["aspan"]) state_dict = torch.load(args.weights_path, map_location="cpu")["state_dict"] matcher.load_state_dict(state_dict, strict=False) matcher.cuda(), matcher.eval() img0, img1 = cv2.imread(args.img0_path), cv2.imread(args.img1_path) img0_g, img1_g = cv2.imread(args.img0_path, 0), cv2.imread(args.img1_path, 0) img0, img1 = demo_utils.resize(img0, args.long_dim0), demo_utils.resize( img1, args.long_dim1 ) img0_g, img1_g = demo_utils.resize(img0_g, args.long_dim0), demo_utils.resize( img1_g, args.long_dim1 ) data = { "image0": torch.from_numpy(img0_g / 255.0)[None, None].cuda().float(), "image1": torch.from_numpy(img1_g / 255.0)[None, None].cuda().float(), } with torch.no_grad(): matcher(data, online_resize=True) corr0, corr1 = data["mkpts0_f"].cpu().numpy(), data["mkpts1_f"].cpu().numpy() F_hat, mask_F = cv2.findFundamentalMat( corr0, corr1, method=cv2.FM_RANSAC, ransacReprojThreshold=1 ) if mask_F is not None: mask_F = mask_F[:, 0].astype(bool) else: mask_F = np.zeros_like(corr0[:, 0]).astype(bool) # visualize match display = demo_utils.draw_match(img0, img1, corr0, corr1) display_ransac = demo_utils.draw_match(img0, img1, corr0[mask_F], corr1[mask_F]) cv2.imwrite("match.png", display) cv2.imwrite("match_ransac.png", display_ransac) print(len(corr1), len(corr1[mask_F]))