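"""hloc-style matcher wrapper around GIM, a DKMv3-based dense matcher.

Loads a GIM checkpoint (by default gim_dkm_100h.ckpt) into a DKMv3 backbone
and exposes dense matching through the BaseModel interface.
"""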
import datetime
import subprocess
from pathlib import Path

import torch

from .. import logger
from ..utils.base_model import BaseModel
from .networks.dkm.models.model_zoo.DKMv3 import DKMv3

weight_path = Path(__file__).parent / "networks" / "dkm"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GIM(BaseModel):
    default_conf = {
        "model_name": "gim_dkm_100h.ckpt",
        "match_threshold": 0.2,
        "checkpoint_dir": weight_path,
        # Assumed default; _forward passes this cap to self.net.sample().
        "max_keypoints": 2000,
    }
required_inputs = [
"image0",
"image1",
]
    # Download URLs for the original DKMv3 weights. The default GIM
    # checkpoint (gim_dkm_100h.ckpt) is not listed here and must be placed
    # under `checkpoint_dir` manually.
    dkm_models = {
        "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
        "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
    }

    def _init(self, conf):
        model_path = weight_path / conf["model_name"]

        # Download the weights if a URL is known; otherwise fail with a clear
        # message instead of a KeyError.
        if not model_path.exists():
            if conf["model_name"] not in self.dkm_models:
                raise FileNotFoundError(
                    f"Checkpoint {model_path} not found and no download URL "
                    f"is known for {conf['model_name']!r}."
                )
            model_path.parent.mkdir(exist_ok=True)
            link = self.dkm_models[conf["model_name"]]
            cmd = ["wget", link, "-O", str(model_path)]
            logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
            subprocess.run(cmd, check=True)
# logger.info(f"Loading GIM model...")
# self.net = DKMv3(path_to_weights=str(model_path), device=device)
        # Build DKMv3 at its 672x896 (h, w) working resolution; the GIM
        # weights are loaded into it below.
        model = DKMv3(None, 672, 896, upsample_preds=True)

        state_dict = torch.load(str(model_path), map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # Drop the encoder's unused classification head and strip the
        # training-time "model." prefix, e.g.
        # "model.encoder.net.conv1.weight" -> "encoder.net.conv1.weight".
        for k in list(state_dict.keys()):
            if "encoder.net.fc" in k:
                state_dict.pop(k)
            elif k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        model.load_state_dict(state_dict)

        self.net = model

def _forward(self, data):
# img0 = data["image0"].cpu().numpy().squeeze() * 255
# img1 = data["image1"].cpu().numpy().squeeze() * 255
# img0 = img0.transpose(1, 2, 0)
# img1 = img1.transpose(1, 2, 0)
# img0 = Image.fromarray(img0.astype("uint8"))
# img1 = Image.fromarray(img1.astype("uint8"))
# W_A, H_A = img0.size
# W_B, H_B = img1.size
#
# warp, certainty = self.net.match(img0, img1, device=device)
# matches, certainty = self.net.sample(warp, certainty)
# kpts1, kpts2 = self.net.to_pixel_coordinates(
# matches, H_A, W_A, H_B, W_B
# )
        image0, image1 = data["image0"], data["image1"]
        orig_height0, orig_width0 = image0.shape[2:]
        orig_height1, orig_width1 = image1.shape[2:]

        # Pad both images onto a shared, centered canvas whose aspect ratio
        # matches the 896x672 (w x h) resolution the network was built for.
        aspect_ratio = 896 / 672
        new_width0 = max(orig_width0, int(orig_height0 * aspect_ratio))
        new_height0 = max(orig_height0, int(orig_width0 / aspect_ratio))
        new_width1 = max(orig_width1, int(orig_height1 * aspect_ratio))
        new_height1 = max(orig_height1, int(orig_width1 / aspect_ratio))
        new_width = max(new_width0, new_width1)
        new_height = max(new_height0, new_height1)

        pad_height0 = new_height - orig_height0
        pad_width0 = new_width - orig_width0
        pad_height1 = new_height - orig_height1
        pad_width1 = new_width - orig_width1

        pad_top0 = pad_height0 // 2
        pad_bottom0 = pad_height0 - pad_top0
        pad_left0 = pad_width0 // 2
        pad_right0 = pad_width0 - pad_left0
        pad_top1 = pad_height1 // 2
        pad_bottom1 = pad_height1 - pad_top1
        pad_left1 = pad_width1 // 2
        pad_right1 = pad_width1 - pad_left1

        image0 = torch.nn.functional.pad(
            image0, (pad_left0, pad_right0, pad_top0, pad_bottom0)
        )
        image1 = torch.nn.functional.pad(
            image1, (pad_left1, pad_right1, pad_top1, pad_bottom1)
        )
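
        # Worked example with illustrative numbers: for a 512x512 input,
        # aspect_ratio = 896 / 672 = 4 / 3, so new_width = max(512, 682) = 682
        # and new_height = max(512, 384) = 512; the image is padded by
        # 682 - 512 = 170 columns, split as pad_left = 85 and pad_right = 85.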

        # Timestamped prints to debug how long dense matching and sampling take.
        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "self.net.match start")
        dense_matches, dense_certainty = self.net.match(image0, image1)
        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "self.net.match end")

        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "self.net.sample start")
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, self.conf["max_keypoints"]
        )
        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "self.net.sample end")
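
        # `sparse_matches` has shape (N, 4): normalized (x0, y0, x1, y1) per
        # match in [-1, 1]; `mconf` has shape (N,) with match certainties.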
        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]

        # Map the normalized [-1, 1] coordinates to pixel coordinates in the
        # padded images.
        kpts0 = sparse_matches[:, :2]
        kpts1 = sparse_matches[:, 2:]
        kpts0 = torch.stack(
            (width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2),
            dim=-1,
        )
        kpts1 = torch.stack(
            (width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2),
            dim=-1,
        )
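
        # Sanity check of the mapping above: normalized x = -1 lands on pixel
        # column 0, x = 0 on width / 2, and x = +1 on width (the right edge
        # of the padded canvas).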
        # mconf[None] adds a batch dimension; torch.where then yields the
        # batch index and match index of every nonzero-confidence match.
        b_ids, i_ids = torch.where(mconf[None])
        # Undo the padding offsets so keypoints are expressed in the original
        # (unpadded) image coordinates.
        kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None]
        kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None]

        # Keep only matches whose keypoints fall inside both original images.
        mask = (
            (kpts0[:, 0] > 0)
            & (kpts0[:, 1] > 0)
            & (kpts1[:, 0] > 0)
            & (kpts1[:, 1] > 0)
        )
        mask = (
            mask
            & (kpts0[:, 0] <= (orig_width0 - 1))
            & (kpts1[:, 0] <= (orig_width1 - 1))
            & (kpts0[:, 1] <= (orig_height0 - 1))
            & (kpts1[:, 1] <= (orig_height1 - 1))
        )
        pred = {
            "keypoints0": kpts0[i_ids],
            "keypoints1": kpts1[i_ids],
            "confidence": mconf[i_ids],
            "batch_indexes": b_ids,
        }
        # Apply the in-bounds mask to every field; only the keypoints are
        # returned to the caller.
        scores, b_ids = pred["confidence"], pred["batch_indexes"]
        kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
        pred["confidence"], pred["batch_indexes"] = scores[mask], b_ids[mask]
        pred["keypoints0"], pred["keypoints1"] = kpts0[mask], kpts1[mask]

        out = {"keypoints0": pred["keypoints0"], "keypoints1": pred["keypoints1"]}
        return out
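
# Minimal usage sketch (an assumption about the surrounding hloc-style
# harness: BaseModel is expected to merge `default_conf` with the user conf,
# call `_init`, and route calls to `_forward`; image sizes are illustrative):
#
#   matcher = GIM({"model_name": "gim_dkm_100h.ckpt", "max_keypoints": 2000})
#   data = {
#       "image0": torch.rand(1, 3, 480, 640, device=device),
#       "image1": torch.rand(1, 3, 480, 640, device=device),
#   }
#   pred = matcher(data)  # {"keypoints0": (N, 2), "keypoints1": (N, 2)}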