from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm

from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import EVAL_PATH
from ..utils.export_predictions import export_predictions
from .eval_pipeline import EvalPipeline, load_eval
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import aggregate_pr_results, get_tp_fp_pts


def eval_dataset(loader, pred_file, suffix=""):
    """Accumulate precision/recall statistics of cached matches over a dataset.

    Args:
        loader: Iterable of data batches; each batch is passed to a
            ``CacheLoader`` that reads the matching entry from ``pred_file``.
        pred_file: Path to the cached predictions file.
        suffix: ``""`` evaluates point matches (``matches0``,
            ``gt_matches0``, ``matching_scores0``); any other value evaluates
            line matches (``line_matches0``, ``gt_line_matches0``,
            ``line_matching_scores0``). The suffix is appended to every key
            of the accumulated results.

    Returns:
        The output of ``aggregate_pr_results`` for the accumulated
        ``tp``/``fp``/``scores``/``num_pos`` entries.
    """
    # "" selects the point-match keys, anything else the line-match keys.
    prefix = "" if suffix == "" else "line_"
    results = defaultdict(list)
    results["num_pos" + suffix] = 0
    cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
    for data in tqdm(loader):
        pred = cache_loader(data)
        scores = pred[f"{prefix}matching_scores0"].numpy()
        # Rank matches by descending score before computing TP/FP.
        sort_indices = np.argsort(scores)[::-1]
        gt_matches = pred[f"gt_{prefix}matches0"].numpy()[sort_indices]
        pred_matches = pred[f"{prefix}matches0"].numpy()[sort_indices]
        scores = scores[sort_indices]

        tp, fp, scores, num_pos = get_tp_fp_pts(pred_matches, gt_matches, scores)
        results["tp" + suffix].append(tp)
        results["fp" + suffix].append(fp)
        results["scores" + suffix].append(scores)
        results["num_pos" + suffix] += num_pos

    # Aggregate the results
    return aggregate_pr_results(results, suffix=suffix)


class ETH3DPipeline(EvalPipeline):
    """Evaluation pipeline for point (and optionally line) matching on ETH3D."""

    default_conf = {
        "data": {
            "name": "eth3d",
            "batch_size": 1,
            "train_batch_size": 1,
            "val_batch_size": 1,
            "test_batch_size": 1,
            "num_workers": 16,
        },
        "model": {
            "name": "gluefactory.models.two_view_pipeline",
            "ground_truth": {
                "name": "gluefactory.models.matchers.depth_matcher",
                "use_lines": False,
            },
            "run_gt_in_forward": True,
        },
        "eval": {"plot_methods": [], "plot_line_methods": [], "eval_lines": False},
    }

    # Keys always written to the predictions file.
    export_keys = [
        "gt_matches0",
        "matches0",
        "matching_scores0",
    ]

    # Keys written to the predictions file when the model produces them.
    optional_export_keys = [
        "gt_line_matches0",
        "line_matches0",
        "line_matching_scores0",
    ]

    def get_dataloader(self, data_conf=None):
        """Return the ``"test"`` split loader of the ``eth3d`` dataset.

        Args:
            data_conf: Dataset configuration; falls back to
                ``default_conf["data"]`` when None.
        """
        data_conf = data_conf if data_conf is not None else self.default_conf["data"]
        dataset = get_dataset("eth3d")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Return the path to ``predictions.h5``, exporting it if needed.

        Predictions are (re)exported when the file is missing or
        ``overwrite`` is True. The model is loaded from
        ``self.conf.model`` / ``self.conf.checkpoint`` if not supplied.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Evaluate cached predictions for points and, optionally, lines.

        Returns:
            A tuple ``(summaries, figures, results)``; summaries and figures
            are empty dicts, ``results`` holds the point results merged with
            the ``_lines``-suffixed line results when ``eval.eval_lines``.
        """
        r = eval_dataset(loader, pred_file)
        if self.conf.eval.eval_lines:
            r.update(eval_dataset(loader, pred_file, suffix="_lines"))
        s = {}

        return s, {}, r


def plot_pr_curve(
    models_name, results, dst_file="eth3d_pr_curve.pdf", title=None, suffix=""
):
    """Plot precision/recall curves of several methods and save the figure.

    Args:
        models_name: Method names to plot, in legend order.
        results: Mapping from method name to its result dict; each must hold
            ``"AP" + suffix``, ``"curve_recall" + suffix`` and
            ``"curve_precision" + suffix``.
        dst_file: Output path passed to ``plt.savefig``.
        title: Optional figure title.
        suffix: Result-key suffix (e.g. ``"_lines"``).
    """
    plt.figure()
    # Iso-F-score reference curves.
    f_scores = np.linspace(0.2, 0.9, num=8)
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        plt.plot(x[y >= 0], y[y >= 0], color=[0, 0.5, 0], alpha=0.3)
        plt.annotate(
            "f={0:0.1}".format(f_score),
            xy=(0.9, y[45] + 0.02),
            alpha=0.4,
            fontsize=14,
        )

    plt.rcParams.update({"font.size": 12})
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.xlabel("Recall", fontsize=18)
    plt.ylabel("Precision", fontsize=18)
    plt.yticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.ylim([0.3, 1.0])
    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    for m, c in zip(models_name, colors):
        sAP_string = f'{m}: {results[m]["AP" + suffix]:.1f}'
        plt.plot(
            results[m]["curve_recall" + suffix],
            results[m]["curve_precision" + suffix],
            label=sAP_string,
            color=c,
        )

    plt.legend(fontsize=16, loc="lower right")
    if title:
        plt.title(title)

    plt.tight_layout(pad=0.5)
    print(f"Saving plot to: {dst_file}")
    plt.savefig(dst_file)
    plt.show()


if __name__ == "__main__":
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(ETH3DPipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ETH3DPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
    )

    # print results
    for k, v in r.items():
        if k.startswith("AP"):
            print(f"{k}: {v:.2f}")

    if args.plot:
        results = {}
        for m in conf.eval.plot_methods:
            exp_dir = output_dir / m
            results[m] = load_eval(exp_dir)[1]

        plot_pr_curve(conf.eval.plot_methods, results, dst_file="eth3d_pr_curve.pdf")
        if conf.eval.eval_lines:
            for m in conf.eval.plot_line_methods:
                exp_dir = output_dir / m
                results[m] = load_eval(exp_dir)[1]
            plot_pr_curve(
                conf.eval.plot_line_methods,
                results,
                dst_file="eth3d_pr_curve_lines.pdf",
                suffix="_lines",
            )