# Benchmark script for LightGlue on real images
import argparse
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo

from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image

torch.set_grad_enabled(False)


def _synchronize(device):
    """Block until all work queued on ``device`` has finished.

    CUDA and MPS launch kernels asynchronously, so a wall-clock measurement
    taken without synchronizing only captures launch overhead. CPU execution
    is synchronous and needs nothing.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def measure(matcher, data, device="cuda", r=100, warmup=10):
    """Time repeated calls ``matcher(data)``.

    Args:
        matcher: callable (e.g. a module) invoked as ``matcher(data)``.
        data: input passed unchanged to ``matcher`` on every call.
        device: ``torch.device`` or device string the matcher runs on; it
            selects the timing strategy (CUDA events on ``cuda``, wall clock
            with explicit synchronization otherwise).
        r: number of timed repetitions.
        warmup: number of untimed calls executed before measuring.

    Returns:
        dict with ``"mean"`` and ``"std"`` of the per-call latency in ms.
    """
    # Normalize so that string devices (including the default "cuda") work;
    # a plain str has no ``.type`` attribute.
    device = torch.device(device)
    use_cuda_events = device.type == "cuda"
    timings = np.zeros((r, 1))
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
    # warmup
    for _ in range(warmup):
        _ = matcher(data)
    # measurements
    with torch.no_grad():
        for rep in range(r):
            if use_cuda_events:
                starter.record()
                _ = matcher(data)
                ender.record()
                # wait for the GPU so that both events have completed
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
            else:
                # Drain any pending async work (e.g. MPS) so it is not billed
                # to this repetition, then wait for this call to complete.
                _synchronize(device)
                start = time.perf_counter()
                _ = matcher(data)
                _synchronize(device)
                curr_time = (time.perf_counter() - start) * 1e3
            timings[rep] = curr_time
    mean_syn = np.sum(timings) / r
    std_syn = np.std(timings)
    return {"mean": mean_syn, "std": std_syn}


def print_as_table(d, title, cnames):
    """Print ``d`` as a fixed-width table, one row per key.

    Args:
        d: mapping from row name to a sequence of numbers, one per column.
        title: header of the first (row-name) column.
        cnames: headers of the numeric columns.
    """
    print()
    header = f"{title:30} " + " ".join(f"{x:>7}" for x in cnames)
    print(header)
    print("-" * len(header))
    for row_name, values in d.items():
        print(f"{row_name:30}", " ".join(f"{x:>7.1f}" for x in values))


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "cpu", "mps"],
        default="auto",
        help="device to benchmark on",
    )
    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
    parser.add_argument(
        "--no_flash", action="store_true", help="disable FlashAttention"
    )
    parser.add_argument(
        "--no_prune_thresholds",
        action="store_true",
        help="disable pruning thresholds (i.e. always do pruning)",
    )
    parser.add_argument(
        "--add_superglue",
        action="store_true",
        help="add SuperGlue to the benchmark (requires hloc)",
    )
    parser.add_argument(
        "--measure", default="time", choices=["time", "log-time", "throughput"]
    )
    parser.add_argument(
        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
    )
    parser.add_argument(
        "--num_keypoints",
        nargs="+",
        type=int,
        default=[256, 512, 1024, 2048, 4096],
        help="number of keypoints (list separated by spaces)",
    )
    parser.add_argument(
        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
    )
    parser.add_argument(
        "--save", default=None, type=str, help="path where figure should be saved"
    )
    return parser.parse_intermixed_args()


def _configure_axis(ax, title, measure_mode):
    """Set scales, ticks, grid, title and labels of one subplot."""
    ax.set_xscale("log", base=2)
    bases = [2**x for x in range(7, 16)]
    ax.set_xticks(bases, bases)
    ax.grid(which="major")
    if measure_mode == "log-time":
        ax.set_yscale("log")
        yticks = [10**x for x in range(6)]
        ax.set_yticks(yticks, yticks)
        # minor ticks at 2..9 x 10^k; only 2 and 5 get a label
        mpos = [10**x * i for x in range(6) for i in range(2, 10)]
        mlabel = [
            10**x * i if i in [2, 5] else None
            for x in range(6)
            for i in range(2, 10)
        ]
        ax.set_yticks(mpos, mlabel, minor=True)
        ax.grid(which="minor", linewidth=0.2)
    ax.set_title(title)
    ax.set_xlabel("# keypoints")
    if measure_mode == "throughput":
        ax.set_ylabel("Throughput [pairs/s]")
    else:
        ax.set_ylabel("Latency [ms]")


def _to_metric(runtime_ms, measure_mode):
    """Convert a mean latency in ms to the plotted quantity."""
    return 1000 / runtime_ms if measure_mode == "throughput" else runtime_ms


def _superglue_inputs(image0, image1, feats0, feats1):
    """Flatten two SuperPoint outputs into the input dict fed to SuperGlue."""
    data = {
        "image0": image0[None],
        "image1": image1[None],
        **{k + "0": v for k, v in feats0.items()},
        **{k + "1": v for k, v in feats1.items()},
    }
    data["scores0"] = data["keypoint_scores0"]
    data["scores1"] = data["keypoint_scores1"]
    # Swap the last two descriptor axes; presumably hloc's SuperGlue wants
    # channels-first descriptors (TODO confirm against hloc).
    data["descriptors0"] = data["descriptors0"].transpose(-1, -2).contiguous()
    data["descriptors1"] = data["descriptors1"].transpose(-1, -2).contiguous()
    return data


def main():
    """Run the benchmark, print result tables and plot latency/throughput."""
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device != "auto":
        device = torch.device(args.device)
    print("Running benchmark on device:", device)

    images = Path("assets")
    inputs = {
        "easy": (
            load_image(images / "DSC_0411.JPG"),
            load_image(images / "DSC_0410.JPG"),
        ),
        "difficult": (
            load_image(images / "sacre_coeur1.jpg"),
            load_image(images / "sacre_coeur2.jpg"),
        ),
    }

    configs = {
        "LightGlue-full": {
            "depth_confidence": -1,
            "width_confidence": -1,
        },
        # 'LG-prune': {
        #     'width_confidence': -1,
        # },
        # 'LG-depth': {
        #     'depth_confidence': -1,
        # },
        "LightGlue-adaptive": {},
    }
    if args.compile:
        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}

    sg_configs = {
        # 'SuperGlue': {},
        "SuperGlue-fast": {"sinkhorn_iterations": 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    # results[pair_name][config_name] -> one value per entry of num_keypoints
    results = {k: defaultdict(list) for k in inputs}

    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)

    figsize = (len(inputs) * 4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
    for title, ax in zip(inputs.keys(), axes):
        _configure_axis(ax, title, args.measure)

    for name, conf in configs.items():
        print("Run benchmark for:", name)
        if device.type == "cuda":
            torch.cuda.empty_cache()
        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds
            }
        matcher = matcher.eval().to(device)
        if name.endswith("compile"):
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        for pair_name, ax in zip(inputs.keys(), axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]
            for num_kpts in args.num_keypoints:
                extractor.conf.max_num_keypoints = num_kpts
                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)
                runtime = measure(
                    matcher,
                    {"image0": feats0, "image1": feats1},
                    device=device,
                    r=args.repeat,
                )["mean"]
                results[pair_name][name].append(_to_metric(runtime, args.measure))
            ax.plot(
                args.num_keypoints, results[pair_name][name], label=name, marker="o"
            )
        del matcher, feats0, feats1

    if args.add_superglue:
        from hloc.matchers.superglue import SuperGlue

        for name, conf in sg_configs.items():
            print("Run benchmark for:", name)
            matcher = SuperGlue(conf)
            matcher = matcher.eval().to(device)
            for pair_name, ax in zip(inputs.keys(), axes):
                image0, image1 = [x.to(device) for x in inputs[pair_name]]
                for num_kpts in args.num_keypoints:
                    extractor.conf.max_num_keypoints = num_kpts
                    feats0 = extractor.extract(image0)
                    feats1 = extractor.extract(image1)
                    data = _superglue_inputs(image0, image1, feats0, feats1)
                    runtime = measure(matcher, data, device=device, r=args.repeat)[
                        "mean"
                    ]
                    results[pair_name][name].append(
                        _to_metric(runtime, args.measure)
                    )
                ax.plot(
                    args.num_keypoints,
                    results[pair_name][name],
                    label=name,
                    marker="o",
                )
            del matcher, data, image0, image1, feats0, feats1

    for name, runtimes in results.items():
        print_as_table(runtimes, name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()


if __name__ == "__main__":
    main()