|
|
|
import argparse |
|
import time |
|
from collections import defaultdict |
|
from pathlib import Path |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
import torch._dynamo |
|
|
|
from lightglue import LightGlue, SuperPoint |
|
from lightglue.utils import load_image |
|
|
|
# Inference-only benchmark: disable autograd globally to avoid the memory and
# time overhead of building computation graphs.
torch.set_grad_enabled(False)
|
|
|
|
|
def measure(matcher, data, device="cuda", r=100):
    """Time repeated forward passes of ``matcher(data)``.

    Args:
        matcher: callable invoked as ``matcher(data)`` on every repetition.
        data: input forwarded unchanged to the matcher.
        device: ``torch.device`` or device string. On CUDA devices timing
            uses CUDA events (GPU-side clock); otherwise a wall-clock
            fallback is used.
        r: number of timed repetitions.

    Returns:
        dict with keys ``"mean"`` and ``"std"``: latency statistics in
        milliseconds over the ``r`` repetitions.
    """
    # Bug fix: the default is the *string* "cuda", but the body reads
    # ``device.type`` — normalize strings to torch.device so both work.
    if isinstance(device, str):
        device = torch.device(device)
    timings = np.zeros((r, 1))
    if device.type == "cuda":
        # CUDA events measure on-GPU elapsed time without host-side noise.
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    # Warm-up: amortize one-time costs (CUDA context init, cudnn autotune,
    # compile caches) so they do not pollute the measurement.
    for _ in range(10):
        _ = matcher(data)

    with torch.no_grad():
        for rep in range(r):
            if device.type == "cuda":
                starter.record()
                _ = matcher(data)
                ender.record()

                # Kernels launch asynchronously; wait for the GPU to finish
                # before reading the event clocks.
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
            else:
                start = time.perf_counter()
                _ = matcher(data)
                curr_time = (time.perf_counter() - start) * 1e3  # s -> ms
            timings[rep] = curr_time
    # np.mean is equivalent to the previous np.sum(timings) / r, but clearer.
    mean_syn = np.mean(timings)
    std_syn = np.std(timings)
    return {"mean": mean_syn, "std": std_syn}
|
|
|
|
|
def print_as_table(d, title, cnames):
    """Print a ``{row_name: [values]}`` mapping as an aligned text table.

    ``title`` heads the first (30-char) column, ``cnames`` label the
    remaining 7-char-wide numeric columns; each value is shown with one
    decimal place. A separator rule matching the header width is printed
    between header and rows.
    """
    print()
    column_labels = " ".join(f"{name:>7}" for name in cnames)
    header = f"{title:30} " + column_labels
    print(header)
    print("-" * len(header))
    for row_name, values in d.items():
        cells = " ".join(f"{value:>7.1f}" for value in values)
        print(f"{row_name:30}", cells)
