import os
import subprocess
from pathlib import Path

import torch

from ..utils.base_model import BaseModel
from .. import logger
from .networks.dkm.models.model_zoo.DKMv3 import DKMv3

weight_path = Path(__file__).parent / "networks" / "dkm"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class GIM(BaseModel):
    """Dense matcher wrapping a DKMv3 network loaded from a GIM checkpoint.

    ``_forward`` pads both images onto a common canvas, runs dense matching,
    samples sparse correspondences and returns them in pixel coordinates of
    the original (unpadded) images.
    """

    default_conf = {
        "model_name": "gim_dkm_100h.ckpt",
        # NOTE(review): not read anywhere in this class.
        "match_threshold": 0.2,
        # Number of correspondences passed to ``self.net.sample``.
        "max_keypoints": 2000,
        # Directory the checkpoint ``model_name`` is loaded from.
        "checkpoint_dir": weight_path,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    # Checkpoint filename -> download URL. Only checkpoints listed here can be
    # fetched automatically; any other checkpoint must already exist on disk.
    dkm_models = {}

    def _init(self, conf):
        """Build a DKMv3 network and load the configured checkpoint into it.

        Args:
            conf: configuration mapping; reads ``model_name`` and
                (optionally) ``checkpoint_dir``.

        Raises:
            FileNotFoundError: if the checkpoint is missing and no download
                link is registered for it in ``dkm_models``.
            subprocess.CalledProcessError: if the ``wget`` download fails.
        """
        model_name = conf["model_name"]
        checkpoint_dir = Path(conf.get("checkpoint_dir") or weight_path)
        model_path = checkpoint_dir / model_name

        if not model_path.exists():
            link = self.dkm_models.get(model_name)
            if link is None:
                raise FileNotFoundError(
                    f"GIM checkpoint {model_path} does not exist and no "
                    f"download link is registered for {model_name!r}."
                )
            model_path.parent.mkdir(parents=True, exist_ok=True)
            cmd = ["wget", link, "-O", str(model_path)]
            logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
            subprocess.run(cmd, check=True)

        logger.info("Loading GIM model...")
        model = DKMv3(None, 672, 896, upsample_preds=True)
        # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(str(model_path), map_location="cpu")
        # Checkpoints may nest the weights under a "state_dict" entry.
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        self._clean_state_dict(state_dict)
        model.load_state_dict(state_dict)
        self.net = model

    @staticmethod
    def _clean_state_dict(state_dict):
        """Strip a leading ``model.`` prefix and drop ``encoder.net.fc`` entries, in place."""
        prefix = "model."
        for key in list(state_dict.keys()):
            new_key = key[len(prefix):] if key.startswith(prefix) else key
            if "encoder.net.fc" in new_key:
                # Decide on the final name first so a prefixed fc key is
                # popped exactly once (popping it twice raised KeyError).
                state_dict.pop(key)
            elif new_key != key:
                state_dict[new_key] = state_dict.pop(key)

    @staticmethod
    def _center_padding(height, width, new_height, new_width):
        """Return ``(left, right, top, bottom)`` padding that centers an image on a canvas."""
        pad_h = new_height - height
        pad_w = new_width - width
        top = pad_h // 2
        left = pad_w // 2
        return left, pad_w - left, top, pad_h - top

    @staticmethod
    def _to_pixels(coords, height, width):
        """Map ``(N, 2)`` coordinates from ``[-1, 1]`` to pixels of a ``height x width`` image."""
        return torch.stack(
            (width * (coords[:, 0] + 1) / 2, height * (coords[:, 1] + 1) / 2),
            dim=-1,
        )

    @staticmethod
    def _inside(kpts, height, width):
        """Boolean mask of ``(x, y)`` points lying in ``[0, width-1] x [0, height-1]``."""
        x, y = kpts[:, 0], kpts[:, 1]
        return (x >= 0) & (y >= 0) & (x <= width - 1) & (y <= height - 1)

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        The images are indexed with ``shape[2]`` as height and ``shape[3]`` as
        width (presumably ``(B, C, H, W)`` with ``B == 1`` -- TODO confirm).

        Returns:
            dict with ``keypoints0`` and ``keypoints1``: row-aligned ``(N, 2)``
            ``(x, y)`` pixel coordinates in the original, unpadded images.
        """
        image0, image1 = data["image0"], data["image1"]
        height0, width0 = image0.shape[2], image0.shape[3]
        height1, width1 = image1.shape[2], image1.shape[3]

        # Smallest canvas (up to int truncation) with the 896:672 aspect ratio
        # that contains both images; each image is centered on it.
        aspect_ratio = 896 / 672
        canvas_width = max(
            width0, int(height0 * aspect_ratio),
            width1, int(height1 * aspect_ratio),
        )
        canvas_height = max(
            height0, int(width0 / aspect_ratio),
            height1, int(width1 / aspect_ratio),
        )
        pad0 = self._center_padding(height0, width0, canvas_height, canvas_width)
        pad1 = self._center_padding(height1, width1, canvas_height, canvas_width)
        image0 = torch.nn.functional.pad(image0, pad0)
        image1 = torch.nn.functional.pad(image1, pad1)

        dense_matches, dense_certainty = self.net.match(image0, image1)
        num_samples = self.conf.get(
            "max_keypoints", self.default_conf["max_keypoints"]
        )
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, num_samples
        )

        # Columns 0:2 / 2:4 are the (x, y) of image0 / image1 in normalized
        # coordinates; map them to padded-canvas pixels, then undo the padding.
        kpts0 = self._to_pixels(sparse_matches[:, :2], canvas_height, canvas_width)
        kpts1 = self._to_pixels(sparse_matches[:, 2:], canvas_height, canvas_width)
        kpts0 = kpts0 - kpts0.new_tensor((pad0[0], pad0[2]))[None]
        kpts1 = kpts1 - kpts1.new_tensor((pad1[0], pad1[2]))[None]

        # Keep matches that fall inside both original images and have non-zero
        # confidence. Building a single mask over all samples keeps the shapes
        # aligned (the old code mixed a full-length mask with nonzero-filtered
        # tensors, which crashed whenever some confidence was exactly 0).
        valid = (
            self._inside(kpts0, height0, width0)
            & self._inside(kpts1, height1, width1)
            & (mconf != 0)
        )
        return {"keypoints0": kpts0[valid], "keypoints1": kpts1[valid]}