import torch from ...geometry.gt_generation import ( gt_line_matches_from_pose_depth, gt_matches_from_pose_depth, ) from ..base_model import BaseModel class DepthMatcher(BaseModel): default_conf = { # GT parameters for points "use_points": True, "th_positive": 3.0, "th_negative": 5.0, "th_epi": None, # add some more epi outliers "th_consistency": None, # check for projection consistency in px # GT parameters for lines "use_lines": False, "n_line_sampled_pts": 50, "line_perp_dist_th": 5, "overlap_th": 0.2, "min_visibility_th": 0.5, } required_data_keys = ["view0", "view1", "T_0to1", "T_1to0"] def _init(self, conf): # TODO (iago): Is this just boilerplate code? if self.conf.use_points: self.required_data_keys += ["keypoints0", "keypoints1"] if self.conf.use_lines: self.required_data_keys += [ "lines0", "lines1", "valid_lines0", "valid_lines1", ] @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) def _forward(self, data): result = {} if self.conf.use_points: if "depth_keypoints0" in data: keys = [ "depth_keypoints0", "valid_depth_keypoints0", "depth_keypoints1", "valid_depth_keypoints1", ] kw = {k: data[k] for k in keys} else: kw = {} result = gt_matches_from_pose_depth( data["keypoints0"], data["keypoints1"], data, pos_th=self.conf.th_positive, neg_th=self.conf.th_negative, epi_th=self.conf.th_epi, cc_th=self.conf.th_consistency, **kw, ) if self.conf.use_lines: line_assignment, line_m0, line_m1 = gt_line_matches_from_pose_depth( data["lines0"], data["lines1"], data["valid_lines0"], data["valid_lines1"], data, self.conf.n_line_sampled_pts, self.conf.line_perp_dist_th, self.conf.overlap_th, self.conf.min_visibility_th, ) result["line_matches0"] = line_m0 result["line_matches1"] = line_m1 result["line_assignment"] = line_assignment return result def loss(self, pred, data): raise NotImplementedError