|
import torch |
|
from pathlib import Path |
|
import gdown, py3_wget |
|
|
|
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher |
|
from matching.utils import to_numpy, resize_to_divisible, add_to_path |
|
|
|
add_to_path(THIRD_PARTY_DIR.joinpath("path/to/submodule")) |
|
|
|
from submodule import model, other_components |
|
|
|
|
|
class NewMatcher(BaseMatcher):
    """Template wrapper that exposes a third-party matching model via ``BaseMatcher``.

    The class-level attributes are placeholders: point ``weights_src`` at the
    hosted checkpoint, ``model_path`` at its local cache location, and fill in
    the post-processing step of ``_forward`` for the concrete model.
    """

    # Remote location of the pretrained checkpoint (placeholder value).
    weights_src = "url-to-weights-hosted-online"
    # Local file the checkpoint is stored at / loaded from.
    model_path = WEIGHTS_DIR.joinpath("my_weights.ckpt")
    # Passed to resize_to_divisible() in preprocess().
    divisible_size = 32

    def __init__(self, device="cpu", *args, **kwargs):
        """Fetch the checkpoint if needed, build the model and load its weights.

        Args:
            device: Forwarded to ``BaseMatcher.__init__``.
            *args: Accepted for signature compatibility; not used here.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)

        self.download_weights()

        self.matcher = model()

        # NOTE: torch.load unpickles the file, so weights_src must point at a
        # trusted checkpoint. Tensors are mapped to CPU at load time.
        checkpoint = torch.load(self.model_path, map_location=torch.device("cpu"))
        self.matcher.load_state_dict(checkpoint["state_dict"])

    def download_weights(self):
        """Download the checkpoint to ``model_path`` unless the file already exists.

        Google Drive links are fetched with ``gdown``; any other URL is fetched
        with ``py3_wget``. Exactly one of the two downloaders runs.
        """
        target = Path(self.model_path)
        if target.is_file():
            return

        print("Downloading model... (takes a while)")
        # Make sure the destination directory exists before writing into it.
        target.parent.mkdir(parents=True, exist_ok=True)

        if "drive.google.com" in self.weights_src:
            gdown.download(
                self.weights_src,
                output=str(target),
                fuzzy=True,
            )
        else:
            # The first argument is the URL to fetch, the second the output path.
            py3_wget.download_file(self.weights_src, target)

    def preprocess(self, img):
        """Prepare a single image for the model.

        Args:
            img: 3-dimensional image tensor; the last two dimensions are
                recorded as the original (h, w).

        Returns:
            Tuple ``(img, orig_shape)`` where ``img`` has been passed through
            ``resize_to_divisible`` and given a leading dimension of size 1,
            and ``orig_shape`` is the ``(h, w)`` of the input.
        """
        _, h, w = img.shape
        orig_shape = h, w

        # Only needed if the model requires divisible input sizes.
        img = resize_to_divisible(img, self.divisible_size)

        # Only needed if the model expects a batch dimension.
        img = img.unsqueeze(0)
        return img, orig_shape

    def _forward(self, img0, img1):
        """Run the model on an image pair and return matches in input-image coordinates.

        Args:
            img0: First image, in the format accepted by ``preprocess``.
            img1: Second image, in the format accepted by ``preprocess``.

        Returns:
            Tuple ``(mkpts0, mkpts1, keypoints_0, keypoints_1, desc0, desc1)``.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        batch = {"image0": img0, "image1": img1}

        # Run the pair through the model.
        output = self.matcher(batch)

        # TODO(template): post-process `output` here to define mkpts0, mkpts1,
        # keypoints_0, keypoints_1, desc0 and desc1. Until that is filled in for
        # the concrete model, the lines below raise NameError.

        # The images may have been resized in preprocess(); pass the resized and
        # original sizes to rescale_coords so matches refer to the input images.
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        return mkpts0, mkpts1, keypoints_0, keypoints_1, desc0, desc1
|
|