# compvis/test/feature/test_mkd.py
# Dexter's picture
# Upload folder using huggingface_hub
# 36c95ba verified
from math import pi
import pytest
import torch
from torch.autograd import gradcheck
import kornia.testing as utils # test utils
from kornia.feature.mkd import (
COEFFS,
EmbedGradients,
ExplicitSpacialEncoding,
get_grid_dict,
get_kron_order,
MKDDescriptor,
MKDGradients,
SimpleKD,
spatial_kernel_embedding,
VonMisesKernel,
Whitening,
)
from kornia.testing import assert_close
@pytest.mark.parametrize("ps", [5, 13, 25])
def test_get_grid_dict(ps):
    """The grid dict holds exactly 'x', 'y', 'phi' and 'rho', each of shape (ps, ps)."""
    expected_names = ('x', 'y', 'phi', 'rho')
    grids = get_grid_dict(ps)

    # No extra and no missing entries.
    assert set(grids.keys()) == set(expected_names)

    # Every grid covers the full ps x ps patch.
    for name in expected_names:
        assert grids[name].shape == (ps, ps)
@pytest.mark.parametrize("d1,d2", [(1, 1), (1, 2), (2, 1), (5, 6)])
def test_get_kron_order(d1, d2):
    """The result of ``get_kron_order(d1, d2)`` has shape (d1 * d2, 2)."""
    expected_shape = (d1 * d2, 2)
    assert get_kron_order(d1, d2).shape == expected_shape
