""" This file implements the evaluation metrics. """ import torch import torch.nn.functional as F import numpy as np from torchvision.ops.boxes import batched_nms from ..misc.geometry_utils import keypoints_to_grid class Metrics(object): """Metric evaluation calculator.""" def __init__( self, detection_thresh, prob_thresh, grid_size, junc_metric_lst=None, heatmap_metric_lst=None, pr_metric_lst=None, desc_metric_lst=None, ): # List supported metrics self.supported_junc_metrics = [ "junc_precision", "junc_precision_nms", "junc_recall", "junc_recall_nms", ] self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"] self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] self.supported_desc_metrics = ["matching_score"] # If metric_lst is None, default to use all metrics if junc_metric_lst is None: self.junc_metric_lst = self.supported_junc_metrics else: self.junc_metric_lst = junc_metric_lst if heatmap_metric_lst is None: self.heatmap_metric_lst = self.supported_heatmap_metrics else: self.heatmap_metric_lst = heatmap_metric_lst if pr_metric_lst is None: self.pr_metric_lst = self.supported_pr_metrics else: self.pr_metric_lst = pr_metric_lst # For the descriptors, the default None assumes no desc metric at all if desc_metric_lst is None: self.desc_metric_lst = [] elif desc_metric_lst == "all": self.desc_metric_lst = self.supported_desc_metrics else: self.desc_metric_lst = desc_metric_lst if not self._check_metrics(): raise ValueError("[Error] Some elements in the metric_lst are invalid.") # Metric mapping table self.metric_table = { "junc_precision": junction_precision(detection_thresh), "junc_precision_nms": junction_precision(detection_thresh), "junc_recall": junction_recall(detection_thresh), "junc_recall_nms": junction_recall(detection_thresh), "heatmap_precision": heatmap_precision(prob_thresh), "heatmap_recall": heatmap_recall(prob_thresh), "junc_pr": junction_pr(), "junc_nms_pr": junction_pr(), "matching_score": matching_score(grid_size), } # Initialize the results self.metric_results = {} for key in self.metric_table.keys(): self.metric_results[key] = 0.0 def evaluate( self, junc_pred, junc_pred_nms, junc_gt, heatmap_pred, heatmap_gt, valid_mask, line_points1=None, line_points2=None, desc_pred1=None, desc_pred2=None, valid_points=None, ): """Perform evaluation.""" for metric in self.junc_metric_lst: # If nms metrics then use nms to compute it. if "nms" in metric: junc_pred_input = junc_pred_nms # Use normal inputs instead. else: junc_pred_input = junc_pred self.metric_results[metric] = self.metric_table[metric]( junc_pred_input, junc_gt, valid_mask ) for metric in self.heatmap_metric_lst: self.metric_results[metric] = self.metric_table[metric]( heatmap_pred, heatmap_gt, valid_mask ) for metric in self.pr_metric_lst: if "nms" in metric: self.metric_results[metric] = self.metric_table[metric]( junc_pred_nms, junc_gt, valid_mask ) else: self.metric_results[metric] = self.metric_table[metric]( junc_pred, junc_gt, valid_mask ) for metric in self.desc_metric_lst: self.metric_results[metric] = self.metric_table[metric]( line_points1, line_points2, desc_pred1, desc_pred2, valid_points ) def _check_metrics(self): """Check if all input metrics are valid.""" flag = True for metric in self.junc_metric_lst: if not metric in self.supported_junc_metrics: flag = False break for metric in self.heatmap_metric_lst: if not metric in self.supported_heatmap_metrics: flag = False break for metric in self.desc_metric_lst: if not metric in self.supported_desc_metrics: flag = False break return flag class AverageMeter(object): def __init__( self, junc_metric_lst=None, heatmap_metric_lst=None, is_training=True, desc_metric_lst=None, ): # List supported metrics self.supported_junc_metrics = [ "junc_precision", "junc_precision_nms", "junc_recall", "junc_recall_nms", ] self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"] self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] self.supported_desc_metrics = ["matching_score"] # Record loss in training mode # if is_training: self.supported_loss = [ "junc_loss", "heatmap_loss", "descriptor_loss", "total_loss", ] self.is_training = is_training # If metric_lst is None, default to use all metrics if junc_metric_lst is None: self.junc_metric_lst = self.supported_junc_metrics else: self.junc_metric_lst = junc_metric_lst if heatmap_metric_lst is None: self.heatmap_metric_lst = self.supported_heatmap_metrics else: self.heatmap_metric_lst = heatmap_metric_lst # For the descriptors, the default None assumes no desc metric at all if desc_metric_lst is None: self.desc_metric_lst = [] elif desc_metric_lst == "all": self.desc_metric_lst = self.supported_desc_metrics else: self.desc_metric_lst = desc_metric_lst if not self._check_metrics(): raise ValueError("[Error] Some elements in the metric_lst are invalid.") # Initialize the results self.metric_results = {} for key in ( self.supported_junc_metrics + self.supported_heatmap_metrics + self.supported_loss + self.supported_desc_metrics ): self.metric_results[key] = 0.0 for key in self.supported_pr_metrics: zero_lst = [0 for _ in range(50)] self.metric_results[key] = { "tp": zero_lst, "tn": zero_lst, "fp": zero_lst, "fn": zero_lst, "precision": zero_lst, "recall": zero_lst, } # Initialize total count self.count = 0 def update(self, metrics, loss_dict=None, num_samples=1): # loss should be given in the training mode if self.is_training and (loss_dict is None): raise ValueError("[Error] loss info should be given in the training mode.") # update total counts self.count += num_samples # update all the metrics for met in ( self.supported_junc_metrics + self.supported_heatmap_metrics + self.supported_desc_metrics ): self.metric_results[met] += num_samples * metrics.metric_results[met] # Update all the losses for loss in loss_dict.keys(): self.metric_results[loss] += num_samples * loss_dict[loss] # Update all pr counts for pr_met in self.supported_pr_metrics: # Update all tp, tn, fp, fn, precision, and recall. for key in metrics.metric_results[pr_met].keys(): # Update each interval for idx in range(len(self.metric_results[pr_met][key])): self.metric_results[pr_met][key][idx] += ( num_samples * metrics.metric_results[pr_met][key][idx] ) def average(self): results = {} for met in self.metric_results.keys(): # Skip pr curve metrics if not met in self.supported_pr_metrics: results[met] = self.metric_results[met] / self.count # Only update precision and recall in pr metrics else: met_results = { "tp": self.metric_results[met]["tp"], "tn": self.metric_results[met]["tn"], "fp": self.metric_results[met]["fp"], "fn": self.metric_results[met]["fn"], "precision": [], "recall": [], } for idx in range(len(self.metric_results[met]["precision"])): met_results["precision"].append( self.metric_results[met]["precision"][idx] / self.count ) met_results["recall"].append( self.metric_results[met]["recall"][idx] / self.count ) results[met] = met_results return results def _check_metrics(self): """Check if all input metrics are valid.""" flag = True for metric in self.junc_metric_lst: if not metric in self.supported_junc_metrics: flag = False break for metric in self.heatmap_metric_lst: if not metric in self.supported_heatmap_metrics: flag = False break for metric in self.desc_metric_lst: if not metric in self.supported_desc_metrics: flag = False break return flag class junction_precision(object): """Junction precision.""" def __init__(self, detection_thresh): self.detection_thresh = detection_thresh # Compute the evaluation result def __call__(self, junc_pred, junc_gt, valid_mask): # Convert prediction to discrete detection junc_pred = (junc_pred >= self.detection_thresh).astype(np.int) junc_pred = junc_pred * valid_mask.squeeze() # Deal with the corner case of the prediction if np.sum(junc_pred) > 0: precision = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_pred) else: precision = 0 return float(precision) class junction_recall(object): """Junction recall.""" def __init__(self, detection_thresh): self.detection_thresh = detection_thresh # Compute the evaluation result def __call__(self, junc_pred, junc_gt, valid_mask): # Convert prediction to discrete detection junc_pred = (junc_pred >= self.detection_thresh).astype(np.int) junc_pred = junc_pred * valid_mask.squeeze() # Deal with the corner case of the recall. if np.sum(junc_gt): recall = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_gt) else: recall = 0 return float(recall) class junction_pr(object): """Junction precision-recall info.""" def __init__(self, num_threshold=50): self.max = 0.4 step = self.max / num_threshold self.min = step self.intervals = np.flip(np.arange(self.min, self.max + step, step)) def __call__(self, junc_pred_raw, junc_gt, valid_mask): tp_lst = [] fp_lst = [] tn_lst = [] fn_lst = [] precision_lst = [] recall_lst = [] valid_mask = valid_mask.squeeze() # Iterate through all the thresholds for thresh in list(self.intervals): # Convert prediction to discrete detection junc_pred = (junc_pred_raw >= thresh).astype(np.int) junc_pred = junc_pred * valid_mask # Compute tp, fp, tn, fn junc_gt = junc_gt.squeeze() tp = np.sum(junc_pred * junc_gt) tn = np.sum( (junc_pred == 0).astype(np.float) * (junc_gt == 0).astype(np.float) * valid_mask ) fp = np.sum( (junc_pred == 1).astype(np.float) * (junc_gt == 0).astype(np.float) * valid_mask ) fn = np.sum( (junc_pred == 0).astype(np.float) * (junc_gt == 1).astype(np.float) * valid_mask ) tp_lst.append(tp) tn_lst.append(tn) fp_lst.append(fp) fn_lst.append(fn) precision_lst.append(tp / (tp + fp)) recall_lst.append(tp / (tp + fn)) return { "tp": np.array(tp_lst), "tn": np.array(tn_lst), "fp": np.array(fp_lst), "fn": np.array(fn_lst), "precision": np.array(precision_lst), "recall": np.array(recall_lst), } class heatmap_precision(object): """Heatmap precision.""" def __init__(self, prob_thresh): self.prob_thresh = prob_thresh def __call__(self, heatmap_pred, heatmap_gt, valid_mask): # Assume NHWC (Handle L1 and L2 cases) NxHxWx1 heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh) heatmap_pred = heatmap_pred * valid_mask.squeeze() # Deal with the corner case of the prediction if np.sum(heatmap_pred) > 0: precision = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum( heatmap_pred ) else: precision = 0.0 return precision class heatmap_recall(object): """Heatmap recall.""" def __init__(self, prob_thresh): self.prob_thresh = prob_thresh def __call__(self, heatmap_pred, heatmap_gt, valid_mask): # Assume NHWC (Handle L1 and L2 cases) NxHxWx1 heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh) heatmap_pred = heatmap_pred * valid_mask.squeeze() # Deal with the corner case of the ground truth if np.sum(heatmap_gt) > 0: recall = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum(heatmap_gt) else: recall = 0.0 return recall class matching_score(object): """Descriptors matching score.""" def __init__(self, grid_size): self.grid_size = grid_size def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices): b_size, _, Hc, Wc = desc_pred1.size() img_size = (Hc * self.grid_size, Wc * self.grid_size) device = desc_pred1.device # Extract valid keypoints n_points = line_indices.size()[1] valid_points = line_indices.bool().flatten() n_correct_points = torch.sum(valid_points).item() if n_correct_points == 0: return torch.tensor(0.0, dtype=torch.float, device=device) # Convert the keypoints to a grid suitable for interpolation grid1 = keypoints_to_grid(points1, img_size) grid2 = keypoints_to_grid(points2, img_size) # Extract the descriptors desc1 = ( F.grid_sample(desc_pred1, grid1) .permute(0, 2, 3, 1) .reshape(b_size * n_points, -1)[valid_points] ) desc1 = F.normalize(desc1, dim=1) desc2 = ( F.grid_sample(desc_pred2, grid2) .permute(0, 2, 3, 1) .reshape(b_size * n_points, -1)[valid_points] ) desc2 = F.normalize(desc2, dim=1) desc_dists = 2 - 2 * (desc1 @ desc2.t()) # Compute percentage of correct matches matches0 = torch.min(desc_dists, dim=1)[1] matches1 = torch.min(desc_dists, dim=0)[1] matching_score = matches1[matches0] == torch.arange(len(matches0)).to(device) matching_score = matching_score.float().mean() return matching_score def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0): """Non-maximum suppression adapted from SuperPoint.""" # Iterate through batch dimension im_h = prob_predictions.shape[1] im_w = prob_predictions.shape[2] output_lst = [] for i in range(prob_predictions.shape[0]): # print(i) prob_pred = prob_predictions[i, ...] # Filter the points using prob_thresh coord = np.where(prob_pred >= prob_thresh) # HW format points = np.concatenate( (coord[0][..., None], coord[1][..., None]), axis=1 ) # HW format # Get the probability score prob_score = prob_pred[points[:, 0], points[:, 1]] # Perform super nms # Modify the in_points to xy format (instead of HW format) in_points = np.concatenate( (coord[1][..., None], coord[0][..., None], prob_score), axis=1 ).T keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh) # Remember to flip outputs back to HW format keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T) keep_score = keep_points_[-1, :].T # Whether we only keep the topk value if (top_k > 0) or (top_k is None): k = min([keep_points.shape[0], top_k]) keep_points = keep_points[:k, :] keep_score = keep_score[:k] # Re-compose the probability map output_map = np.zeros([im_h, im_w]) output_map[ keep_points[:, 0].astype(np.int), keep_points[:, 1].astype(np.int) ] = keep_score.squeeze() output_lst.append(output_map[None, ...]) return np.concatenate(output_lst, axis=0) def nms_fast(in_corners, H, W, dist_thresh): """ Run a faster approximate Non-Max-Suppression on numpy corners shaped: 3xN [x_i,y_i,conf_i]^T Algo summary: Create a grid sized HxW. Assign each corner location a 1, rest are zeros. Iterate through all the 1's and convert them to -1 or 0. Suppress points by setting nearby values to 0. Grid Value Legend: -1 : Kept. 0 : Empty or suppressed. 1 : To be processed (converted to either kept or supressed). NOTE: The NMS first rounds points to integers, so NMS distance might not be exactly dist_thresh. It also assumes points are within image boundary. Inputs in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T. H - Image height. W - Image width. dist_thresh - Distance to suppress, measured as an infinite distance. Returns nmsed_corners - 3xN numpy matrix with surviving corners. nmsed_inds - N length numpy vector with surviving corner indices. """ grid = np.zeros((H, W)).astype(int) # Track NMS data. inds = np.zeros((H, W)).astype(int) # Store indices of points. # Sort by confidence and round to nearest int. inds1 = np.argsort(-in_corners[2, :]) corners = in_corners[:, inds1] rcorners = corners[:2, :].round().astype(int) # Rounded corners. # Check for edge case of 0 or 1 corners. if rcorners.shape[1] == 0: return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int) if rcorners.shape[1] == 1: out = np.vstack((rcorners, in_corners[2])).reshape(3, 1) return out, np.zeros((1)).astype(int) # Initialize the grid. for i, rc in enumerate(rcorners.T): grid[rcorners[1, i], rcorners[0, i]] = 1 inds[rcorners[1, i], rcorners[0, i]] = i # Pad the border of the grid, so that we can NMS points near the border. pad = dist_thresh grid = np.pad(grid, ((pad, pad), (pad, pad)), mode="constant") # Iterate through points, highest to lowest conf, suppress neighborhood. count = 0 for i, rc in enumerate(rcorners.T): # Account for top and left padding. pt = (rc[0] + pad, rc[1] + pad) if grid[pt[1], pt[0]] == 1: # If not yet suppressed. grid[pt[1] - pad : pt[1] + pad + 1, pt[0] - pad : pt[0] + pad + 1] = 0 grid[pt[1], pt[0]] = -1 count += 1 # Get all surviving -1's and return sorted array of remaining corners. keepy, keepx = np.where(grid == -1) keepy, keepx = keepy - pad, keepx - pad inds_keep = inds[keepy, keepx] out = corners[:, inds_keep] values = out[-1, :] inds2 = np.argsort(-values) out = out[:, inds2] out_inds = inds1[inds_keep[inds2]] return out, out_inds