'''
Manually passing scale to COTR, skip the scale difference estimation.

Demo on the bundled "petrzin" sample pair. The ground-truth keypoints of
image A are used as queries; the matches COTR predicts in image B are plotted
next to the ground-truth keypoints of image B. The manual scale is presumably
what is passed through the ``areas`` argument of
``SparseEngine.cotr_corr_multiscale`` (both set to 1.0 here) -- confirm
against the engine.
'''
import argparse
import os
import sys  # used for opt.command; previously only reachable via a star import
import time
import cv2
import numpy as np
import torch
import imageio
from scipy.spatial import distance_matrix
import matplotlib.pyplot as plt

from COTR.utils import utils, debug_utils
from COTR.models import build_model
from COTR.options.options import *
from COTR.options.options_utils import *
from COTR.inference.sparse_engine import SparseEngine


# Module-level setup: fix randomness with seed 0 (via COTR's helper) and turn
# off autograd globally, since this script only runs inference.
utils.fix_randomness(0)
torch.set_grad_enabled(False)

# Sample inputs, relative to the working directory the script is run from.
IMG_A_PATH = './sample_data/imgs/petrzin_01.png'
IMG_B_PATH = './sample_data/imgs/petrzin_02.png'
GT_CORRS_PATH = './sample_data/petrzin_pts.txt'


def main(opt):
    """Load COTR, match the sample pair, and plot predictions against ground truth.

    Args:
        opt: parsed options namespace. Besides whatever ``build_model`` reads,
            it must carry ``load_weights_path`` (set in ``__main__`` from
            ``--load_weights``).

    Raises:
        ValueError: if ``opt`` has no ``load_weights_path`` (i.e. the script
            was run without ``--load_weights``).
    """
    # Fail fast with a clear message instead of an AttributeError on the
    # Namespace after the model has already been built.
    weights_path = getattr(opt, 'load_weights_path', None)
    if not weights_path:
        raise ValueError(
            'No checkpoint to load: pass --load_weights <model id> '
            '(resolved to <out_dir>/<model id>/checkpoint.pth.tar).')

    model = build_model(opt)
    model = model.cuda()
    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; safe_load_weights then (presumably) copies the
    # tensors into the CUDA-resident model's parameters.
    weights = torch.load(weights_path, map_location='cpu')['model_state_dict']
    utils.safe_load_weights(model, weights)
    model = model.eval()

    img_a = imageio.imread(IMG_A_PATH)
    img_b = imageio.imread(IMG_B_PATH)
    # Manually supplied relative areas, handed to the engine via `areas=` below.
    img_a_area = 1.0
    img_b_area = 1.0

    # Columns 0:2 are (x, y) queries in image A, columns 2:4 the matching
    # (x, y) ground truth in image B. ndmin=2 keeps a one-row file 2-D so the
    # column slicing below still works.
    gt_corrs = np.loadtxt(GT_CORRS_PATH, ndmin=2)
    kp_a = gt_corrs[:, :2]
    kp_b = gt_corrs[:, 2:]

    # NOTE(review): 32 is presumably the engine's batch size -- confirm in SparseEngine.
    engine = SparseEngine(model, 32, mode='tile')
    t0 = time.perf_counter()
    # np.linspace(0.75, 0.1, 4) -> [0.75, 0.533.., 0.316.., 0.1]; presumably the
    # multiscale zoom schedule, coarse to fine. One correspondence is requested
    # per query keypoint.
    corrs = engine.cotr_corr_multiscale(
        img_a, img_b, np.linspace(0.75, 0.1, 4), 1,
        max_corrs=kp_a.shape[0], queries_a=kp_a, force=True,
        areas=[img_a_area, img_b_area])
    t1 = time.perf_counter()
    print(f'COTR spent {t1-t0} seconds.')

    utils.visualize_corrs(img_a, img_b, corrs)

    # Overlay on image B: ground-truth points, predicted points (corrs columns
    # 2:4), and one red segment per row joining kp_b[i] to corrs[i]. Passing
    # 2xN arrays to plt.plot draws each column as its own line. This pairing
    # assumes corrs rows follow query order -- TODO confirm for force=True.
    plt.imshow(img_b)
    plt.scatter(kp_b[:, 0], kp_b[:, 1])
    plt.scatter(corrs[:, 2], corrs[:, 3])
    plt.plot(np.stack([kp_b[:, 0], corrs[:, 2]], axis=1).T,
             np.stack([kp_b[:, 1], corrs[:, 3]], axis=1).T,
             color=[1, 0, 0])
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')

    opt = parser.parse_args()
    opt.command = ' '.join(sys.argv)

    # dim_feedforward is set to the channel count of the selected feature
    # layer (--layer).
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    print_opt(opt)
    main(opt)