# Spaces: Running, Running  (page-header residue from the hosting site; not part of the program)
import argparse | |
import torch | |
from pathlib import Path | |
import h5py | |
import logging | |
from tqdm import tqdm | |
import pprint | |
from queue import Queue | |
from threading import Thread | |
from functools import partial | |
from typing import Dict, List, Optional, Tuple, Union | |
import localization.matchers as matchers | |
from localization.base_model import dynamic_load | |
from colmap_utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval | |
# Registry of matcher configurations, keyed by the name accepted by ``--conf``.
# Every entry holds:
#   "output": tag embedded in the default matches file name built by ``main``;
#   "model":  options passed to the matcher class, where "name" selects the
#             class that is loaded from ``localization.matchers``.
confs = {
    "gm": {
        "output": "gm",
        "model": {
            "name": "gm",
            "weight_path": "weights/imp_gm.900.pth",
            "sinkhorn_iterations": 20,
        },
    },
    "gml": {
        "output": "gml",
        "model": {
            "name": "gml",
            "weight_path": "weights/imp_gml.920.pth",
            "sinkhorn_iterations": 20,
        },
    },
    "adagml": {
        "output": "adagml",
        "model": {
            "name": "adagml",
            "weight_path": "weights/imp_adagml.80.pth",
            "sinkhorn_iterations": 20,
        },
    },
    "superglue": {
        "output": "superglue",
        "model": {
            "name": "superglue",
            "weights": "outdoor",
            "sinkhorn_iterations": 20,
            "weight_path": "weights/superglue_outdoor.pth",
        },
    },
    "NNM": {
        "output": "NNM",
        "model": {
            "name": "nearest_neighbor",
            "do_mutual_check": True,
            "distance_threshold": None,
        },
    },
}
class WorkQueue:
    """Fixed pool of threads that apply ``work_fn`` to queued items.

    Items travel through a ``Queue`` created with ``num_threads`` as its
    maximum size. ``None`` is reserved as the shutdown signal and therefore
    must not be submitted as a work item.
    """

    def __init__(self, work_fn, num_threads=1):
        """Create ``num_threads`` workers running ``work_fn`` and start them."""
        self.queue = Queue(num_threads)
        self.threads = []
        for _ in range(num_threads):
            worker = Thread(target=self.thread_fn, args=(work_fn,))
            self.threads.append(worker)
        # Workers are started in a second pass, once all of them exist.
        for worker in self.threads:
            worker.start()

    def put(self, data):
        """Enqueue one work item for the workers."""
        self.queue.put(data)

    def thread_fn(self, work_fn):
        """Worker loop: process items until the ``None`` sentinel is received."""
        while True:
            job = self.queue.get()
            if job is None:
                break
            work_fn(job)

    def join(self):
        """Send one stop sentinel per worker, then wait for all of them."""
        for _ in self.threads:
            self.queue.put(None)
        for worker in self.threads:
            worker.join()
