CasP / demo.py
pq-chen's picture
update inference code and online demo
f831607
import argparse
from argparse import Namespace
from typing import Any, Dict, Optional, Tuple, Type
import cv2
import torch
from numpy import ndarray
from omegaconf import OmegaConf
from torch.nn import Module
from src.data.utils import load_image
from src.models.nets import CasP
from src.models.utils import make_matching_figure
# Registry of the matchers selectable through ``--method``. Each entry holds
# exactly the keyword arguments that ``load_matcher`` expects: the matcher
# class, the config name and the path of the checkpoint to load.
matcher_configs = dict(
    casp_outdoor=dict(
        matcher=CasP,
        name="casp",
        ckpt_path="weights/casp_outdoor.pth",
    ),
    casp_minima=dict(
        matcher=CasP,
        name="casp",
        ckpt_path="weights/casp_minima.pth",
    ),
)
def load_matcher(
    matcher: Type[Module],
    name: str,
    ckpt_path: str,
    threshold: Optional[float] = None,
    device: str = "cpu",
) -> Module:
    """Build a matcher, load its checkpoint and return it in eval mode.

    Args:
        matcher: Matcher class; it is called with the loaded config object.
        name: Config file stem, resolved as ``configs/model/net/{name}.yaml``.
        ckpt_path: Path of the checkpoint handed to ``torch.load``.
        threshold: If not ``None``, overrides ``config.threshold`` before the
            matcher is constructed.
        device: Device the model is moved to once the weights are loaded.

    Returns:
        The instantiated matcher, in eval mode and placed on ``device``.
    """
    config = OmegaConf.load(f"configs/model/net/{name}.yaml").config
    if threshold is not None:
        config.threshold = threshold
    model = matcher(config)
    # Deserialize onto the CPU first: without ``map_location`` torch restores
    # tensors to the device they were saved from, which fails on a machine
    # without CUDA when the checkpoint was written from a GPU. The ``.to``
    # call below moves the weights to their final device.
    # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval().to(device)