class TestMKDGradients:
    """Shape, value, gradient and TorchScript checks for ``MKDGradients``."""

    @pytest.mark.parametrize("ps", [5, 13, 25])
    def test_shape(self, ps, device):
        """A (1, 1, ps, ps) input produces a (1, 2, ps, ps) output."""
        module = MKDGradients().to(device)
        patch = torch.ones(1, 1, ps, ps).to(device)
        assert module(patch).shape == (1, 2, ps, ps)

    @pytest.mark.parametrize("bs", [1, 5, 13])
    def test_batch_shape(self, bs, device):
        """The batch dimension is carried through unchanged."""
        module = MKDGradients().to(device)
        batch = torch.ones(bs, 1, 15, 15).to(device)
        assert module(batch).shape == (bs, 2, 15, 15)

    def test_print(self, device):
        """Building the textual representation must not raise."""
        repr(MKDGradients().to(device))

    def test_toy(self, device):
        """A vertical step-edge patch yields the hand-written magnitude/orientation maps."""
        # Columns 0-2 are one, columns 3-5 are zero.
        patch = torch.ones(1, 1, 6, 6).to(device).float()
        patch[0, 0, :, 3:] = 0
        out = MKDGradients().to(device)(patch)

        mags_row = torch.Tensor([0, 0, 1.0, 1.0, 0, 0]).to(device)
        oris_row = torch.Tensor([-pi, -pi, 0, 0, -pi, -pi]).to(device)

        # Each expected map is the same row stacked six times.
        assert_close(out[0, 0], mags_row.repeat(6, 1), atol=1e-3, rtol=1e-3)
        assert_close(out[0, 1], oris_row.repeat(6, 1), atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device):
        """``gradcheck`` passes on a random 13x13 patch."""
        patches = torch.rand(1, 1, 13, 13).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def grad_describe(x):
            # A fresh module is built on every call made by gradcheck.
            return MKDGradients().to(device)(x)

        assert gradcheck(grad_describe, (patches,), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted module matches the eager module on the same input."""
        patches = torch.rand(2, 1, 13, 13, device=device, dtype=dtype)

        def build():
            return MKDGradients().to(patches.device, patches.dtype).eval()

        eager = build()
        scripted = torch.jit.script(build())
        assert_close(eager(patches), scripted(patches))
class TestVonMisesKernel:
    """Shape, value, gradient and TorchScript checks for ``VonMisesKernel``."""

    @staticmethod
    def _two_coeff_kernel(patch_size):
        """Build a ``VonMisesKernel`` with the two coefficients most tests here share."""
        return VonMisesKernel(patch_size=patch_size, coeffs=[0.38214156, 0.48090413])

    @pytest.mark.parametrize("ps", [5, 13, 25])
    def test_shape(self, ps, device):
        """Two coefficients give a three-channel output of the input's spatial size."""
        kernel = self._two_coeff_kernel(ps).to(device)
        patch = torch.ones(1, 1, ps, ps).to(device)
        assert kernel(patch).shape == (1, 3, ps, ps)

    @pytest.mark.parametrize("bs", [1, 5, 13])
    def test_batch_shape(self, bs, device):
        """The batch dimension is carried through unchanged."""
        kernel = self._two_coeff_kernel(15).to(device)
        batch = torch.ones(bs, 1, 15, 15).to(device)
        assert kernel(batch).shape == (bs, 3, 15, 15)

    @pytest.mark.parametrize("coeffs", COEFFS.values())
    def test_coeffs(self, coeffs, device):
        """For every entry of ``COEFFS`` the output has ``2 * len(coeffs) - 1`` channels."""
        kernel = VonMisesKernel(patch_size=15, coeffs=coeffs).to(device)
        patch = torch.ones(1, 1, 15, 15).to(device)
        assert kernel(patch).shape == (1, 2 * len(coeffs) - 1, 15, 15)

    def test_print(self, device):
        """Building the textual representation must not raise."""
        repr(self._two_coeff_kernel(32).to(device))

    def test_toy(self, device):
        """A left-ones / right-zeros patch reproduces the hand-written channel values."""
        patch = torch.ones(1, 1, 6, 6).float().to(device)
        patch[0, 0, :, 3:] = 0
        out = self._two_coeff_kernel(6).to(device)(patch)

        # Channel 0 is expected to be constant over the whole patch.
        constant = torch.ones_like(out[0, 0]).to(device) * 0.6182
        assert_close(out[0, 0], constant, atol=1e-3, rtol=1e-3)

        # Channels 1 and 2 are expected to repeat a single row six times.
        row_ch1 = torch.Tensor([0.3747, 0.3747, 0.3747, 0.6935, 0.6935, 0.6935]).to(device)
        assert_close(out[0, 1], row_ch1.repeat(6, 1), atol=1e-3, rtol=1e-3)

        row_ch2 = torch.Tensor([0.5835, 0.5835, 0.5835, 0.0000, 0.0000, 0.0000]).to(device)
        assert_close(out[0, 2], row_ch2.repeat(6, 1), atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device):
        """``gradcheck`` passes for a double-precision kernel on a random patch."""
        ps = 13
        patches = torch.rand(1, 1, ps, ps).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def vm_describe(x, patch_size=13):
            kernel = self._two_coeff_kernel(patch_size).double().to(device)
            return kernel(x.double())

        # The patch size is forwarded as a second (non-tensor) input.
        assert gradcheck(vm_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted module matches the eager module on the same input."""
        patches = torch.rand(2, 1, 13, 13, device=device, dtype=dtype)

        def build():
            return self._two_coeff_kernel(13).to(patches.device, patches.dtype).eval()

        eager = build()
        scripted = torch.jit.script(build())
        assert_close(eager(patches), scripted(patches))
class TestEmbedGradients:
    """Shape, value, gradient and TorchScript checks for ``EmbedGradients``."""

    @pytest.mark.parametrize("ps,relative", [(5, True), (13, True), (25, True), (5, False), (13, False), (25, False)])
    def test_shape(self, ps, relative, device):
        """A two-channel input maps to seven channels, with or without ``relative``."""
        module = EmbedGradients(patch_size=ps, relative=relative).to(device)
        grads = torch.ones(1, 2, ps, ps).to(device)
        assert module(grads).shape == (1, 7, ps, ps)

    @pytest.mark.parametrize("bs", [1, 5, 13])
    def test_batch_shape(self, bs, device):
        """The batch dimension is carried through unchanged."""
        module = EmbedGradients(patch_size=15, relative=True).to(device)
        grads = torch.ones(bs, 2, 15, 15).to(device)
        assert module(grads).shape == (bs, 7, 15, 15)

    def test_print(self, device):
        """Building the textual representation must not raise."""
        repr(EmbedGradients(patch_size=15, relative=True).to(device))

    def test_toy(self, device):
        """Zeroing input channel 0 on the right half zeroes output channel 0 there."""
        grads = torch.ones(1, 2, 6, 6).float().to(device)
        grads[0, 0, :, 3:] = 0
        out = EmbedGradients(patch_size=6, relative=True).to(device)(grads)

        left_half = out[0, 0, :, :3]
        right_half = out[0, 0, :, 3:]
        ones = torch.ones_like(left_half).to(device)

        assert_close(left_half, ones * 0.3787, atol=1e-3, rtol=1e-3)
        assert_close(right_half, ones * 0, atol=1e-3, rtol=1e-3)

    # TODO: review this test implementation
    @pytest.mark.xfail(reason="RuntimeError: Jacobian mismatch for output 0 with respect to input 0,")
    def test_gradcheck(self, device):
        """``gradcheck`` on a double-precision module; currently marked as an expected failure."""
        ps = 13
        patches = torch.rand(1, 2, ps, ps).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def emb_grads_describe(x, patch_size=13):
            module = EmbedGradients(patch_size=patch_size, relative=True).double().to(device)
            return module(x.double())

        # The patch size is forwarded as a second (non-tensor) input.
        assert gradcheck(emb_grads_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted module matches the eager module on the same input."""
        width = 13
        patches = torch.rand(2, 2, 13, width, device=device, dtype=dtype)

        def build():
            return EmbedGradients(patch_size=width, relative=True).to(patches.device, patches.dtype).eval()

        eager = build()
        scripted = torch.jit.script(build())
        assert_close(eager(patches), scripted(patches))
@pytest.mark.parametrize("kernel_type,d,ps", [('cart', 9, 9), ('polar', 25, 9), ('cart', 9, 16), ('polar', 25, 16)])
def test_spatial_kernel_embedding(kernel_type, ps, d):
    """The spatial embedding built from a ps-sized grid dict has shape (d, ps, ps)."""
    embedding = spatial_kernel_embedding(kernel_type, get_grid_dict(ps))
    assert embedding.shape == (d, ps, ps)
class TestExplicitSpacialEncoding:
    """Shape, value, gradient and TorchScript checks for ``ExplicitSpacialEncoding``."""

    # Factor by which in_dims is multiplied to get the expected output width.
    _width_factor = {'cart': 9, 'polar': 25}

    @pytest.mark.parametrize(
        "kernel_type,ps,in_dims", [('cart', 9, 3), ('polar', 9, 3), ('cart', 13, 7), ('polar', 13, 7)]
    )
    def test_shape(self, kernel_type, ps, in_dims, device):
        """The output is flat, with ``factor * in_dims`` entries per sample."""
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=ps, in_dims=in_dims).to(device)
        fmap = torch.ones(1, in_dims, ps, ps).to(device)
        expected_width = self._width_factor[kernel_type] * in_dims
        assert ese(fmap).shape == (1, expected_width)

    @pytest.mark.parametrize(
        "kernel_type,bs", [('cart', 1), ('cart', 5), ('cart', 13), ('polar', 1), ('polar', 5), ('polar', 13)]
    )
    def test_batch_shape(self, kernel_type, bs, device):
        """The batch dimension is carried through unchanged."""
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=15, in_dims=7).to(device)
        fmap = torch.ones(bs, 7, 15, 15).to(device)
        expected_width = self._width_factor[kernel_type] * 7
        assert ese(fmap).shape == (bs, expected_width)

    @pytest.mark.parametrize("kernel_type", ['cart', 'polar'])
    def test_print(self, kernel_type, device):
        """Building the textual representation must not raise."""
        repr(ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=15, in_dims=7).to(device))

    def test_toy(self, device):
        """With input channel 0 zeroed, the leading slice of the output is zero for both kernels."""
        fmap = torch.ones(1, 2, 6, 6).to(device).float()
        fmap[0, 0, :, :] = 0

        # Same check for both kernel types; only the slice width differs (9 vs 25).
        for kernel_type, width in (('cart', 9), ('polar', 25)):
            ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=6, in_dims=2).to(device)
            head = ese(fmap)[:, :width]
            zeros = torch.zeros_like(head).to(device)
            assert_close(head, zeros, atol=1e-3, rtol=1e-3)

    @pytest.mark.parametrize("kernel_type", ['cart', 'polar'])
    def test_gradcheck(self, kernel_type, device):
        """``gradcheck`` passes on a random two-channel 13x13 input."""
        ps = 13
        patches = torch.rand(1, 2, ps, ps).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def explicit_spatial_describe(x, fmap_size=13):
            ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=fmap_size, in_dims=2).to(device)
            return ese(x)

        # The patch size is forwarded as a second (non-tensor) input.
        assert gradcheck(explicit_spatial_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted 'cart' module matches the eager module on the same input."""
        width = 13
        patches = torch.rand(2, 2, 13, width, device=device, dtype=dtype)

        def build():
            ese = ExplicitSpacialEncoding(kernel_type='cart', fmap_size=width, in_dims=2)
            return ese.to(patches.device, patches.dtype).eval()

        eager = build()
        scripted = torch.jit.script(build())
        assert_close(eager(patches), scripted(patches))
class TestWhitening:
    """Shape, value, gradient and TorchScript checks for ``Whitening``."""

    @pytest.mark.parametrize(
        "kernel_type,xform,output_dims",
        [
            ('cart', None, 3),
            ('polar', None, 3),
            ('cart', 'lw', 7),
            ('polar', 'lw', 7),
            ('cart', 'pca', 9),
            ('polar', 'pca', 9),
        ],
    )
    def test_shape(self, kernel_type, xform, output_dims, device):
        """A (1, in_dims) input is mapped to (1, output_dims) for every ``xform``."""
        # Input width used for each kernel type in this test.
        in_dims = 63 if kernel_type == 'cart' else 175
        wh = Whitening(xform=xform, whitening_model=None, in_dims=in_dims, output_dims=output_dims).to(device)
        inp = torch.ones(1, in_dims).to(device)
        out = wh(inp)
        assert out.shape == (1, output_dims)

    @pytest.mark.parametrize("bs", [1, 3, 7])
    def test_batch_shape(self, bs, device):
        """The batch dimension is carried through unchanged."""
        wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=128).to(device)
        inp = torch.ones(bs, 175).to(device)
        out = wh(inp)
        assert out.shape == (bs, 128)

    def test_print(self, device):
        """Building the textual representation must not raise."""
        wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=128).to(device)
        repr(wh)

    def test_toy(self, device):
        """An all-ones input with no whitening model gives a constant 0.0756 output."""
        wh = Whitening(xform='lw', whitening_model=None, in_dims=175, output_dims=175).to(device)
        inp = torch.ones(1, 175).to(device).float()
        out = wh(inp)
        expected = torch.ones_like(inp).to(device) * 0.0756
        assert_close(out, expected, atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device):
        """``gradcheck`` passes for a double-precision module on a random input."""
        batch_size, in_dims = 1, 175
        patches = torch.rand(batch_size, in_dims).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def whitening_describe(patches, in_dims=175):
            wh = Whitening(xform='lw', whitening_model=None, in_dims=in_dims).double()
            wh.to(device)
            return wh(patches.double())

        # in_dims is forwarded as a second (non-tensor) input.
        assert gradcheck(whitening_describe, (patches, in_dims), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted module matches the eager module on the same input."""
        batch_size, in_dims = 1, 175
        # Create the input in the requested dtype so the ``dtype`` fixture is
        # actually exercised, as in the other test_jit methods of this file.
        patches = torch.rand(batch_size, in_dims, device=device, dtype=dtype)
        model = Whitening(xform='lw', whitening_model=None, in_dims=in_dims).to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            Whitening(xform='lw', whitening_model=None, in_dims=in_dims).to(patches.device, patches.dtype).eval()
        )
        assert_close(model(patches), model_jit(patches))
