|
import pytest |
|
import torch |
|
|
|
from kornia.feature import DescriptorMatcher, GFTTAffNetHardNet, LocalFeatureMatcher, SIFTFeature |
|
from kornia.geometry import resize, transform_points |
|
from kornia.testing import assert_close |
|
from kornia.tracking import HomographyTracker |
|
|
|
|
|
@pytest.fixture
def data():
    """Load the LoFTR-outdoor / homography test bundle from the kornia ``data_test`` repository.

    The returned mapping is consumed by the tracker tests below, which read the
    ``image0``, ``image1``, ``pts0`` and ``pts1`` entries from it. Downloading goes
    through ``torch.hub.load_state_dict_from_url`` (which, per its documentation,
    reuses a locally cached copy when one exists).
    """
    return torch.hub.load_state_dict_from_url(
        'https://github.com/kornia/data_test/blob/main/loftr_outdoor_and_homography_data.pt?raw=true'
    )
|
|
|
|
|
class TestHomographyTracker:
    """Tests for :class:`kornia.tracking.HomographyTracker`."""

    @staticmethod
    def _to_device(data, device, dtype):
        """Cast every tensor value of ``data`` in place to ``device``/``dtype``.

        Non-tensor values are left untouched. Only existing keys are reassigned,
        so mutating the dict while iterating over it is safe.
        """
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                data[key] = value.to(device, dtype)

    @staticmethod
    def _assert_maps_points(homography, pts_src, pts_dst):
        """Assert that ``homography`` maps ``pts_src`` onto ``pts_dst`` within a loose tolerance."""
        # transform_points works on batched inputs, so add a leading batch dimension to each.
        assert_close(transform_points(homography[None], pts_src[None]), pts_dst[None], rtol=5e-2, atol=5)

    def test_smoke(self, device):
        """The default tracker can be built and moved to ``device``."""
        tracker = HomographyTracker().to(device)
        assert tracker is not None

    def test_nomatch(self, device, dtype, data):
        """Tracking an all-zero frame against a real target must report failure."""
        matcher = LocalFeatureMatcher(SIFTFeature(100), DescriptorMatcher('smnn', 0.95)).to(device, dtype)
        # A high ``minimum_inliers_num`` so a blank query frame cannot be accepted
        # (presumably a lower bound on RANSAC inliers — confirm in HomographyTracker).
        # Moved to device/dtype for consistency with test_real.
        tracker = HomographyTracker(matcher, matcher, minimum_inliers_num=100).to(device, dtype)
        self._to_device(data, device, dtype)

        with torch.no_grad():
            tracker.set_target(data["image0"])
            # Seed so any randomized estimation inside the tracker is reproducible.
            torch.random.manual_seed(0)
            _, success = tracker(torch.zeros_like(data["image0"]))
        assert not success

    def test_real(self, device, dtype, data):
        """The tracker recovers the ground-truth homography between two real images."""
        matcher = LocalFeatureMatcher(GFTTAffNetHardNet(1000), DescriptorMatcher('snn', 0.8)).to(device, dtype)
        tracker = HomographyTracker(matcher, matcher).to(device, dtype)
        self._to_device(data, device, dtype)

        # Halve each image along its OWN axes. The ground-truth points are halved below,
        # which is only consistent if image1 is scaled relative to its own size, not image0's.
        for key in ("image0", "image1"):
            height, width = data[key].shape[-2:]
            data[key] = resize(data[key], (height // 2, width // 2))
        pts_src = data['pts0'] / 2.0
        pts_dst = data['pts1'] / 2.0

        with torch.no_grad():
            tracker.set_target(data["image0"])
            torch.random.manual_seed(0)
            homography, success = tracker(data["image1"])
        assert success
        self._assert_maps_points(homography, pts_src, pts_dst)

        # Call again after a successful match. Presumably this exercises the tracker's
        # follow-up path, which reuses the previous estimate instead of matching from
        # scratch (TODO confirm against HomographyTracker); the result must still be accurate.
        with torch.no_grad():
            torch.random.manual_seed(0)
            homography, success = tracker(data["image1"])
        assert success
        self._assert_maps_points(homography, pts_src, pts_dst)
|
|