# Copyright 2020 Toyota Research Institute. All rights reserved.

import numpy as np
import torch
from tqdm import tqdm

from evaluation.descriptor_evaluation import (compute_homography,
                                              compute_matching_score)
from evaluation.detector_evaluation import compute_repeatability


def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), top_k=300):
    """Keypoint net evaluation script.

    Parameters
    ----------
    data_loader: torch.utils.data.DataLoader
        Dataset loader.
    keypoint_net: torch.nn.Module
        Keypoint network.
    output_shape: tuple
        Original image shape.
    top_k: int
        Number of keypoints used to compute the metrics, selected by detection probability.
    """
    keypoint_net.eval()
    keypoint_net.training = False

    conf_threshold = 0.0
    localization_err, repeatability = [], []
    correctness1, correctness3, correctness5, MScore = [], [], [], []

    with torch.no_grad():
        for i, sample in tqdm(enumerate(data_loader), desc="Evaluate point model"):
            image = sample['image'].cuda()
            warped_image = sample['warped_image'].cuda()

            score_1, coord_1, desc1 = keypoint_net(image)
            score_2, coord_2, desc2 = keypoint_net(warped_image)
            # Note: the reshaping below assumes a batch size of 1.
            B, C, Hc, Wc = desc1.shape

            # Scores & Descriptors: flatten predictions into (N, 3) arrays of
            # [x, y, score] and (N, C) descriptor arrays.
            score_1 = torch.cat([coord_1, score_1], dim=1).view(3, -1).t().cpu().numpy()
            score_2 = torch.cat([coord_2, score_2], dim=1).view(3, -1).t().cpu().numpy()
            desc1 = desc1.view(C, -1).t().cpu().numpy()
            desc2 = desc2.view(C, -1).t().cpu().numpy()

            # Filter based on confidence threshold
            desc1 = desc1[score_1[:, 2] > conf_threshold, :]
            desc2 = desc2[score_2[:, 2] > conf_threshold, :]
            score_1 = score_1[score_1[:, 2] > conf_threshold, :]
            score_2 = score_2[score_2[:, 2] > conf_threshold, :]

            # Prepare data for evaluation
            data = {'image': sample['image'].numpy().squeeze(),
                    'image_shape': output_shape[::-1],
                    'warped_image': sample['warped_image'].numpy().squeeze(),
                    'homography': sample['homography'].squeeze().numpy(),
                    'prob': score_1,
                    'warped_prob': score_2,
                    'desc': desc1,
                    'warped_desc': desc2}

            # Compute repeatability and localization error
            _, _, rep, loc_err = compute_repeatability(data, keep_k_points=top_k, distance_thresh=3)
            repeatability.append(rep)
            localization_err.append(loc_err)

            # Compute homography estimation correctness at 1, 3 and 5 pixel thresholds
            c1, c2, c3 = compute_homography(data, keep_k_points=top_k)
            correctness1.append(c1)
            correctness3.append(c2)
            correctness5.append(c3)

            # Compute matching score
            mscore = compute_matching_score(data, keep_k_points=top_k)
            MScore.append(mscore)

    return np.mean(repeatability), np.mean(localization_err), \
           np.mean(correctness1), np.mean(correctness3), np.mean(correctness5), np.mean(MScore)
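

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). A minimal way to
# drive `evaluate_keypoint_net`, assuming the repo provides a `KeypointNet`
# model and an HPatches-style `PatchesDataset` whose samples contain the
# 'image', 'warped_image' and 'homography' keys used above. The module paths,
# class names and checkpoint layout below are assumptions, not confirmed APIs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    from datasets.patches_dataset import PatchesDataset  # assumed module path
    from networks.keypoint_net import KeypointNet        # assumed module path

    # Load a trained model (checkpoint path and key are assumptions).
    keypoint_net = KeypointNet().cuda()
    checkpoint = torch.load('checkpoints/keypoint_net.ckpt')
    keypoint_net.load_state_dict(checkpoint['state_dict'])

    # Batch size must stay at 1: the evaluation loop flattens one sample at a time.
    data_loader = DataLoader(PatchesDataset(root_dir='data/hpatches', output_shape=(320, 240)),
                             batch_size=1, shuffle=False)

    rep, loc_err, c1, c3, c5, mscore = evaluate_keypoint_net(
        data_loader, keypoint_net, output_shape=(320, 240), top_k=300)
    print('Repeatability {:.3f} | Loc. Error {:.3f} | '
          'Correctness d1 {:.3f} d3 {:.3f} d5 {:.3f} | MScore {:.3f}'.format(
              rep, loc_err, c1, c3, c5, mscore))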