Spaces:
Running
Running
import argparse | |
from typing import Union, Optional, Dict, List, Tuple | |
from pathlib import Path | |
import pprint | |
from queue import Queue | |
from threading import Thread | |
from functools import partial | |
from tqdm import tqdm | |
import h5py | |
import torch | |
from . import matchers, logger | |
from .utils.base_model import dynamic_load | |
from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval | |
import numpy as np | |
""" | |
A set of standard configurations that can be directly selected from the command | |
line using their name. Each is a dictionary with the following entries: | |
- output: the name of the match file that will be generated. | |
- model: the model configuration, as passed to a feature matcher. | |
""" | |
confs = { | |
"superglue": { | |
"output": "matches-superglue", | |
"model": { | |
"name": "superglue", | |
"weights": "outdoor", | |
"sinkhorn_iterations": 50, | |
"match_threshold": 0.2, | |
}, | |
"preprocessing": { | |
"grayscale": True, | |
"resize_max": 1024, | |
"dfactor": 8, | |
"force_resize": False, | |
}, | |
}, | |
"superglue-fast": { | |
"output": "matches-superglue-it5", | |
"model": { | |
"name": "superglue", | |
"weights": "outdoor", | |
"sinkhorn_iterations": 5, | |
"match_threshold": 0.2, | |
}, | |
}, | |
"superpoint-lightglue": { | |
"output": "matches-lightglue", | |
"model": { | |
"name": "lightglue", | |
"match_threshold": 0.2, | |
"width_confidence": 0.99, # for point pruning | |
"depth_confidence": 0.95, # for early stopping, | |
"features": "superpoint", | |
"model_name": "superpoint_lightglue.pth", | |
}, | |
"preprocessing": { | |
"grayscale": True, | |
"resize_max": 1024, | |
"dfactor": 8, | |
"force_resize": False, | |
}, | |
}, | |
"disk-lightglue": { | |
"output": "matches-lightglue", | |
"model": { | |
"name": "lightglue", | |
"match_threshold": 0.2, | |
"width_confidence": 0.99, # for point pruning | |
"depth_confidence": 0.95, # for early stopping, | |
"features": "disk", | |
"model_name": "disk_lightglue.pth", | |
}, | |
"preprocessing": { | |
"grayscale": True, | |
"resize_max": 1024, | |
"dfactor": 8, | |
"force_resize": False, | |
}, | |
}, | |
"sgmnet": { | |
"output": "matches-sgmnet", | |
"model": { | |
"name": "sgmnet", | |
"seed_top_k": [256, 256], | |
"seed_radius_coe": 0.01, | |
"net_channels": 128, | |
"layer_num": 9, | |
"head": 4, | |
"seedlayer": [0, 6], | |
"use_mc_seeding": True, | |
"use_score_encoding": False, | |
"conf_bar": [1.11, 0.1], | |
"sink_iter": [10, 100], | |
"detach_iter": 1000000, | |
"match_threshold": 0.2, | |
}, | |
"preprocessing": { | |
"grayscale": True, | |
"resize_max": 1024, | |
"dfactor": 8, | |
"force_resize": False, | |
}, | |
}, | |
"NN-superpoint": { | |
"output": "matches-NN-mutual-dist.7", | |
"model": { | |
"name": "nearest_neighbor", | |
"do_mutual_check": True, | |
"distance_threshold": 0.7, | |
"match_threshold": 0.2, | |
}, | |
}, | |
"NN-ratio": { | |
"output": "matches-NN-mutual-ratio.8", | |
"model": { | |
"name": "nearest_neighbor", | |
"do_mutual_check": True, | |
"ratio_threshold": 0.8, | |
"match_threshold": 0.2, | |
}, | |
}, | |
"NN-mutual": { | |
"output": "matches-NN-mutual", | |
"model": { | |
"name": "nearest_neighbor", | |
"do_mutual_check": True, | |
"match_threshold": 0.2, | |
}, | |
}, | |
"Dual-Softmax": { | |
"output": "matches-Dual-Softmax", | |
"model": { | |
"name": "dual_softmax", | |
"do_mutual_check": True, | |
"match_threshold": 0.2, # TODO | |
}, | |
}, | |
"adalam": { | |
"output": "matches-adalam", | |
"model": { | |
"name": "adalam", | |
"match_threshold": 0.2, | |
}, | |
}, | |
} | |
class WorkQueue: | |
def __init__(self, work_fn, num_threads=1): | |
self.queue = Queue(num_threads) | |
self.threads = [ | |
Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads) | |
] | |
for thread in self.threads: | |
thread.start() | |
def join(self): | |
for thread in self.threads: | |
self.queue.put(None) | |
for thread in self.threads: | |
thread.join() | |
def thread_fn(self, work_fn): | |
item = self.queue.get() | |
while item is not None: | |
work_fn(item) | |
item = self.queue.get() | |
def put(self, data): | |
self.queue.put(data) | |
class FeaturePairsDataset(torch.utils.data.Dataset): | |
def __init__(self, pairs, feature_path_q, feature_path_r): | |
self.pairs = pairs | |
self.feature_path_q = feature_path_q | |
self.feature_path_r = feature_path_r | |
def __getitem__(self, idx): | |
name0, name1 = self.pairs[idx] | |
data = {} | |
with h5py.File(self.feature_path_q, "r") as fd: | |
grp = fd[name0] | |
for k, v in grp.items(): | |
data[k + "0"] = torch.from_numpy(v.__array__()).float() | |
# some matchers might expect an image but only use its size | |
data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) | |
with h5py.File(self.feature_path_r, "r") as fd: | |
grp = fd[name1] | |
for k, v in grp.items(): | |
data[k + "1"] = torch.from_numpy(v.__array__()).float() | |
data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) | |
return data | |
def __len__(self): | |
return len(self.pairs) | |
def writer_fn(inp, match_path): | |
pair, pred = inp | |
with h5py.File(str(match_path), "a", libver="latest") as fd: | |
if pair in fd: | |
del fd[pair] | |
grp = fd.create_group(pair) | |
matches = pred["matches0"][0].cpu().short().numpy() | |
grp.create_dataset("matches0", data=matches) | |
if "matching_scores0" in pred: | |
scores = pred["matching_scores0"][0].cpu().half().numpy() | |
grp.create_dataset("matching_scores0", data=scores) | |
def main( | |
conf: Dict, | |
pairs: Path, | |
features: Union[Path, str], | |
export_dir: Optional[Path] = None, | |
matches: Optional[Path] = None, | |
features_ref: Optional[Path] = None, | |
overwrite: bool = False, | |
) -> Path: | |
if isinstance(features, Path) or Path(features).exists(): | |
features_q = features | |
if matches is None: | |
raise ValueError( | |
"Either provide both features and matches as Path" " or both as names." | |
) | |
else: | |
if export_dir is None: | |
raise ValueError( | |
"Provide an export_dir if features is not" f" a file path: {features}." | |
) | |
features_q = Path(export_dir, features + ".h5") | |
if matches is None: | |
matches = Path(export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5') | |
if features_ref is None: | |
features_ref = features_q | |
match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite) | |
return matches | |
def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): | |
"""Avoid to recompute duplicates to save time.""" | |
pairs = set() | |
for i, j in pairs_all: | |
if (j, i) not in pairs: | |
pairs.add((i, j)) | |
pairs = list(pairs) | |
if match_path is not None and match_path.exists(): | |
with h5py.File(str(match_path), "r", libver="latest") as fd: | |
pairs_filtered = [] | |
for i, j in pairs: | |
if ( | |
names_to_pair(i, j) in fd | |
or names_to_pair(j, i) in fd | |
or names_to_pair_old(i, j) in fd | |
or names_to_pair_old(j, i) in fd | |
): | |
continue | |
pairs_filtered.append((i, j)) | |
return pairs_filtered | |
return pairs | |
def match_from_paths( | |
conf: Dict, | |
pairs_path: Path, | |
match_path: Path, | |
feature_path_q: Path, | |
feature_path_ref: Path, | |
overwrite: bool = False, | |
) -> Path: | |
logger.info( | |
"Matching local features with configuration:" f"\n{pprint.pformat(conf)}" | |
) | |
if not feature_path_q.exists(): | |
raise FileNotFoundError(f"Query feature file {feature_path_q}.") | |
if not feature_path_ref.exists(): | |
raise FileNotFoundError(f"Reference feature file {feature_path_ref}.") | |
match_path.parent.mkdir(exist_ok=True, parents=True) | |
assert pairs_path.exists(), pairs_path | |
pairs = parse_retrieval(pairs_path) | |
pairs = [(q, r) for q, rs in pairs.items() for r in rs] | |
pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) | |
if len(pairs) == 0: | |
logger.info("Skipping the matching.") | |
return | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
Model = dynamic_load(matchers, conf["model"]["name"]) | |
model = Model(conf["model"]).eval().to(device) | |
dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) | |
loader = torch.utils.data.DataLoader( | |
dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True | |
) | |
writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5) | |
for idx, data in enumerate(tqdm(loader, smoothing=0.1)): | |
data = { | |
k: v if k.startswith("image") else v.to(device, non_blocking=True) | |
for k, v in data.items() | |
} | |
pred = model(data) | |
pair = names_to_pair(*pairs[idx]) | |
writer_queue.put((pair, pred)) | |
writer_queue.join() | |
logger.info("Finished exporting matches.") | |
def scale_keypoints(kpts, scale): | |
if np.any(scale != 1.0): | |
kpts *= kpts.new_tensor(scale) | |
return kpts | |
def match_images(model, feat0, feat1): | |
# forward pass to match keypoints | |
desc0 = feat0["descriptors"][0] | |
desc1 = feat1["descriptors"][0] | |
if len(desc0.shape) == 2: | |
desc0 = desc0.unsqueeze(0) | |
if len(desc1.shape) == 2: | |
desc1 = desc1.unsqueeze(0) | |
pred = model( | |
{ | |
"image0": feat0["image"], | |
"keypoints0": feat0["keypoints"][0], | |
"scores0": feat0["scores"][0].unsqueeze(0), | |
"descriptors0": desc0, | |
"image1": feat1["image"], | |
"keypoints1": feat1["keypoints"][0], | |
"scores1": feat1["scores"][0].unsqueeze(0), | |
"descriptors1": desc1, | |
} | |
) | |
pred = { | |
k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v | |
for k, v in pred.items() | |
} | |
kpts0, kpts1 = ( | |
feat0["keypoints"][0].cpu().numpy(), | |
feat1["keypoints"][0].cpu().numpy(), | |
) | |
matches, confid = pred["matches0"], pred["matching_scores0"] | |
# Keep the matching keypoints. | |
valid = matches > -1 | |
mkpts0 = kpts0[valid] | |
mkpts1 = kpts1[matches[valid]] | |
mconfid = confid[valid] | |
# rescale the keypoints to their original size | |
s0 = feat0["original_size"] / feat0["size"] | |
s1 = feat1["original_size"] / feat1["size"] | |
kpts0_origin = scale_keypoints(torch.from_numpy(mkpts0 + 0.5), s0) - 0.5 | |
kpts1_origin = scale_keypoints(torch.from_numpy(mkpts1 + 0.5), s1) - 0.5 | |
ret = { | |
"image0_orig": feat0["image_orig"], | |
"image1_orig": feat1["image_orig"], | |
"keypoints0": kpts0, | |
"keypoints1": kpts1, | |
"keypoints0_orig": kpts0_origin.numpy(), | |
"keypoints1_orig": kpts1_origin.numpy(), | |
"mconf": mconfid, | |
} | |
del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin | |
torch.cuda.empty_cache() | |
return ret | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--pairs", type=Path, required=True) | |
parser.add_argument("--export_dir", type=Path) | |
parser.add_argument("--features", type=str, default="feats-superpoint-n4096-r1024") | |
parser.add_argument("--matches", type=Path) | |
parser.add_argument( | |
"--conf", type=str, default="superglue", choices=list(confs.keys()) | |
) | |
args = parser.parse_args() | |
main(confs[args.conf], args.pairs, args.features, args.export_dir) | |