import sys
from pathlib import Path
import subprocess
import logging
import torch
from PIL import Image
from collections import OrderedDict, namedtuple
from ..utils.base_model import BaseModel
from ..utils import do_system

sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
# The third-party SGMNet package is not installed; make it importable
# before the import below.
sys.path.append(str(sgmnet_path))

from sgmnet import matcher as SGM_Model  # noqa: E402

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger(__name__)


class SGMNet(BaseModel):
    """Keypoint matcher wrapping the third-party SGMNet network.

    On first use the pretrained weight archive is downloaded with ``gdown``
    (retrying through ``proxy`` if the direct download fails) and extracted
    into ``sgmnet_path``. Every entry of the merged configuration is passed
    to the SGMNet constructor as a namedtuple field.
    """

    default_conf = {
        "name": "SGM",
        "model_name": "model_best.pth",
        "seed_top_k": [256, 256],
        "seed_radius_coe": 0.01,
        "net_channels": 128,
        "layer_num": 9,
        "head": 4,
        "seedlayer": [0, 6],
        "use_mc_seeding": True,
        "use_score_encoding": False,
        "conf_bar": [1.11, 0.1],
        "sink_iter": [10, 100],
        "detach_iter": 1000000,
        "match_threshold": 0.2,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    weight_urls = {
        "model_best.pth": "https://drive.google.com/uc?id=1Ca0WmKSSt2G6P7m8YAOlSAHEFar_TAWb&confirm=t",
    }
    # HTTP proxy used only as a fallback when the direct download fails.
    proxy = "http://localhost:1080"

    def _init(self, conf):
        """Fetch the weights if they are missing, then build and load SGMNet.

        Args:
            conf: merged configuration dict; ``conf["model_name"]`` selects
                the checkpoint file and every key/value is forwarded to the
                SGMNet constructor.

        Raises:
            subprocess.CalledProcessError: if downloading or extracting the
                weight archive fails.
            KeyError: if the weights are missing locally and
                ``conf["model_name"]`` has no entry in ``weight_urls``.
        """
        sgmnet_weights = sgmnet_path / "weights/sgm/root" / conf["model_name"]
        if not sgmnet_weights.exists():
            self._download_weights(conf["model_name"])

        # SGMNet reads its hyper-parameters as attributes, not dict keys.
        config = namedtuple("config", conf.keys())(*conf.values())
        self.net = SGM_Model(config)
        checkpoint = torch.load(sgmnet_weights, map_location="cpu")
        self.net.load_state_dict(
            self._strip_ddp_prefix(checkpoint["state_dict"])
        )
        logger.info("Load SGMNet model done.")

    def _download_weights(self, model_name):
        """Download (unless already present) and extract the weight archive."""
        link = self.weight_urls[model_name]
        tar_path = sgmnet_path / "weights.tar.gz"
        if not tar_path.exists():
            cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
            cmd = cmd_wo_proxy + ["--proxy", self.proxy]
            logger.info(
                f"Downloading the SGMNet model with `{cmd_wo_proxy}`."
            )
            try:
                subprocess.run(cmd_wo_proxy, check=True)
            except subprocess.CalledProcessError:
                logger.info(f"Downloading the SGMNet model with `{cmd}`.")
                try:
                    subprocess.run(cmd, check=True)
                except subprocess.CalledProcessError:
                    logger.error("Failed to download the SGMNet model.")
                    raise
        # Extract with `tar -C` rather than `cd <dir> & tar ...`: in a POSIX
        # shell `&` backgrounds the `cd`, so tar would extract into the
        # current working directory instead of sgmnet_path.
        cmd = ["tar", "-xvf", str(tar_path), "-C", str(sgmnet_path)]
        logger.info(f"Unzip model file `{cmd}`.")
        subprocess.run(cmd, check=True)

    @staticmethod
    def _strip_ddp_prefix(state_dict):
        """Drop the leading ``module.`` from keys of a DDP-saved checkpoint.

        The check looks only at the first key; if it does not start with a
        ``module`` component the state dict is returned unchanged.
        """
        first_key = next(iter(state_dict), "")
        if first_key.split(".")[0] != "module":
            return state_dict
        return OrderedDict(
            (key[len("module."):], value) for key, value in state_dict.items()
        )

    def _forward(self, data):
        """Match the keypoints of image0 against those of image1.

        Args:
            data: dict with ``keypoints{0,1}`` (N x 2 once reshaped),
                ``scores{0,1}`` (one score per keypoint),
                ``descriptors{0,1}`` (permuted to 1 x N x D) and
                ``image{0,1}``. Assumes images are B x C x H x W; TODO confirm.

        Returns:
            dict with ``matches0`` (1 x N; index into keypoints1, or -1) and
            ``matching_scores0`` (1 x N; always zeros).
        """
        # reshape instead of squeeze so a single keypoint stays 1 x 2.
        kpts0 = data["keypoints0"].reshape(-1, 2)
        kpts1 = data["keypoints1"].reshape(-1, 2)
        score0 = data["scores0"].reshape(-1, 1)  # N x 1
        score1 = data["scores1"].reshape(-1, 1)
        desc0 = data["descriptors0"].permute(0, 2, 1)  # 1 x N x D
        desc1 = data["descriptors1"].permute(0, 2, 1)
        # (H, W) -> (W, H) so it lines up with (x, y); created on the
        # keypoints' device so normalize_size does not mix devices.
        size0 = torch.tensor(
            data["image0"].shape[2:], device=kpts0.device
        ).flip(0)
        size1 = torch.tensor(
            data["image1"].shape[2:], device=kpts1.device
        ).flip(0)
        x0 = torch.cat((self.normalize_size(kpts0, size0), score0), dim=-1)
        x1 = torch.cat((self.normalize_size(kpts1, size1), score1), dim=-1)
        net_input = {
            "x1": x0[None],  # 1 x N x 3
            "x2": x1[None],
            "desc1": desc0,
            "desc2": desc1,
        }
        net_input = {k: v.to(device).float() for k, v in net_input.items()}
        pred = self.net(net_input, test_mode=True)

        p = pred["p"]
        # Drop the last row and column before matching (presumably the
        # dustbin entries of the assignment matrix; TODO confirm).
        indices0 = self.match_p(p[0, :-1, :-1])
        scores0 = torch.zeros(indices0.size(0), device=indices0.device)
        return {
            "matches0": indices0.unsqueeze(0),
            "matching_scores0": scores0.unsqueeze(0),
        }

    def match_p(self, p):
        """Mutual-nearest-neighbour matching on an N x M score matrix.

        Row i is matched to column j = argmax p[i, :] only if that score
        exceeds ``self.conf["match_threshold"]`` and i is also the argmax
        of column j.

        Returns:
            int64 tensor of length N holding the matched column index, or -1.
        """
        num_rows, num_cols = p.shape
        if num_rows == 0 or num_cols == 0:
            # topk(k=1) cannot run over an empty dimension.
            return torch.full(
                (num_rows,), -1, dtype=torch.long, device=p.device
            )
        score, index = torch.topk(p, k=1, dim=-1)
        _, index2 = torch.topk(p, k=1, dim=-2)
        score, index, index2 = score[:, 0], index[:, 0], index2[0]
        mask_th = score > self.conf["match_threshold"]
        mask_mc = index2[index] == torch.arange(num_rows, device=p.device)
        return torch.where(mask_th & mask_mc, index, index.new_tensor(-1))

    def normalize_size(self, x, size, scale=1):
        """Shift pixel coordinates by half the image size and rescale.

        Args:
            x: N x 2 (x, y) coordinates.
            size: (W, H) tensor.
            scale: extra factor multiplied into the divisor.

        Returns:
            ``(x - size / 2 + 0.5) / (max(size) * scale)``.
        """
        norm_fac = size.max()
        return (x - size / 2 + 0.5) / (norm_fac * scale)