# File size: 2,015 Bytes
# 9223079
import sys
from pathlib import Path
import subprocess
import logging
from ..utils.base_model import BaseModel
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Location of the vendored DarkFeat sources, expressed relative to the
# directory containing this file (two levels up, then third_party/DarkFeat).
darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
# Make that directory importable. This must run before the `darkfeat` import
# below, which is why that import is not grouped with the ones at the top.
sys.path.append(str(darkfeat_path))
# Imported under an alias because this module defines its own `DarkFeat` class.
from darkfeat import DarkFeat as DarkFeat_
class DarkFeat(BaseModel):
    """Wrapper that runs the third-party DarkFeat network behind ``BaseModel``.

    On initialisation the checkpoint named by ``conf["model_name"]`` is looked
    up under ``darkfeat_path / "checkpoints"`` and, when it is not there yet,
    fetched with the ``gdown`` command-line tool (first directly, then once
    more through ``proxy``).
    """

    # Default configuration. Only "model_name" is read inside this class; how
    # the remaining entries are consumed is not visible in this file.
    default_conf = {
        "model_name": "DarkFeat.pth",
        "max_keypoints": 1000,
        "detection_threshold": 0.5,
        "sub_pixel": False,
    }
    # Download URL for each known checkpoint file name.
    weight_urls = {
        "DarkFeat.pth": "https://drive.google.com/uc?id=1Thl6m8NcmQ7zSAF-1_xaFs3F4H8UU6HX&confirm=t",
    }
    # Proxy passed to `gdown --proxy` when the direct download fails.
    proxy = "http://localhost:1080"
    # NOTE(review): presumably checked by BaseModel before _forward -- confirm.
    required_inputs = ["image"]

    def _init(self, conf):
        """Make sure the checkpoint is on disk, then build the network.

        Args:
            conf: Mapping providing ``"model_name"``, the checkpoint file name.

        Raises:
            KeyError: The checkpoint is missing locally and ``weight_urls``
                has no URL for it.
            subprocess.CalledProcessError: Both download attempts (direct and
                via proxy) exited with a non-zero status.
        """
        model_name = conf["model_name"]
        model_path = darkfeat_path / "checkpoints" / model_name
        if not model_path.exists():
            # Resolve the URL only when a download is actually needed, so a
            # checkpoint that is already on disk loads even if it has no
            # entry in weight_urls.
            link = self.weight_urls[model_name]
            # parents=True: the checkpoint directory's parent may be missing too.
            model_path.parent.mkdir(parents=True, exist_ok=True)
            cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
            cmd = cmd_wo_proxy + ["--proxy", self.proxy]
            logger.info(f"Downloading the DarkFeat model with `{cmd_wo_proxy}`.")
            try:
                subprocess.run(cmd_wo_proxy, check=True)
            except subprocess.CalledProcessError:
                # Direct download failed; retry once through the proxy.
                logger.info(f"Downloading the DarkFeat model with `{cmd}`.")
                try:
                    subprocess.run(cmd, check=True)
                except subprocess.CalledProcessError:
                    logger.error("Failed to download the DarkFeat model.")
                    # Bare raise keeps the original traceback intact.
                    raise
        self.net = DarkFeat_(model_path)

    def _forward(self, data):
        """Run the network on ``data["image"]`` and return batched outputs.

        Args:
            data: Mapping with an ``"image"`` entry, forwarded to the network.

        Returns:
            Dict with ``"keypoints"``, ``"scores"`` and ``"descriptors"``,
            each taken from the network output and indexed with ``[None]``
            (which, for array/tensor outputs, prepends a length-1 axis).
        """
        pred = self.net({"image": data["image"]})
        # Shape notes below are carried over from the original comments.
        return {
            "keypoints": pred["keypoints"][None],  # 1 x N x 2
            "scores": pred["scores"][None],  # 1 x N
            "descriptors": pred["descriptors"][None],  # 1 x 128 x N
        }
|