import torch
from torch import nn


def simple_nms(scores, nms_radius):
    """Suppress every score that is not a local maximum.

    A value survives when it equals the maximum of its
    ``(2 * nms_radius + 1)``-wide square neighbourhood. Two refinement
    passes then re-admit local maxima that lie outside the neighbourhood
    of any already accepted maximum.

    Args:
        scores: Score tensor accepted by ``max_pool2d``; the last two
            dimensions are the ones pooled over.
        nms_radius: Non-negative suppression radius, in pixels.

    Returns:
        A tensor shaped like ``scores`` with all suppressed entries set
        to zero.
    """
    assert nms_radius >= 0

    def max_pool(x):
        # stride 1 + symmetric padding keeps the spatial size unchanged.
        return torch.nn.functional.max_pool2d(
            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        # Everything within the radius of an accepted maximum is suppressed.
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        # Only accept new maxima located outside the suppressed area.
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)


def remove_borders(keypoints, scores, b, h, w):
    """Drop keypoints closer than ``b`` pixels to the image border.

    Args:
        keypoints: ``(N, 2)`` tensor; column 0 is compared against ``h``
            and column 1 against ``w`` (i.e. (row, col) order).
        scores: ``(N,)`` tensor of scores aligned with ``keypoints``.
        b: Border width in pixels.
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        ``(keypoints, scores)`` restricted to the points with
        ``b <= row < h - b`` and ``b <= col < w - b``.
    """
    mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b))
    mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b))
    mask = mask_h & mask_w
    return keypoints[mask], scores[mask]


def top_k_keypoints(keypoints, scores, k):
    """Keep the ``k`` highest-scoring keypoints.

    Args:
        keypoints: ``(N, ...)`` tensor of keypoints.
        scores: ``(N,)`` tensor of scores aligned with ``keypoints``.
        k: Maximum number of keypoints to keep.

    Returns:
        ``(keypoints, scores)``; the inputs are returned untouched when
        ``k >= N``, otherwise the top ``k`` entries in the order produced
        by ``torch.topk``.
    """
    if k >= len(keypoints):
        return keypoints, scores
    scores, indices = torch.topk(scores, k, dim=0)
    return keypoints[indices], scores


def _grid_sample_kwargs():
    """Return the extra keyword arguments to pass to ``grid_sample``.

    The original check looked at a single character of the version string
    (``torch.__version__[2]``), which reads ``0``/``1`` for releases such as
    ``1.10`` or ``2.x`` and therefore silently dropped
    ``align_corners=True``. Comparing the parsed ``(major, minor)`` pair
    keeps the intended rule: pass the flag from torch 1.3 onwards.
    """
    try:
        major, minor = (int(p) for p in torch.__version__.split('.')[:2])
    except ValueError:
        # Unparseable version string (custom build): assume a modern torch.
        return {'align_corners': True}
    return {'align_corners': True} if (major, minor) >= (1, 3) else {}


def sample_descriptors(keypoints, descriptors, s):
    """Bilinearly interpolate descriptors at keypoint locations.

    Args:
        keypoints: ``(b, N, 2)`` tensor of (x, y) pixel coordinates in the
            full-resolution image.
        descriptors: ``(b, c, h, w)`` dense descriptor map whose cells are
            ``s`` image pixels wide.
        s: Size of one descriptor cell in image pixels.

    Returns:
        ``(b, c, N)`` tensor of L2-normalised descriptors (normalised over
        the channel dimension).
    """
    b, c, h, w = descriptors.shape
    # Shift so that the centre of the first cell maps to 0.
    keypoints = keypoints - s / 2 + 0.5
    # Scale so that the last image pixel maps to 1; this normalisation is
    # only exact together with align_corners=True below.
    extent = torch.tensor(
        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
    ).to(keypoints)[None]
    keypoints = keypoints / extent
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear',
        **_grid_sample_kwargs())
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1)
    return descriptors


class SuperPoint(nn.Module):
    """Convolutional keypoint detector and descriptor.

    A shared encoder (four pairs of 3x3 convolutions with three 2x2
    max-pools, i.e. an 8x downsampling) feeds two heads: a 65-channel
    keypoint-score head and a ``descriptor_dim``-channel descriptor head.

    Config keys read by this class: ``model_path``, ``descriptor_dim``,
    ``max_keypoints``, ``nms_radius``, ``detection_threshold`` and
    ``remove_borders``.
    """

    def __init__(self, config):
        """Build the network and load its weights.

        Args:
            config: Mapping holding the keys listed in the class docstring.
                ``max_keypoints`` must be positive, or ``-1`` to keep all
                keypoints.

        Raises:
            ValueError: If ``max_keypoints`` is ``0`` or below ``-1``.
        """
        super().__init__()
        self.config = {**config}

        # Validate before building layers / touching the disk so that a bad
        # configuration fails fast.
        mk = self.config['max_keypoints']
        if mk == 0 or mk < -1:
            raise ValueError('\"max_keypoints\" must be positive or \"-1\"')

        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256

        # Shared encoder.
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

        # Keypoint-score head: 64 in-cell positions + 1 extra channel that
        # forward() drops after the softmax.
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)

        # Descriptor head.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(
            c5, self.config['descriptor_dim'],
            kernel_size=1, stride=1, padding=0)

        # map_location='cpu' lets checkpoints saved from a GPU load on
        # CPU-only hosts; load_state_dict copies the values into the module's
        # own parameters, so the module's device is unaffected.
        # NOTE: torch.load unpickles the file -- only load trusted weights.
        state_dict = torch.load(config['model_path'], map_location='cpu')
        self.load_state_dict(state_dict)

        print('Loaded SuperPoint model')

    def forward(self, data):
        """Detect keypoints and compute their descriptors.

        Args:
            data: Image batch fed to a 1-input-channel convolution, i.e.
                ``(B, 1, H, W)``.

        Returns:
            Dict with per-image lists/tuples:
            ``'keypoints'`` -- ``(N_i, 2)`` float tensors in (x, y) order,
            ``'scores'`` -- ``(N_i,)`` tensors,
            ``'descriptors'`` -- ``(descriptor_dim, N_i)`` tensors.
        """
        # Shared Encoder
        x = self.relu(self.conv1a(data))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        # Compute the dense keypoint scores
        cPa = self.relu(self.convPa(x))
        scores = self.convPb(cPa)
        # Softmax over the 65 channels, then drop the last one.
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        # Unfold the 64 channels of each cell into an 8x8 pixel patch to get
        # a full-resolution (b, h*8, w*8) score map.
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
        scores = simple_nms(scores, self.config['nms_radius'])

        # Extract keypoints
        threshold = self.config['detection_threshold']
        keypoints = [torch.nonzero(s > threshold) for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

        # Discard keypoints near the image borders
        border = self.config['remove_borders']
        keypoints, scores = list(zip(*[
            remove_borders(k, s, border, h * 8, w * 8)
            for k, s in zip(keypoints, scores)]))

        # Keep the k keypoints with highest score
        max_keypoints = self.config['max_keypoints']
        if max_keypoints >= 0:
            keypoints, scores = list(zip(*[
                top_k_keypoints(k, s, max_keypoints)
                for k, s in zip(keypoints, scores)]))

        # Convert (h, w) to (x, y)
        keypoints = [torch.flip(k, [1]).float() for k in keypoints]

        # Compute the dense descriptors
        cDa = self.relu(self.convDa(x))
        descriptors = self.convDb(cDa)
        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)

        # Extract descriptors
        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
                       for k, d in zip(keypoints, descriptors)]

        return {
            'keypoints': keypoints,
            'scores': scores,
            'descriptors': descriptors,
        }