|
|
|
|
|
import numpy as np |
|
import torch |
|
import torchvision.transforms as transforms |
|
from tqdm import tqdm |
|
|
|
from evaluation.descriptor_evaluation import (compute_homography, |
|
compute_matching_score) |
|
from evaluation.detector_evaluation import compute_repeatability |
|
|
|
|
|
def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), top_k=300):
    """Evaluate a keypoint network on image / warped-image pairs.

    Every sample is run through the network twice (once on the image, once on
    the warped image). The resulting keypoints and descriptors are scored with
    ``compute_repeatability``, ``compute_homography`` and
    ``compute_matching_score``, and the per-sample metrics are averaged.

    Parameters
    ----------
    data_loader: torch.utils.data.DataLoader
        Dataset loader. Each sample must provide the keys ``'image'``,
        ``'warped_image'`` and ``'homography'``. The flattening performed
        below is only meaningful for a batch size of 1.
    keypoint_net: torch.nn.module
        Keypoint network. Called on an image batch it must return the triple
        ``(score, coord, desc)``.
    output_shape: tuple
        Original image shape. It is forwarded in reversed order as
        ``'image_shape'`` (presumably given as (W, H) -- confirm against the
        evaluation helpers).
    top_k: int
        Number of keypoints to use to compute metrics, selected based on probability.

    Returns
    -------
    tuple
        ``(repeatability, localization_err, correctness1, correctness3,
        correctness5, matching_score)``, each one the ``np.mean`` of the
        per-sample values. The three correctness values are the three outputs
        of ``compute_homography``, in order. An empty loader yields NaN for
        every entry.
    """
    keypoint_net.eval()
    # Redundant for a plain nn.Module (eval() already clears the flag); kept so
    # the flag is also cleared on objects that provide a custom eval().
    keypoint_net.training = False

    # Only points whose score is strictly above this value are evaluated.
    conf_threshold = 0.0
    localization_err, repeatability = [], []
    correctness1, correctness3, correctness5, MScore = [], [], [], []

    with torch.no_grad():
        # Wrapping the loader directly (instead of enumerate(loader)) lets tqdm
        # read its length and display the total / ETA.
        for sample in tqdm(data_loader, desc="Evaluate point model"):
            image = sample['image'].cuda()
            warped_image = sample['warped_image'].cuda()

            score_1, coord_1, desc1 = keypoint_net(image)
            score_2, coord_2, desc2 = keypoint_net(warped_image)
            # Take the descriptor size from the network output rather than
            # assuming 256 channels.
            _, desc_dim, Hc, Wc = desc1.shape

            # Stack coordinates and score, then flatten to one row per point:
            # columns are (coord channel 0, coord channel 1, score).
            score_1 = torch.cat([coord_1, score_1], dim=1).view(3, -1).t().cpu().numpy()
            score_2 = torch.cat([coord_2, score_2], dim=1).view(3, -1).t().cpu().numpy()
            # One descriptor row per point, in the same order as the score rows.
            desc1 = desc1.view(desc_dim, Hc * Wc).t().cpu().numpy()
            desc2 = desc2.view(desc_dim, Hc * Wc).t().cpu().numpy()

            # Filter based on confidence threshold; build each mask once and
            # apply it to both descriptors and scores so they stay aligned.
            keep_1 = score_1[:, 2] > conf_threshold
            keep_2 = score_2[:, 2] > conf_threshold
            desc1 = desc1[keep_1, :]
            desc2 = desc2[keep_2, :]
            score_1 = score_1[keep_1, :]
            score_2 = score_2[keep_2, :]

            # Bundle everything the evaluation helpers consume.
            data = {'image': sample['image'].numpy().squeeze(),
                    'image_shape': output_shape[::-1],
                    'warped_image': sample['warped_image'].numpy().squeeze(),
                    'homography': sample['homography'].squeeze().numpy(),
                    'prob': score_1,
                    'warped_prob': score_2,
                    'desc': desc1,
                    'warped_desc': desc2}

            # Repeatability and localization error.
            _, _, rep, loc_err = compute_repeatability(data, keep_k_points=top_k, distance_thresh=3)
            repeatability.append(rep)
            localization_err.append(loc_err)

            # Homography correctness.
            c1, c2, c3 = compute_homography(data, keep_k_points=top_k)
            correctness1.append(c1)
            correctness3.append(c2)
            correctness5.append(c3)

            # Matching score.
            mscore = compute_matching_score(data, keep_k_points=top_k)
            MScore.append(mscore)

    return np.mean(repeatability), np.mean(localization_err), \
        np.mean(correctness1), np.mean(correctness3), np.mean(correctness5), np.mean(MScore)
|
|