"""
A two-view sparse feature matching pipeline.

This model contains sub-models for each step: feature extraction, feature
matching, outlier filtering, and pose estimation. Each step is optional, and
the features or matches can instead be provided as input.
Default: SuperPoint with nearest neighbor matching.

Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to keypoint i in image 0; m0[i] = -1 if i is unmatched.
"""

import numpy as np
import torch

from .. import get_model
from .base_model import BaseModel


def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
    """Keep only the keypoints lying in one randomly chosen image quadrant.

    Args:
        keypoints: tensor whose last dimension holds (x, y) coordinates;
            presumably shaped (1, N, 2) — TODO confirm against callers.
        scores: per-keypoint scores, indexed with the same mask as
            ``keypoints[..., 0]``.
        descs: descriptors laid out as (B, D, N); they are permuted to
            (B, N, D) before masking.
        h: image height in pixels.
        w: image width in pixels.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the chosen quadrant.
        Boolean masking flattens the batch dimension and a single leading
        batch dimension is re-added, so inputs are expected to have batch
        size 1.
    """
    h2, w2 = h // 2, w // 2
    # Random top-left corner of the quadrant (uses numpy's RNG, unlike the
    # torch-based helpers below).
    w_x = np.random.choice([0, w2])
    w_y = np.random.choice([0, h2])
    valid_mask = ((keypoints[..., 0] >= w_x)
                  & (keypoints[..., 0] < w_x + w2)
                  & (keypoints[..., 1] >= w_y)
                  & (keypoints[..., 1] < w_y + h2))
    keypoints = keypoints[valid_mask][None]
    scores = scores[valid_mask][None]
    # (B, D, N) -> (B, N, D) so the (B, N) mask selects whole descriptor
    # vectors, then transpose back to (D, K) and re-add the batch dim.
    descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
    return keypoints, scores, descs


def keep_random_kp_subset(keypoints, scores, descs, num_selected):
    """Keep a random subset of at most ``num_selected`` keypoints.

    The same random permutation is applied to every batch element.
    Indexing implies keypoints (B, N, ...), scores (B, N) and
    descs (B, D, N).
    """
    num_kp = keypoints.shape[1]
    selected_kp = torch.randperm(num_kp)[:num_selected]
    keypoints = keypoints[:, selected_kp]
    scores = scores[:, selected_kp]
    descs = descs[:, :, selected_kp]
    return keypoints, scores, descs


def keep_best_kp_subset(keypoints, scores, descs, num_selected):
    """Keep the ``num_selected`` highest-scoring keypoints per batch element.

    The kept keypoints are returned in ascending score order (the scores
    are sorted ascending and the last ``num_selected`` are taken).
    Indexing implies keypoints (B, N, 2), scores (B, N) and descs (B, D, N).
    """
    sorted_indices = torch.sort(scores, dim=1)[1]
    selected_kp = sorted_indices[:, -num_selected:]
    keypoints = torch.gather(keypoints, 1,
                             selected_kp[:, :, None].repeat(1, 1, 2))
    scores = torch.gather(scores, 1, selected_kp)
    descs = torch.gather(descs, 2,
                         selected_kp[:, None].repeat(1, descs.shape[1], 1))
    return keypoints, scores, descs


