import subprocess
from pathlib import Path

import torch

from .. import logger
from ..utils.base_model import BaseModel
from .networks.dkm.models.model_zoo.DKMv3 import DKMv3

weight_path = Path(__file__).parent / "networks" / "dkm"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class GIM(BaseModel):
    default_conf = {
        "model_name": "gim_dkm_100h.ckpt",
        "match_threshold": 0.2,
        "checkpoint_dir": weight_path,
        # `_forward` reads this key, so it needs a default; 2000 is an
        # assumed value, override it through the conf as needed.
        "max_keypoints": 2000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    # DKMv3 weights the GIM models were exported from; `_init` also uses
    # this mapping as a download table for missing checkpoints.
    dkm_models = {
        "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
        "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
    }
    def _init(self, conf):
        model_path = weight_path / conf["model_name"]

        # Download the checkpoint if it is not already on disk. `wget` is
        # assumed to be available on PATH.
        if not model_path.exists():
            model_path.parent.mkdir(exist_ok=True)
            if conf["model_name"] not in self.dkm_models:
                raise FileNotFoundError(
                    f"{model_path} does not exist and no download URL is known "
                    f"for {conf['model_name']!r}; place the checkpoint there manually."
                )
            link = self.dkm_models[conf["model_name"]]
            cmd = ["wget", link, "-O", str(model_path)]
            logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
            subprocess.run(cmd, check=True)

        # Build DKMv3 at its 672x896 (HxW) resolution and load the GIM
        # checkpoint on CPU.
        logger.info("Loading GIM model...")
        model = DKMv3(None, 672, 896, upsample_preds=True)
        state_dict = torch.load(str(model_path), map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # Strip the Lightning-style `model.` prefix, then drop the ResNet
        # classification head, which DKMv3 does not use. Two passes avoid
        # popping a key that the first pass has already renamed.
        for k in list(state_dict.keys()):
            if k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        for k in list(state_dict.keys()):
            if "encoder.net.fc" in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)
        self.net = model
    def _forward(self, data):
        # Alternative, kept for reference: DKMv3's high-level PIL-based API.
        # img0 = data["image0"].cpu().numpy().squeeze() * 255
        # img1 = data["image1"].cpu().numpy().squeeze() * 255
        # img0 = img0.transpose(1, 2, 0)
        # img1 = img1.transpose(1, 2, 0)
        # img0 = Image.fromarray(img0.astype("uint8"))
        # img1 = Image.fromarray(img1.astype("uint8"))
        # W_A, H_A = img0.size
        # W_B, H_B = img1.size
        #
        # warp, certainty = self.net.match(img0, img1, device=device)
        # matches, certainty = self.net.sample(warp, certainty)
        # kpts1, kpts2 = self.net.to_pixel_coordinates(
        #     matches, H_A, W_A, H_B, W_B
        # )
        image0, image1 = data["image0"], data["image1"]
        orig_height0, orig_width0 = image0.shape[2:4]
        orig_height1, orig_width1 = image1.shape[2:4]

        # Pad both images to a common size with the network's 896:672 (W:H)
        # aspect ratio, keeping each original centered in the padded canvas.
        aspect_ratio = 896 / 672
        new_width0 = max(orig_width0, int(orig_height0 * aspect_ratio))
        new_height0 = max(orig_height0, int(orig_width0 / aspect_ratio))
        new_width1 = max(orig_width1, int(orig_height1 * aspect_ratio))
        new_height1 = max(orig_height1, int(orig_width1 / aspect_ratio))
        new_width = max(new_width0, new_width1)
        new_height = max(new_height0, new_height1)

        pad_height0 = new_height - orig_height0
        pad_width0 = new_width - orig_width0
        pad_height1 = new_height - orig_height1
        pad_width1 = new_width - orig_width1
        pad_top0 = pad_height0 // 2
        pad_bottom0 = pad_height0 - pad_top0
        pad_left0 = pad_width0 // 2
        pad_right0 = pad_width0 - pad_left0
        pad_top1 = pad_height1 // 2
        pad_bottom1 = pad_height1 - pad_top1
        pad_left1 = pad_width1 // 2
        pad_right1 = pad_width1 - pad_left1
        image0 = torch.nn.functional.pad(
            image0, (pad_left0, pad_right0, pad_top0, pad_bottom0)
        )
        image1 = torch.nn.functional.pad(
            image1, (pad_left1, pad_right1, pad_top1, pad_bottom1)
        )
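        # Worked example with assumed sizes: a 1x3x480x640 image0 and a
        # 1x3x512x512 image1 give new_width = max(640, 682) = 682 and
        # new_height = max(480, 512) = 512, so image0 is padded by 21/21
        # pixels left/right and 16/16 top/bottom, while image1 is padded by
        # 85/85 left/right and not at all vertically.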
        logger.info("Running GIM dense matching.")
        dense_matches, dense_certainty = self.net.match(image0, image1)
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, self.conf["max_keypoints"]
        )
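        # `dense_matches` is a dense warp in normalized [-1, 1] coordinates
        # and `dense_certainty` its per-pixel confidence; `sample` draws up
        # to `max_keypoints` correspondences, weighted by that confidence.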
        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]

        # Convert matches from normalized to pixel coordinates in the
        # padded images.
        kpts0 = sparse_matches[:, :2]
        kpts1 = sparse_matches[:, 2:]
        kpts0 = torch.stack(
            (width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2),
            dim=-1,
        )
        kpts1 = torch.stack(
            (width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2),
            dim=-1,
        )
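        # In this convention a normalized coordinate of -1 maps to pixel 0
        # and +1 maps to the padded width (or height), so values lie in
        # [0, W] x [0, H] before the padding offset is subtracted below.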
        b_ids, i_ids = torch.where(mconf[None])

        # Subtract the padding offset, then keep only keypoints that land
        # inside both original images.
        kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None]
        kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None]
        mask = (
            (kpts0[:, 0] > 0)
            & (kpts0[:, 1] > 0)
            & (kpts1[:, 0] > 0)
            & (kpts1[:, 1] > 0)
        )
        mask = (
            mask
            & (kpts0[:, 0] <= (orig_width0 - 1))
            & (kpts1[:, 0] <= (orig_width1 - 1))
            & (kpts0[:, 1] <= (orig_height0 - 1))
            & (kpts1[:, 1] <= (orig_height1 - 1))
        )
        pred = {
            "keypoints0": kpts0[i_ids],
            "keypoints1": kpts1[i_ids],
            "confidence": mconf[i_ids],
            "batch_indexes": b_ids,
        }
        # Index the mask the same way so it stays aligned with the gathered
        # keypoints, then filter every field with it.
        mask = mask[i_ids]
        pred = {k: v[mask] for k, v in pred.items()}
        out = {"keypoints0": pred["keypoints0"], "keypoints1": pred["keypoints1"]}
        return out
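
# Minimal usage sketch (assumed API): the hloc-style `BaseModel` wrapper used
# elsewhere in this repo merges the given conf into `default_conf` and routes
# a dict of BxCxHxW float tensors in [0, 1] through `_forward`. The image
# sizes here are arbitrary, and `matcher(data)` stands in for however
# `BaseModel` exposes the forward pass.
#
#     matcher = GIM({"max_keypoints": 2000})
#     data = {
#         "image0": torch.rand(1, 3, 480, 640),
#         "image1": torch.rand(1, 3, 512, 512),
#     }
#     pred = matcher(data)  # {"keypoints0": (N, 2), "keypoints1": (N, 2)}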