# -*- coding: utf-8 -*-
# @Author : xuelun

import cv2
import torch
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

from os.path import join

from dkm.models.model_zoo.DKMv3 import DKMv3
from gluefactory.superpoint import SuperPoint
from gluefactory.models.matchers.lightglue import LightGlue

DEFAULT_MIN_NUM_MATCHES = 4
DEFAULT_RANSAC_MAX_ITER = 10000
DEFAULT_RANSAC_CONFIDENCE = 0.999
DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
DEFAULT_RANSAC_METHOD = "USAC_MAGSAC"

RANSAC_ZOO = {
    "RANSAC": cv2.RANSAC,
    "USAC_FAST": cv2.USAC_FAST,
    "USAC_MAGSAC": cv2.USAC_MAGSAC,
    "USAC_PROSAC": cv2.USAC_PROSAC,
    "USAC_DEFAULT": cv2.USAC_DEFAULT,
    "USAC_FM_8PTS": cv2.USAC_FM_8PTS,
    "USAC_ACCURATE": cv2.USAC_ACCURATE,
    "USAC_PARALLEL": cv2.USAC_PARALLEL,
}


def read_image(path, grayscale=False):
    """Read an image from disk; color images are returned as RGB."""
    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise ValueError(f'Cannot read image {path}.')
    if not grayscale and len(image.shape) == 3:
        image = image[:, :, ::-1]  # BGR to RGB
    return image


def resize_image(image, size, interp):
    """Resize with an OpenCV interpolation mode given as 'cv2_<name>'."""
    assert interp.startswith('cv2_')
    interp = getattr(cv2, 'INTER_' + interp[len('cv2_'):].upper())
    h, w = image.shape[:2]
    # INTER_AREA is only meant for shrinking; fall back to linear when upsampling
    if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
        interp = cv2.INTER_LINEAR
    return cv2.resize(image, size, interpolation=interp)


def fast_make_matching_figure(data, b_id):
    """Stack both images into a two-row canvas and mark inlier keypoints."""
    color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    # canvas
    out = 255 * np.ones((H, W), np.uint8)
    out = np.stack([out] * 3, -1)

    wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)

    sh = hx(row=1)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    sh = hx(row=2)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # draw inlier keypoints on the second row only
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        c = (0, 255, 0)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    return out

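
# A minimal sketch of how the two drawing helpers are combined in __main__
# below (hypothetical standalone usage; `data` must already hold the keys
# 'color0', 'color1', 'hw0_i', 'hw1_i', 'mkpts0_f', 'mkpts1_f', 'inliers'):
#
#   out = fast_make_matching_figure(data, b_id=0)       # keypoints only
#   overlay = fast_make_matching_overlay(data, b_id=0)  # adds match lines
#   blended = cv2.addWeighted(out, 0.5, overlay, 0.5, 0)
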

def fast_make_matching_overlay(data, b_id):
    """Same canvas as fast_make_matching_figure, but also draws match lines."""
    color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8)  # (rH, rW, 3)
    kpts0 = data['mkpts0_f'].cpu().detach().numpy()
    kpts1 = data['mkpts1_f'].cpu().detach().numpy()
    inliers = data['inliers']

    rows = 2
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h = max(h0, h1)
    H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1

    # canvas
    out = 255 * np.ones((H, W), np.uint8)
    out = np.stack([out] * 3, -1)

    wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)

    sh = hx(row=1)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    sh = hx(row=2)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # draw match lines and keypoints on the second row only
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]):
        c = (0, 255, 0)
        cv2.line(out, (x0, y0 + sh), (x1 + margin + w0, y1 + sh), color=c, thickness=1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    return out


def preprocess(image: np.ndarray, grayscale: bool = False, resize_max: int = None, dfactor: int = 8):
    """Convert an image to a [0, 1] float tensor whose sides are divisible by dfactor.

    Returns the tensor and the (w, h) scale that maps resized coordinates back
    to the original image.
    """
    image = image.astype(np.float32, copy=False)
    size = image.shape[:2][::-1]

    if resize_max:
        scale = resize_max / max(size)
        if scale < 1.0:
            size_new = tuple(int(round(x * scale)) for x in size)
            image = resize_image(image, size_new, 'cv2_area')

    if grayscale:
        assert image.ndim == 2, image.shape
        image = image[None]
    else:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    image = torch.from_numpy(image / 255.0).float()

    # assure that the size is divisible by dfactor
    size_new = tuple(map(lambda x: int(x // dfactor * dfactor), image.shape[-2:]))
    image = F.resize(image, size=size_new)
    scale = np.array(size) / np.array(size_new)[::-1]
    return image, scale


def compute_geom(
        data,
        ransac_method=DEFAULT_RANSAC_METHOD,
        ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD,
        ransac_confidence=DEFAULT_RANSAC_CONFIDENCE,
        ransac_max_iter=DEFAULT_RANSAC_MAX_ITER,
) -> dict:
    """Fit a fundamental matrix, a homography, and rectifying homographies.

    Returns a dict with any of the keys 'Fundamental', 'Homography', 'H1', 'H2';
    empty if there are too few matches.
    """
    mkpts0 = data["mkpts0_f"].cpu().detach().numpy()
    mkpts1 = data["mkpts1_f"].cpu().detach().numpy()

    if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
        return {}

    h0, w0 = data["hw0_i"]
    geo_info = {}

    # named F_mat rather than F to avoid shadowing the module-level F import
    F_mat, _ = cv2.findFundamentalMat(
        mkpts0,
        mkpts1,
        method=RANSAC_ZOO[ransac_method],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if F_mat is not None:
        geo_info["Fundamental"] = F_mat.tolist()

    H, _ = cv2.findHomography(
        mkpts1,
        mkpts0,
        method=RANSAC_ZOO[ransac_method],
        ransacReprojThreshold=ransac_reproj_threshold,
        confidence=ransac_confidence,
        maxIters=ransac_max_iter,
    )
    if H is not None:
        geo_info["Homography"] = H.tolist()

    # rectification requires a valid fundamental matrix
    if F_mat is not None:
        _, H1, H2 = cv2.stereoRectifyUncalibrated(
            mkpts0.reshape(-1, 2),
            mkpts1.reshape(-1, 2),
            F_mat,
            imgSize=(w0, h0),
        )
        geo_info["H1"] = H1.tolist()
        geo_info["H2"] = H2.tolist()

    return geo_info

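
# Hypothetical example of consuming compute_geom's output; the dictionary keys
# are exactly the ones written above, the variable names are illustrative:
#
#   geo = compute_geom(data)
#   if "Homography" in geo:
#       H = np.array(geo["Homography"])  # 3x3, maps image-1 pixels onto image 0
#   if "H1" in geo and "H2" in geo:      # rectifying pair, present only when
#       H1 = np.array(geo["H1"])         # a fundamental matrix was found
#       H2 = np.array(geo["H2"])
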

def warp_images(img0, img1, geo_info, geom_type):
    """Warp image 1 onto image 0 (Homography) or rectify both (Fundamental)."""
    img0 = img0[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    img1 = img1[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1]
    h0, w0, _ = img0.shape
    h1, w1, _ = img1.shape

    if geom_type == "Homography":
        H = np.array(geo_info["Homography"])
        rectified_image0 = img0
        rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
        title = ["Image 0", "Image 1 - warped"]
    elif geom_type == "Fundamental":
        H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
        rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
        rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
        title = ["Image 0 - warped", "Image 1 - warped"]
    else:
        raise ValueError(f"Unknown geometry type {geom_type}.")

    fig = plot_images(
        [rectified_image0.squeeze(), rectified_image1.squeeze()],
        title,
        dpi=300,
    )
    img = fig2im(fig)
    plt.close(fig)
    return img


def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
    """Plot a set of images horizontally.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images.
        dpi: figure resolution in dots per inch.
        size: width of each image panel in inches (None keeps the default figsize).
        pad: padding passed to fig.tight_layout.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    figsize = (size * n, size * 6 / 5) if size is not None else None
    fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])
    fig.tight_layout(pad=pad)
    return fig


def fig2im(fig):
    """Render a Matplotlib figure into an RGB uint8 array."""
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
    im = buf_ndarray.reshape(h, w, 3)
    return im


if __name__ == '__main__':
    model_zoo = ['gim_dkm', 'gim_lightglue']

    # model
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='gim_dkm', choices=model_zoo)
    args = parser.parse_args()

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load model
    ckpt = None
    model = None
    detector = None
    if args.model == 'gim_dkm':
        ckpt = 'gim_dkm_100h.ckpt'
        model = DKMv3(weights=None, h=672, w=896)
    elif args.model == 'gim_lightglue':
        ckpt = 'gim_lightglue_100h.ckpt'
        detector = SuperPoint({
            'max_num_keypoints': 2048,
            'force_num_keypoints': True,
            'detection_threshold': 0.0,
            'nms_radius': 3,
            'trainable': False,
        })
        model = LightGlue({
            'filter_threshold': 0.1,
            'flash': False,
            'checkpointed': True,
        })

    # weights path
    checkpoints_path = join('weights', ckpt)

    # load state dict
    if args.model == 'gim_dkm':
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # strip the 'model.' prefix first, then drop the unused encoder fc head
        # (doing both in one pass would try to pop already-renamed keys)
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        for k in list(state_dict.keys()):
            if 'encoder.net.fc' in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)
    elif args.model == 'gim_lightglue':
        # SuperPoint weights live under the 'superpoint.' prefix
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict.pop(k)
            if k.startswith('superpoint.'):
                state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
        detector.load_state_dict(state_dict)

        # LightGlue weights live under the 'model.' prefix
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('superpoint.'):
                state_dict.pop(k)
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.load_state_dict(state_dict)

    # eval mode
    if detector is not None:
        detector = detector.eval().to(device)
    model = model.eval().to(device)
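
    # Inference only: no gradients are needed below. Wrapping the matching
    # calls in torch.no_grad() would skip autograd bookkeeping, e.g.:
    #   with torch.no_grad():
    #       dense_matches, dense_certainty = model.match(image0, image1)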

    name0 = 'a1'
    name1 = 'a2'
    postfix = '.png'
    image_dir = join('assets', 'demo')
    img_path0 = join(image_dir, name0 + postfix)
    img_path1 = join(image_dir, name1 + postfix)

    image0 = read_image(img_path0)
    image1 = read_image(img_path1)
    image0, scale0 = preprocess(image0)
    image1, scale1 = preprocess(image1)

    image0 = image0.to(device)[None]
    image1 = image1.to(device)[None]

    data = dict(color0=image0, color1=image1, image0=image0, image1=image1)

    if args.model == 'gim_dkm':
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            dense_matches, dense_certainty = model.match(image0, image1)
            sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]

        # DKM returns matches in normalized [-1, 1] coordinates;
        # map them back to pixel coordinates
        kpts0 = sparse_matches[:, :2]
        kpts0 = torch.stack((
            width0 * (kpts0[:, 0] + 1) / 2,
            height0 * (kpts0[:, 1] + 1) / 2), dim=-1,)
        kpts1 = sparse_matches[:, 2:]
        kpts1 = torch.stack((
            width1 * (kpts1[:, 0] + 1) / 2,
            height1 * (kpts1[:, 1] + 1) / 2), dim=-1,)
        b_ids = torch.where(mconf[None])[0]
    elif args.model == 'gim_lightglue':
        gray0 = read_image(img_path0, grayscale=True)
        gray1 = read_image(img_path1, grayscale=True)
        gray0 = preprocess(gray0, grayscale=True)[0]
        gray1 = preprocess(gray1, grayscale=True)[0]
        gray0 = gray0.to(device)[None]
        gray1 = gray1.to(device)[None]

        scale0 = torch.tensor(scale0).to(device)[None]
        scale1 = torch.tensor(scale1).to(device)[None]

        data.update(dict(gray0=gray0, gray1=gray1))
        size0 = torch.tensor(data["gray0"].shape[-2:][::-1])[None]
        size1 = torch.tensor(data["gray1"].shape[-2:][::-1])[None]
        data.update(dict(size0=size0, size1=size1))
        data.update(dict(scale0=scale0, scale1=scale1))

        # detect keypoints on both images, then match them
        pred = {}
        pred.update({k + '0': v for k, v in detector({
            "image": data["gray0"],
            "image_size": data["size0"],
        }).items()})
        pred.update({k + '1': v for k, v in detector({
            "image": data["gray1"],
            "image_size": data["size1"],
        }).items()})
        pred.update(model({**pred, **data,
                           **{'resize0': data['size0'], 'resize1': data['size1']}}))

        # scale keypoints back to the original image resolution
        kpts0 = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])])
        kpts1 = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])])

        m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0]
        matches = pred['matches']
        bs = data['image0'].size(0)
        kpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        kpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])
        b_ids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mconf = torch.cat(pred['scores'])

    # robust fitting: keep only matches consistent with a fundamental matrix
    _, mask = cv2.findFundamentalMat(kpts0.cpu().detach().numpy(),
                                     kpts1.cpu().detach().numpy(),
                                     cv2.USAC_MAGSAC, ransacReprojThreshold=1.0,
                                     confidence=0.999999, maxIters=10000)
    mask = mask.ravel() > 0

    data.update({
        'hw0_i': image0.shape[-2:],
        'hw1_i': image1.shape[-2:],
        'mkpts0_f': kpts0,
        'mkpts1_f': kpts1,
        'm_bids': b_ids,
        'mconf': mconf,
        'inliers': mask,
    })

    # save visualization
    alpha = 0.5
    out = fast_make_matching_figure(data, b_id=0)
    overlay = fast_make_matching_overlay(data, b_id=0)
    out = cv2.addWeighted(out, 1 - alpha, overlay, alpha, 0)
    cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_match.png'), out[..., ::-1])  # RGB -> BGR

    geom_info = compute_geom(data)
    if "Homography" in geom_info:  # compute_geom may return an empty dict
        warped_image = warp_images(image0, image1, geom_info, "Homography")
        cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_warp.png'), warped_image)
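
# Running this script (a sketch, assuming it is saved as demo.py; paths follow
# the defaults hard-coded above):
#   python demo.py --model gim_dkm        # expects weights/gim_dkm_100h.ckpt
#   python demo.py --model gim_lightglue  # expects weights/gim_lightglue_100h.ckpt
# Outputs are written next to the inputs, e.g.
#   assets/demo/a1_a2_gim_dkm_match.png and assets/demo/a1_a2_gim_dkm_warp.png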