import sys

import pytest
import torch
import torch.nn as nn
from torch.autograd import gradcheck

import kornia
import kornia.testing as utils
from kornia.feature import (
    DescriptorMatcher,
    extract_patches_from_pyramid,
    get_laf_descriptors,
    GFTTAffNetHardNet,
    LAFDescriptor,
    LocalFeature,
    ScaleSpaceDetector,
    SIFTDescriptor,
    SIFTFeature,
)
from kornia.feature.integrated import LocalFeatureMatcher
from kornia.geometry import RANSAC, resize, transform_points
from kornia.testing import assert_close
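
# The tests below cover the integrated local-feature API: descriptor extraction on
# LAFs (get_laf_descriptors, LAFDescriptor), detector + descriptor wrappers
# (LocalFeature, SIFTFeature, GFTTAffNetHardNet), and end-to-end matching with
# LocalFeatureMatcher.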


class TestGetLAFDescriptors:
    def test_same(self, device, dtype):
        B, C, H, W = 1, 3, 64, 64
        PS = 16
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        img_gray = kornia.color.rgb_to_grayscale(img)
        centers = torch.tensor([[H / 3.0, W / 3.0], [2.0 * H / 3.0, W / 2.0]], device=device, dtype=dtype).view(1, 2, 2)
        scales = torch.tensor([(H + W) / 4.0, (H + W) / 8.0], device=device, dtype=dtype).view(1, 2, 1, 1)
        ori = torch.tensor([0.0, 30.0], device=device, dtype=dtype).view(1, 2, 1)
        lafs = kornia.feature.laf_from_center_scale_ori(centers, scales, ori)
        sift = SIFTDescriptor(PS).to(device, dtype)
        descs_test_from_rgb = get_laf_descriptors(img, lafs, sift, PS, True)
        descs_test_from_gray = get_laf_descriptors(img_gray, lafs, sift, PS, True)

        # Reference: extract the patches manually and describe them with the same SIFT.
        patches = extract_patches_from_pyramid(img_gray, lafs, PS)
        B1, N1, CH1, H1, W1 = patches.size()
        descs_reference = sift(patches.view(B1 * N1, CH1, H1, W1)).view(B1, N1, -1)
        assert_close(descs_test_from_rgb, descs_reference)
        assert_close(descs_test_from_gray, descs_reference)

    def test_gradcheck(self, device, dtype=torch.float64):
        B, C, H, W = 1, 1, 32, 32
        PS = 16
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        centers = torch.tensor([[H / 3.0, W / 3.0], [2.0 * H / 3.0, W / 2.0]], device=device, dtype=dtype).view(1, 2, 2)
        scales = torch.tensor([(H + W) / 4.0, (H + W) / 8.0], device=device, dtype=dtype).view(1, 2, 1, 1)
        ori = torch.tensor([0.0, 30.0], device=device, dtype=dtype).view(1, 2, 1)
        lafs = kornia.feature.laf_from_center_scale_ori(centers, scales, ori)
        img = utils.tensor_to_gradcheck_var(img)
        lafs = utils.tensor_to_gradcheck_var(lafs)

        # Trivial descriptor (mean over the patch) to keep gradcheck fast.
        class _MeanPatch(nn.Module):
            def forward(self, inputs):
                return inputs.mean(dim=(2, 3))

        desc = _MeanPatch()
        assert gradcheck(get_laf_descriptors, (img, lafs, desc, PS, True),
                         eps=1e-3, atol=1e-3, raise_exception=True, nondet_tol=1e-3)


class TestLAFDescriptor:
    def test_same(self, device, dtype):
        B, C, H, W = 1, 3, 64, 64
        PS = 16
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        img_gray = kornia.color.rgb_to_grayscale(img)
        centers = torch.tensor([[H / 3.0, W / 3.0], [2.0 * H / 3.0, W / 2.0]], device=device, dtype=dtype).view(1, 2, 2)
        scales = torch.tensor([(H + W) / 4.0, (H + W) / 8.0], device=device, dtype=dtype).view(1, 2, 1, 1)
        ori = torch.tensor([0.0, 30.0], device=device, dtype=dtype).view(1, 2, 1)
        lafs = kornia.feature.laf_from_center_scale_ori(centers, scales, ori)
        sift = SIFTDescriptor(PS).to(device, dtype)
        lafsift = LAFDescriptor(sift, PS)
        descs_test = lafsift(img, lafs)

        # Reference: extract the patches manually and describe them with the same SIFT.
        patches = extract_patches_from_pyramid(img_gray, lafs, PS)
        B1, N1, CH1, H1, W1 = patches.size()
        descs_reference = sift(patches.view(B1 * N1, CH1, H1, W1)).view(B1, N1, -1)
        assert_close(descs_test, descs_reference)

    def test_gradcheck(self, device, dtype=torch.float64):
        B, C, H, W = 1, 1, 32, 32
        PS = 16
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        centers = torch.tensor([[H / 3.0, W / 3.0], [2.0 * H / 3.0, W / 2.0]], device=device, dtype=dtype).view(1, 2, 2)
        scales = torch.tensor([(H + W) / 4.0, (H + W) / 8.0], device=device, dtype=dtype).view(1, 2, 1, 1)
        ori = torch.tensor([0.0, 30.0], device=device, dtype=dtype).view(1, 2, 1)
        lafs = kornia.feature.laf_from_center_scale_ori(centers, scales, ori)
        img = utils.tensor_to_gradcheck_var(img)
        lafs = utils.tensor_to_gradcheck_var(lafs)

        # Trivial descriptor (mean over the patch) to keep gradcheck fast.
        class _MeanPatch(nn.Module):
            def forward(self, inputs):
                return inputs.mean(dim=(2, 3))

        lafdesc = LAFDescriptor(_MeanPatch(), PS)
        assert gradcheck(lafdesc, (img, lafs), eps=1e-3, atol=1e-3, raise_exception=True, nondet_tol=1e-3)


class TestLocalFeature:
    def test_smoke(self, device, dtype):
        det = ScaleSpaceDetector(10)
        desc = SIFTDescriptor(32)
        local_feature = LocalFeature(det, desc).to(device, dtype)
        assert local_feature is not None

    def test_same(self, device, dtype):
        B, C, H, W = 1, 1, 64, 64
        PS = 16
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        det = ScaleSpaceDetector(10)
        desc = SIFTDescriptor(PS)
        local_feature = LocalFeature(det, LAFDescriptor(desc, PS)).to(device, dtype)
        lafs, responses, descs = local_feature(img)
        lafs1, responses1 = det(img)
        assert_close(lafs, lafs1)
        assert_close(responses, responses1)

        # The combined module must match running the detector and descriptor separately.
        patches = extract_patches_from_pyramid(img, lafs1, PS)
        B1, N1, CH1, H1, W1 = patches.size()
        descs1 = desc(patches.view(B1 * N1, CH1, H1, W1)).view(B1, N1, -1)
        assert_close(descs, descs1)

    @pytest.mark.skip("Takes too long (but works)")
    def test_gradcheck(self, device):
        B, C, H, W = 1, 1, 32, 32
        PS = 16
        img = torch.rand(B, C, H, W, device=device)
        img = utils.tensor_to_gradcheck_var(img)
        local_feature = LocalFeature(ScaleSpaceDetector(2), LAFDescriptor(SIFTDescriptor(PS), PS)).to(device, img.dtype)
        assert gradcheck(local_feature, img, eps=1e-4, atol=1e-4, raise_exception=True)


class TestSIFTFeature:
    def test_smoke(self, device, dtype):
        sift = SIFTFeature()
        assert sift is not None

    @pytest.mark.skip("Jacobian is not computed correctly")
    def test_gradcheck(self, device):
        B, C, H, W = 1, 1, 32, 32
        img = torch.rand(B, C, H, W, device=device)
        local_feature = SIFTFeature(2, True).to(device)
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(local_feature, img, eps=1e-4, atol=1e-4, raise_exception=True)


class TestGFTTAffNetHardNet:
    def test_smoke(self, device, dtype):
        feat = GFTTAffNetHardNet().to(device, dtype)
        assert feat is not None

    @pytest.mark.skip("Jacobian is not computed correctly")
    def test_gradcheck(self, device):
        B, C, H, W = 1, 1, 32, 32
        img = torch.rand(B, C, H, W, device=device)
        img = utils.tensor_to_gradcheck_var(img)
        local_feature = GFTTAffNetHardNet(2, True).to(device, img.dtype)
        assert gradcheck(local_feature, img, eps=1e-4, atol=1e-4, raise_exception=True)


class TestLocalFeatureMatcher:
    def test_smoke(self, device):
        matcher = LocalFeatureMatcher(SIFTFeature(5), DescriptorMatcher('snn', 0.8)).to(device)
        assert matcher is not None

    @pytest.mark.parametrize("data", ["loftr_homo"], indirect=True)
    def test_nomatch(self, device, dtype, data):
        matcher = LocalFeatureMatcher(GFTTAffNetHardNet(100), DescriptorMatcher('snn', 0.8)).to(device, dtype)
        data_dev = utils.dict_to(data, device, dtype)
        with torch.no_grad():
            # Matching against an all-zero image should yield no keypoints.
            out = matcher({"image0": data_dev["image0"], "image1": 0 * data_dev["image0"]})
        assert len(out['keypoints0']) == 0

    @pytest.mark.skip("Takes too long (but works)")
    def test_gradcheck(self, device):
        matcher = LocalFeatureMatcher(SIFTFeature(5), DescriptorMatcher('nn', 1.0)).to(device)
        patches = torch.rand(1, 1, 32, 32, device=device)
        patches05 = resize(patches, (48, 48))
        patches = utils.tensor_to_gradcheck_var(patches)
        patches05 = utils.tensor_to_gradcheck_var(patches05)

        def proxy_forward(x, y):
            return matcher({"image0": x, "image1": y})["keypoints0"]

        assert gradcheck(proxy_forward, (patches, patches05), eps=1e-4, atol=1e-4, raise_exception=True)

    @pytest.mark.parametrize("data", ["loftr_homo"], indirect=True)
    def test_real_sift(self, device, dtype, data):
        torch.random.manual_seed(0)

        matcher = LocalFeatureMatcher(SIFTFeature(2000), DescriptorMatcher('snn', 0.8)).to(device, dtype)
        ransac = RANSAC('homography', 1.0, 2048, 10).to(device, dtype)
        data_dev = utils.dict_to(data, device, dtype)
        pts_src = data_dev['pts0']
        pts_dst = data_dev['pts1']
        with torch.no_grad():
            out = matcher(data_dev)
        homography, inliers = ransac(out['keypoints0'], out['keypoints1'])
        assert inliers.sum().item() > 50

        # The estimated homography should map the annotated source points onto the destination points.
        assert_close(transform_points(homography[None], pts_src[None]), pts_dst[None], rtol=5e-2, atol=5)

    @pytest.mark.parametrize("data", ["loftr_homo"], indirect=True)
    def test_real_sift_preextract(self, device, dtype, data):
        torch.random.manual_seed(0)

        feat = SIFTFeature(2000).to(device, dtype)
        matcher = LocalFeatureMatcher(feat, DescriptorMatcher('snn', 0.8)).to(device, dtype)
        ransac = RANSAC('homography', 1.0, 2048, 10).to(device, dtype)
        data_dev = utils.dict_to(data, device, dtype)
        pts_src = data_dev['pts0']
        pts_dst = data_dev['pts1']

        # Pre-extract the local features so the matcher can reuse them instead of detecting again.
        lafs, _, descs = feat(data_dev["image0"])
        data_dev["lafs0"] = lafs
        data_dev["descriptors0"] = descs

        lafs2, _, descs2 = feat(data_dev["image1"])
        data_dev["lafs1"] = lafs2
        data_dev["descriptors1"] = descs2

        with torch.no_grad():
            out = matcher(data_dev)
        homography, inliers = ransac(out['keypoints0'], out['keypoints1'])
        assert inliers.sum().item() > 50

        assert_close(transform_points(homography[None], pts_src[None]), pts_dst[None], rtol=5e-2, atol=5)

    @pytest.mark.skipif(sys.platform == "win32", reason="this test takes too much memory in the Windows CI")
    @pytest.mark.parametrize("data", ["loftr_homo"], indirect=True)
    def test_real_gftt(self, device, dtype, data):
        torch.random.manual_seed(0)

        matcher = LocalFeatureMatcher(GFTTAffNetHardNet(2000), DescriptorMatcher('snn', 0.8)).to(device, dtype)
        ransac = RANSAC('homography', 1.0, 2048, 10).to(device, dtype)
        data_dev = utils.dict_to(data, device, dtype)
        pts_src = data_dev['pts0']
        pts_dst = data_dev['pts1']
        with torch.no_grad():
            out = matcher(data_dev)
        homography, inliers = ransac(out['keypoints0'], out['keypoints1'])
        assert inliers.sum().item() > 50

        assert_close(transform_points(homography[None], pts_src[None]), pts_dst[None], rtol=5e-2, atol=5)
@pytest.mark.skip("ScaleSpaceDetector now is not jittable") |
|
def test_jit(self, device, dtype): |
|
B, C, H, W = 1, 1, 32, 32 |
|
patches = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
patches2x = resize(patches, (48, 48)) |
|
inputs = {"image0": patches, "image1": patches2x} |
|
model = LocalFeatureMatcher(SIFTDescriptor(32), DescriptorMatcher('snn', 0.8)).to(device).eval() |
|
model_jit = torch.jit.script(model) |
|
|
|
out = model(inputs) |
|
out_jit = model_jit(inputs) |
|
for k, v in out.items(): |
|
assert_close(v, out_jit[k]) |
|
|