# gim-online / hloc / matchers / aspanformer.py
# Author: Vincentqyw
# Commit 60ad158: "update: limit keypoints number"
# (web-view residue: raw · history · blame · 3.56 kB)
import sys
import torch
from ..utils.base_model import BaseModel
from ..utils import do_system
from pathlib import Path
import subprocess
import logging
# Module-level logger for this matcher.
logger = logging.getLogger(__name__)
# Make the vendored third-party code importable. This must run before the
# `ASpanFormer.*` imports that follow it.
sys.path.append(str(Path(__file__).parent / "../../third_party"))
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config
from ASpanFormer.demo import demo_utils
# Root of the vendored ASpanFormer checkout. The config files, the
# `weights/` directory and the downloaded weights archive are all located
# relative to this path.
aspanformer_path = Path(__file__).parent.joinpath("..", "..", "third_party", "ASpanFormer")
class ASpanFormer(BaseModel):
    """hloc matcher wrapper around the third-party ASpanFormer network.

    On first use, the pretrained weights archive is downloaded with ``gdown``.
    If the direct download fails, the download is retried through
    ``self.proxy``. The archive is then extracted into ``aspanformer_path``.
    The network is built from the yacs config at ``conf["config_path"]``,
    and the checkpoint's ``state_dict`` is loaded into it.
    """

    default_conf = {
        # Checkpoint stem: the weights are read from
        # ``aspanformer_path / "weights" / f"{weights}.ckpt"``.
        "weights": "outdoor",
        # Overrides ``aspan.match_coarse.thr`` in the loaded config.
        "match_threshold": 0.2,
        # Overrides ``aspan.match_coarse.skh_iters`` in the loaded config.
        "sinkhorn_iterations": 20,
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        # Archive file name; also the key into ``aspanformer_models``.
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # Proxy used only as a fallback when the direct gdown download fails.
    proxy = "http://localhost:1080"
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _init(self, conf):
        """Make sure the checkpoint is available, then build and load the network.

        Args:
            conf: Matcher configuration (see ``default_conf`` for the keys used).

        Raises:
            subprocess.CalledProcessError: If the download fails both with and
                without the proxy, or if extracting the archive fails.
            FileNotFoundError: If the checkpoint is still missing after the
                archive has been extracted.
        """
        model_path = aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
        if not model_path.exists():
            tar_path = aspanformer_path / conf["model_name"]
            if not tar_path.exists():
                self._download_weights_archive(conf["model_name"], tar_path)
            self._extract_weights_archive(tar_path)
            if not model_path.exists():
                raise FileNotFoundError(
                    f"Aspanformer checkpoint {model_path} not found after "
                    f"extracting {tar_path}."
                )

        logger.info("Loading Aspanformer model...")
        config = get_cfg_defaults()
        config.merge_from_file(str(conf["config_path"]))
        _config = lower_config(config)
        # Apply the matcher-level overrides to the coarse-matching settings.
        match_coarse = _config["aspan"]["match_coarse"]
        match_coarse["thr"] = conf["match_threshold"]
        match_coarse["skh_iters"] = conf["sinkhorn_iterations"]
        self.net = _ASpanFormer(config=_config["aspan"])
        checkpoint = torch.load(str(model_path), map_location="cpu")
        # strict=False: checkpoint keys that don't match the module are
        # silently ignored. NOTE(review): consider logging the
        # missing/unexpected keys that load_state_dict returns.
        self.net.load_state_dict(checkpoint["state_dict"], strict=False)

    def _download_weights_archive(self, model_name, tar_path):
        """Download the archive for ``model_name`` to ``tar_path`` with gdown.

        Tries a direct download first, then retries once through
        ``self.proxy``. Re-raises the last ``CalledProcessError`` if both fail.
        """
        link = self.aspanformer_models[model_name]
        cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
        cmd = [*cmd_wo_proxy, "--proxy", self.proxy]
        logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
        try:
            subprocess.run(cmd_wo_proxy, check=True)
        except subprocess.CalledProcessError:
            logger.info(f"Downloading the Aspanformer model with `{cmd}`.")
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError:
                logger.error("Failed to download the Aspanformer model.")
                raise

    @staticmethod
    def _extract_weights_archive(tar_path):
        """Extract ``tar_path`` into ``aspanformer_path``.

        Presumably the archive contains ``weights/<name>.ckpt``; ``_init``
        raises if the expected checkpoint is missing afterwards.
        """
        # Run tar with cwd= set to the target directory rather than building
        # a shell `cd ... && tar ...` string. This needs no shell quoting, and
        # a failed extraction raises instead of passing silently. The archive
        # path is resolved first so it stays valid under the new cwd.
        subprocess.run(
            ["tar", "-xvf", str(tar_path.resolve())],
            cwd=str(aspanformer_path),
            check=True,
        )

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: Dict holding at least ``"image0"`` and ``"image1"``. Only
                those two entries are passed to the network.

        Returns:
            Dict with ``keypoints0`` / ``keypoints1``, taken from the
            ``mkpts0_f`` / ``mkpts1_f`` outputs the network writes into its
            input dict. Presumably row i of each forms one correspondence
            (TODO confirm).
        """
        batch = {"image0": data["image0"], "image1": data["image1"]}
        # The network stores its outputs into ``batch`` in place; the call's
        # return value is not used.
        self.net(batch, online_resize=True)
        return {"keypoints0": batch["mkpts0_f"], "keypoints1": batch["mkpts1_f"]}