Vincentqyw
fix: roma
8b973ee
import copy
import os
import cv2
import glob
import logging
import argparse
import numpy as np
from tqdm import tqdm
from alike import ALike, configs
class ImageLoader(object):
    """Indexable frame source backed by a camera, a video file or an image folder.

    ``filepath`` selects the mode:

    * ``"camera<N>"`` (e.g. ``"camera0"``) opens OpenCV capture device ``N``;
    * a path to an existing file is opened as a video with ``cv2.VideoCapture``;
    * any other existing path is scanned for ``*.png``, ``*.jpg`` and ``*.ppm``
      files, which are served in sorted path order.

    Indexing returns the frame exactly as loaded by OpenCV
    (``VideoCapture.read`` / ``cv2.imread``). In camera/video mode an index
    greater than ``self.N`` returns ``None`` (callers are expected to stop on
    it); in image mode an out-of-range index raises ``IndexError``, which ends
    iteration through the sequence protocol.

    Raises:
        IOError: if the camera/video cannot be opened, the path does not
            exist, or a frame/image cannot be read.
    """

    def __init__(self, filepath: str):
        # Frame budget for camera mode (a live stream has no natural length);
        # video and image modes overwrite it below.
        self.N = 3000
        if filepath.startswith("camera"):
            camera = int(filepath[len("camera"):])
            self.cap = cv2.VideoCapture(camera)
            if not self.cap.isOpened():
                raise IOError(f"Can't open camera {camera}!")
            logging.info(f"Opened camera {camera}")
            self.mode = "camera"
        elif os.path.exists(filepath):
            if os.path.isfile(filepath):
                self.cap = cv2.VideoCapture(filepath)
                if not self.cap.isOpened():
                    raise IOError(f"Can't open video {filepath}!")
                rate = self.cap.get(cv2.CAP_PROP_FPS)
                # Valid indices are 0..N (see __getitem__).
                self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                # Some containers report 0 FPS; don't crash on a log line.
                duration = self.N / rate if rate > 0 else float("nan")
                logging.info(f"Opened video {filepath}")
                logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
                self.mode = "video"
            else:
                self.images = sorted(
                    glob.glob(os.path.join(filepath, "*.png"))
                    + glob.glob(os.path.join(filepath, "*.jpg"))
                    + glob.glob(os.path.join(filepath, "*.ppm"))
                )
                self.N = len(self.images)
                logging.info(f"Loading {self.N} images")
                self.mode = "images"
        else:
            # Pass one message string: OSError(a, b) treats its two arguments
            # as (errno, strerror) and would render "[Errno ...] path".
            raise IOError(
                f"Error filepath (camerax/path of images/path of videos): {filepath}"
            )

    def __getitem__(self, item):
        """Return frame ``item`` (see the class docstring for end-of-stream behaviour)."""
        if self.mode == "camera" or self.mode == "video":
            if item > self.N:
                return None
            if self.mode == "video" and self.cap.get(cv2.CAP_PROP_POS_FRAMES) != item:
                # Seek *before* reading so index ``item`` yields frame ``item``.
                # Seeking after the read returned frame 0 twice and shifted every
                # later frame by one. The seek is skipped on sequential access,
                # where the decoder is already positioned correctly.
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
            ret, img = self.cap.read()
            if not ret:
                raise IOError(f"Can't read image from {self.mode}")
        elif self.mode == "images":
            filename = self.images[item]  # IndexError past the end stops iteration
            img = cv2.imread(filename)
            if img is None:
                raise IOError(f"Error reading image {filename}")
        return img

    def __len__(self):
        """Return ``self.N``: the camera budget, video max index, or image count."""
        return self.N