|
|
|
|
|
# Entry point: benchmark LightGlue (and optionally SuperGlue) matching
# latency/throughput over increasing keypoint counts, print the results as
# tables and plot them with matplotlib.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "cpu", "mps"],
        default="auto",
        help="device to benchmark on",
    )
    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
    parser.add_argument(
        "--no_flash", action="store_true", help="disable FlashAttention"
    )
    parser.add_argument(
        "--no_prune_thresholds",
        action="store_true",
        help="disable pruning thresholds (i.e. always do pruning)",
    )
    parser.add_argument(
        "--add_superglue",
        action="store_true",
        help="add SuperGlue to the benchmark (requires hloc)",
    )
    # What to plot/report: raw latency, log-scaled latency, or pairs/second.
    parser.add_argument(
        "--measure", default="time", choices=["time", "log-time", "throughput"]
    )
    parser.add_argument(
        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
    )
    parser.add_argument(
        "--num_keypoints",
        nargs="+",
        type=int,
        default=[256, 512, 1024, 2048, 4096],
        help="number of keypoints (list separated by spaces)",
    )
    parser.add_argument(
        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
    )
    parser.add_argument(
        "--save", default=None, type=str, help="path where figure should be saved"
    )
    # parse_intermixed_args allows positional/optional args to be interleaved.
    args = parser.parse_intermixed_args()

    # Default to CUDA when available; an explicit --device always wins.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device != "auto":
        device = torch.device(args.device)

    print("Running benchmark on device:", device)

    # Two image pairs of differing matching difficulty, loaded from ./assets.
    images = Path("assets")
    inputs = {
        "easy": (
            load_image(images / "DSC_0411.JPG"),
            load_image(images / "DSC_0410.JPG"),
        ),
        "difficult": (
            load_image(images / "sacre_coeur1.jpg"),
            load_image(images / "sacre_coeur2.jpg"),
        ),
    }

    # LightGlue variants: "full" disables early stopping and point pruning
    # (confidence thresholds of -1); "adaptive" uses the model defaults.
    configs = {
        "LightGlue-full": {
            "depth_confidence": -1,
            "width_confidence": -1,
        },
        "LightGlue-adaptive": {},
    }

    # Optionally duplicate every config with a torch.compile'd counterpart.
    if args.compile:
        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}

    sg_configs = {
        "SuperGlue-fast": {"sinkhorn_iterations": 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    # results[pair_name][config_name] -> list of measurements, one per
    # keypoint count in args.num_keypoints.
    results = {k: defaultdict(list) for k, v in inputs.items()}

    # Unlimited detection; max_num_keypoints is overridden per benchmark run.
    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)
    figsize = (len(inputs) * 4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    # plt.subplots returns a bare Axes (not a list) when there is one column.
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")

    # Configure one subplot per image pair before any data is plotted.
    for title, ax in zip(inputs.keys(), axes):
        ax.set_xscale("log", base=2)
        bases = [2**x for x in range(7, 16)]
        ax.set_xticks(bases, bases)
        ax.grid(which="major")
        if args.measure == "log-time":
            ax.set_yscale("log")
            yticks = [10**x for x in range(6)]
            ax.set_yticks(yticks, yticks)
            # Minor ticks at 2x..9x per decade; label only 2x and 5x.
            mpos = [10**x * i for x in range(6) for i in range(2, 10)]
            mlabel = [
                10**x * i if i in [2, 5] else None
                for x in range(6)
                for i in range(2, 10)
            ]
            ax.set_yticks(mpos, mlabel, minor=True)
            ax.grid(which="minor", linewidth=0.2)
        ax.set_title(title)

        ax.set_xlabel("# keypoints")
        if args.measure == "throughput":
            ax.set_ylabel("Throughput [pairs/s]")
        else:
            ax.set_ylabel("Latency [ms]")

    # --- LightGlue benchmark loop: one matcher instance per config. ---
    for name, conf in configs.items():
        print("Run benchmark for:", name)
        torch.cuda.empty_cache()
        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            # A threshold of -1 means pruning is always applied regardless of
            # the number of keypoints.
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds
            }
        matcher = matcher.eval().to(device)
        if name.endswith("compile"):
            import torch._dynamo

            # Reset dynamo caches so each compiled config starts fresh.
            torch._dynamo.reset()  # avoid buffer overflow in compile cache
            matcher.compile()
        for pair_name, ax in zip(inputs.keys(), axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]
            runtimes = []
            for num_kpts in args.num_keypoints:
                # Cap the extractor's keypoint budget for this measurement.
                extractor.conf.max_num_keypoints = num_kpts
                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)
                runtime = measure(
                    matcher,
                    {"image0": feats0, "image1": feats1},
                    device=device,
                    r=args.repeat,
                )["mean"]
                results[pair_name][name].append(
                    1000 / runtime if args.measure == "throughput" else runtime
                )
            ax.plot(
                args.num_keypoints, results[pair_name][name], label=name, marker="o"
            )
        # Free GPU memory before the next config is instantiated.
        del matcher, feats0, feats1

    # --- Optional SuperGlue baseline (needs hloc, imported lazily). ---
    if args.add_superglue:
        from hloc.matchers.superglue import SuperGlue

        for name, conf in sg_configs.items():
            print("Run benchmark for:", name)
            matcher = SuperGlue(conf)
            matcher = matcher.eval().to(device)
            for pair_name, ax in zip(inputs.keys(), axes):
                image0, image1 = [x.to(device) for x in inputs[pair_name]]
                runtimes = []
                for num_kpts in args.num_keypoints:
                    extractor.conf.max_num_keypoints = num_kpts
                    feats0 = extractor.extract(image0)
                    feats1 = extractor.extract(image1)
                    # SuperGlue expects a flat dict with "0"/"1" suffixed keys
                    # and batched images, unlike LightGlue's nested format.
                    data = {
                        "image0": image0[None],
                        "image1": image1[None],
                        **{k + "0": v for k, v in feats0.items()},
                        **{k + "1": v for k, v in feats1.items()},
                    }
                    data["scores0"] = data["keypoint_scores0"]
                    data["scores1"] = data["keypoint_scores1"]
                    # SuperGlue wants descriptors as (dim, num_kpts).
                    data["descriptors0"] = (
                        data["descriptors0"].transpose(-1, -2).contiguous()
                    )
                    data["descriptors1"] = (
                        data["descriptors1"].transpose(-1, -2).contiguous()
                    )
                    runtime = measure(matcher, data, device=device, r=args.repeat)[
                        "mean"
                    ]
                    results[pair_name][name].append(
                        1000 / runtime if args.measure == "throughput" else runtime
                    )
                ax.plot(
                    args.num_keypoints, results[pair_name][name], label=name, marker="o"
                )
            del matcher, data, image0, image1, feats0, feats1

    # Text summary: one table per image pair, one column per keypoint count.
    for name, runtimes in results.items():
        print_as_table(runtimes, name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()
|
|