Vincentqyw
fix: roma
c74a070
raw
history blame
No virus
6.35 kB
"""
A two-view sparse feature matching pipeline.
This model contains sub-models for each step:
feature extraction, feature matching, outlier filtering, pose estimation.
Each step is optional, and the features or matches can be provided as input.
Default: SuperPoint with nearest neighbor matching.
Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
"""
import numpy as np
import torch
from .. import get_model
from .base_model import BaseModel
def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
    """Keep only the keypoints that lie in one randomly chosen image quadrant.

    The quadrant is drawn with ``np.random.choice``, so the result depends on
    NumPy's global RNG state.

    Args:
        keypoints: tensor whose last dim holds (x, y) coordinates. The
            masking below assumes a batch of one, i.e. shape (1, N, 2)
            (NOTE(review): confirm callers never pass B > 1).
        scores: per-keypoint scores, indexed like keypoints[..., 0], e.g. (1, N).
        descs: descriptors with the keypoint axis last, e.g. (1, D, N).
        h: image height in pixels.
        w: image width in pixels.

    Returns:
        (keypoints, scores, descs) restricted to the chosen quadrant, each with
        a leading batch dimension of size 1: (1, M, 2), (1, M), (1, D, M).
    """
    h2, w2 = h // 2, w // 2
    w_x = np.random.choice([0, w2])
    w_y = np.random.choice([0, h2])
    # For odd sizes, 2 * (w // 2) == w - 1: let the right/bottom quadrants run
    # up to the image border so the last column/row belongs to some quadrant.
    x_end = w2 if w_x == 0 else w
    y_end = h2 if w_y == 0 else h
    valid_mask = (
        (keypoints[..., 0] >= w_x)
        & (keypoints[..., 0] < x_end)
        & (keypoints[..., 1] >= w_y)
        & (keypoints[..., 1] < y_end)
    )
    # Boolean indexing flattens the batch axis; [None] restores a batch of 1.
    keypoints = keypoints[valid_mask][None]
    scores = scores[valid_mask][None]
    # Move keypoints to axis 1 so the (1, N) mask selects descriptor rows,
    # then transpose back to channels-first (1, D, M).
    descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
    return keypoints, scores, descs
def keep_random_kp_subset(keypoints, scores, descs, num_selected):
    """Return a uniformly random subset of at most ``num_selected`` keypoints.

    A single permutation from ``torch.randperm`` is drawn and shared by every
    batch element, so keypoints, scores and descriptors stay aligned. If
    ``num_selected`` exceeds the number of keypoints, all of them are kept
    (in shuffled order).

    Args:
        keypoints: tensor with keypoints on axis 1, e.g. (B, N, 2).
        scores: tensor with keypoints on axis 1, e.g. (B, N).
        descs: tensor with keypoints on axis 2, e.g. (B, D, N).
        num_selected: maximum number of keypoints to keep.

    Returns:
        Tuple (keypoints, scores, descs) restricted to the sampled indices.
    """
    shuffled = torch.randperm(keypoints.shape[1])
    keep = shuffled[:num_selected]
    subset_kp = keypoints[:, keep]
    subset_scores = scores[:, keep]
    subset_descs = descs[:, :, keep]
    return subset_kp, subset_scores, subset_descs
def keep_best_kp_subset(keypoints, scores, descs, num_selected):
    """Keep the ``num_selected`` highest-scoring keypoints per batch element.

    The kept keypoints are returned in ascending score order (best last). If
    ``num_selected`` exceeds the number of keypoints, all of them are kept; if
    it is 0, empty tensors (keypoint axis of length 0) are returned.

    Args:
        keypoints: tensor with keypoints on axis 1, e.g. (B, N, 2).
        scores: tensor of shape (B, N).
        descs: tensor with keypoints on axis 2, e.g. (B, D, N).
        num_selected: non-negative number of keypoints to keep.

    Returns:
        Tuple (keypoints, scores, descs) gathered at the selected indices.
    """
    num_kp = scores.shape[1]
    # Use an explicit start index: slicing with ``-num_selected`` would turn
    # into ``[-0:]`` and keep *every* keypoint when num_selected == 0.
    start = max(num_kp - num_selected, 0)
    sorted_indices = torch.sort(scores, dim=1)[1]
    selected_kp = sorted_indices[:, start:]
    kp_index = selected_kp[:, :, None].expand(-1, -1, keypoints.shape[-1])
    keypoints = torch.gather(keypoints, 1, kp_index)
    scores = torch.gather(scores, 1, selected_kp)
    desc_index = selected_kp[:, None].expand(-1, descs.shape[1], -1)
    descs = torch.gather(descs, 2, desc_index)
    return keypoints, scores, descs
