Realcat
add: GIM (https://github.com/xuelunshen/gim)
4d4dd90
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import EVAL_PATH
from ..utils.export_predictions import export_predictions
from .eval_pipeline import EvalPipeline, load_eval
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import aggregate_pr_results, get_tp_fp_pts
def eval_dataset(loader, pred_file, suffix=""):
    """Accumulate precision/recall statistics over cached match predictions.

    Args:
        loader: iterable of data batches; each batch is fed to a
            ``CacheLoader`` reading ``pred_file``.
        pred_file: path of the exported predictions file.
        suffix: ``""`` evaluates point matches. Any other value evaluates
            line matches, and the suffix is appended to every result key
            (e.g. ``"_lines"``).

    Returns:
        The output of ``aggregate_pr_results`` applied to the per-batch
        true/false positives, sorted scores and total positive count.
    """
    # Point and line predictions are stored under different keys.
    if suffix == "":
        score_key, gt_key, match_key = (
            "matching_scores0",
            "gt_matches0",
            "matches0",
        )
    else:
        score_key, gt_key, match_key = (
            "line_matching_scores0",
            "gt_line_matches0",
            "line_matches0",
        )
    tp_key, fp_key, scores_key, pos_key = (
        name + suffix for name in ("tp", "fp", "scores", "num_pos")
    )

    results = defaultdict(list)
    results[pos_key] = 0
    cache = CacheLoader({"path": str(pred_file), "collate": None}).eval()
    for batch in tqdm(loader):
        prediction = cache(batch)
        raw_scores = prediction[score_key].numpy()
        # Process matches in order of decreasing confidence.
        order = np.argsort(raw_scores)[::-1]
        gt_sorted = prediction[gt_key].numpy()[order]
        matches_sorted = prediction[match_key].numpy()[order]
        tp, fp, kept_scores, n_pos = get_tp_fp_pts(
            matches_sorted, gt_sorted, raw_scores[order]
        )
        results[tp_key].append(tp)
        results[fp_key].append(fp)
        results[scores_key].append(kept_scores)
        results[pos_key] += n_pos

    return aggregate_pr_results(results, suffix=suffix)
class ETH3DPipeline(EvalPipeline):
    """Evaluation pipeline for two-view matching on the ETH3D dataset.

    Exports match predictions of a model to ``predictions.h5`` and scores
    them with ``eval_dataset``. Point matches are always evaluated; line
    matches are evaluated as well when ``eval.eval_lines`` is enabled.
    """

    default_conf = {
        "data": {
            "name": "eth3d",
            "batch_size": 1,
            "train_batch_size": 1,
            "val_batch_size": 1,
            "test_batch_size": 1,
            "num_workers": 16,
        },
        "model": {
            "name": "gluefactory.models.two_view_pipeline",
            "ground_truth": {
                "name": "gluefactory.models.matchers.depth_matcher",
                "use_lines": False,
            },
            "run_gt_in_forward": True,
        },
        "eval": {"plot_methods": [], "plot_line_methods": [], "eval_lines": False},
    }

    # Point-matching keys written to the predictions file.
    export_keys = [
        "gt_matches0",
        "matches0",
        "matching_scores0",
    ]
    # Line-matching keys, passed as ``optional_keys`` to export_predictions
    # (presumably skipped when the model does not output them — confirm).
    optional_export_keys = [
        "gt_line_matches0",
        "line_matches0",
        "line_matching_scores0",
    ]

    def get_dataloader(self, data_conf=None):
        """Return the ETH3D ``"test"`` data loader.

        Args:
            data_conf: dataset configuration; falls back to
                ``default_conf["data"]`` when None.
        """
        data_conf = data_conf if data_conf is not None else self.default_conf["data"]
        dataset = get_dataset("eth3d")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Return the path of ``predictions.h5``, exporting it if needed.

        Predictions are (re)computed when the file is missing or
        ``overwrite`` is True. If no ``model`` is given, one is loaded from
        ``self.conf.model`` / ``self.conf.checkpoint``.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Evaluate cached predictions.

        Point-match statistics are always computed. Line-match statistics,
        stored under ``*_lines`` keys, are merged in when
        ``self.conf.eval.eval_lines`` is set.

        Returns:
            A 3-tuple whose first two elements are empty dicts and whose
            last element holds the merged precision/recall results.
        """
        eval_conf = self.conf.eval
        results = eval_dataset(loader, pred_file)
        if eval_conf.eval_lines:
            # eval_dataset() takes no ``conf`` argument; passing one used to
            # raise TypeError whenever line evaluation was enabled.
            results.update(eval_dataset(loader, pred_file, suffix="_lines"))
        summaries = {}
        return summaries, {}, results
def plot_pr_curve(
    models_name, results, dst_file="eth3d_pr_curve.pdf", title=None, suffix=""
):
    """Plot precision/recall curves of several methods and save the figure.

    Iso-F-score contours (f = 0.2, 0.3, ..., 0.9) are drawn in the background.

    Args:
        models_name: method names to plot, in legend order.
        results: mapping from method name to a results dict containing
            ``"AP" + suffix``, ``"curve_recall" + suffix`` and
            ``"curve_precision" + suffix``.
        dst_file: output path passed to ``plt.savefig``.
        title: optional figure title.
        suffix: key suffix selecting which results to plot (e.g. ``"_lines"``).
    """
    plt.figure()
    num_samples = 50
    # x[45] ~= 0.92, i.e. right next to the label's x-position of 0.9.
    label_idx = 45
    f_scores = np.linspace(0.2, 0.9, num=8)
    for f_score in f_scores:
        x = np.linspace(0.01, 1, num=num_samples)
        # Solve f = 2xy / (x + y) for precision y at recall x; y < 0 for
        # x < f/2, so only the valid part of the contour is drawn.
        y = f_score * x / (2 * x - f_score)
        plt.plot(x[y >= 0], y[y >= 0], color=[0, 0.5, 0], alpha=0.3)
        plt.annotate(
            "f={0:0.1}".format(f_score),
            xy=(0.9, y[label_idx] + 0.02),
            alpha=0.4,
            fontsize=14,
        )

    plt.rcParams.update({"font.size": 12})
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.xlabel("Recall", fontsize=18)
    plt.ylabel("Precision", fontsize=18)
    plt.yticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.ylim([0.3, 1.0])

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    for i, m in enumerate(models_name):
        # Cycle through the colours instead of zip()-truncating, which
        # silently dropped every method beyond the length of the colour cycle.
        c = colors[i % len(colors)]
        sAP_string = f'{m}: {results[m]["AP" + suffix]:.1f}'
        plt.plot(
            results[m]["curve_recall" + suffix],
            results[m]["curve_precision" + suffix],
            label=sAP_string,
            color=c,
        )

    plt.legend(fontsize=16, loc="lower right")
    if title:
        plt.title(title)

    plt.tight_layout(pad=0.5)
    print(f"Saving plot to: {dst_file}")
    plt.savefig(dst_file)
    plt.show()
if __name__ == "__main__":
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(ETH3DPipeline.default_conf)

    # Every experiment of this benchmark is stored under EVAL_PATH/<dataset>.
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ETH3DPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
    )

    # Report only the average-precision entries.
    ap_entries = [(key, value) for key, value in r.items() if key.startswith("AP")]
    for key, value in ap_entries:
        print(f"{key}: {value:.2f}")

    if args.plot:
        results = {}

        def _collect(methods):
            # Load the stored results of each listed experiment into ``results``.
            for method in methods:
                results[method] = load_eval(output_dir / method)[1]

        _collect(conf.eval.plot_methods)
        plot_pr_curve(conf.eval.plot_methods, results, dst_file="eth3d_pr_curve.pdf")

        if conf.eval.eval_lines:
            _collect(conf.eval.plot_line_methods)
            plot_pr_curve(
                conf.eval.plot_line_methods,
                results,
                dst_file="eth3d_pr_curve_lines.pdf",
                suffix="_lines",
            )