class FeaturePairsDataset(torch.utils.data.Dataset):
    """Dataset that returns the stored features of one image pair per item.

    ``pairs`` is an indexable collection of ``(name0, name1)`` tuples.
    ``name0`` is looked up in the HDF5 file ``feature_path_q`` and ``name1``
    in ``feature_path_r``.
    """

    def __init__(self, pairs, feature_path_q, feature_path_r):
        self.pairs = pairs
        self.feature_path_q = feature_path_q
        self.feature_path_r = feature_path_r

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return a dict with every stored entry suffixed by "0" or "1"."""
        query_name, ref_name = self.pairs[idx]
        sample = {}
        self._read_features(self.feature_path_q, query_name, "0", sample)
        self._read_features(self.feature_path_r, ref_name, "1", sample)
        return sample

    @staticmethod
    def _read_features(h5_path, image_name, suffix, out):
        """Copy the group ``image_name`` of ``h5_path`` into ``out``.

        Each dataset is converted to a float tensor and stored under
        ``<key><suffix>``; the "descriptors" entry is transposed.
        """
        with h5py.File(h5_path, "r") as h5:
            group = h5[image_name]
            for key, dataset in group.items():
                tensor = torch.from_numpy(dataset.__array__()).float()
                if key == "descriptors":
                    tensor = tensor.t()
                out[key + suffix] = tensor
            # some matchers might expect an image but only use its size, so an
            # uninitialised placeholder is enough; the stored size is reversed
            # to form the trailing dims (presumably (w, h) -> (h, w); confirm).
            trailing_shape = tuple(group["image_size"])[::-1]
            out["image" + suffix] = torch.empty((1,) + trailing_shape)
def writer_fn(inp, match_path):
    """Store the prediction for one pair in the matches HDF5 file.

    Args:
        inp: ``(pair, pred)`` tuple. ``pair`` is the group name to write;
            ``pred`` is a mapping holding "matches0" and, optionally,
            "matching_scores0". Only element ``[0]`` of each is written.
        match_path: location of the HDF5 file, opened in append mode.

    An existing group with the same name is deleted and rewritten.
    """
    pair_key, prediction = inp
    with h5py.File(str(match_path), "a", libver="latest") as h5:
        if pair_key in h5:
            del h5[pair_key]
        group = h5.create_group(pair_key)

        # Matches are stored as 16-bit integers.
        group.create_dataset(
            "matches0", data=prediction["matches0"][0].cpu().short().numpy()
        )

        if "matching_scores0" in prediction:
            # Scores are stored as 16-bit floats.
            group.create_dataset(
                "matching_scores0",
                data=prediction["matching_scores0"][0].cpu().half().numpy(),
            )
def main(
    conf: Dict,
    pairs: Path,
    features: Union[Path, str],
    export_dir: Optional[Path] = None,
    matches: Optional[Path] = None,
    features_ref: Optional[Path] = None,
    overwrite: bool = False,
) -> Path:
    """Resolve the input/output locations and run the matching.

    ``features`` is either a path to an existing feature file or a bare name.
    A bare name is resolved to ``<export_dir>/<features>.h5`` and, when
    ``matches`` is not given, the matches file defaults to
    ``<export_dir>/<features>-<conf["output"]>-<pairs stem>.h5``.

    Args:
        conf: matcher configuration (one entry of ``confs``).
        pairs: file listing the image pairs to match.
        features: query feature file, as a path or as a name (see above).
        export_dir: directory used to resolve ``features`` given as a name.
        matches: output file; mandatory when ``features`` is a file path.
        features_ref: reference feature file; defaults to the query file.
        overwrite: recompute pairs already present in the matches file.

    Returns:
        The path of the matches file.

    Raises:
        ValueError: if ``features`` is a file path and ``matches`` is missing,
            or if ``features`` is a name and ``export_dir`` is missing.
    """
    pairs = Path(pairs)
    if isinstance(features, Path) or Path(features).exists():
        # A string pointing at an existing file is accepted as well; turn it
        # into a Path because match_from_paths calls Path methods on it.
        features_q = Path(features)
        if matches is None:
            raise ValueError(
                "Either provide both features and matches as Path" " or both as names."
            )
    else:
        if export_dir is None:
            raise ValueError(
                "Provide an export_dir if features is not" f" a file path: {features}."
            )
        features_q = Path(export_dir, features + ".h5")
        if matches is None:
            matches = Path(export_dir, f'{features}-{conf["output"]}-{pairs.stem}.h5')

    # Same normalisation for the remaining path-like arguments.
    matches = Path(matches)
    features_ref = features_q if features_ref is None else Path(features_ref)

    match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)
    return matches
def find_unique_new_pairs(
    pairs_all: List[Tuple[str, str]], match_path: Optional[Path] = None
):
    """Avoid to recompute duplicates to save time.

    Exact duplicates are dropped, and ``(j, i)`` is dropped when ``(i, j)``
    was seen earlier. The result keeps the order of first occurrence in
    ``pairs_all``, so it is identical from run to run.

    Args:
        pairs_all: candidate ``(name0, name1)`` pairs.
        match_path: optional matches file. If it exists, pairs already stored
            in it (in either order, under the current or the old key format)
            are removed from the result.

    Returns:
        List of the remaining ``(name0, name1)`` pairs.
    """
    # A dict is used as an insertion-ordered set: iterating a plain set of
    # strings yields a hash-seed-dependent order that changes between runs.
    unique = {}
    for i, j in pairs_all:
        if (j, i) not in unique:
            unique[(i, j)] = None
    pairs = list(unique)

    if match_path is None or not match_path.exists():
        return pairs

    with h5py.File(str(match_path), "r", libver="latest") as fd:
        return [
            (i, j)
            for i, j in pairs
            if not (
                names_to_pair(i, j) in fd
                or names_to_pair(j, i) in fd
                or names_to_pair_old(i, j) in fd
                or names_to_pair_old(j, i) in fd
            )
        ]
def match_from_paths(
    conf: Dict,
    pairs_path: Path,
    match_path: Path,
    feature_path_q: Path,
    feature_path_ref: Path,
    overwrite: bool = False,
) -> Path:
    """Match the features of every listed pair and write the results.

    Args:
        conf: matcher configuration; ``conf["model"]`` is passed to the
            matcher class selected by ``conf["model"]["name"]``.
        pairs_path: retrieval/pairs file read with ``parse_retrieval``.
        match_path: HDF5 file the matches are written to; its parent
            directory is created if needed.
        feature_path_q: HDF5 file with the features of the first image of
            each pair.
        feature_path_ref: HDF5 file with the features of the second image.
        overwrite: if False, pairs already present in ``match_path`` are
            skipped.

    Returns:
        ``match_path``, also when there was nothing left to match.

    Raises:
        FileNotFoundError: if a feature file or the pairs file is missing.
    """
    logging.info(
        "Matching local features with configuration:" f"\n{pprint.pformat(conf)}"
    )

    if not feature_path_q.exists():
        raise FileNotFoundError(f"Query feature file {feature_path_q}.")
    if not feature_path_ref.exists():
        raise FileNotFoundError(f"Reference feature file {feature_path_ref}.")
    # Raised explicitly rather than asserted: asserts vanish under ``-O``.
    if not pairs_path.exists():
        raise FileNotFoundError(f"Pairs file {pairs_path}.")
    match_path.parent.mkdir(exist_ok=True, parents=True)

    pairs = parse_retrieval(pairs_path)
    pairs = [(q, r) for q, rs in pairs.items() for r in rs]
    pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
    if not pairs:
        logging.info("Skipping the matching.")
        return match_path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    Model = dynamic_load(matchers, conf["model"]["name"])
    model = Model(conf["model"]).eval().to(device)

    dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
    loader = torch.utils.data.DataLoader(
        dataset, num_workers=4, batch_size=1, shuffle=False, pin_memory=True
    )
    writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)

    try:
        # Pure inference: without no_grad the outputs may carry autograd
        # state, which costs memory and can make ``.numpy()`` fail in the
        # writer.
        with torch.no_grad():
            for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
                # "image*" entries are placeholders and stay where they are.
                data = {
                    k: v if k.startswith("image") else v.to(device, non_blocking=True)
                    for k, v in data.items()
                }
                pred = model(data)
                # shuffle=False and batch_size=1, so idx indexes ``pairs``.
                pair = names_to_pair(*pairs[idx])
                writer_queue.put((pair, pred))
    finally:
        # Always stop the writer threads; otherwise an error above would
        # leave them blocked on the queue and keep the process alive.
        writer_queue.join()

    logging.info("Finished exporting matches.")
    return match_path
if __name__ == "__main__":
    # Command-line entry point: every option is mandatory, and --conf must
    # name one of the entries of ``confs``.
    cli = argparse.ArgumentParser()
    cli.add_argument("--export_dir", type=Path, required=True)
    cli.add_argument("--features", type=str, required=True)
    cli.add_argument("--pairs", type=Path, required=True)
    cli.add_argument("--conf", type=str, required=True, choices=list(confs))
    opts = cli.parse_args()

    selected_conf = confs[opts.conf]
    main(selected_conf, opts.pairs, opts.features, opts.export_dir)