class TwoViewPipeline(BaseModel):
    """Run optional extraction, matching, filtering and solving on an image pair.

    Inputs for each view are keyed with a '0' or '1' suffix (e.g. 'image0',
    'image1'), and per-view predictions are returned with the same suffixes.
    """

    default_conf = {
        'extractor': {
            'name': 'superpoint',
            'trainable': False,
        },
        # NOTE(review): 'use_lines', 'use_points', 'randomize_num_kp' and
        # 'ground_truth' are not read in this file; presumably consumed by
        # sub-models or training code — verify before changing.
        'use_lines': False,
        'use_points': True,
        'randomize_num_kp': False,
        'detector': {'name': None},
        'descriptor': {'name': None},
        'matcher': {'name': 'nearest_neighbor_matcher'},
        'filter': {'name': None},
        'solver': {'name': None},
        'ground_truth': {
            'from_pose_depth': False,
            'from_homography': False,
            'th_positive': 3,
            'th_negative': 5,
            'reward_positive': 1,
            'reward_negative': -0.25,
            'is_likelihood_soft': True,
            'p_random_occluders': 0,
            'n_line_sampled_pts': 50,
            'line_perp_dist_th': 5,
            'overlap_th': 0.2,
            'min_visibility_th': 0.5
        },
    }
    required_data_keys = ['image0', 'image1']

    strict_conf = False  # need to pass new confs to children models

    # Sub-model attribute names, in the order their losses/metrics are
    # aggregated.
    components = [
        'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']

    def _init(self, conf):
        """Instantiate the sub-models enabled in ``conf``.

        A step without a sub-model needs its outputs supplied as input, so
        the corresponding keys are added to ``required_data_keys``.
        Detector and descriptor are only considered when no extractor is
        configured.
        """
        # `required_data_keys` is a class-level list: extending it in place
        # with `+=` would leak the extra keys into the class and thus into
        # every other TwoViewPipeline instance. Work on a per-instance copy.
        self.required_data_keys = list(self.required_data_keys)

        if conf.extractor.name:
            self.extractor = get_model(conf.extractor.name)(conf.extractor)
        else:
            if self.conf.detector.name:
                self.detector = get_model(conf.detector.name)(conf.detector)
            else:
                self.required_data_keys += ['keypoints0', 'keypoints1']
            if self.conf.descriptor.name:
                self.descriptor = get_model(conf.descriptor.name)(
                    conf.descriptor)
            else:
                self.required_data_keys += ['descriptors0', 'descriptors1']

        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(conf.matcher)
        else:
            self.required_data_keys += ['matches0']

        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(conf.filter)

        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(conf.solver)

    def _forward(self, data):
        """Run every enabled step on the image pair described by ``data``.

        Returns a dict of suffixed per-view predictions, plus the outputs
        of the matcher, filter and solver when they are enabled.
        """

        def process_siamese(data, i):
            """Detect and describe features for a single view.

            Args:
                data: the full pipeline input dict.
                i: view suffix, '0' or '1'.

            Returns:
                Per-view predictions keyed without the view suffix.
            """
            # Keep only this view's entries and strip the suffix,
            # e.g. 'image0' -> 'image' when i == '0'.
            data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
            if self.conf.extractor.name:
                pred_i = self.extractor(data_i)
            else:
                pred_i = {}
                if self.conf.detector.name:
                    pred_i = self.detector(data_i)
                else:
                    # No detector: pass through precomputed features.
                    for k in ['keypoints', 'keypoint_scores', 'descriptors',
                              'lines', 'line_scores', 'line_descriptors',
                              'valid_lines']:
                        if k in data_i:
                            pred_i[k] = data_i[k]
                if self.conf.descriptor.name:
                    pred_i = {
                        **pred_i, **self.descriptor({**data_i, **pred_i})}
            return pred_i

        pred0 = process_siamese(data, '0')
        pred1 = process_siamese(data, '1')

        # Re-attach the view suffixes to the per-view predictions.
        pred = {**{k + '0': v for k, v in pred0.items()},
                **{k + '1': v for k, v in pred1.items()}}

        # Each later step sees the inputs plus all predictions so far
        # (predictions win on key collisions) and adds its own outputs.
        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}

        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}

        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}

        return pred

    def _get_component(self, name):
        """Return sub-model ``name`` if configured and instantiated, else None.

        A step may be named in the config without being built, e.g. a
        detector when an extractor is also configured (see ``_init``).
        """
        if not self.conf[name].name:
            return None
        return getattr(self, name, None)

    def loss(self, pred, data):
        """Aggregate the losses of all enabled sub-models.

        Sub-models whose ``loss`` raises NotImplementedError are skipped.

        Returns:
            A dict merging every sub-model's loss terms, where 'total' is
            the sum of their 'total' entries (0 if none contributed).
        """
        losses = {}
        total = 0
        for k in self.components:
            module = self._get_component(k)
            if module is None:
                continue
            try:
                # Here input keys in `data` override same-named predictions
                # (the opposite merge order from `_forward`).
                losses_ = module.loss(pred, {**pred, **data})
            except NotImplementedError:
                continue
            losses = {**losses, **losses_}
            total = losses_['total'] + total
        return {**losses, 'total': total}

    def metrics(self, pred, data):
        """Merge the metrics of all enabled sub-models.

        Sub-models whose ``metrics`` raises NotImplementedError are skipped;
        later components overwrite same-named metrics from earlier ones.
        """
        metrics = {}
        for k in self.components:
            module = self._get_component(k)
            if module is None:
                continue
            try:
                metrics_ = module.metrics(pred, {**pred, **data})
            except NotImplementedError:
                continue
            metrics = {**metrics, **metrics_}
        return metrics