class TestMKDDescriptor:
    """Shape, value, gradient and TorchScript checks for ``MKDDescriptor``."""

    # Expected descriptor width per kernel type when no whitening is applied.
    dims = {'cart': 63, 'polar': 175, 'concat': 238}

    @pytest.mark.parametrize(
        "ps,kernel_type", [(9, 'concat'), (9, 'cart'), (9, 'polar'), (32, 'concat'), (32, 'cart'), (32, 'polar')]
    )
    def test_shape(self, ps, kernel_type, device):
        """Without whitening the descriptor width equals ``dims[kernel_type]``."""
        mkd = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=None).to(device)
        inp = torch.ones(1, 1, ps, ps).to(device)
        out = mkd(inp)
        assert out.shape == (1, self.dims[kernel_type])

    @pytest.mark.parametrize(
        "ps,kernel_type,whitening",
        [
            (9, 'concat', 'lw'),
            (9, 'cart', 'lw'),
            (9, 'polar', 'lw'),
            (9, 'concat', 'pcawt'),
            (9, 'cart', 'pcawt'),
            (9, 'polar', 'pcawt'),
        ],
    )
    def test_whitened_shape(self, ps, kernel_type, whitening, device):
        """With whitening the descriptor width is ``dims[kernel_type]`` capped at 128."""
        mkd = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=whitening).to(device)
        inp = torch.ones(1, 1, ps, ps).to(device)
        out = mkd(inp)
        output_dims = min(self.dims[kernel_type], 128)
        assert out.shape == (1, output_dims)

    @pytest.mark.parametrize("bs", [1, 3, 7])
    def test_batch_shape(self, bs, device):
        """The batch dimension is carried through unchanged."""
        mkd = MKDDescriptor(patch_size=19, kernel_type='concat', whitening=None).to(device)
        inp = torch.ones(bs, 1, 19, 19).to(device)
        out = mkd(inp)
        assert out.shape == (bs, 238)

    def test_print(self, device):
        """Building the textual representation must not raise."""
        mkd = MKDDescriptor(patch_size=32, whitening='lw', training_set='liberty', output_dims=128).to(device)
        repr(mkd)

    def test_toy(self, device):
        """For an all-zero patch the last 28 descriptor entries are zero."""
        inp = torch.ones(1, 1, 6, 6).to(device).float()
        inp[0, 0, :, :] = 0
        mkd = MKDDescriptor(patch_size=6, kernel_type='concat', whitening=None).to(device)
        out = mkd(inp)
        out_part = out[0, -28:]
        expected = torch.zeros_like(out_part).to(device)
        assert_close(out_part, expected, atol=1e-3, rtol=1e-3)

    @pytest.mark.skip("Just because")
    @pytest.mark.parametrize("whitening", [None, 'lw', 'pca'])
    def test_gradcheck(self, whitening, device):
        """``gradcheck`` on a double-precision descriptor (currently skipped)."""
        batch_size, channels, ps = 1, 1, 19
        patches = torch.rand(batch_size, channels, ps, ps).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def mkd_describe(patches, patch_size=19):
            mkd = MKDDescriptor(patch_size=patch_size, kernel_type='concat', whitening=whitening).double()
            mkd.to(device)
            return mkd(patches.double())

        # The patch size is forwarded as a second (non-tensor) input.
        assert gradcheck(mkd_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.skip("neither dict, nor nn.ModuleDict works")
    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted descriptor matches the eager one (currently skipped)."""
        batch_size, channels, ps = 1, 1, 19
        # Create the input in the requested dtype so the ``dtype`` fixture is
        # actually exercised, as in the other test_jit methods of this file.
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=dtype)
        kt = 'concat'
        wt = 'lw'
        model = MKDDescriptor(patch_size=ps, kernel_type=kt, whitening=wt).to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            MKDDescriptor(patch_size=ps, kernel_type=kt, whitening=wt).to(patches.device, patches.dtype).eval()
        )
        assert_close(model(patches), model_jit(patches))
class TestSimpleKD:
    """Shape, gradient and TorchScript checks for ``SimpleKD``."""

    # Descriptor width per kernel type before the 128 cap applied in test_shape.
    dims = {'cart': 63, 'polar': 175}

    @pytest.mark.parametrize("ps,kernel_type", [(9, 'cart'), (9, 'polar'), (32, 'cart'), (32, 'polar')])
    def test_shape(self, ps, kernel_type, device):
        """The descriptor width is ``dims[kernel_type]`` capped at 128."""
        skd = SimpleKD(patch_size=ps, kernel_type=kernel_type).to(device)
        inp = torch.ones(1, 1, ps, ps).to(device)
        out = skd(inp)
        assert out.shape == (1, min(128, self.dims[kernel_type]))

    @pytest.mark.parametrize("bs", [1, 3, 7])
    def test_batch_shape(self, bs, device):
        """The batch dimension is carried through unchanged."""
        skd = SimpleKD(patch_size=19, kernel_type='polar').to(device)
        inp = torch.ones(bs, 1, 19, 19).to(device)
        out = skd(inp)
        assert out.shape == (bs, 128)

    def test_print(self, device):
        """Building the textual representation must not raise."""
        skd = SimpleKD(patch_size=19, kernel_type='polar').to(device)
        repr(skd)

    def test_gradcheck(self, device):
        """``gradcheck`` passes for a double-precision module on a random patch."""
        batch_size, channels, ps = 1, 1, 19
        patches = torch.rand(batch_size, channels, ps, ps).to(device)
        patches = utils.tensor_to_gradcheck_var(patches)  # to var

        def skd_describe(patches, patch_size=19):
            # Use the argument gradcheck passes in rather than the enclosing
            # ``ps``, so the parameter is not silently ignored.
            skd = SimpleKD(patch_size=patch_size, kernel_type='polar', whitening='lw').double()
            skd.to(device)
            return skd(patches.double())

        # The patch size is forwarded as a second (non-tensor) input.
        assert gradcheck(skd_describe, (patches, ps), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The scripted module matches the eager module on the same input."""
        batch_size, channels, ps = 1, 1, 19
        # Create the input in the requested dtype so the ``dtype`` fixture is
        # actually exercised, as in the other test_jit methods of this file.
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=dtype)
        model = SimpleKD(patch_size=ps, kernel_type='polar', whitening='lw').to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            SimpleKD(patch_size=ps, kernel_type='polar', whitening='lw').to(patches.device, patches.dtype).eval()
        )
        assert_close(model(patches), model_jit(patches))