File size: 1,751 Bytes
c705408 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import kornia
import torch
from .utils import Extractor
class DISK(Extractor):
    """DISK local-feature extractor, wrapping the pretrained kornia model.

    Given a batched image tensor, detects keypoints and computes their
    detection scores and 128-D descriptors via ``kornia.feature.DISK``.
    """

    default_conf = {
        "weights": "depth",
        "max_num_keypoints": None,
        "desc_dim": 128,
        "nms_window_size": 5,
        "detection_threshold": 0.0,
        "pad_if_not_divisible": True,
    }
    preprocess_conf = {
        "resize": 1024,
        "grayscale": False,
    }
    required_data_keys = ["image"]

    def __init__(self, **conf) -> None:
        # Base class merges `conf` over `default_conf` into `self.conf`.
        super().__init__(**conf)
        self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)

    def forward(self, data: dict) -> dict:
        """Compute keypoints, scores, descriptors for image"""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"

        image = data["image"]
        # DISK expects 3-channel input; promote single-channel batches.
        if image.shape[1] == 1:
            image = kornia.color.grayscale_to_rgb(image)

        detections = self.model(
            image,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
            pad_if_not_divisible=self.conf.pad_if_not_divisible,
        )

        # Transpose the per-image feature records into three field lists,
        # then batch each along dim 0.
        # NOTE(review): stacking assumes every image in the batch yields the
        # same number of keypoints (e.g. max_num_keypoints is set) — confirm
        # with callers if batch size > 1 is used.
        kpts, confs, descs = zip(
            *((f.keypoints, f.detection_scores, f.descriptors) for f in detections)
        )
        return {
            "keypoints": torch.stack(kpts, 0).to(image).contiguous(),
            "keypoint_scores": torch.stack(confs, 0).to(image).contiguous(),
            "descriptors": torch.stack(descs, 0).to(image).contiguous(),
        }
|