# Page header captured with this file: author Vincentqyw, commit c74a070 ("fix: roma"), size 3.39 kB.
# Copyright 2020 Toyota Research Institute. All rights reserved.
import numpy as np
import torch
import torchvision.transforms as transforms
from tqdm import tqdm
from evaluation.descriptor_evaluation import compute_homography, compute_matching_score
from evaluation.detector_evaluation import compute_repeatability
def _flatten_predictions(score, coord, desc, conf_threshold):
    """Flatten one image's network outputs and keep only confident cells.

    Parameters
    ----------
    score: torch.Tensor
        Per-cell confidence, concatenated after ``coord`` along dim 1.
    coord: torch.Tensor
        Per-cell keypoint coordinates. Assumed to have 2 channels, so that
        together with ``score`` they fill the 3 rows of the flattened view.
    desc: torch.Tensor
        Per-cell descriptors, assumed shape (1, C, Hc, Wc) (batch size 1).
    conf_threshold: float
        Cells whose score is not strictly greater than this are dropped.

    Returns
    -------
    points: np.ndarray
        (N, 3) array; column 2 holds the score.
    descriptors: np.ndarray
        (N, C) array aligned row-by-row with ``points``.
    """
    points = torch.cat([coord, score], dim=1).view(3, -1).t().cpu().numpy()
    # Read the descriptor size from the tensor instead of hard-coding 256.
    descriptors = desc.view(desc.shape[1], -1).t().cpu().numpy()
    keep = points[:, 2] > conf_threshold
    return points[keep, :], descriptors[keep, :]


def evaluate_keypoint_net(
    data_loader,
    keypoint_net,
    output_shape=(320, 240),
    top_k=300,
    conf_threshold=0.0,
):
    """Keypoint net evaluation script.

    Runs ``keypoint_net`` on every (image, warped_image) pair yielded by
    ``data_loader`` and averages repeatability, localization error,
    homography correctness and matching score over the whole loader.

    Parameters
    ----------
    data_loader: torch.utils.data.DataLoader
        Dataset loader. Each sample must provide ``"image"``,
        ``"warped_image"`` and ``"homography"`` tensors, with batch size 1.
    keypoint_net: torch.nn.Module
        Keypoint network returning ``(score, coord, desc)``. It is switched
        to eval mode and left there. Inputs are moved to the current CUDA
        device, so the network is expected to live there as well.
    output_shape: tuple
        Image shape; it is reversed before being passed on as
        ``"image_shape"`` (presumably given as (width, height) - TODO confirm).
    top_k: int
        Number of keypoints to use to compute metrics, selected based on probability.
    conf_threshold: float
        Keypoints whose score is not strictly above this value are discarded
        before evaluation. Defaults to 0.0, the previously hard-coded value.

    Returns
    -------
    tuple of float
        Means over the loader of: repeatability, localization error, the
        three correctness values from ``compute_homography`` (stored as
        correctness1 / correctness3 / correctness5), and matching score.
        Each mean is NaN (with a NumPy warning) if the loader is empty.

    Raises
    ------
    ValueError
        If the network outputs for a sample have a batch size other than 1.
    """
    # eval() already sets ``training = False`` on the module and its children.
    keypoint_net.eval()

    localization_err, repeatability = [], []
    correctness1, correctness3, correctness5, MScore = [], [], [], []

    with torch.no_grad():
        # Iterating the loader directly lets tqdm use len() to show a total.
        for sample in tqdm(data_loader, desc="Evaluate point model"):
            image = sample["image"].cuda()
            warped_image = sample["warped_image"].cuda()

            score_1, coord_1, desc1 = keypoint_net(image)
            score_2, coord_2, desc2 = keypoint_net(warped_image)

            # The flattening views each output as a single image; with a
            # larger batch the views would silently interleave images.
            if desc1.shape[0] != 1 or desc2.shape[0] != 1:
                raise ValueError(
                    "evaluate_keypoint_net expects batch size 1, got "
                    f"{desc1.shape[0]} and {desc2.shape[0]}"
                )

            # Scores & descriptors, filtered by the confidence threshold.
            prob_1, desc_1 = _flatten_predictions(
                score_1, coord_1, desc1, conf_threshold
            )
            prob_2, desc_2 = _flatten_predictions(
                score_2, coord_2, desc2, conf_threshold
            )

            # Prepare data for eval
            data = {
                "image": sample["image"].numpy().squeeze(),
                "image_shape": output_shape[::-1],
                "warped_image": sample["warped_image"].numpy().squeeze(),
                "homography": sample["homography"].squeeze().numpy(),
                "prob": prob_1,
                "warped_prob": prob_2,
                "desc": desc_1,
                "warped_desc": desc_2,
            }

            # Compute repeatability and localization error
            _, _, rep, loc_err = compute_repeatability(
                data, keep_k_points=top_k, distance_thresh=3
            )
            repeatability.append(rep)
            localization_err.append(loc_err)

            # Compute correctness
            c1, c3, c5 = compute_homography(data, keep_k_points=top_k)
            correctness1.append(c1)
            correctness3.append(c3)
            correctness5.append(c5)

            # Compute matching score
            MScore.append(compute_matching_score(data, keep_k_points=top_k))

    return (
        np.mean(repeatability),
        np.mean(localization_err),
        np.mean(correctness1),
        np.mean(correctness3),
        np.mean(correctness5),
        np.mean(MScore),
    )