import argparse
import copy
import glob
import logging
import os

import cv2
import numpy as np
from tqdm import tqdm

from alike import ALike, configs


class ImageLoader(object):
    """Iterable frame source backed by a webcam, a video file, or an image folder.

    The ``filepath`` argument selects the mode:
      * ``"cameraX"``   -> OpenCV capture device ``X`` (e.g. ``"camera0"``)
      * path to a file  -> video file
      * path to a dir   -> all ``*.png`` / ``*.jpg`` / ``*.ppm`` images, sorted

    Raises:
        IOError: if the camera/video cannot be opened or the path is invalid.
    """

    def __init__(self, filepath: str):
        # Default frame budget; a live camera has no natural end of stream,
        # so cap it at 3000 frames. Video/image modes overwrite this below.
        self.N = 3000
        if filepath.startswith("camera"):
            camera = int(filepath[6:])  # digits after the "camera" prefix
            self.cap = cv2.VideoCapture(camera)
            if not self.cap.isOpened():
                raise IOError(f"Can't open camera {camera}!")
            logging.info(f"Opened camera {camera}")
            self.mode = "camera"
        elif os.path.exists(filepath):
            if os.path.isfile(filepath):
                self.cap = cv2.VideoCapture(filepath)
                if not self.cap.isOpened():
                    raise IOError(f"Can't open video {filepath}!")
                rate = self.cap.get(cv2.CAP_PROP_FPS)
                self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                duration = self.N / rate
                logging.info(f"Opened video {filepath}")
                logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
                self.mode = "video"
            else:
                self.images = (
                    glob.glob(os.path.join(filepath, "*.png"))
                    + glob.glob(os.path.join(filepath, "*.jpg"))
                    + glob.glob(os.path.join(filepath, "*.ppm"))
                )
                self.images.sort()
                self.N = len(self.images)
                logging.info(f"Loading {self.N} images")
                self.mode = "images"
        else:
            # BUGFIX: the original passed two arguments to IOError, which
            # renders as a tuple instead of a readable message.
            raise IOError(
                f"Error filepath (camerax/path of images/path of videos): {filepath}"
            )

    def __getitem__(self, item):
        """Return frame ``item`` as a BGR ndarray, or None past end of stream."""
        if self.mode == "camera" or self.mode == "video":
            if item > self.N:
                return None
            ret, img = self.cap.read()
            if not ret:
                # BUGFIX: the original raised a bare string, which is a
                # TypeError in Python 3 and masked the real read failure.
                raise IOError("Can't read image from camera")
            if self.mode == "video":
                # NOTE(review): the seek happens *after* the read, so it
                # positions the capture for the NEXT call — confirm intended.
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
        elif self.mode == "images":
            filename = self.images[item]
            img = cv2.imread(filename)
            if img is None:
                raise Exception("Error reading image %s" % filename)
        return img

    def __len__(self):
        return self.N


class SimpleTracker(object):
    """Frame-to-frame keypoint tracker using mutual-nearest-neighbor matching.

    Keeps the previous frame's keypoints/descriptors and draws matches
    (green lines) and current keypoints (red dots) onto a copy of the image.
    """

    def __init__(self):
        # Previous frame's keypoints (N x 2) and descriptors (N x D);
        # None until the first call to update().
        self.pts_prev = None
        self.desc_prev = None

    def update(self, img, pts, desc):
        """Match ``pts``/``desc`` against the previous frame.

        Args:
            img: BGR image to draw the visualization on (not modified).
            pts: keypoint coordinates, one (x, y) row per keypoint.
            desc: L2-normalized descriptors, one row per keypoint.

        Returns:
            (out, N_matches): annotated image copy and the match count
            (0 on the very first frame, which has nothing to match against).
        """
        N_matches = 0
        if self.pts_prev is None:
            # First frame: nothing to match — just draw the keypoints.
            self.pts_prev = pts
            self.desc_prev = desc

            out = copy.deepcopy(img)
            for pt1 in pts:
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                cv2.circle(out, p1, 1, (0, 0, 255), -1, lineType=16)
        else:
            matches = self.mnn_mather(self.desc_prev, desc)
            mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]]
            N_matches = len(matches)

            out = copy.deepcopy(img)
            for pt1, pt2 in zip(mpts1, mpts2):
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                p2 = (int(round(pt2[0])), int(round(pt2[1])))
                cv2.line(out, p1, p2, (0, 255, 0), lineType=16)
                cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16)

            self.pts_prev = pts
            self.desc_prev = desc

        return out, N_matches

    def mnn_mather(self, desc1, desc2):
        """Mutual-nearest-neighbor matching on the descriptor similarity matrix.

        Similarities below 0.9 are zeroed; a pair (i, j) is kept only if j is
        i's best match AND i is j's best match. Returns an (M, 2) index array.
        """
        sim = desc1 @ desc2.transpose()
        sim[sim < 0.9] = 0
        nn12 = np.argmax(sim, axis=1)
        nn21 = np.argmax(sim, axis=0)
        ids1 = np.arange(0, sim.shape[0])
        # Mutual check: i -> nn12[i] -> back to i.
        mask = ids1 == nn21[nn12]
        matches = np.stack([ids1[mask], nn12[mask]])
        return matches.transpose()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ALike Demo.")
    parser.add_argument(
        "input",
        type=str,
        default="",
        help='Image directory or movie file or "camera0" (for webcam0).',
    )
    parser.add_argument(
        "--model",
        choices=["alike-t", "alike-s", "alike-n", "alike-l"],
        default="alike-t",
        help="The model configuration",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Running device (default: cuda)."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=-1,
        help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
    )
    parser.add_argument(
        "--scores_th",
        type=float,
        default=0.2,
        help="Detector score threshold (default: 0.2).",
    )
    parser.add_argument(
        "--n_limit",
        type=int,
        default=5000,
        help="Maximum number of keypoints to be detected (default: 5000).",
    )
    parser.add_argument(
        "--no_display",
        action="store_true",
        help="Do not display images to screen. Useful if running remotely (default: False).",
    )
    parser.add_argument(
        "--no_sub_pixel",
        action="store_true",
        help="Do not detect sub-pixel keypoints (default: False).",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    image_loader = ImageLoader(args.input)
    model = ALike(
        **configs[args.model],
        device=args.device,
        top_k=args.top_k,
        scores_th=args.scores_th,
        n_limit=args.n_limit,
    )
    tracker = SimpleTracker()

    if not args.no_display:
        logging.info("Press 'q' to stop!")
        cv2.namedWindow(args.model)

    runtime = []
    progress_bar = tqdm(image_loader)
    for img in progress_bar:
        if img is None:
            break

        # ALike expects RGB input; OpenCV delivers BGR.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
        kpts = pred["keypoints"]
        desc = pred["descriptors"]
        runtime.append(pred["time"])

        out, N_matches = tracker.update(img, kpts, desc)

        ave_fps = (1.0 / np.stack(runtime)).mean()
        status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
        progress_bar.set_description(status)

        if not args.no_display:
            cv2.setWindowTitle(args.model, args.model + ": " + status)
            cv2.imshow(args.model, out)
            if cv2.waitKey(1) == ord("q"):
                break

    logging.info("Finished!")
    if not args.no_display:
        logging.info("Press any key to exit!")
        cv2.waitKey()