import torch from ..utils.base_model import BaseModel from kornia.feature.adalam import AdalamFilter from kornia.utils.helpers import get_cuda_device_if_available class AdaLAM(BaseModel): # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html. default_conf = { "area_ratio": 100, "search_expansion": 4, "ransac_iters": 128, "min_inliers": 6, "min_confidence": 200, "orientation_difference_threshold": 30, "scale_rate_threshold": 1.5, "detected_scale_rate_threshold": 5, "refit": True, "force_seed_mnn": True, "device": get_cuda_device_if_available(), } required_inputs = [ "image0", "image1", "descriptors0", "descriptors1", "keypoints0", "keypoints1", "scales0", "scales1", "oris0", "oris1", ] def _init(self, conf): self.adalam = AdalamFilter(conf) def _forward(self, data): assert data["keypoints0"].size(0) == 1 if data["keypoints0"].size(1) < 2 or data["keypoints1"].size(1) < 2: matches = torch.zeros( (0, 2), dtype=torch.int64, device=data["keypoints0"].device ) else: matches = self.adalam.match_and_filter( data["keypoints0"][0], data["keypoints1"][0], data["descriptors0"][0].T, data["descriptors1"][0].T, data["image0"].shape[2:], data["image1"].shape[2:], data["oris0"][0], data["oris1"][0], data["scales0"][0], data["scales1"][0], ) matches_new = torch.full( (data["keypoints0"].size(1),), -1, dtype=torch.int64, device=data["keypoints0"].device, ) matches_new[matches[:, 0]] = matches[:, 1] return { "matches0": matches_new.unsqueeze(0), "matching_scores0": torch.zeros(matches_new.size(0)).unsqueeze(0), }