File size: 2,579 Bytes
0a82b18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import torch
from pathlib import Path
import gdown
import torchvision.transforms as tfm
import tarfile
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import to_numpy, resize_to_divisible, lower_config, add_to_path
BASE_PATH = THIRD_PARTY_DIR.joinpath("aspanformer")
add_to_path(BASE_PATH)
from src.ASpanFormer.aspanformer import ASpanFormer
from src.config.default import get_cfg_defaults as aspan_cfg_defaults
class AspanformerMatcher(BaseMatcher):
    """Matcher wrapping the third-party ASpanFormer model with its outdoor checkpoint.

    On construction the checkpoint archive is downloaded and extracted if the
    expected checkpoint file is not already on disk, then the model is built
    from the outdoor test config and put in eval mode.
    """

    # Google Drive link of the tar archive that contains the checkpoints.
    weights_src = "https://drive.google.com/file/d/1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k/view"
    # Checkpoint file expected on disk after the archive has been extracted.
    weights_path = WEIGHTS_DIR.joinpath("aspanformer", "weights", "outdoor.ckpt")
    # Value handed to resize_to_divisible() in preprocess().
    divisible_size = 32

    def __init__(self, device="cpu", **kwargs):
        """Fetch the weights if needed, build ASpanFormer and load the checkpoint.

        Args:
            device: device the model is moved to (default ``"cpu"``).
            **kwargs: forwarded unchanged to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)

        self.download_weights()

        config = aspan_cfg_defaults()
        config.merge_from_file(BASE_PATH.joinpath("configs", "aspan", "outdoor", "aspan_test.py"))
        self.matcher = ASpanFormer(config=lower_config(config)["aspan"])

        # NOTE(review): torch.load unpickles the downloaded checkpoint; only
        # point weights_src / weights_path at sources you trust.
        # strict=False: keys that do not match between checkpoint and model
        # are ignored instead of raising.
        self.matcher.load_state_dict(
            torch.load(self.weights_path, map_location=self.device)["state_dict"], strict=False
        )

        self.matcher = self.matcher.to(device).eval()

    def download_weights(self):
        """Download and extract the checkpoint archive unless the checkpoint exists.

        The archive is saved as ``WEIGHTS_DIR/weights_aspanformer.tar`` and
        extracted into ``WEIGHTS_DIR/aspanformer``. Nothing happens when
        ``weights_path`` already points to a file.
        """
        if Path(self.weights_path).is_file():
            return

        archive_path = WEIGHTS_DIR.joinpath("weights_aspanformer.tar")
        weights_subdir = WEIGHTS_DIR.joinpath("aspanformer")
        # Create the target directory (and WEIGHTS_DIR itself if missing)
        # before downloading, so neither the download nor the extraction
        # fails on a fresh checkout.
        weights_subdir.mkdir(parents=True, exist_ok=True)

        print("Downloading Aspanformer outdoor... (takes a while)")
        gdown.download(
            self.weights_src,
            output=str(archive_path),
            fuzzy=True,
        )

        # The context manager closes the archive even if extraction raises.
        with tarfile.open(archive_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: the "data" filter
                # rejects members that would be written outside the target
                # directory. Only available on Pythons with extraction filters.
                tar.extractall(weights_subdir, filter="data")
            else:
                tar.extractall(weights_subdir)

    def preprocess(self, img):
        """Resize, convert to grayscale and add a leading batch dimension.

        Args:
            img: image whose shape unpacks as ``(channels, height, width)``.

        Returns:
            Tuple ``(tensor, (h, w))`` where ``tensor`` is the processed image
            and ``(h, w)`` is the height and width of ``img`` before resizing.
        """
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        """Match two images and return the matched keypoints of each.

        Args:
            img0: first image, in the form accepted by ``preprocess``.
            img1: second image, in the form accepted by ``preprocess``.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``; the matched
            keypoints are passed through ``rescale_coords`` together with the
            original and resized image sizes. The four trailing ``None``
            values are placeholders (presumably for outputs this matcher does
            not produce - confirm against ``BaseMatcher``).
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        batch = {"image0": img0, "image1": img1}
        # The model writes its results into `batch` in place; the return
        # value of the call is not used.
        self.matcher(batch)

        mkpts0 = to_numpy(batch["mkpts0_f"])
        mkpts1 = to_numpy(batch["mkpts1_f"])

        # Height/width of the resized inputs that were fed to the model.
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        return mkpts0, mkpts1, None, None, None, None
|