class SimpleTracker(object):
    """Frame-to-frame keypoint tracker based on mutual-nearest-neighbour matching.

    Each :meth:`update` call matches the current descriptors against those
    from the previous call and draws the matched tracks on a copy of the frame.
    """

    def __init__(self):
        self.pts_prev = None  # keypoints from the previous update(), None before the first call
        self.desc_prev = None  # descriptors from the previous update(), row-aligned with pts_prev

    @staticmethod
    def _to_pixel(pt):
        """Round the first two coordinates of ``pt`` to an integer pixel tuple."""
        return (int(round(pt[0])), int(round(pt[1])))

    def update(self, img, pts, desc):
        """Match against the previous frame, draw the result and remember this frame.

        Args:
            img: Current frame; it is copied, never modified.
            pts: Keypoints, one per row; columns 0 and 1 are used as pixel
                coordinates for drawing (presumably (x, y) — confirm with caller).
            desc: Descriptors, one row per keypoint in ``pts``.

        Returns:
            ``(out, N_matches)``: the annotated copy of ``img`` and the number of
            mutual matches with the previous frame (0 on the first call).
        """
        out = copy.deepcopy(img)
        N_matches = 0
        if self.pts_prev is None:
            # First frame: nothing to match against, just mark the keypoints.
            for pt in pts:
                cv2.circle(out, self._to_pixel(pt), 1, (0, 0, 255), -1, lineType=cv2.LINE_AA)
        else:
            matches = self.mnn_mather(self.desc_prev, desc)
            mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]]
            N_matches = len(matches)
            for pt1, pt2 in zip(mpts1, mpts2):
                p1, p2 = self._to_pixel(pt1), self._to_pixel(pt2)
                cv2.line(out, p1, p2, (0, 255, 0), lineType=cv2.LINE_AA)
                cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=cv2.LINE_AA)
        self.pts_prev = pts
        self.desc_prev = desc
        return out, N_matches

    def mnn_mather(self, desc1, desc2, min_sim=0.9):
        """Return mutual-nearest-neighbour matches between two descriptor sets.

        Similarity is the dot product ``desc1 @ desc2.T``; pairs scoring below
        ``min_sim`` are discarded. A pair ``(i, j)`` is kept when ``j`` is the
        best match of ``i``, ``i`` is the best match of ``j``, and their
        similarity is at least ``min_sim``. ``min_sim`` is expected to be
        positive (discarded scores are zeroed before the argmax).

        Args:
            desc1: Descriptor matrix, one row per keypoint.
            desc2: Descriptor matrix with the same row width as ``desc1``.
            min_sim: Minimum similarity for a match (default 0.9).

        Returns:
            Integer array of shape ``(K, 2)``; each row is
            ``(index into desc1, index into desc2)``.
        """
        if len(desc1) == 0 or len(desc2) == 0:
            # np.argmax raises on an empty axis, and there is nothing to match.
            return np.empty((0, 2), dtype=np.intp)
        sim = desc1 @ desc2.transpose()
        valid = sim >= min_sim
        sim[~valid] = 0
        nn12 = np.argmax(sim, axis=1)
        nn21 = np.argmax(sim, axis=0)
        ids1 = np.arange(0, sim.shape[0])
        # An all-zero row/column argmaxes to index 0 and could pass the mutual
        # check by accident, so also require that the chosen pair survived
        # the threshold.
        mask = (ids1 == nn21[nn12]) & valid[ids1, nn12]
        matches = np.stack([ids1[mask], nn12[mask]])
        return matches.transpose()
if __name__ == "__main__":
    # Command-line demo: detect ALIKE keypoints frame by frame, track them with
    # SimpleTracker and (optionally) display the tracks in an OpenCV window.
    parser = argparse.ArgumentParser(description="ALike Demo.")
    parser.add_argument(
        "input",
        type=str,
        help='Image directory or movie file or "camera0" (for webcam0).',
    )
    parser.add_argument(
        "--model",
        choices=["alike-t", "alike-s", "alike-n", "alike-l"],
        default="alike-t",
        help="The model configuration",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Running device (default: cuda)."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=-1,
        help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
    )
    parser.add_argument(
        "--scores_th",
        type=float,
        default=0.2,
        help="Detector score threshold (default: 0.2).",
    )
    parser.add_argument(
        "--n_limit",
        type=int,
        default=5000,
        help="Maximum number of keypoints to be detected (default: 5000).",
    )
    parser.add_argument(
        "--no_display",
        action="store_true",
        help="Do not display images to screen. Useful if running remotely (default: False).",
    )
    parser.add_argument(
        "--no_sub_pixel",
        action="store_true",
        help="Do not detect sub-pixel keypoints (default: False).",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    image_loader = ImageLoader(args.input)
    model = ALike(
        **configs[args.model],
        device=args.device,
        top_k=args.top_k,
        scores_th=args.scores_th,
        n_limit=args.n_limit,
    )
    tracker = SimpleTracker()

    if not args.no_display:
        logging.info("Press 'q' to stop!")
        cv2.namedWindow(args.model)

    runtime = []
    # The context manager closes the bar (flushing its final status) even when
    # the loop exits early via `break`.
    with tqdm(image_loader) as progress_bar:
        for img in progress_bar:
            if img is None:  # camera/video loaders signal end-of-stream with None
                break
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
            kpts = pred["keypoints"]
            desc = pred["descriptors"]
            runtime.append(pred["time"])

            out, N_matches = tracker.update(img, kpts, desc)

            # Running average of per-frame throughput (mean of 1/time).
            ave_fps = (1.0 / np.stack(runtime)).mean()
            status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
            progress_bar.set_description(status)

            if not args.no_display:
                cv2.setWindowTitle(args.model, args.model + ": " + status)
                cv2.imshow(args.model, out)
                if cv2.waitKey(1) == ord("q"):
                    break

    logging.info("Finished!")
    if not args.no_display:
        logging.info("Press any key to exit!")
        cv2.waitKey()
        cv2.destroyAllWindows()