|
""" |
|
A two-view sparse feature matching pipeline. |
|
|
|
This model contains sub-models for each step: |
|
feature extraction, feature matching, outlier filtering, pose estimation. |
|
Each step is optional, and the features or matches can be provided as input. |
|
Default: SuperPoint with nearest neighbor matching. |
|
|
|
Convention for the matches: m0[i] is the index of the keypoint in image 1 |
|
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched. |
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .. import get_model |
|
from .base_model import BaseModel |
|
|
|
|
|
def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
    """Keep only the keypoints lying in one randomly chosen image quadrant.

    The quadrant's x offset and then its y offset are drawn with
    ``np.random.choice`` from ``{0, w // 2}`` and ``{0, h // 2}``. A keypoint
    is kept when its x (``[..., 0]``) and y (``[..., 1]``) fall inside the
    half-open box ``[x0, x0 + w // 2) x [y0, y0 + h // 2)``.

    NOTE(review): boolean-mask indexing flattens the batch dimension and the
    result is re-wrapped with a leading dimension of 1, so this presumably
    expects a batch of size 1. Confirm with callers.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the kept keypoints, each
        with a new leading dimension of size 1. ``descs`` keeps its
        (batch, dim, num_kp) layout.
    """
    half_h = h // 2
    half_w = w // 2
    # Draw order matters for reproducibility: x offset first, then y offset.
    x0 = np.random.choice([0, half_w])
    y0 = np.random.choice([0, half_h])

    xs = keypoints[..., 0]
    ys = keypoints[..., 1]
    inside_x = (xs >= x0) & (xs < x0 + half_w)
    inside_y = (ys >= y0) & (ys < y0 + half_h)
    keep = inside_x & inside_y

    # Move the keypoint axis next to the batch axis so the same mask applies,
    # then restore the (dim, num_kept) layout.
    kept_descs = descs.transpose(1, 2)[keep]
    return keypoints[keep][None], scores[keep][None], kept_descs.t()[None]
|
|
|
|
|
def keep_random_kp_subset(keypoints, scores, descs, num_selected):
    """Keep a random subset of at most ``num_selected`` keypoints.

    The subset is the first ``num_selected`` entries of a ``torch.randperm``
    over the keypoint axis (dim 1 of ``keypoints``). The same random indices
    are applied to every batch element: keypoints and scores are indexed along
    dim 1, descriptors along dim 2. Asking for more keypoints than exist keeps
    all of them, in permuted order.
    """
    permutation = torch.randperm(keypoints.shape[1])
    chosen = permutation[:num_selected]
    return (
        keypoints[:, chosen],
        scores[:, chosen],
        descs[:, :, chosen],
    )
|
|
|
|
|
def keep_best_kp_subset(keypoints, scores, descs, num_selected):
    """Keep the ``num_selected`` highest-scoring keypoints.

    Keypoints are ranked per batch element by ascending score (``torch.sort``
    along dim 1) and the last ``num_selected`` are kept. The output is
    therefore ordered from lowest to highest score among the kept keypoints.
    If ``num_selected`` exceeds the number of keypoints, all keypoints are
    kept (sorted). If it is 0 (or negative), none are kept.

    Args:
        keypoints: tensor gathered along dim 1 with a trailing index of size
            2; presumably (batch, num_kp, 2). Only the first two coordinates
            are gathered.
        scores: tensor of shape (batch, num_kp) used for ranking.
        descs: tensor gathered along dim 2; presumably (batch, dim, num_kp).
        num_selected: number of keypoints to keep.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the selected keypoints.
    """
    num_kp = scores.shape[1]
    num_kept = max(0, min(num_selected, num_kp))
    sorted_indices = torch.sort(scores, dim=1)[1]
    # Slice from an explicit start index: ``sorted_indices[:, -0:]`` would
    # select *every* keypoint when num_selected == 0.
    selected_kp = sorted_indices[:, num_kp - num_kept:]

    keypoints = torch.gather(keypoints, 1, selected_kp[:, :, None].repeat(1, 1, 2))
    scores = torch.gather(scores, 1, selected_kp)
    descs = torch.gather(descs, 2, selected_kp[:, None].repeat(1, descs.shape[1], 1))
    return keypoints, scores, descs
|
|
|
|
|
class TwoViewPipeline(BaseModel):
    """Compose optional sub-models into a two-view sparse matching pipeline.

    Stages whose ``name`` is set in the config are instantiated and run in
    this order: extractor (or detector + descriptor), matcher, filter, solver.
    Every stage after feature extraction receives the input data merged with
    the predictions produced so far, and its outputs are merged into the
    prediction dict.
    """

    default_conf = {
        "extractor": {
            "name": "superpoint",
            "trainable": False,
        },
        "use_lines": False,
        "use_points": True,
        "randomize_num_kp": False,
        "detector": {"name": None},
        "descriptor": {"name": None},
        "matcher": {"name": "nearest_neighbor_matcher"},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {
            "from_pose_depth": False,
            "from_homography": False,
            "th_positive": 3,
            "th_negative": 5,
            "reward_positive": 1,
            "reward_negative": -0.25,
            "is_likelihood_soft": True,
            "p_random_occluders": 0,
            "n_line_sampled_pts": 50,
            "line_perp_dist_th": 5,
            "overlap_th": 0.2,
            "min_visibility_th": 0.5,
        },
    }
    required_data_keys = ["image0", "image1"]
    strict_conf = False
    # Config sections that may name a sub-model; iterated by loss()/metrics().
    components = ["extractor", "detector", "descriptor", "matcher", "filter", "solver"]

    def _init(self, conf):
        """Instantiate the sub-models enabled in ``conf``.

        Disabled stages turn their would-be outputs into required inputs.
        Without an extractor, a missing detector requires
        ``keypoints0``/``keypoints1`` and a missing descriptor requires
        ``descriptors0``/``descriptors1``. A missing matcher requires
        ``matches0``.
        """
        # Copy the class-level list into an instance attribute. ``+=`` on the
        # shared class attribute would mutate it in place, leaking these keys
        # into every other TwoViewPipeline and accumulating duplicates.
        self.required_data_keys = list(self.required_data_keys)

        if conf.extractor.name:
            self.extractor = get_model(conf.extractor.name)(conf.extractor)
        else:
            if self.conf.detector.name:
                self.detector = get_model(conf.detector.name)(conf.detector)
            else:
                self.required_data_keys += ["keypoints0", "keypoints1"]
            if self.conf.descriptor.name:
                self.descriptor = get_model(conf.descriptor.name)(conf.descriptor)
            else:
                self.required_data_keys += ["descriptors0", "descriptors1"]

        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(conf.matcher)
        else:
            self.required_data_keys += ["matches0"]

        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(conf.filter)

        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(conf.solver)

    def _forward(self, data):
        """Run the enabled stages on a pair of views.

        Keys of ``data`` ending in ``"0"`` or ``"1"`` are split per view, with
        the suffix stripped, and fed to the feature stage. The per-view outputs
        are re-suffixed and merged into ``pred``. The matcher, filter and
        solver (when enabled) each receive ``{**data, **pred}``, and their
        outputs are merged into ``pred``.

        Returns:
            The dict of predictions.
        """

        def process_siamese(data, i):
            # Select this view's entries and drop the trailing view index.
            data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
            if self.conf.extractor.name:
                pred_i = self.extractor(data_i)
            else:
                pred_i = {}
                if self.conf.detector.name:
                    pred_i = self.detector(data_i)
                else:
                    # No detector: forward precomputed features from the data.
                    for k in [
                        "keypoints",
                        "keypoint_scores",
                        "descriptors",
                        "lines",
                        "line_scores",
                        "line_descriptors",
                        "valid_lines",
                    ]:
                        if k in data_i:
                            pred_i[k] = data_i[k]
                if self.conf.descriptor.name:
                    pred_i = {**pred_i, **self.descriptor({**data_i, **pred_i})}
            return pred_i

        pred0 = process_siamese(data, "0")
        pred1 = process_siamese(data, "1")

        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }

        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}

        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}

        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}

        return pred

    def loss(self, pred, data):
        """Aggregate the losses of all enabled components.

        Components whose ``loss`` raises ``NotImplementedError`` are skipped.
        The loss dicts are merged, with later components overwriting equal
        keys. Each component's ``"total"`` entry is summed into the returned
        ``"total"``, which is 0 if no component contributes.
        """
        losses = {}
        total = 0
        for k in self.components:
            if self.conf[k].name:
                try:
                    losses_ = getattr(self, k).loss(pred, {**pred, **data})
                except NotImplementedError:
                    continue
                losses = {**losses, **losses_}
                total = losses_["total"] + total
        return {**losses, "total": total}

    def metrics(self, pred, data):
        """Merge the metrics of all enabled components.

        Components whose ``metrics`` raises ``NotImplementedError`` are
        skipped. Later components overwrite equal keys.
        """
        metrics = {}
        for k in self.components:
            if self.conf[k].name:
                try:
                    metrics_ = getattr(self, k).metrics(pred, {**pred, **data})
                except NotImplementedError:
                    continue
                metrics = {**metrics, **metrics_}
        return metrics
|
|