"""
A two-view sparse feature matching pipeline.

This model contains sub-models for each step:
    feature extraction, feature matching, outlier filtering, pose estimation.
Each step is optional, and the features or matches can be provided as input.
The default configuration enables no sub-model (every stage's `name` is None);
the concrete model for each step is selected through the configuration.

Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
"""

from omegaconf import OmegaConf

from . import get_model
from .base_model import BaseModel

to_ctr = OmegaConf.to_container  # convert DictConfig to dict


class TwoViewPipeline(BaseModel):
    """Chain optional extractor, matcher, filter, solver and ground-truth models.

    A stage is instantiated only when its ``name`` is set in the config.
    Stage outputs are merged into one flat prediction dict; per-view extractor
    outputs get their keys suffixed with ``"0"`` / ``"1"``.
    """

    default_conf = {
        "extractor": {
            "name": None,
            "trainable": False,
        },
        "matcher": {"name": None},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {"name": None},
        "allow_no_extract": False,
        "run_gt_in_forward": False,
    }
    required_data_keys = ["view0", "view1"]
    strict_conf = False  # need to pass new confs to children models
    components = [
        "extractor",
        "matcher",
        "filter",
        "solver",
        "ground_truth",
    ]

    def _init(self, conf):
        """Instantiate every stage whose ``name`` is set in ``conf``.

        Each stage model is built with ``get_model(<name>)`` and receives its
        own sub-config converted to a plain dict.
        """
        # Fixed tuple (not self.components) so subclasses that extend
        # `components` keep full control over constructing their extra stages.
        # The order matches the original explicit assignments.
        for stage in ("extractor", "matcher", "filter", "solver", "ground_truth"):
            stage_conf = conf[stage]
            if stage_conf.name:
                setattr(self, stage, get_model(stage_conf.name)(to_ctr(stage_conf)))

    def extract_view(self, data, i):
        """Return the extractor predictions for view ``i``.

        Args:
            data: input dict that contains the key ``f"view{i}"``.
            i: view index, interpolated into the key (called with "0"/"1").

        Returns:
            A dict that starts from ``data[f"view{i}"]["cache"]`` (or ``{}``).
            The extractor is run on the view and its outputs override cached
            keys, unless no extractor is configured or the cache is non-empty
            and ``allow_no_extract`` is set. When no extraction happens, the
            returned dict may be the cache object itself (not a copy).
        """
        data_i = data[f"view{i}"]
        pred_i = data_i.get("cache", {})
        skip_extract = len(pred_i) > 0 and self.conf.allow_no_extract
        if self.conf.extractor.name and not skip_extract:
            # Extractor outputs take precedence over cached values.
            pred_i = {**pred_i, **self.extractor(data_i)}
        # NOTE(review): there is no path that feeds cached values into the
        # extractor (e.g. extractor({**data_i, **pred_i})). If that was the
        # intent for allow_no_extract=False, it needs an explicit, reachable
        # condition; confirm with the extractor implementations.
        return pred_i

    def _forward(self, data):
        """Run all configured stages on a two-view ``data`` dict.

        Returns:
            A flat dict holding the per-view extractor outputs (keys suffixed
            ``"0"`` / ``"1"``), then the matcher, filter and solver outputs.
            Each of these stages is called on ``{**data, **pred}`` and its
            outputs override earlier keys. When ``run_gt_in_forward`` is set,
            ground-truth outputs are added with a ``gt_`` prefix.
        """
        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }

        for stage in ("matcher", "filter", "solver"):
            if self.conf[stage].name:
                pred = {**pred, **getattr(self, stage)({**data, **pred})}

        if self.conf.ground_truth.name and self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        return pred

    def loss(self, pred, data):
        """Sum the losses of every stage that provides one.

        Side effect: unless ground truth was already computed in ``_forward``,
        it is computed here and written into ``pred`` in place with a ``gt_``
        prefix.

        A stage contributes when its ``name`` is set and its optional
        ``apply_loss`` config entry is truthy (treated as True when absent).
        Stages whose ``loss`` raises ``NotImplementedError`` are skipped.

        Returns:
            ``(losses, metrics)``: per-stage dicts merged (later stages win on
            key clashes). ``losses["total"]`` is the sum of each contributing
            stage's ``"total"``.
        """
        losses = {}
        metrics = {}
        total = 0

        # get labels
        if self.conf.ground_truth.name and not self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})

        for k in self.components:
            apply = True
            if "apply_loss" in self.conf[k].keys():
                apply = self.conf[k].apply_loss
            if self.conf[k].name and apply:
                try:
                    # Here `data` overrides `pred` on key clashes
                    # (the opposite precedence of the calls in _forward).
                    losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
                except NotImplementedError:
                    continue
                losses = {**losses, **losses_}
                metrics = {**metrics, **metrics_}
                total = losses_["total"] + total
        return {**losses, "total": total}, metrics