"""Keypoint matchers: learned graph matchers ("SGM" / "SG") and a descriptor
nearest-neighbour baseline with ratio test and optional mutual check."""
import os
import sys
from collections import OrderedDict, namedtuple

import numpy as np
import torch

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

# Local packages are imported after ROOT_DIR has been put on sys.path.
from sgmnet import matcher as SGM_Model  # noqa: E402
from superglue import matcher as SG_Model  # noqa: E402
from utils import evaluation_utils  # noqa: E402


class GNN_Matcher(object):
    """Match keypoints with a learned graph matcher ("SGM" or "SG").

    The network weights are read from ``<config.model_dir>/model_best.pth``
    and the model is run on CUDA in eval mode.
    """

    def __init__(self, config, model_name):
        """Build the matcher and load its checkpoint.

        Args:
            config: mapping of options; must provide at least ``p_th`` (score
                threshold used by ``match_p``) and ``model_dir``. The whole
                mapping is also forwarded to the model constructor.
            model_name: ``"SGM"`` or ``"SG"``; selects the model class.

        Raises:
            ValueError: if ``model_name`` is not one of the supported names.
        """
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if model_name not in ("SGM", "SG"):
            raise ValueError(f"model_name must be 'SGM' or 'SG', got {model_name!r}")
        config = namedtuple("config", config.keys())(*config.values())
        self.p_th = config.p_th
        self.model = SGM_Model(config) if model_name == "SGM" else SG_Model(config)
        self.model.cuda()
        self.model.eval()

        # map_location="cpu" makes loading independent of which GPU the
        # checkpoint was written from; load_state_dict then copies the values
        # into the (already CUDA-resident) model parameters.
        checkpoint = torch.load(
            os.path.join(config.model_dir, "model_best.pth"), map_location="cpu"
        )
        state_dict = checkpoint["state_dict"]
        # Checkpoints saved from a DDP-wrapped model carry a "module." prefix
        # on every key; strip it so the keys match the bare model.
        first_key = next(iter(state_dict), "")
        if first_key.split(".")[0] == "module":
            prefix_len = len("module.")
            state_dict = OrderedDict(
                (key[prefix_len:], value) for key, value in state_dict.items()
            )
        self.model.load_state_dict(state_dict)

    @staticmethod
    def _normalize_keypoints(kpts, size):
        """Return ``kpts`` with its xy columns normalised by image ``size``.

        Column 2 of ``kpts`` is appended unchanged (presumably a detector
        score - TODO confirm); any further columns are dropped.
        """
        norm_xy = evaluation_utils.normalize_size(kpts[:, :2], size)
        return np.concatenate([norm_xy, kpts[:, 2, np.newaxis]], axis=-1)

    def run(self, test_data):
        """Run the network on one image pair and return matched coordinates.

        Args:
            test_data: dict with keys ``"x1"``/``"x2"`` (keypoint arrays whose
                first two columns are xy and third column is forwarded to the
                network), ``"size1"``/``"size2"`` (passed to
                ``evaluation_utils.normalize_size``) and ``"desc1"``/``"desc2"``
                (descriptor arrays).

        Returns:
            ``(corr1, corr2)``: arrays of shape (K, 2) holding the original,
            un-normalised xy of the K matched keypoints in each image.
        """
        x1 = self._normalize_keypoints(test_data["x1"], test_data["size1"])
        x2 = self._normalize_keypoints(test_data["x2"], test_data["size2"])
        # Add a leading batch dimension of 1 and move everything to the GPU.
        feed_data = {
            "x1": torch.from_numpy(x1[np.newaxis]).cuda().float(),
            "x2": torch.from_numpy(x2[np.newaxis]).cuda().float(),
            "desc1": torch.from_numpy(test_data["desc1"][np.newaxis]).cuda().float(),
            "desc2": torch.from_numpy(test_data["desc2"][np.newaxis]).cuda().float(),
        }
        with torch.no_grad():
            res = self.model(feed_data, test_mode=True)
        p = res["p"]
        # Drop the batch dimension and the last row/column of the score matrix.
        index1, index2 = self.match_p(p[0, :-1, :-1])
        # Index with real NumPy arrays: NumPy may treat a single-element torch
        # tensor as a scalar index (via __index__) and return a 1-D result.
        corr1 = test_data["x1"][:, :2][index1.cpu().numpy()]
        corr2 = test_data["x2"][:, :2][index2.cpu().numpy()]
        return corr1, corr2

    def match_p(self, p):
        """Select mutual-best matches from a score matrix.

        Args:
            p: 2-D tensor of shape (N, M) scoring every keypoint of image 1
                against every keypoint of image 2.

        Returns:
            ``(index1, index2)``: 1-D index tensors of equal length. Row
            ``index1[k]`` and column ``index2[k]`` are each other's best match,
            and their score exceeds ``self.p_th``.
        """
        score, index = torch.topk(p, k=1, dim=-1)  # best column per row: (N, 1)
        _, index2 = torch.topk(p, k=1, dim=-2)  # best row per column: (1, M)
        score, index, index2 = score[:, 0], index[:, 0], index2.squeeze(0)
        # Mutual check: the best row of row i's best column must be i itself.
        rows = torch.arange(len(p), device=p.device)
        mask = (score > self.p_th) & (index2[index] == rows)
        index1 = torch.nonzero(mask).squeeze(1)
        return index1, index[mask]


