Vincentqyw
update: limit keypoints number
e15a186
raw
history blame
4.4 kB
from pathlib import Path
import argparse
from ... import extract_features, match_features, triangulation, logger
from ... import pairs_from_covisibility, pairs_from_retrieval, localize_sfm
# Slice indices offered as the full set when --slices is left at "*":
# 2 through 6 and 13 through 21 (inclusive).
TEST_SLICES = [*range(2, 7), *range(13, 22)]
def generate_query_list(dataset, path, slice_):
    """Write the list of query images with their camera parameters.

    Camera entries are read from ``dataset / "intrinsics.txt"``: the first
    whitespace-separated token of each line is the camera key and the
    remaining tokens are its parameters. Query image names are read from
    ``dataset / f"{slice_}/test-images-{slice_}.txt"``, one per line. Each
    query is matched to a camera through the third ``_``-separated field of
    its name.

    Args:
        dataset: Path-like root directory (must support the ``/`` operator).
        path: Destination file; overwritten. One line per query, formatted as
            ``<name> <camera params...>``, with no trailing newline.
        slice_: Slice directory name, also used in the test-images file name.

    Raises:
        AssertionError: if the intrinsics file does not define exactly two
            cameras.
        KeyError: if a query name refers to a camera key that is not defined.
    """
    cameras = {}
    with open(dataset / "intrinsics.txt", "r") as f:
        for line in f:
            tokens = line.split()
            # Skip blank / whitespace-only lines and comment lines; testing
            # the tokens (rather than line[0]) avoids an IndexError on lines
            # that contain only spaces and also tolerates indented comments.
            if not tokens or tokens[0].startswith("#"):
                continue
            cameras[tokens[0]] = tokens[1:]
    assert len(cameras) == 2, f"expected 2 cameras, found {len(cameras)}"

    queries_file = dataset / f"{slice_}/test-images-{slice_}.txt"
    with open(queries_file, "r") as f:
        names = [line.rstrip("\n") for line in f]
    # Ignore empty lines (e.g. a trailing blank line), which have no camera
    # field to look up.
    names = [name for name in names if name.strip()]

    rows = [[name] + cameras[name.split("_")[2]] for name in names]
    with open(path, "w") as f:
        f.write("\n".join(" ".join(row) for row in rows))
def run_slice(slice_, root, outputs, num_covis, num_loc):
    """Run every pipeline stage for a single slice.

    The stages are executed strictly in this order: covisibility pairs,
    feature extraction and matching on the database images, triangulation,
    query-list generation, retrieval-based pairing, feature extraction and
    matching on the query images, and finally localization.

    Args:
        slice_: Name of the slice sub-directory (the script entry point
            passes ``f"slice{n}"``).
        root: Dataset root; the slice data is read from ``root / slice_``.
        outputs: Output root; everything is written under
            ``outputs / slice_``, which is created if missing.
        num_covis: Forwarded as ``num_matched`` to
            ``pairs_from_covisibility.main`` and embedded in the SfM pairs
            file name.
        num_loc: Forwarded to ``pairs_from_retrieval.main`` and embedded in
            the localization pairs and results file names.
    """
    # Inputs that live inside the slice directory.
    slice_dir = root / slice_
    db_images = slice_dir / "database"
    query_images = slice_dir / "query"
    sparse_model = slice_dir / "sparse"
    query_list = slice_dir / "queries_with_intrinsics.txt"

    # Everything produced for this slice goes into its own output directory.
    out_dir = outputs / slice_
    out_dir.mkdir(exist_ok=True, parents=True)
    sfm_pairs = out_dir / f"pairs-db-covis{num_covis}.txt"
    loc_pairs = out_dir / f"pairs-query-netvlad{num_loc}.txt"
    ref_sfm = out_dir / "sfm_superpoint+superglue"
    results = out_dir / f"CMU_hloc_superpoint+superglue_netvlad{num_loc}.txt"

    # Pick one of the configurations for extraction and matching.
    retrieval_conf = extract_features.confs["netvlad"]
    feature_conf = extract_features.confs["superpoint_aachen"]
    matcher_conf = match_features.confs["superglue"]

    # Database side: pairs, local features, matches, then triangulation.
    pairs_from_covisibility.main(
        sparse_model, sfm_pairs, num_matched=num_covis
    )
    db_features = extract_features.main(
        feature_conf, db_images, out_dir, as_half=True
    )
    db_matches = match_features.main(
        matcher_conf, sfm_pairs, feature_conf["output"], out_dir
    )
    triangulation.main(
        ref_sfm, sparse_model, db_images, sfm_pairs, db_features, db_matches
    )

    # Query side: write the query list, then pair queries with the database.
    generate_query_list(root, query_list, slice_)
    # Both calls use the same configuration and export directory; only the
    # value returned by the second call is used afterwards.
    extract_features.main(retrieval_conf, db_images, out_dir)
    global_descriptors = extract_features.main(
        retrieval_conf, query_images, out_dir
    )
    pairs_from_retrieval.main(
        global_descriptors,
        loc_pairs,
        num_loc,
        query_list=query_list,
        db_model=ref_sfm,
    )

    query_features = extract_features.main(
        feature_conf, query_images, out_dir, as_half=True
    )
    loc_matches = match_features.main(
        matcher_conf, loc_pairs, feature_conf["output"], out_dir
    )

    # NOTE(review): this query pattern differs from `query_list`, the file
    # written by generate_query_list above -- confirm which one
    # localize_sfm.main is meant to read.
    localize_sfm.main(
        ref_sfm,
        slice_dir / "queries/*_time_queries_with_intrinsics.txt",
        loc_pairs,
        query_features,
        loc_matches,
        results,
    )
