Spaces:
Running
Running
import sys | |
from pathlib import Path | |
from ..utils.base_model import BaseModel | |
from .. import logger | |
lightglue_path = Path(__file__).parent / "../../third_party/LightGlue" | |
sys.path.append(str(lightglue_path)) | |
from lightglue import LightGlue as LG | |
class LightGlue(BaseModel):
    """Adapter that runs the third-party LightGlue matcher behind ``BaseModel``.

    The configuration is forwarded as keyword arguments to the upstream
    ``LightGlue`` constructor (imported here as ``LG``).
    """

    default_conf = {
        "match_threshold": 0.2,
        # NOTE: `_init` always replaces this value with `match_threshold`.
        "filter_threshold": 0.2,
        "width_confidence": 0.99,  # for point pruning
        "depth_confidence": 0.95,  # for early stopping
        "features": "superpoint",
        "model_name": "superpoint_lightglue.pth",
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
    }

    # `scores0` / `scores1` are declared as required although `_forward`
    # does not read them.
    required_inputs = [
        "image0", "keypoints0", "scores0", "descriptors0",
        "image1", "keypoints1", "scores1", "descriptors1",
    ]

    def _init(self, conf):
        """Build the underlying network from ``conf``.

        ``conf`` is modified in place: ``weights`` is set to the checkpoint
        path under ``<lightglue_path>/weights`` and ``filter_threshold`` is
        set to ``match_threshold`` before everything is passed to ``LG``.
        """
        conf["weights"] = str(lightglue_path / "weights" / conf["model_name"])
        conf["filter_threshold"] = conf["match_threshold"]
        self.net = LG(**conf)
        logger.info("Load lightglue model done.")

    def _forward(self, data):
        """Repack the flat ``data`` mapping per image and run the matcher.

        For each image index the image, keypoints and descriptors are
        grouped under ``image0`` / ``image1``; the result of ``self.net`` is
        returned unchanged.
        """
        pair = {
            f"image{idx}": {
                "image": data[f"image{idx}"],
                "keypoints": data[f"keypoints{idx}"],
                # Swap the last two axes; presumably (B, D, N) -> (B, N, D)
                # as the matcher expects -- TODO confirm against upstream.
                "descriptors": data[f"descriptors{idx}"].permute(0, 2, 1),
            }
            for idx in ("0", "1")
        }
        return self.net(pair)