class NN_Matcher(object):
    """Nearest-neighbour descriptor matcher with a ratio test and an
    optional mutual nearest-neighbour check."""

    def __init__(self, config):
        """Args:
            config: mapping that provides ``mutual_check`` (bool) and
                ``ratio_th`` (nearest / second-nearest distance threshold).
        """
        config = namedtuple("config", config.keys())(*config.values())
        self.mutual_check = config.mutual_check
        self.ratio_th = config.ratio_th

    def run(self, test_data):
        """Match ``desc1`` against ``desc2`` and return matched coordinates.

        Args:
            test_data: dict with ``"desc1"``/``"desc2"`` descriptor arrays
                (presumably (N, D) and (M, D) - TODO confirm) and
                ``"x1"``/``"x2"`` keypoint arrays whose first two columns are
                xy. Image 2 needs at least two descriptors for the ratio test.

        Returns:
            ``(corr1, corr2)``: arrays of shape (K, 2) with the xy of the
            accepted matches in image 1 and image 2 respectively.
        """
        desc1, desc2 = test_data["desc1"], test_data["desc2"]
        x1, x2 = test_data["x1"], test_data["x2"]

        # Pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b. Round-off
        # can push the expansion slightly below zero; clamp before the sqrt.
        sq_dist = (
            (desc1**2).sum(-1)[:, np.newaxis]
            + (desc2**2).sum(-1)[np.newaxis]
            - 2 * desc1 @ desc2.T
        )
        desc_mat = np.sqrt(np.maximum(sq_dist, 0))

        # Partial sort: columns 0 and 1 become the nearest and second-nearest
        # neighbour of every row (needs only two candidates in image 2).
        nn_index = np.argpartition(desc_mat, kth=(0, 1), axis=-1)[:, :2]
        dis_value12 = np.take_along_axis(desc_mat, nn_index, axis=-1)
        # A zero second-nearest distance gives 0/0 = nan, which fails the
        # `<` comparison below; silence the RuntimeWarning it would raise.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio_score = dis_value12[:, 0] / dis_value12[:, 1]

        nn_index1 = nn_index[:, 0]  # best image-2 match for each image-1 row
        nn_index2 = np.argmin(desc_mat, axis=0)  # best image-1 match per column
        mask = ratio_score < self.ratio_th
        if self.mutual_check:
            mask &= np.arange(len(x1)) == nn_index2[nn_index1]

        corr1 = x1[:, :2][mask]
        corr2 = x2[:, :2][nn_index1][mask]
        return corr1, corr2