# -*- coding: utf-8 -*-
# @Author : xuelun
import cv2
import torch
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

from os.path import join

from dkm.models.model_zoo.DKMv3 import DKMv3
from gluefactory.superpoint import SuperPoint
from gluefactory.models.matchers.lightglue import LightGlue

DEFAULT_MIN_NUM_MATCHES = 4
DEFAULT_RANSAC_MAX_ITER = 10000
DEFAULT_RANSAC_CONFIDENCE = 0.999
DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
DEFAULT_RANSAC_METHOD = "USAC_MAGSAC"

RANSAC_ZOO = {
    "RANSAC": cv2.RANSAC,
    "USAC_FAST": cv2.USAC_FAST,
    "USAC_MAGSAC": cv2.USAC_MAGSAC,
    "USAC_PROSAC": cv2.USAC_PROSAC,
    "USAC_DEFAULT": cv2.USAC_DEFAULT,
    "USAC_FM_8PTS": cv2.USAC_FM_8PTS,
    "USAC_ACCURATE": cv2.USAC_ACCURATE,
    "USAC_PARALLEL": cv2.USAC_PARALLEL,
}
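# Usage sketch: these names map directly to OpenCV's robust-estimation flags
# and are looked up when fitting geometry in compute_geom below, e.g.
#   flag = RANSAC_ZOO[DEFAULT_RANSAC_METHOD]  # == cv2.USAC_MAGSAC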


def read_image(path, grayscale=False):
    if grayscale:
        mode = cv2.IMREAD_GRAYSCALE
    else:
        mode = cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise ValueError(f'Cannot read image {path}.')
    if not grayscale and len(image.shape) == 3:
        image = image[:, :, ::-1]  # BGR to RGB
    return image
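# Usage sketch (hypothetical paths): unlike raw cv2.imread, read_image returns
# RGB channel order, so downstream tensors and plots agree on channels.
#   rgb = read_image('assets/demo/a1.png')                    # (H, W, 3) uint8, RGB
#   gray = read_image('assets/demo/a1.png', grayscale=True)   # (H, W) uint8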


def resize_image(image, size, interp):
    if interp.startswith('cv2_'):
        interp = getattr(cv2, 'INTER_' + interp[len('cv2_'):].upper())
        h, w = image.shape[:2]
        if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
            # INTER_AREA is meant for shrinking; fall back to bilinear when upscaling
            interp = cv2.INTER_LINEAR
        resized = cv2.resize(image, size, interpolation=interp)
    else:
        raise ValueError(f'Unknown interpolation {interp}.')
    return resized
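# Usage sketch: `size` is (width, height) and `interp` follows the 'cv2_*'
# naming convention, e.g.
#   small = resize_image(img, (640, 480), 'cv2_area')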


def fast_make_matching_figure(data, b_id):
    color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    # white canvas: two rows, each showing image0 and image1 side by side
    out = 255 * np.ones((H, W), np.uint8)
    wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)
    out = np.stack([out] * 3, -1)

    # row 1: raw image pair
    sh = hx(row=1)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # row 2: image pair with inlier keypoints drawn on top
    sh = hx(row=2)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        c = (0, 255, 0)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    return out


def fast_make_matching_overlay(data, b_id):
    color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    # white canvas: same two-row layout as fast_make_matching_figure
    out = 255 * np.ones((H, W), np.uint8)
    wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)
    out = np.stack([out] * 3, -1)

    # row 1: raw image pair
    sh = hx(row=1)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # row 2: inlier matches drawn as lines connecting the two images
    sh = hx(row=2)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        c = (0, 255, 0)
        cv2.line(out, (x0, y0 + sh), (x1 + margin + w0, y1 + sh), color=c, thickness=1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    return out
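# Usage sketch: the two helpers share one canvas layout, so their outputs can
# be alpha-blended (as done in __main__ below) to overlay match lines softly.
#   canvas = fast_make_matching_figure(data, b_id=0)
#   overlay = fast_make_matching_overlay(data, b_id=0)
#   blend = cv2.addWeighted(canvas, 0.5, overlay, 0.5, 0)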


def preprocess(image: np.ndarray, grayscale: bool = False, resize_max: int = None,
               dfactor: int = 8):
    image = image.astype(np.float32, copy=False)
    size = image.shape[:2][::-1]  # (w, h) of the original image
    scale = np.array([1.0, 1.0])

    if resize_max:
        scale = resize_max / max(size)
        if scale < 1.0:
            size_new = tuple(int(round(x * scale)) for x in size)
            image = resize_image(image, size_new, 'cv2_area')
            scale = np.array(size) / np.array(size_new)

    if grayscale:
        assert image.ndim == 2, image.shape
        image = image[None]
    else:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    image = torch.from_numpy(image / 255.0).float()

    # ensure that the size is divisible by dfactor
    size_new = tuple(map(
        lambda x: int(x // dfactor * dfactor),
        image.shape[-2:]))
    image = F.resize(image, size=size_new)
    scale = np.array(size) / np.array(size_new)[::-1]  # (w, h) ratio back to the original
    return image, scale
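# Usage sketch (hypothetical values): preprocess returns a CxHxW tensor in
# [0, 1] whose spatial size is a multiple of `dfactor`, plus a (w, h) scale
# that maps coordinates in the resized image back to the original resolution
# (this is what 'scale0'/'scale1' are used for in __main__ below).
#   img = read_image('assets/demo/a1.png')
#   tensor, scale = preprocess(img, resize_max=1024)
#   kpts_orig = kpts_resized * scale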


def compute_geom(data,
                 ransac_method=DEFAULT_RANSAC_METHOD,
                 ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
                 ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
                 ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
                 ) -> dict:
    mkpts0 = data["mkpts0_f"].cpu().detach().numpy()
    mkpts1 = data["mkpts1_f"].cpu().detach().numpy()

    if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
        return {}

    h0, w0 = data["hw0_i"]
    geo_info = {}

    # renamed from F to avoid shadowing torchvision.transforms.functional
    F_mat, _ = cv2.findFundamentalMat(
        mkpts0,
        mkpts1,
        method=RANSAC_ZOO[ransac_method],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if F_mat is not None:
        geo_info["Fundamental"] = F_mat.tolist()

    H, _ = cv2.findHomography(
        mkpts1,
        mkpts0,
        method=RANSAC_ZOO[ransac_method],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if H is not None:
        geo_info["Homography"] = H.tolist()

    if F_mat is not None:
        # uncalibrated rectification derived from the fundamental matrix
        _, H1, H2 = cv2.stereoRectifyUncalibrated(
            mkpts0.reshape(-1, 2),
            mkpts1.reshape(-1, 2),
            F_mat,
            imgSize=(w0, h0),
        )
        geo_info["H1"] = H1.tolist()
        geo_info["H2"] = H2.tolist()

    return geo_info
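# Sanity-check sketch (hypothetical): an inlier correspondence (x0, x1)
# should approximately satisfy the epipolar constraint x1^T F x0 = 0.
#   F_mat = np.array(geo_info["Fundamental"])
#   x0 = np.append(mkpts0[0], 1.0)  # homogeneous pixel in image 0
#   x1 = np.append(mkpts1[0], 1.0)  # homogeneous pixel in image 1
#   residual = x1 @ F_mat @ x0      # near zero for true inliers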


def wrap_images(img0, img1, geo_info, geom_type):
    img0 = img0[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    img1 = img1[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    h0, w0, _ = img0.shape
    h1, w1, _ = img1.shape

    rectified_image0 = img0
    rectified_image1 = None
    if geom_type == "Homography":
        H = np.array(geo_info["Homography"])
        rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
        title = ["Image 0", "Image 1 - warped"]
    elif geom_type == "Fundamental":
        H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
        rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
        rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
        title = ["Image 0 - warped", "Image 1 - warped"]
    else:
        raise ValueError(f'Unknown geometry type {geom_type}.')

    fig = plot_images(
        [rectified_image0.squeeze(), rectified_image1.squeeze()],
        title,
        dpi=300,
    )
    img = fig2im(fig)
    plt.close(fig)
    return img


def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
    """Plot a set of images horizontally.
    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images.
        dpi: figure resolution in dots per inch.
        size: horizontal size (in inches) allotted to each image.
        pad: padding passed to fig.tight_layout.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    figsize = (size * n, size * 6 / 5) if size is not None else None
    fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])
    fig.tight_layout(pad=pad)
    return fig


def fig2im(fig):
    """Convert a Matplotlib figure to an (H, W, 3) uint8 RGB array."""
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
    im = buf_ndarray.reshape(h, w, 3)
    return im
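# Usage sketch: render a figure into an array that cv2.imwrite can save
# (cv2 expects BGR, hence the channel flip).
#   fig = plot_images([rgb0, rgb1], titles=['Image 0', 'Image 1'])
#   im = fig2im(fig)
#   plt.close(fig)
#   cv2.imwrite('pair.png', im[..., ::-1])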


if __name__ == '__main__':
    model_zoo = ['gim_dkm', 'gim_lightglue']

    # model
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='gim_dkm', choices=model_zoo)
    args = parser.parse_args()

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load model
    ckpt = None
    model = None
    detector = None
    if args.model == 'gim_dkm':
        ckpt = 'gim_dkm_100h.ckpt'
        model = DKMv3(weights=None, h=672, w=896)
    elif args.model == 'gim_lightglue':
        ckpt = 'gim_lightglue_100h.ckpt'
        detector = SuperPoint({
            'max_num_keypoints': 2048,
            'force_num_keypoints': True,
            'detection_threshold': 0.0,
            'nms_radius': 3,
            'trainable': False,
        })
        model = LightGlue({
            'filter_threshold': 0.1,
            'flash': False,
            'checkpointed': True,
        })

    # weights path
    checkpoints_path = join('weights', ckpt)

    # load state dict
    if args.model == 'gim_dkm':
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if 'encoder.net.fc' in k:
                # drop the classification head, which DKMv3 does not use;
                # checking this first also avoids popping a key twice
                state_dict.pop(k)
            elif k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.load_state_dict(state_dict)
    elif args.model == 'gim_lightglue':
        # the checkpoint stores both networks; split it into detector and matcher
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict.pop(k)
            if k.startswith('superpoint.'):
                state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
        detector.load_state_dict(state_dict)

        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('superpoint.'):
                state_dict.pop(k)
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.load_state_dict(state_dict)

    # eval mode
    if detector is not None:
        detector = detector.eval().to(device)
    model = model.eval().to(device)

    # input images
    name0 = 'a1'
    name1 = 'a2'
    postfix = '.png'
    image_dir = join('assets', 'demo')
    img_path0 = join(image_dir, name0 + postfix)
    img_path1 = join(image_dir, name1 + postfix)

    image0 = read_image(img_path0)
    image1 = read_image(img_path1)
    image0, scale0 = preprocess(image0)
    image1, scale1 = preprocess(image1)

    image0 = image0.to(device)[None]
    image1 = image1.to(device)[None]
    data = dict(color0=image0, color1=image1, image0=image0, image1=image1)

    if args.model == 'gim_dkm':
        # dense matching: sample sparse correspondences from the dense warp
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            dense_matches, dense_certainty = model.match(image0, image1)
            sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]

        # matches are in normalized [-1, 1] coordinates; convert to pixels
        kpts0 = sparse_matches[:, :2]
        kpts0 = torch.stack((
            width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1,)
        kpts1 = sparse_matches[:, 2:]
        kpts1 = torch.stack((
            width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1,)

        b_ids = torch.where(mconf[None])[0]
    elif args.model == 'gim_lightglue':
        # sparse matching: detect keypoints first, then match descriptors
        gray0 = read_image(img_path0, grayscale=True)
        gray1 = read_image(img_path1, grayscale=True)
        gray0 = preprocess(gray0, grayscale=True)[0]
        gray1 = preprocess(gray1, grayscale=True)[0]
        gray0 = gray0.to(device)[None]
        gray1 = gray1.to(device)[None]
        scale0 = torch.tensor(scale0).to(device)[None]
        scale1 = torch.tensor(scale1).to(device)[None]

        data.update(dict(gray0=gray0, gray1=gray1))

        size0 = torch.tensor(data["gray0"].shape[-2:][::-1])[None]
        size1 = torch.tensor(data["gray1"].shape[-2:][::-1])[None]
        data.update(dict(size0=size0, size1=size1))
        data.update(dict(scale0=scale0, scale1=scale1))

        pred = {}
        pred.update({k + '0': v for k, v in detector({
            "image": data["gray0"],
            "image_size": data["size0"],
        }).items()})
        pred.update({k + '1': v for k, v in detector({
            "image": data["gray1"],
            "image_size": data["size1"],
        }).items()})
        pred.update(model({**pred, **data,
                           **{'resize0': data['size0'], 'resize1': data['size1']}}))

        # rescale keypoints to the original image resolution
        kpts0 = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
        kpts1 = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])

        m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
        matches = pred['matches']
        bs = data['image0'].size(0)
        kpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        kpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
        b_ids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mconf = torch.cat(pred['scores'])

    # robust fitting: keep matches consistent with a single fundamental matrix
    _, mask = cv2.findFundamentalMat(kpts0.cpu().detach().numpy(),
                                     kpts1.cpu().detach().numpy(),
                                     cv2.USAC_MAGSAC, ransacReprojThreshold=1.0,
                                     confidence=0.999999, maxIters=10000)
    mask = mask.ravel() > 0

    data.update({
        'hw0_i': image0.shape[-2:],
        'hw1_i': image1.shape[-2:],
        'mkpts0_f': kpts0,
        'mkpts1_f': kpts1,
        'm_bids': b_ids,
        'mconf': mconf,
        'inliers': mask,
    })

    # save visualization
    alpha = 0.5
    out = fast_make_matching_figure(data, b_id=0)
    overlay = fast_make_matching_overlay(data, b_id=0)
    out = cv2.addWeighted(out, 1 - alpha, overlay, alpha, 0)
    cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_match.png'), out[..., ::-1])  # RGB to BGR

    geom_info = compute_geom(data)
    wrapped_images = wrap_images(image0, image1, geom_info, "Homography")
    cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_warp.png'), wrapped_images)
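
# Run sketch (assuming this file is saved as demo.py, with checkpoints under
# ./weights and the demo pair under assets/demo):
#   python demo.py --model gim_dkm
#   python demo.py --model gim_lightglue
# Outputs land in assets/demo: a1_a2_<model>_match.png (match visualization)
# and a1_a2_<model>_warp.png (homography warp).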