""" Inference model of SuperPoint, a feature detector and descriptor. Described in: SuperPoint: Self-Supervised Interest Point Detection and Description, Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018. Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork """ import torch from torch import nn from .. import GLUESTICK_ROOT from ..models.base_model import BaseModel def simple_nms(scores, radius): """Perform non maximum suppression on the heatmap using max-pooling. This method does not suppress contiguous points that have the same score. Args: scores: the score heatmap of size `(B, H, W)`. size: an interger scalar, the radius of the NMS window. """ def max_pool(x): return torch.nn.functional.max_pool2d( x, kernel_size=radius * 2 + 1, stride=1, padding=radius ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) for _ in range(2): supp_mask = max_pool(max_mask.float()) > 0 supp_scores = torch.where(supp_mask, zeros, scores) new_max_mask = supp_scores == max_pool(supp_scores) max_mask = max_mask | (new_max_mask & (~supp_mask)) return torch.where(max_mask, scores, zeros) def remove_borders(keypoints, scores, b, h, w): mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b)) mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b)) mask = mask_h & mask_w return keypoints[mask], scores[mask] def top_k_keypoints(keypoints, scores, k): if k >= len(keypoints): return keypoints, scores scores, indices = torch.topk(scores, k, dim=0, sorted=True) return keypoints[indices], scores def sample_descriptors(keypoints, descriptors, s): b, c, h, w = descriptors.shape keypoints = keypoints - s / 2 + 0.5 keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],).to( keypoints )[None] keypoints = keypoints * 2 - 1 # normalize to (-1, 1) args = {"align_corners": True} if torch.__version__ >= "1.3" else {} descriptors = torch.nn.functional.grid_sample( descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args ) descriptors = torch.nn.functional.normalize( descriptors.reshape(b, c, -1), p=2, dim=1 ) return descriptors class SuperPoint(BaseModel): default_conf = { "has_detector": True, "has_descriptor": True, "descriptor_dim": 256, # Inference "return_all": False, "sparse_outputs": True, "nms_radius": 4, "detection_threshold": 0.005, "max_num_keypoints": -1, "force_num_keypoints": False, "remove_borders": 4, } required_data_keys = ["image"] def _init(self, conf): self.relu = nn.ReLU(inplace=True) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) if conf.has_detector: self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) if conf.has_descriptor: self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) self.convDb = nn.Conv2d( c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0 ) path = GLUESTICK_ROOT / "resources" / "weights" / "superpoint_v1.pth" self.load_state_dict(torch.load(str(path)), strict=False) def _forward(self, data): image = data["image"] if image.shape[1] == 3: # RGB scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) image = (image * scale).sum(1, keepdim=True) # Shared Encoder x = self.relu(self.conv1a(image)) x = self.relu(self.conv1b(x)) x = self.pool(x) x = self.relu(self.conv2a(x)) x = self.relu(self.conv2b(x)) x = self.pool(x) x = self.relu(self.conv3a(x)) x = self.relu(self.conv3b(x)) x = self.pool(x) x = self.relu(self.conv4a(x)) x = self.relu(self.conv4b(x)) pred = {} if self.conf.has_detector and self.conf.max_num_keypoints != 0: # Compute the dense keypoint scores cPa = self.relu(self.convPa(x)) scores = self.convPb(cPa) scores = torch.nn.functional.softmax(scores, 1)[:, :-1] b, c, h, w = scores.shape scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) pred["keypoint_scores"] = dense_scores = scores if self.conf.has_descriptor: # Compute the dense descriptors cDa = self.relu(self.convDa(x)) all_desc = self.convDb(cDa) all_desc = torch.nn.functional.normalize(all_desc, p=2, dim=1) pred["descriptors"] = all_desc if self.conf.max_num_keypoints == 0: # Predict dense descriptors only b_size = len(image) device = image.device return { "keypoints": torch.empty(b_size, 0, 2, device=device), "keypoint_scores": torch.empty(b_size, 0, device=device), "descriptors": torch.empty( b_size, self.conf.descriptor_dim, 0, device=device ), "all_descriptors": all_desc, } if self.conf.sparse_outputs: assert self.conf.has_detector and self.conf.has_descriptor scores = simple_nms(scores, self.conf.nms_radius) # Extract keypoints keypoints = [ torch.nonzero(s > self.conf.detection_threshold) for s in scores ] scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] # Discard keypoints near the image borders keypoints, scores = list( zip( *[ remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8) for k, s in zip(keypoints, scores) ] ) ) # Keep the k keypoints with highest score if self.conf.max_num_keypoints > 0: keypoints, scores = list( zip( *[ top_k_keypoints(k, s, self.conf.max_num_keypoints) for k, s in zip(keypoints, scores) ] ) ) # Convert (h, w) to (x, y) keypoints = [torch.flip(k, [1]).float() for k in keypoints] if self.conf.force_num_keypoints: _, _, h, w = data["image"].shape assert self.conf.max_num_keypoints > 0 scores = list(scores) for i in range(len(keypoints)): k, s = keypoints[i], scores[i] missing = self.conf.max_num_keypoints - len(k) if missing > 0: new_k = torch.rand(missing, 2).to(k) new_k = new_k * k.new_tensor([[w - 1, h - 1]]) new_s = torch.zeros(missing).to(s) keypoints[i] = torch.cat([k, new_k], 0) scores[i] = torch.cat([s, new_s], 0) # Extract descriptors desc = [ sample_descriptors(k[None], d[None], 8)[0] for k, d in zip(keypoints, all_desc) ] if (len(keypoints) == 1) or self.conf.force_num_keypoints: keypoints = torch.stack(keypoints, 0) scores = torch.stack(scores, 0) desc = torch.stack(desc, 0) pred = { "keypoints": keypoints, "keypoint_scores": scores, "descriptors": desc, } if self.conf.return_all: pred["all_descriptors"] = all_desc pred["dense_score"] = dense_scores else: del all_desc torch.cuda.empty_cache() return pred def loss(self, pred, data): raise NotImplementedError def metrics(self, pred, data): raise NotImplementedError