|
from unittest.mock import patch, PropertyMock |
|
|
|
import pytest |
|
import torch |
|
from torch.autograd import gradcheck |
|
|
|
import kornia |
|
import kornia.testing as utils |
|
from kornia.testing import assert_close |
|
from packaging import version |
|
|
|
|
|
class TestVisionTransformer: |
|
@pytest.mark.parametrize("B", [1, 2]) |
|
@pytest.mark.parametrize("H", [1, 3, 8]) |
|
@pytest.mark.parametrize("D", [128, 768]) |
|
@pytest.mark.parametrize("image_size", [32, 224]) |
|
def test_smoke(self, device, dtype, B, H, D, image_size): |
|
patch_size = 16 |
|
T = image_size ** 2 // patch_size ** 2 + 1 |
|
|
|
img = torch.rand(B, 3, image_size, image_size, device=device, dtype=dtype) |
|
vit = kornia.contrib.VisionTransformer(image_size=image_size, num_heads=H, embed_dim=D).to(device, dtype) |
|
|
|
out = vit(img) |
|
assert isinstance(out, torch.Tensor) and out.shape == (B, T, D) |
|
|
|
feats = vit.encoder_results |
|
assert isinstance(feats, list) and len(feats) == 12 |
|
for f in feats: |
|
assert f.shape == (B, T, D) |
|
|
|
def test_backbone(self, device, dtype): |
|
def backbone_mock(x): |
|
return torch.ones(1, 128, 14, 14, device=device, dtype=dtype) |
|
|
|
img = torch.rand(1, 3, 32, 32, device=device, dtype=dtype) |
|
vit = kornia.contrib.VisionTransformer(backbone=backbone_mock).to(device, dtype) |
|
out = vit(img) |
|
assert out.shape == (1, 197, 128) |
|
|
|
|
|
class TestMobileViT: |
|
@pytest.mark.parametrize("B", [1, 2]) |
|
@pytest.mark.parametrize("image_size", [(256, 256)]) |
|
@pytest.mark.parametrize("mode", ['xxs', 'xs', 's']) |
|
@pytest.mark.parametrize("patch_size", [(2, 2)]) |
|
def test_smoke(self, device, dtype, B, image_size, mode, patch_size): |
|
ih, iw = image_size |
|
channel = {'xxs': 320, 'xs': 384, 's': 640} |
|
|
|
img = torch.rand(B, 3, ih, iw, device=device, dtype=dtype) |
|
mvit = kornia.contrib.MobileViT(mode=mode, patch_size=patch_size).to(device, dtype) |
|
|
|
out = mvit(img) |
|
assert isinstance(out, torch.Tensor) and out.shape == (B, channel[mode], 8, 8) |
|
|
|
|
|
class TestClassificationHead: |
|
@pytest.mark.parametrize("B, D, N", [(1, 8, 10), (2, 2, 5)]) |
|
def test_smoke(self, device, dtype, B, D, N): |
|
feat = torch.rand(B, D, D, device=device, dtype=dtype) |
|
head = kornia.contrib.ClassificationHead(embed_size=D, num_classes=N).to(device, dtype) |
|
logits = head(feat) |
|
assert logits.shape == (B, N) |
|
|
|
|
|
class TestConnectedComponents: |
|
def test_smoke(self, device, dtype): |
|
img = torch.rand(1, 1, 3, 4, device=device, dtype=dtype) |
|
out = kornia.contrib.connected_components(img, num_iterations=10) |
|
assert out.shape == (1, 1, 3, 4) |
|
|
|
@pytest.mark.parametrize("shape", [(1, 3, 4), (2, 1, 3, 4)]) |
|
def test_cardinality(self, device, dtype, shape): |
|
img = torch.rand(shape, device=device, dtype=dtype) |
|
out = kornia.contrib.connected_components(img, num_iterations=10) |
|
assert out.shape == shape |
|
|
|
def test_exception(self, device, dtype): |
|
img = torch.rand(1, 1, 3, 4, device=device, dtype=dtype) |
|
|
|
with pytest.raises(TypeError): |
|
assert kornia.contrib.connected_components(img, 1.0) |
|
|
|
with pytest.raises(TypeError): |
|
assert kornia.contrib.connected_components(img, 0) |
|
|
|
with pytest.raises(ValueError): |
|
img = torch.rand(1, 2, 3, 4, device=device, dtype=dtype) |
|
assert kornia.contrib.connected_components(img, 2) |
|
|
|
def test_value(self, device, dtype): |
|
img = torch.tensor( |
|
[ |
|
[ |
|
[ |
|
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
[0.0, 1.0, 1.0, 0.0, 0.0, 1.0], |
|
[0.0, 1.0, 1.0, 0.0, 0.0, 0.0], |
|
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
[0.0, 0.0, 0.0, 1.0, 1.0, 0.0], |
|
[0.0, 0.0, 0.0, 1.0, 1.0, 0.0], |
|
] |
|
] |
|
], |
|
device=device, |
|
dtype=dtype, |
|
) |
|
|
|
expected = torch.tensor( |
|
[ |
|
[ |
|
[ |
|
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
[0.0, 14.0, 14.0, 0.0, 0.0, 11.0], |
|
[0.0, 14.0, 14.0, 0.0, 0.0, 0.0], |
|
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
[0.0, 0.0, 0.0, 34.0, 34.0, 0.0], |
|
[0.0, 0.0, 0.0, 34.0, 34.0, 0.0], |
|
] |
|
] |
|
], |
|
device=device, |
|
dtype=dtype, |
|
) |
|
|
|
out = kornia.contrib.connected_components(img, num_iterations=10) |
|
assert_close(out, expected) |
|
|
|
@pytest.mark.skipif( |
|
version.parse(torch.__version__) < version.parse("1.9"), reason="Tuple cannot be used with PyTorch < v1.9" |
|
) |
|
def test_gradcheck(self, device, dtype): |
|
B, C, H, W = 2, 1, 4, 4 |
|
img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True) |
|
assert gradcheck(kornia.contrib.connected_components, (img,), raise_exception=True) |
|
|
|
def test_jit(self, device, dtype): |
|
B, C, H, W = 2, 1, 4, 4 |
|
img = torch.ones(B, C, H, W, device=device, dtype=dtype) |
|
op = kornia.contrib.connected_components |
|
op_jit = torch.jit.script(op) |
|
assert_close(op(img), op_jit(img)) |
|
|
|
|
|
class TestExtractTensorPatches: |
|
def test_smoke(self, device): |
|
img = torch.arange(16.0, device=device).view(1, 1, 4, 4) |
|
m = kornia.contrib.ExtractTensorPatches(3) |
|
assert m(img).shape == (1, 4, 1, 3, 3) |
|
|
|
def test_b1_ch1_h4w4_ws3(self, device): |
|
img = torch.arange(16.0, device=device).view(1, 1, 4, 4) |
|
m = kornia.contrib.ExtractTensorPatches(3) |
|
patches = m(img) |
|
assert patches.shape == (1, 4, 1, 3, 3) |
|
assert_close(img[0, :, :3, :3], patches[0, 0]) |
|
assert_close(img[0, :, :3, 1:], patches[0, 1]) |
|
assert_close(img[0, :, 1:, :3], patches[0, 2]) |
|
assert_close(img[0, :, 1:, 1:], patches[0, 3]) |
|
|
|
def test_b1_ch2_h4w4_ws3(self, device): |
|
img = torch.arange(16.0, device=device).view(1, 1, 4, 4) |
|
img = img.expand(-1, 2, -1, -1) |
|
m = kornia.contrib.ExtractTensorPatches(3) |
|
patches = m(img) |
|
assert patches.shape == (1, 4, 2, 3, 3) |
|
assert_close(img[0, :, :3, :3], patches[0, 0]) |
|
assert_close(img[0, :, :3, 1:], patches[0, 1]) |
|
assert_close(img[0, :, 1:, :3], patches[0, 2]) |
|
assert_close(img[0, :, 1:, 1:], patches[0, 3]) |
|
|
|
def test_b1_ch1_h4w4_ws2(self, device): |
|
img = torch.arange(16.0, device=device).view(1, 1, 4, 4) |
|
m = kornia.contrib.ExtractTensorPatches(2) |
|
patches = m(img) |
|
assert patches.shape == (1, 9, 1, 2, 2) |
|
assert_close(img[0, :, 0:2, 1:3], patches[0, 1]) |
|
assert_close(img[0, :, 0:2, 2:4], patches[0, 2]) |
|
assert_close(img[0, :, 1:3, 1:3], patches[0, 4]) |
|
assert_close(img[0, :, 2:4, 1:3], patches[0, 7]) |
|
|
|
def test_b1_ch1_h4w4_ws2_stride2(self, device): |
|
img = torch.arange(16.0, device=device).view(1, 1, 4, 4) |
|
m = kornia.contrib.ExtractTensorPatches(2, stride=2) |
|
patches = m(img) |
|
assert patches.shape == (1, 4, 1, 2, 2) |
|
assert_close(img[0, :, 0:2, 0:2], patches[0, 0]) |
|
assert_close(img[0, :, 0:2, 2:4], patches[0, 1]) |
|
assert_close(img[0, :, 2:4, 0:2], patches[0, 2]) |
|
assert_close(img[0, :, 2:4, 2:4], patches[0, 3]) |
|
|
|
def test_b1_ch1_h4w4_ws2_stride21(self, device): |
|
img = torch.arange(16.0, device=device).view(1, 1, 4, 4) |
|
m = kornia.contrib.ExtractTensorPatches(2, stride=(2, 1)) |
|
patches = m(img) |
|
assert patches.shape == (1, 6, 1, 2, 2) |
|
assert_close(img[0, :, 0:2, 1:3], patches[0, 1]) |
|
assert_close(img[0, :, 0:2, 2:4], patches[0, 2]) |
|
assert_close(img[0, :, 2:4, 0:2], patches[0, 3]) |
|
assert_close(img[0, :, 2:4, 2:4], patches[0, 5]) |
|
|
|
def test_b1_ch1_h3w3_ws2_stride1_padding1(self, device): |
|
img = torch.arange(9.0).view(1, 1, 3, 3).to(device) |
|
m = kornia.contrib.ExtractTensorPatches(2, stride=1, padding=1) |
|
patches = m(img) |
|
assert patches.shape == (1, 16, 1, 2, 2) |
|
assert_close(img[0, :, 0:2, 0:2], patches[0, 5]) |
|
assert_close(img[0, :, 0:2, 1:3], patches[0, 6]) |
|
assert_close(img[0, :, 1:3, 0:2], patches[0, 9]) |
|
assert_close(img[0, :, 1:3, 1:3], patches[0, 10]) |
|
|
|
def test_b2_ch1_h3w3_ws2_stride1_padding1(self, device): |
|
batch_size = 2 |
|
img = torch.arange(9.0).view(1, 1, 3, 3).to(device) |
|
img = img.expand(batch_size, -1, -1, -1) |
|
m = kornia.contrib.ExtractTensorPatches(2, stride=1, padding=1) |
|
patches = m(img) |
|
assert patches.shape == (batch_size, 16, 1, 2, 2) |
|
for i in range(batch_size): |
|
assert_close(img[i, :, 0:2, 0:2], patches[i, 5]) |
|
assert_close(img[i, :, 0:2, 1:3], patches[i, 6]) |
|
assert_close(img[i, :, 1:3, 0:2], patches[i, 9]) |
|
assert_close(img[i, :, 1:3, 1:3], patches[i, 10]) |
|
|
|
def test_b1_ch1_h3w3_ws23(self, device): |
|
img = torch.arange(9.0).view(1, 1, 3, 3).to(device) |
|
m = kornia.contrib.ExtractTensorPatches((2, 3)) |
|
patches = m(img) |
|
assert patches.shape == (1, 2, 1, 2, 3) |
|
assert_close(img[0, :, 0:2, 0:3], patches[0, 0]) |
|
assert_close(img[0, :, 1:3, 0:3], patches[0, 1]) |
|
|
|
def test_b1_ch1_h3w4_ws23(self, device): |
|
img = torch.arange(12.0).view(1, 1, 3, 4).to(device) |
|
m = kornia.contrib.ExtractTensorPatches((2, 3)) |
|
patches = m(img) |
|
assert patches.shape == (1, 4, 1, 2, 3) |
|
assert_close(img[0, :, 0:2, 0:3], patches[0, 0]) |
|
assert_close(img[0, :, 0:2, 1:4], patches[0, 1]) |
|
assert_close(img[0, :, 1:3, 0:3], patches[0, 2]) |
|
assert_close(img[0, :, 1:3, 1:4], patches[0, 3]) |
|
|
|
@pytest.mark.skip(reason="turn off all jit for a while") |
|
def test_jit(self, device): |
|
@torch.jit.script |
|
def op_script(img: torch.Tensor, height: int, width: int) -> torch.Tensor: |
|
return kornia.geometry.denormalize_pixel_coordinates(img, height, width) |
|
|
|
height, width = 3, 4 |
|
grid = kornia.utils.create_meshgrid(height, width, normalized_coordinates=True).to(device) |
|
|
|
actual = op_script(grid, height, width) |
|
expected = kornia.denormalize_pixel_coordinates(grid, height, width) |
|
|
|
assert_close(actual, expected) |
|
|
|
def test_gradcheck(self, device): |
|
img = torch.rand(2, 3, 4, 4).to(device) |
|
img = utils.tensor_to_gradcheck_var(img) |
|
assert gradcheck(kornia.contrib.extract_tensor_patches, (img, 3), raise_exception=True) |
|
|
|
|
|
class TestCombineTensorPatches: |
|
def test_smoke(self, device, dtype): |
|
img = torch.arange(16, device=device, dtype=dtype).view(1, 1, 4, 4) |
|
m = kornia.contrib.CombineTensorPatches((2, 2)) |
|
patches = kornia.contrib.extract_tensor_patches(img, window_size=(2, 2), stride=(2, 2)) |
|
assert m(patches).shape == (1, 1, 4, 4) |
|
assert (img == m(patches)).all() |
|
|
|
def test_error(self, device, dtype): |
|
patches = kornia.contrib.extract_tensor_patches( |
|
torch.arange(16, device=device, dtype=dtype).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2), padding=1 |
|
) |
|
with pytest.raises(NotImplementedError): |
|
kornia.contrib.combine_tensor_patches(patches, window_size=(2, 2), stride=(3, 2)) |
|
|
|
def test_padding1(self, device, dtype): |
|
img = torch.arange(16, device=device, dtype=dtype).view(1, 1, 4, 4) |
|
patches = kornia.contrib.extract_tensor_patches(img, window_size=(2, 2), stride=(2, 2), padding=1) |
|
m = kornia.contrib.CombineTensorPatches((2, 2), unpadding=1) |
|
assert m(patches).shape == (1, 1, 4, 4) |
|
assert (img == m(patches)).all() |
|
|
|
def test_gradcheck(self, device, dtype): |
|
patches = kornia.contrib.extract_tensor_patches( |
|
torch.arange(16.0, device=device, dtype=dtype).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2) |
|
) |
|
img = utils.tensor_to_gradcheck_var(patches) |
|
assert gradcheck(kornia.contrib.combine_tensor_patches, (img, (2, 2), (2, 2)), raise_exception=True) |
|
|
|
|
|
class TestLambdaModule: |
|
def add_2_layer(self, tensor): |
|
return tensor + 2 |
|
|
|
def add_x_mul_y(self, tensor, x, y=2): |
|
return torch.mul(tensor + x, y) |
|
|
|
def test_smoke(self, device, dtype): |
|
B, C, H, W = 1, 3, 4, 5 |
|
img = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
func = self.add_2_layer |
|
if not callable(func): |
|
raise TypeError(f"Argument lambd should be callable, got {repr(type(func).__name__)}") |
|
assert isinstance(kornia.contrib.Lambda(func)(img), torch.Tensor) |
|
|
|
@pytest.mark.parametrize("x", [3, 2, 5]) |
|
def test_lambda_with_arguments(self, x, device, dtype): |
|
B, C, H, W = 2, 3, 5, 7 |
|
img = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
func = self.add_x_mul_y |
|
lambda_module = kornia.contrib.Lambda(func) |
|
out = lambda_module(img, x) |
|
assert isinstance(out, torch.Tensor) |
|
|
|
@pytest.mark.parametrize("shape", [(1, 3, 2, 3), (2, 3, 5, 7)]) |
|
def test_lambda(self, shape, device, dtype): |
|
B, C, H, W = shape |
|
img = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
func = kornia.color.bgr_to_grayscale |
|
lambda_module = kornia.contrib.Lambda(func) |
|
out = lambda_module(img) |
|
assert isinstance(out, torch.Tensor) |
|
|
|
def test_gradcheck(self, device, dtype): |
|
B, C, H, W = 1, 3, 4, 5 |
|
img = torch.rand(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True) |
|
func = kornia.color.bgr_to_grayscale |
|
assert gradcheck(kornia.contrib.Lambda(func), (img,), raise_exception=True) |
|
|
|
|
|
class TestImageStitcher: |
|
@pytest.mark.parametrize("estimator", ['ransac', 'vanilla']) |
|
def test_smoke(self, estimator, device, dtype): |
|
B, C, H, W = 1, 3, 224, 224 |
|
input1 = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
input2 = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
return_value = { |
|
"keypoints0": torch.rand((15, 2), device=device, dtype=dtype), |
|
"keypoints1": torch.rand((15, 2), device=device, dtype=dtype), |
|
"confidence": torch.rand((15,), device=device, dtype=dtype), |
|
"batch_indexes": torch.zeros((15,), device=device, dtype=dtype), |
|
} |
|
with patch( |
|
'kornia.contrib.ImageStitcher.on_matcher', new_callable=PropertyMock, return_value=lambda x: return_value |
|
): |
|
|
|
|
|
matcher = kornia.feature.LoFTR(None) |
|
stitcher = kornia.contrib.ImageStitcher(matcher, estimator=estimator).to(device=device, dtype=dtype) |
|
out = stitcher(input1, input2) |
|
assert out.shape[:-1] == torch.Size([1, 3, 224]) |
|
assert out.shape[-1] <= 448 |
|
|
|
def test_exception(self, device, dtype): |
|
B, C, H, W = 1, 3, 224, 224 |
|
input1 = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
input2 = torch.rand(B, C, H, W, device=device, dtype=dtype) |
|
|
|
matcher = kornia.feature.LoFTR(None) |
|
|
|
with pytest.raises(NotImplementedError): |
|
stitcher = kornia.contrib.ImageStitcher(matcher, estimator='random').to(device=device, dtype=dtype) |
|
|
|
stitcher = kornia.contrib.ImageStitcher(matcher).to(device=device, dtype=dtype) |
|
with pytest.raises(RuntimeError): |
|
stitcher(input1, input2) |
|
|