class TwoViewPipeline(BaseModel):
    """Configurable two-view matching pipeline.

    Sub-models are optional and built in ``_init`` only when the matching
    ``conf.<stage>.name`` is set: either an ``extractor``, or a ``detector``
    and/or ``descriptor``; then a ``matcher``, ``filter`` and ``solver``.
    Features or matches that a missing stage would have produced must be
    supplied in the input data instead (they are added to
    ``required_data_keys``).
    """

    default_conf = {
        "extractor": {
            "name": "superpoint",
            "trainable": False,
        },
        "use_lines": False,
        "use_points": True,
        "randomize_num_kp": False,
        "detector": {"name": None},
        "descriptor": {"name": None},
        "matcher": {"name": "nearest_neighbor_matcher"},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {
            "from_pose_depth": False,
            "from_homography": False,
            "th_positive": 3,
            "th_negative": 5,
            "reward_positive": 1,
            "reward_negative": -0.25,
            "is_likelihood_soft": True,
            "p_random_occluders": 0,
            "n_line_sampled_pts": 50,
            "line_perp_dist_th": 5,
            "overlap_th": 0.2,
            "min_visibility_th": 0.5,
        },
    }
    required_data_keys = ["image0", "image1"]
    strict_conf = False  # need to pass new confs to children models
    components = [
        "extractor",
        "detector",
        "descriptor",
        "matcher",
        "filter",
        "solver",
    ]

    def _init(self, conf):
        """Build every sub-model named in ``conf`` via ``get_model``.

        When an extractor is configured, the detector/descriptor entries are
        ignored. Without a detector, "keypoints0"/"keypoints1" become required
        inputs; without a descriptor, "descriptors0"/"descriptors1"; without a
        matcher, "matches0".
        """
        # ``required_data_keys`` is a class attribute and list ``+=`` extends
        # in place: without this per-instance copy, requirements added here
        # would leak into every other TwoViewPipeline instance.
        self.required_data_keys = list(self.required_data_keys)
        if conf.extractor.name:
            self.extractor = get_model(conf.extractor.name)(conf.extractor)
        else:
            if self.conf.detector.name:
                self.detector = get_model(conf.detector.name)(conf.detector)
            else:
                self.required_data_keys += ["keypoints0", "keypoints1"]
            if self.conf.descriptor.name:
                self.descriptor = get_model(conf.descriptor.name)(
                    conf.descriptor
                )
            else:
                self.required_data_keys += ["descriptors0", "descriptors1"]
        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(conf.matcher)
        else:
            self.required_data_keys += ["matches0"]
        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(conf.filter)
        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(conf.solver)

    def _forward(self, data):
        """Run the configured stages on a pair of views.

        Input keys ending in "0"/"1" are split per view (suffix stripped) and
        processed independently; the per-view outputs are re-suffixed, then
        the matcher, filter and solver run in turn, each receiving ``data``
        merged with the predictions so far and adding its outputs to them.
        """

        def process_siamese(data, i):
            # Per-view inputs, e.g. "image0" -> "image" for i == "0".
            data_i = {k[:-1]: v for k, v in data.items() if k.endswith(i)}
            if self.conf.extractor.name:
                pred_i = self.extractor(data_i)
            else:
                pred_i = {}
                if self.conf.detector.name:
                    pred_i = self.detector(data_i)
                else:
                    # No detector: pass through precomputed features, if any.
                    for k in [
                        "keypoints",
                        "keypoint_scores",
                        "descriptors",
                        "lines",
                        "line_scores",
                        "line_descriptors",
                        "valid_lines",
                    ]:
                        if k in data_i:
                            pred_i[k] = data_i[k]
                if self.conf.descriptor.name:
                    pred_i = {**pred_i, **self.descriptor({**data_i, **pred_i})}
            return pred_i

        pred0 = process_siamese(data, "0")
        pred1 = process_siamese(data, "1")
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }
        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}
        return pred

    def _active_components(self):
        """Yield (name, sub-model) for each configured *and* built component.

        A component can be named in the config yet never instantiated:
        ``_init`` skips the detector/descriptor when an extractor is set.
        """
        for k in self.components:
            if not self.conf[k].name:
                continue
            model = getattr(self, k, None)
            if model is not None:
                yield k, model

    def loss(self, pred, data):
        """Collect the losses of all active sub-models.

        Sub-models whose ``loss`` raises NotImplementedError are skipped.
        Their loss dicts are merged (later components overwrite equal keys)
        and ``"total"`` is the sum of each component's ``"total"`` entry,
        or 0 when no component contributed.
        """
        losses = {}
        total = 0
        for _, model in self._active_components():
            try:
                losses_ = model.loss(pred, {**pred, **data})
            except NotImplementedError:
                continue
            losses = {**losses, **losses_}
            total = losses_["total"] + total
        return {**losses, "total": total}

    def metrics(self, pred, data):
        """Merge the metric dicts of all active sub-models.

        Sub-models whose ``metrics`` raises NotImplementedError are skipped;
        later components overwrite equal keys.
        """
        metrics = {}
        for _, model in self._active_components():
            try:
                metrics_ = model.metrics(pred, {**pred, **data})
            except NotImplementedError:
                continue
            metrics = {**metrics, **metrics_}
        return metrics