Spaces:
Running
Running
import sys | |
from pathlib import Path | |
import torchvision.transforms as tvf | |
from ..utils.base_model import BaseModel | |
# Put the vendored ``third_party`` directory on ``sys.path`` so the ``r2d2``
# package that lives inside it can be imported further down in this module.
base_path = Path(__file__).parent.joinpath("../../third_party")
sys.path.append(str(base_path))
# Root of the r2d2 sources; model weights are looked up under ``models/`` there.
r2d2_path = base_path / "r2d2"
from r2d2.extract import load_network, NonMaxSuppression, extract_multiscale | |
class R2D2(BaseModel):
    """Keypoint and descriptor extractor built on the vendored ``r2d2`` code.

    The class loads a network with ``load_network``, builds a
    ``NonMaxSuppression`` detector, and runs ``extract_multiscale`` on the
    normalized input image.

    Configuration keys (see ``default_conf``):
        model_name: file name of the weights, resolved under
            ``r2d2_path / "models"``.
        max_keypoints: number of highest-scoring keypoints to keep; a value
            of ``0`` disables the cap and keeps every detection.
        scale_factor, min_size, max_size, min_scale, max_scale: forwarded to
            ``extract_multiscale`` as ``scale_f``, ``min_size``, ``max_size``,
            ``min_scale`` and ``max_scale``.
        reliability_threshold, repetability_threshold: forwarded to
            ``NonMaxSuppression`` as ``rel_thr`` and ``rep_thr``. The spelling
            of ``repetability_threshold`` is part of the config interface.
    """

    default_conf = {
        "model_name": "r2d2_WASF_N16.pt",
        "max_keypoints": 5000,
        "scale_factor": 2**0.25,
        "min_size": 256,
        "max_size": 1024,
        "min_scale": 0,
        "max_scale": 1,
        "reliability_threshold": 0.7,
        "repetability_threshold": 0.7,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        """Create the normalizer, the network and the detector.

        Args:
            conf: mapping providing ``model_name``, ``reliability_threshold``
                and ``repetability_threshold``.
        """
        weights_path = r2d2_path / "models" / conf["model_name"]

        # Fixed per-channel statistics applied to the image before inference
        # (these are the values commonly used for ImageNet-pretrained models).
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.net = load_network(weights_path)
        self.detector = NonMaxSuppression(
            rel_thr=conf["reliability_threshold"],
            rep_thr=conf["repetability_threshold"],
        )

    def _forward(self, data):
        """Run multi-scale extraction and keep the best-scoring keypoints.

        Args:
            data: mapping with an ``"image"`` entry that is passed through
                ``self.norm_rgb`` and then to ``extract_multiscale``.

        Returns:
            dict with ``"keypoints"``, ``"descriptors"`` and ``"scores"``;
            each value has an extra leading axis added via ``[None]``.
        """
        # NOTE(review): ``self.conf`` is not assigned in this class; it is
        # presumably populated by ``BaseModel`` -- confirm against the base class.
        conf = self.conf

        normalized = self.norm_rgb(data["image"])
        pyramid_kwargs = {
            key: conf[key]
            for key in ("min_size", "max_size", "min_scale", "max_scale")
        }
        xys, desc, scores = extract_multiscale(
            self.net,
            normalized,
            self.detector,
            scale_f=conf["scale_factor"],
            **pyramid_kwargs,
        )

        # ``argsort`` is ascending, so the tail of the index list holds the
        # highest scores. With ``max_keypoints == 0`` the start index becomes
        # ``None`` and the slice keeps every detection.
        first_kept = -conf["max_keypoints"] or None
        keep = scores.argsort()[first_kept:]

        return {
            # Only the first two columns of ``xys`` are returned
            # (presumably x and y; any further columns are dropped).
            "keypoints": xys[keep, :2][None],
            # Descriptors are transposed after selection.
            "descriptors": desc[keep].t()[None],
            "scores": scores[keep][None],
        }