Spaces:
Running
Running
from collections import defaultdict | |
from collections.abc import Iterable | |
from pathlib import Path | |
from pprint import pprint | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from omegaconf import OmegaConf | |
from tqdm import tqdm | |
from ..datasets import get_dataset | |
from ..models.cache_loader import CacheLoader | |
from ..settings import EVAL_PATH | |
from ..utils.export_predictions import export_predictions | |
from ..utils.tensor import map_tensor | |
from ..utils.tools import AUCMetric | |
from ..visualization.viz2d import plot_cumulative | |
from .eval_pipeline import EvalPipeline | |
from .io import get_eval_parser, load_model, parse_eval_args | |
from .utils import ( | |
eval_homography_dlt, | |
eval_homography_robust, | |
eval_matches_homography, | |
eval_poses, | |
) | |
class HPatchesPipeline(EvalPipeline):
    """Homography-estimation evaluation on the HPatches ``"test"`` split.

    Predictions are exported once to ``<experiment_dir>/predictions.h5`` and then
    re-read through a ``CacheLoader`` to compute per-pair match / DLT homography
    metrics, robust homography estimation at one or more RANSAC thresholds, and
    summary statistics (medians and AUCs).
    """

    default_conf = {
        "data": {
            "batch_size": 1,
            "name": "hpatches",
            "num_workers": 16,
            "preprocessing": {
                "resize": 480,  # we also resize during eval to have comparable metrics
                "side": "short",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            # > 0: single threshold; <= 0: sweep a fixed grid; an iterable of
            # thresholds is used as-is. The best threshold is then selected.
            "ransac_th": 1.0,
        },
    }

    # Prediction keys passed to export_predictions as required keys.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    # Line-related keys passed as ``optional_keys`` (presumably only written when
    # the model actually outputs them — confirm in export_predictions).
    optional_export_keys = [
        "lines0",
        "lines1",
        "orig_lines0",
        "orig_lines1",
        "line_matches0",
        "line_matches1",
        "line_matching_scores0",
        "line_matching_scores1",
    ]

    def _init(self, conf):
        """Pipeline-specific setup hook; HPatches needs no extra state."""
        pass

    def get_dataloader(self, data_conf=None):
        """Return the data loader for the HPatches ``"test"`` split.

        Args:
            data_conf: dataset configuration. Falls back to
                ``default_conf["data"]`` when falsy (e.g. ``None``).
        """
        data_conf = data_conf if data_conf else self.default_conf["data"]
        dataset = get_dataset("hpatches")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export model predictions to ``experiment_dir / "predictions.h5"``.

        The export only runs when the file is missing or ``overwrite`` is set.

        Args:
            experiment_dir: ``Path`` of the experiment output directory.
            model: model to run; when ``None`` it is loaded from
                ``self.conf.model`` and ``self.conf.checkpoint``.
            overwrite: re-export even if the predictions file exists.

        Returns:
            Path to the predictions file.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Evaluate cached predictions against the HPatches homographies.

        Args:
            loader: iterable of batches. Dimension 0 is squeezed, so batches are
                expected to have batch size 1.
            pred_file: path to the predictions file from ``get_predictions``.

        Returns:
            ``(summaries, figures, results)``:
            ``summaries`` maps metric names to scalars (medians of numeric
            per-pair metrics, DLT AUCs when available, and the RANSAC summary at
            the selected threshold); ``figures`` holds the ``"homography_recall"``
            cumulative plot; ``results`` maps metric names to per-pair lists.
        """
        assert pred_file.exists(), f"Missing predictions file: {pred_file}"
        results = defaultdict(list)
        conf = self.conf.eval

        if isinstance(conf.ransac_th, Iterable):
            test_thresholds = conf.ransac_th
        elif conf.ransac_th > 0:
            test_thresholds = [conf.ransac_th]
        else:
            test_thresholds = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]

        # pose_results[threshold][metric] -> list of per-pair values
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for data in tqdm(loader):
            pred = cache_loader(data)
            # Remove batch dimension
            data = map_tensor(data, lambda t: torch.squeeze(t, dim=0))
            # add custom evaluations here
            if "keypoints0" in pred:
                results_i = eval_matches_homography(data, pred)
                results_i = {**results_i, **eval_homography_dlt(data, pred)}
            else:
                # No keypoints: no match / DLT metrics (so no "H_error_dlt").
                results_i = {}
            for th in test_thresholds:
                pose_results_i = eval_homography_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                for k, v in pose_results_i.items():
                    pose_results[th][k].append(v)

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            results_i["scenes"] = data["scene"][0]
            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            # Skip non-numeric entries (e.g. the stored names / scenes).
            if not np.issubdtype(arr.dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.median(arr), 3)

        auc_ths = [1, 3, 5]
        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=auc_ths, key="H_error_ransac", unit="px"
        )
        if "H_error_dlt" in results:
            dlt_aucs = AUCMetric(auc_ths, results["H_error_dlt"]).compute()
            for ath, auc in zip(auc_ths, dlt_aucs):
                summaries[f"H_error_dlt@{ath}px"] = auc

        # Expose the per-pair RANSAC results of the selected threshold.
        results = {**results, **pose_results[best_th]}
        summaries = {
            **summaries,
            **best_pose_results,
        }

        # "H_error_dlt" is absent when no prediction had keypoints (see the guard
        # above); `results` is a plain dict here, so indexing it unconditionally
        # would raise KeyError.
        recall_curves = {}
        if "H_error_dlt" in results:
            recall_curves["DLT"] = results["H_error_dlt"]
        recall_curves[conf.estimator] = results["H_error_ransac"]
        figures = {
            "homography_recall": plot_cumulative(
                recall_curves,
                [0, 10],
                unit="px",
                title="Homography ",
            )
        }

        return summaries, figures, results
if __name__ == "__main__":
    # The benchmark name is taken from this script's filename stem.
    dataset_name = Path(__file__).stem
    args = get_eval_parser().parse_intermixed_args()
    default_conf = OmegaConf.create(HPatchesPipeline.default_conf)

    # Outputs go to EVAL_PATH/<dataset_name>/<experiment name>.
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    summaries, figures, _results = HPatchesPipeline(conf).run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    # Report scalar summaries; optionally show the figures interactively.
    pprint(summaries)
    if args.plot:
        for fig_name, fig in figures.items():
            fig.canvas.manager.set_window_title(fig_name)
        plt.show()