# compvis/test/test_contrib.py
# Uploaded with huggingface_hub (revision 36c95ba, verified)
from unittest.mock import patch, PropertyMock
import pytest
import torch
from torch.autograd import gradcheck
import kornia
import kornia.testing as utils # test utils
from kornia.testing import assert_close
from packaging import version
class TestVisionTransformer:
    """Smoke tests for kornia.contrib.VisionTransformer."""

    @pytest.mark.parametrize("B", [1, 2])
    @pytest.mark.parametrize("H", [1, 3, 8])
    @pytest.mark.parametrize("D", [128, 768])
    @pytest.mark.parametrize("image_size", [32, 224])
    def test_smoke(self, device, dtype, B, H, D, image_size):
        # The ViT tokenizes the image into (image_size / patch_size)^2 patches
        # plus one class token.
        patch_size = 16
        num_tokens = (image_size // patch_size) ** 2 + 1
        model = kornia.contrib.VisionTransformer(image_size=image_size, num_heads=H, embed_dim=D).to(device, dtype)
        inp = torch.rand(B, 3, image_size, image_size, device=device, dtype=dtype)
        out = model(inp)
        assert isinstance(out, torch.Tensor)
        assert out.shape == (B, num_tokens, D)
        # each of the 12 encoder layers exposes its intermediate features
        intermediates = model.encoder_results
        assert isinstance(intermediates, list)
        assert len(intermediates) == 12
        for feat in intermediates:
            assert feat.shape == (B, num_tokens, D)

    def test_backbone(self, device, dtype):
        # A custom backbone replaces the patch-embedding stem; this stub emits
        # a fixed 14x14 grid of 128-dim features regardless of the input.
        def backbone_stub(x):
            return torch.ones(1, 128, 14, 14, device=device, dtype=dtype)

        model = kornia.contrib.VisionTransformer(backbone=backbone_stub).to(device, dtype)
        out = model(torch.rand(1, 3, 32, 32, device=device, dtype=dtype))
        assert out.shape == (1, 197, 128)  # 14*14 patches + 1 class token
class TestMobileViT:
    """Smoke tests for kornia.contrib.MobileViT."""

    @pytest.mark.parametrize("B", [1, 2])
    @pytest.mark.parametrize("image_size", [(256, 256)])
    @pytest.mark.parametrize("mode", ['xxs', 'xs', 's'])
    @pytest.mark.parametrize("patch_size", [(2, 2)])
    def test_smoke(self, device, dtype, B, image_size, mode, patch_size):
        # each variant ends with a different channel width; output grid is 8x8
        out_channels = {'xxs': 320, 'xs': 384, 's': 640}[mode]
        height, width = image_size
        net = kornia.contrib.MobileViT(mode=mode, patch_size=patch_size).to(device, dtype)
        prediction = net(torch.rand(B, 3, height, width, device=device, dtype=dtype))
        assert isinstance(prediction, torch.Tensor)
        assert prediction.shape == (B, out_channels, 8, 8)
class TestClassificationHead:
    """Smoke test for kornia.contrib.ClassificationHead."""

    @pytest.mark.parametrize("B, D, N", [(1, 8, 10), (2, 2, 5)])
    def test_smoke(self, device, dtype, B, D, N):
        # (B, D, D) token features are reduced to (B, N) class logits
        head = kornia.contrib.ClassificationHead(embed_size=D, num_classes=N).to(device, dtype)
        features = torch.rand(B, D, D, device=device, dtype=dtype)
        scores = head(features)
        assert scores.shape == (B, N)
class TestConnectedComponents:
    """Tests for kornia.contrib.connected_components (iterative labeling)."""

    def test_smoke(self, device, dtype):
        image = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
        labels = kornia.contrib.connected_components(image, num_iterations=10)
        assert labels.shape == (1, 1, 3, 4)

    @pytest.mark.parametrize("shape", [(1, 3, 4), (2, 1, 3, 4)])
    def test_cardinality(self, device, dtype, shape):
        # output labeling keeps the input shape, batched or not
        image = torch.rand(shape, device=device, dtype=dtype)
        labels = kornia.contrib.connected_components(image, num_iterations=10)
        assert labels.shape == shape

    def test_exception(self, device, dtype):
        image = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
        # num_iterations must be a positive integer
        with pytest.raises(TypeError):
            assert kornia.contrib.connected_components(image, 1.0)
        with pytest.raises(TypeError):
            assert kornia.contrib.connected_components(image, 0)
        # multi-channel input is rejected
        with pytest.raises(ValueError):
            multi_channel = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
            assert kornia.contrib.connected_components(multi_channel, 2)

    def test_value(self, device, dtype):
        # three separate foreground blobs -> three distinct labels
        img = torch.tensor(
            [
                [
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
                        [0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                        [0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        # expected labels are the flat indices propagated by the algorithm
        expected = torch.tensor(
            [
                [
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 14.0, 14.0, 0.0, 0.0, 11.0],
                        [0.0, 14.0, 14.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 34.0, 34.0, 0.0],
                        [0.0, 0.0, 0.0, 34.0, 34.0, 0.0],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        labels = kornia.contrib.connected_components(img, num_iterations=10)
        assert_close(labels, expected)

    @pytest.mark.skipif(
        version.parse(torch.__version__) < version.parse("1.9"), reason="Tuple cannot be used with PyTorch < v1.9"
    )
    def test_gradcheck(self, device, dtype):
        # gradcheck requires float64 inputs with requires_grad set
        image = torch.ones(2, 1, 4, 4, device=device, dtype=torch.float64, requires_grad=True)
        assert gradcheck(kornia.contrib.connected_components, (image,), raise_exception=True)

    def test_jit(self, device, dtype):
        image = torch.ones(2, 1, 4, 4, device=device, dtype=dtype)
        op = kornia.contrib.connected_components
        op_jit = torch.jit.script(op)
        assert_close(op(image), op_jit(image))
class TestExtractTensorPatches:
    """Tests for kornia.contrib.ExtractTensorPatches / extract_tensor_patches.

    Patch layout convention checked below: output is (B, num_patches, C, h, w)
    with patches enumerated in row-major (left-to-right, top-to-bottom) order.
    """

    def test_smoke(self, device):
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        m = kornia.contrib.ExtractTensorPatches(3)
        assert m(img).shape == (1, 4, 1, 3, 3)

    def test_b1_ch1_h4w4_ws3(self, device):
        # 4x4 image, 3x3 window, stride 1 -> 2x2 = 4 patches
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        m = kornia.contrib.ExtractTensorPatches(3)
        patches = m(img)
        assert patches.shape == (1, 4, 1, 3, 3)
        assert_close(img[0, :, :3, :3], patches[0, 0])
        assert_close(img[0, :, :3, 1:], patches[0, 1])
        assert_close(img[0, :, 1:, :3], patches[0, 2])
        assert_close(img[0, :, 1:, 1:], patches[0, 3])

    def test_b1_ch2_h4w4_ws3(self, device):
        # channels are carried through unchanged
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        img = img.expand(-1, 2, -1, -1)  # copy all channels
        m = kornia.contrib.ExtractTensorPatches(3)
        patches = m(img)
        assert patches.shape == (1, 4, 2, 3, 3)
        assert_close(img[0, :, :3, :3], patches[0, 0])
        assert_close(img[0, :, :3, 1:], patches[0, 1])
        assert_close(img[0, :, 1:, :3], patches[0, 2])
        assert_close(img[0, :, 1:, 1:], patches[0, 3])

    def test_b1_ch1_h4w4_ws2(self, device):
        # 2x2 window, stride 1 -> 3x3 = 9 patches
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        m = kornia.contrib.ExtractTensorPatches(2)
        patches = m(img)
        assert patches.shape == (1, 9, 1, 2, 2)
        assert_close(img[0, :, 0:2, 1:3], patches[0, 1])
        assert_close(img[0, :, 0:2, 2:4], patches[0, 2])
        assert_close(img[0, :, 1:3, 1:3], patches[0, 4])
        assert_close(img[0, :, 2:4, 1:3], patches[0, 7])

    def test_b1_ch1_h4w4_ws2_stride2(self, device):
        # non-overlapping tiling: stride equals window size
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        m = kornia.contrib.ExtractTensorPatches(2, stride=2)
        patches = m(img)
        assert patches.shape == (1, 4, 1, 2, 2)
        assert_close(img[0, :, 0:2, 0:2], patches[0, 0])
        assert_close(img[0, :, 0:2, 2:4], patches[0, 1])
        assert_close(img[0, :, 2:4, 0:2], patches[0, 2])
        assert_close(img[0, :, 2:4, 2:4], patches[0, 3])

    def test_b1_ch1_h4w4_ws2_stride21(self, device):
        # anisotropic stride (2 vertical, 1 horizontal) -> 2x3 = 6 patches
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        m = kornia.contrib.ExtractTensorPatches(2, stride=(2, 1))
        patches = m(img)
        assert patches.shape == (1, 6, 1, 2, 2)
        assert_close(img[0, :, 0:2, 1:3], patches[0, 1])
        assert_close(img[0, :, 0:2, 2:4], patches[0, 2])
        assert_close(img[0, :, 2:4, 0:2], patches[0, 3])
        assert_close(img[0, :, 2:4, 2:4], patches[0, 5])

    def test_b1_ch1_h3w3_ws2_stride1_padding1(self, device):
        # padding 1 on a 3x3 image -> 4x4 = 16 patches; interior patches
        # (indices 5, 6, 9, 10) match unpadded slices of the original
        img = torch.arange(9.0).view(1, 1, 3, 3).to(device)
        m = kornia.contrib.ExtractTensorPatches(2, stride=1, padding=1)
        patches = m(img)
        assert patches.shape == (1, 16, 1, 2, 2)
        assert_close(img[0, :, 0:2, 0:2], patches[0, 5])
        assert_close(img[0, :, 0:2, 1:3], patches[0, 6])
        assert_close(img[0, :, 1:3, 0:2], patches[0, 9])
        assert_close(img[0, :, 1:3, 1:3], patches[0, 10])

    def test_b2_ch1_h3w3_ws2_stride1_padding1(self, device):
        # same as above, per batch element
        batch_size = 2
        img = torch.arange(9.0).view(1, 1, 3, 3).to(device)
        img = img.expand(batch_size, -1, -1, -1)
        m = kornia.contrib.ExtractTensorPatches(2, stride=1, padding=1)
        patches = m(img)
        assert patches.shape == (batch_size, 16, 1, 2, 2)
        for i in range(batch_size):
            assert_close(img[i, :, 0:2, 0:2], patches[i, 5])
            assert_close(img[i, :, 0:2, 1:3], patches[i, 6])
            assert_close(img[i, :, 1:3, 0:2], patches[i, 9])
            assert_close(img[i, :, 1:3, 1:3], patches[i, 10])

    def test_b1_ch1_h3w3_ws23(self, device):
        # rectangular 2x3 window on a 3x3 image -> 2 patches
        img = torch.arange(9.0).view(1, 1, 3, 3).to(device)
        m = kornia.contrib.ExtractTensorPatches((2, 3))
        patches = m(img)
        assert patches.shape == (1, 2, 1, 2, 3)
        assert_close(img[0, :, 0:2, 0:3], patches[0, 0])
        assert_close(img[0, :, 1:3, 0:3], patches[0, 1])

    def test_b1_ch1_h3w4_ws23(self, device):
        # rectangular 2x3 window on a 3x4 image -> 2x2 = 4 patches
        img = torch.arange(12.0).view(1, 1, 3, 4).to(device)
        m = kornia.contrib.ExtractTensorPatches((2, 3))
        patches = m(img)
        assert patches.shape == (1, 4, 1, 2, 3)
        assert_close(img[0, :, 0:2, 0:3], patches[0, 0])
        assert_close(img[0, :, 0:2, 1:4], patches[0, 1])
        assert_close(img[0, :, 1:3, 0:3], patches[0, 2])
        assert_close(img[0, :, 1:3, 1:4], patches[0, 3])

    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self, device):
        # NOTE(review): this test previously scripted and checked
        # kornia.geometry.denormalize_pixel_coordinates — a copy-paste
        # leftover from another test class. It now exercises the op this
        # class is about.
        op = kornia.contrib.extract_tensor_patches
        op_jit = torch.jit.script(op)
        img = torch.arange(16.0, device=device).view(1, 1, 4, 4)
        assert_close(op(img, 3), op_jit(img, 3))

    def test_gradcheck(self, device):
        img = torch.rand(2, 3, 4, 4).to(device)
        img = utils.tensor_to_gradcheck_var(img)  # to var
        assert gradcheck(kornia.contrib.extract_tensor_patches, (img, 3), raise_exception=True)
class TestCombineTensorPatches:
    """Tests that CombineTensorPatches inverts extract_tensor_patches."""

    def test_smoke(self, device, dtype):
        original = torch.arange(16, device=device, dtype=dtype).view(1, 1, 4, 4)
        tiles = kornia.contrib.extract_tensor_patches(original, window_size=(2, 2), stride=(2, 2))
        combiner = kornia.contrib.CombineTensorPatches((2, 2))
        restored = combiner(tiles)
        assert restored.shape == (1, 1, 4, 4)
        assert (original == restored).all()

    def test_error(self, device, dtype):
        # stride != window_size is not supported when recombining
        source = torch.arange(16, device=device, dtype=dtype).view(1, 1, 4, 4)
        tiles = kornia.contrib.extract_tensor_patches(source, window_size=(2, 2), stride=(2, 2), padding=1)
        with pytest.raises(NotImplementedError):
            kornia.contrib.combine_tensor_patches(tiles, window_size=(2, 2), stride=(3, 2))

    def test_padding1(self, device, dtype):
        # padding used at extraction must be undone via `unpadding`
        original = torch.arange(16, device=device, dtype=dtype).view(1, 1, 4, 4)
        tiles = kornia.contrib.extract_tensor_patches(original, window_size=(2, 2), stride=(2, 2), padding=1)
        combiner = kornia.contrib.CombineTensorPatches((2, 2), unpadding=1)
        restored = combiner(tiles)
        assert restored.shape == (1, 1, 4, 4)
        assert (original == restored).all()

    def test_gradcheck(self, device, dtype):
        source = torch.arange(16.0, device=device, dtype=dtype).view(1, 1, 4, 4)
        tiles = kornia.contrib.extract_tensor_patches(source, window_size=(2, 2), stride=(2, 2))
        tiles_var = utils.tensor_to_gradcheck_var(tiles)  # to var
        assert gradcheck(kornia.contrib.combine_tensor_patches, (tiles_var, (2, 2), (2, 2)), raise_exception=True)
class TestLambdaModule:
    """Tests for kornia.contrib.Lambda, which wraps a plain callable as a module."""

    def add_2_layer(self, tensor):
        # minimal callable taking a single tensor
        return tensor + 2

    def add_x_mul_y(self, tensor, x, y=2):
        # callable requiring extra positional (and optional keyword) arguments
        return torch.mul(tensor + x, y)

    def test_smoke(self, device, dtype):
        B, C, H, W = 1, 3, 4, 5
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        # NOTE(review): removed a dead `if not callable(func): raise TypeError`
        # guard here — a bound method is always callable, so that branch could
        # never execute; validating the argument is Lambda's job, not the test's.
        assert isinstance(kornia.contrib.Lambda(self.add_2_layer)(img), torch.Tensor)

    @pytest.mark.parametrize("x", [3, 2, 5])
    def test_lambda_with_arguments(self, x, device, dtype):
        # extra positional arguments are forwarded to the wrapped callable
        B, C, H, W = 2, 3, 5, 7
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        lambda_module = kornia.contrib.Lambda(self.add_x_mul_y)
        out = lambda_module(img, x)
        assert isinstance(out, torch.Tensor)

    @pytest.mark.parametrize("shape", [(1, 3, 2, 3), (2, 3, 5, 7)])
    def test_lambda(self, shape, device, dtype):
        # wrapping an existing kornia color op works like any other callable
        B, C, H, W = shape
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        lambda_module = kornia.contrib.Lambda(kornia.color.bgr_to_grayscale)
        out = lambda_module(img)
        assert isinstance(out, torch.Tensor)

    def test_gradcheck(self, device, dtype):
        B, C, H, W = 1, 3, 4, 5
        # gradcheck requires float64 inputs with requires_grad set
        img = torch.rand(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
        assert gradcheck(kornia.contrib.Lambda(kornia.color.bgr_to_grayscale), (img,), raise_exception=True)
class TestImageStitcher:
    """Tests for kornia.contrib.ImageStitcher with a mocked LoFTR matcher."""

    @pytest.mark.parametrize("estimator", ['ransac', 'vanilla'])
    def test_smoke(self, estimator, device, dtype):
        B, C, H, W = 1, 3, 224, 224
        left = torch.rand(B, C, H, W, device=device, dtype=dtype)
        right = torch.rand(B, C, H, W, device=device, dtype=dtype)
        # canned matcher output: 15 random correspondences in a single batch
        fake_matches = {
            "keypoints0": torch.rand((15, 2), device=device, dtype=dtype),
            "keypoints1": torch.rand((15, 2), device=device, dtype=dtype),
            "confidence": torch.rand((15,), device=device, dtype=dtype),
            "batch_indexes": torch.zeros((15,), device=device, dtype=dtype),
        }
        with patch(
            'kornia.contrib.ImageStitcher.on_matcher', new_callable=PropertyMock, return_value=lambda x: fake_matches
        ):
            # NOTE: This will need to download the pretrained weights.
            # To avoid that, we mock as below
            matcher = kornia.feature.LoFTR(None)
            stitcher = kornia.contrib.ImageStitcher(matcher, estimator=estimator).to(device=device, dtype=dtype)
            result = stitcher(left, right)
            # panorama keeps batch/channel/height; width grows up to 2x
            assert result.shape[:-1] == torch.Size([1, 3, 224])
            assert result.shape[-1] <= 448

    def test_exception(self, device, dtype):
        B, C, H, W = 1, 3, 224, 224
        left = torch.rand(B, C, H, W, device=device, dtype=dtype)
        right = torch.rand(B, C, H, W, device=device, dtype=dtype)
        # NOTE: This will need to download the pretrained weights.
        matcher = kornia.feature.LoFTR(None)
        # unknown estimator name is rejected at construction time
        with pytest.raises(NotImplementedError):
            stitcher = kornia.contrib.ImageStitcher(matcher, estimator='random').to(device=device, dtype=dtype)
        # without mocked matches, stitching random noise fails at runtime
        stitcher = kornia.contrib.ImageStitcher(matcher).to(device=device, dtype=dtype)
        with pytest.raises(RuntimeError):
            stitcher(left, right)