def parse_args() -> Namespace:
    """Parse the demo's command-line options.

    ``--path0``, ``--path1`` and ``--save_path`` are mandatory. The options
    ``--matching_threshold`` and ``--ransac`` have no default and are
    therefore ``None`` when omitted.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="CasP")

    # The two input images and the output location are all required strings.
    for flag in ("--path0", "--path1", "--save_path"):
        cli.add_argument(flag, type=str, required=True)

    cli.add_argument("--image_size", type=int, default=1152)
    cli.add_argument(
        "--method",
        type=str,
        choices=["casp_outdoor", "casp_minima"],
        default="casp_outdoor",
    )
    cli.add_argument("--matching_threshold", type=float)

    # Optional geometric verification settings.
    cli.add_argument(
        "--ransac",
        type=str,
        choices=["fundamental", "homography"],
    )
    cli.add_argument(
        "--estimator",
        type=str,
        choices=["CV2_RANSAC", "CV2_USAC_MAGSAC"],
        default="CV2_USAC_MAGSAC",
    )
    cli.add_argument("--inlier_threshold", type=float, default=3.0)

    return cli.parse_args()
def ransac_optimize(
    points0: ndarray,
    points1: ndarray,
    model: str,
    estimator: str,
    threshold: float,
) -> Tuple[ndarray, ndarray]:
    """Robustly fit a two-view model to point correspondences with OpenCV.

    Args:
        points0: Points in the first image.
        points1: Corresponding points in the second image.
        model: ``"fundamental"`` (``cv2.findFundamentalMat``) or
            ``"homography"`` (``cv2.findHomography``).
        estimator: ``"CV2_RANSAC"`` or ``"CV2_USAC_MAGSAC"``; selects the
            OpenCV method flag.
        threshold: Value forwarded as ``ransacReprojThreshold``.

    Returns:
        The ``(matrix, inlier_mask)`` pair returned by the OpenCV call.

    Raises:
        NotImplementedError: If ``model`` or ``estimator`` is not supported.
    """
    # Validate the model first, then the estimator; the cv2 attributes are
    # only looked up for the branch that is actually selected.
    if model not in ("fundamental", "homography"):
        raise NotImplementedError()
    solver = (
        cv2.findFundamentalMat if model == "fundamental" else cv2.findHomography
    )

    if estimator not in ("CV2_RANSAC", "CV2_USAC_MAGSAC"):
        raise NotImplementedError()
    flag = cv2.RANSAC if estimator == "CV2_RANSAC" else cv2.USAC_MAGSAC

    matrix, mask = solver(
        points0,
        points1,
        method=flag,
        ransacReprojThreshold=threshold,
        confidence=0.99999,
        maxIters=10000,
    )
    return matrix, mask
def main(
    args: Dict[str, Any],
) -> Tuple[ndarray, ndarray, ndarray, Optional[ndarray]]:
    """Match one image pair and optionally verify the matches geometrically.

    Args:
        args: Options with the keys ``method``, ``matching_threshold``,
            ``image_size``, ``path0``, ``path1``, ``ransac``, ``estimator``
            and ``inlier_threshold`` (the names produced by ``parse_args``).

    Returns:
        ``(points0, points1, scores, inlier_mask)``. The first three are the
        matcher outputs converted to NumPy arrays. ``inlier_mask`` is ``None``
        when ``args["ransac"]`` is ``None``; otherwise it is a flat boolean
        array, all ``False`` if the estimator returned no mask.

    Raises:
        ValueError: If the matcher's ``data_mode`` is neither ``"gray"`` nor
            ``"color"``.
    """
    # Local import: the module itself only imports ``ndarray`` from numpy.
    import numpy as np

    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = load_matcher(
        **matcher_configs[args["method"]],
        threshold=args["matching_threshold"],
        device=device,
    )

    # Image loading options are dictated by the matcher itself.
    data_configs = {
        "mode": matcher.data_mode,
        "size": args["image_size"],
        "factor": matcher.data_factor,
    }
    image0, mask0, scale0 = load_image(args["path0"], **data_configs)
    image1, mask1, scale1 = load_image(args["path1"], **data_configs)

    # Scale pixel values by 1/255 and put the channel axis first.
    if matcher.data_mode == "gray":
        image0, image1 = image0[None] / 255.0, image1[None] / 255.0
    elif matcher.data_mode == "color":
        image0 = image0.transpose(2, 0, 1) / 255.0
        image1 = image1.transpose(2, 0, 1) / 255.0
    else:
        raise ValueError(f"unsupported data mode: {matcher.data_mode!r}")

    # ``[None]`` prepends one more axis (presumably the batch axis — confirm
    # against the matcher's expected input).
    data = {
        "image0": image0[None],
        "image1": image1[None],
        "scale0": scale0[None],
        "scale1": scale1[None],
    }
    if mask0 is not None:
        data["mask0"] = mask0[None]
    if mask1 is not None:
        # Each image gets its own mask (this previously reused ``mask0``).
        data["mask1"] = mask1[None]
    data = {
        key: torch.from_numpy(value).float().to(device)
        for key, value in data.items()
    }

    with torch.no_grad():
        results = matcher(data)
    points0 = results["points0"].cpu().numpy()
    points1 = results["points1"].cpu().numpy()
    scores = results["scores"].cpu().numpy()

    inlier_mask = None
    if args["ransac"] is not None:
        _, raw_mask = ransac_optimize(
            points0,
            points1,
            args["ransac"],
            args["estimator"],
            args["inlier_threshold"],
        )
        if raw_mask is None:
            # The OpenCV estimators can return no mask when no model could be
            # fitted; report that as "no inliers" instead of crashing on
            # ``None.ravel()``.
            inlier_mask = np.zeros(len(points0), dtype=bool)
        else:
            inlier_mask = raw_mask.ravel() == 1
    return points0, points1, scores, inlier_mask
if __name__ == "__main__":
    # Script entry point: match the pair given on the command line and save
    # a figure of the matches.
    args = vars(parse_args()).copy()
    points0, points1, scores, inlier_mask = main(args)

    # When geometric verification ran, keep only the inlier matches.
    if inlier_mask is not None:
        points0 = points0[inlier_mask]
        points1 = points1[inlier_mask]
        scores = scores[inlier_mask]

    # NOTE(review): ``text`` is built but never used below — confirm whether
    # it was meant to be passed to ``make_matching_figure``.
    text = [args["method"], f"#matches: {len(points0)}"]

    # The figure helper receives ``1 - score`` per match as its error value.
    errors = 1 - scores
    make_matching_figure(
        args["path0"],
        args["path1"],
        points0,
        points1,
        errors,
        0.5,
        dpi=300,
        save_path=args["save_path"],
    )