def parse_slices(spec, all_slices):
    """Turn the ``--slices`` command-line string into a list of slice numbers.

    Accepted forms:
        * ``"*"``: every entry of ``all_slices`` (returned as a new list);
        * an inclusive interval such as ``"2-6"``;
        * a Python literal int or sequence such as ``"5"`` or ``"[2, 3, 4]"``.

    Args:
        spec: The raw string given on the command line.
        all_slices: Iterable of slice numbers used for ``"*"``.

    Returns:
        A list of slice numbers.

    Raises:
        ValueError, SyntaxError, TypeError: if ``spec`` matches none of the
            accepted forms.
    """
    # Local import keeps this block self-contained; the parser only accepts
    # literals, so arbitrary expressions from the command line are never
    # executed (unlike eval).
    import ast

    spec = spec.strip()
    if spec == "*":
        return list(all_slices)
    if "-" in spec:
        min_, max_ = spec.split("-")
        return list(range(int(min_), int(max_) + 1))
    parsed = ast.literal_eval(spec)
    if isinstance(parsed, int):
        return [parsed]
    return list(parsed)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--slices",
        type=str,
        default="*",
        help="a single number, an interval (e.g. 2-6), "
        "or a Python-style list or int (e.g. [2, 3, 4])",
    )
    parser.add_argument(
        "--dataset",
        type=Path,
        default="datasets/cmu_extended",
        help="Path to the dataset, default: %(default)s",
    )
    # NOTE(review): the default output directory is named "aachen_extended"
    # although the dataset default is "cmu_extended" -- confirm whether this
    # is intentional before changing it.
    parser.add_argument(
        "--outputs",
        type=Path,
        default="outputs/aachen_extended",
        help="Path to the output directory, default: %(default)s",
    )
    parser.add_argument(
        "--num_covis",
        type=int,
        default=20,
        help="Number of image pairs for SfM, default: %(default)s",
    )
    parser.add_argument(
        "--num_loc",
        type=int,
        default=10,
        help="Number of image pairs for loc, default: %(default)s",
    )
    args = parser.parse_args()

    # Report a malformed --slices value as a usage error instead of a
    # traceback; parser.error() prints the message and exits.
    try:
        slices = parse_slices(args.slices, TEST_SLICES)
    except (ValueError, SyntaxError, TypeError) as err:
        parser.error(f"invalid --slices value {args.slices!r}: {err}")

    for slice_ in slices:
        logger.info("Working on slice %s.", slice_)
        run_slice(
            f"slice{slice_}",
            args.dataset,
            args.outputs,
            args.num_covis,
            args.num_loc,
        )