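"""Two-view matching demo for ASpanFormer: match an image pair, filter the
matches with a RANSAC fundamental-matrix fit, and save the visualizations to
match.png and match_ransac.png."""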
import os
import sys

# make the repository root importable so that the src.* packages resolve
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from src.ASpanFormer.aspanformer import ASpanFormer
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config
import demo_utils

import cv2
import torch
import numpy as np
import argparse
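
# Command-line arguments. The relative defaults assume the script is run from
# the directory that contains it.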
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config_path",
    type=str,
    default="../configs/aspan/outdoor/aspan_test.py",
    help="path for config file.",
)
parser.add_argument(
    "--img0_path",
    type=str,
    default="../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg",
    help="path for image0.",
)
parser.add_argument(
    "--img1_path",
    type=str,
    default="../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg",
    help="path for image1.",
)
parser.add_argument(
    "--weights_path",
    type=str,
    default="../weights/outdoor.ckpt",
    help="path for model weights.",
)
parser.add_argument(
    "--long_dim0", type=int, default=1024, help="resize for longest dim of image0."
)
parser.add_argument(
    "--long_dim1", type=int, default=1024, help="resize for longest dim of image1."
)
args = parser.parse_args()
if __name__ == "__main__":
    # build the model from the merged config and load the pretrained weights
    config = get_cfg_defaults()
    config.merge_from_file(args.config_path)
    _config = lower_config(config)
    matcher = ASpanFormer(config=_config["aspan"])
    state_dict = torch.load(args.weights_path, map_location="cpu")["state_dict"]
    matcher.load_state_dict(state_dict, strict=False)
    matcher.cuda()
    matcher.eval()

    # color images are kept for visualization; matching runs on grayscale
    img0, img1 = cv2.imread(args.img0_path), cv2.imread(args.img1_path)
    img0_g, img1_g = cv2.imread(args.img0_path, 0), cv2.imread(args.img1_path, 0)
    img0, img1 = demo_utils.resize(img0, args.long_dim0), demo_utils.resize(
        img1, args.long_dim1
    )
    img0_g, img1_g = demo_utils.resize(img0_g, args.long_dim0), demo_utils.resize(
        img1_g, args.long_dim1
    )

    # normalize to [0, 1] and add batch and channel dimensions: (1, 1, H, W)
    data = {
        "image0": torch.from_numpy(img0_g / 255.0)[None, None].cuda().float(),
        "image1": torch.from_numpy(img1_g / 255.0)[None, None].cuda().float(),
    }

    # run the matcher; fine-level matches are written back into `data`
    with torch.no_grad():
        matcher(data, online_resize=True)
    corr0, corr1 = data["mkpts0_f"].cpu().numpy(), data["mkpts1_f"].cpu().numpy()

    # filter matches by estimating a fundamental matrix with RANSAC
    F_hat, mask_F = cv2.findFundamentalMat(
        corr0, corr1, method=cv2.FM_RANSAC, ransacReprojThreshold=1
    )
    if mask_F is not None:
        mask_F = mask_F[:, 0].astype(bool)
    else:
        mask_F = np.zeros_like(corr0[:, 0]).astype(bool)

    # visualize matches before and after RANSAC filtering
    display = demo_utils.draw_match(img0, img1, corr0, corr1)
    display_ransac = demo_utils.draw_match(img0, img1, corr0[mask_F], corr1[mask_F])
    cv2.imwrite("match.png", display)
    cv2.imwrite("match_ransac.png", display_ransac)

    # report the number of raw matches and the number of RANSAC inliers
    print(len(corr1), len(corr1[mask_F]))