# Source: compvis/test/enhance/test_zca.py
# Uploaded by Dexter via huggingface_hub ("Upload folder using huggingface_hub"), commit 36c95ba (verified).
import pytest
import torch
from torch.autograd import gradcheck
import kornia
import kornia.testing as utils # test utils
from kornia.testing import assert_close
class TestZCA:
    """Tests for kornia's ZCA whitening.

    Covers the ``kornia.enhance.ZCAWhitening`` module and the functional
    ``kornia.enhance.zca_mean`` / ``kornia.enhance.zca_whiten`` API.

    ``device`` and ``dtype`` are pytest fixtures provided by the test
    configuration; they are not defined in this file.
    """

    @staticmethod
    def _unit_axis_data(device, dtype):
        """Return the four 2-D points (0, 1), (1, 0), (-1, 0), (0, -1) as a (4, 2) tensor."""
        return torch.tensor([[0, 1], [1, 0], [-1, 0], [0, -1]], device=device, dtype=dtype)

    @staticmethod
    def _expected_unit_axis_whitening(data, unbiased):
        """Return the analytically whitened version of ``_unit_axis_data``.

        The four points have zero mean, and every point has one zero
        coordinate, so the cross terms vanish. Each coordinate's sum of
        squares is 2 over n = 4 samples. The covariance is therefore a multiple
        of the identity:

        - 2 / 3 with the unbiased (n - 1) estimator;
        - 2 / 4 with the biased (n) estimator.

        Whitening reduces to scaling every entry by 1 / sqrt(variance), which
        is sqrt(1.5) or sqrt(2.0) respectively.
        """
        scale = 1.5 if unbiased else 2.0
        # Entries are in {-1, 0, 1}, so this is sqrt(scale) * data.
        return torch.sqrt(scale * torch.abs(data)) * torch.sign(data)

    @pytest.mark.parametrize("unbiased", [True, False])
    def test_zca_unbiased(self, unbiased, device, dtype):
        """Module API: fit followed by transform matches the analytic whitening."""
        data = self._unit_axis_data(device, dtype)
        expected = self._expected_unit_axis_whitening(data, unbiased)

        zca = kornia.enhance.ZCAWhitening(unbiased=unbiased).fit(data)
        actual = zca(data)

        tol_val: float = utils._get_precision(device, dtype)
        assert_close(actual, expected, rtol=tol_val, atol=tol_val)

    @pytest.mark.parametrize("dim", [0, 1])
    def test_dim_args(self, dim, device, dtype):
        """Fitting along ``dim`` in the same call reproduces the stored reference outputs."""
        if 'xla' in device.type:
            pytest.skip("buggy with XLA devices.")
        data = self._unit_axis_data(device, dtype)

        # Reference outputs per ``dim`` (1.2247448 ~= sqrt(1.5)).
        # The dict lookup keeps ``expected`` bound for every parametrized value;
        # an unexpected ``dim`` fails loudly with KeyError instead of NameError.
        expected_by_dim = {
            0: [[0.0, 1.2247448], [1.2247448, 0.0], [-1.2247448, 0.0], [0.0, -1.2247448]],
            1: [
                [-0.35360718, 0.35360718],
                [0.35351562, -0.35351562],
                [-0.35353088, 0.35353088],
                [0.35353088, -0.35353088],
            ],
        }
        expected = torch.tensor(expected_by_dim[dim], device=device, dtype=dtype)

        zca = kornia.enhance.ZCAWhitening(dim=dim)
        actual = zca(data, include_fit=True)

        tol_val: float = utils._get_precision(device, dtype)
        assert_close(actual, expected, rtol=tol_val, atol=tol_val)

    @pytest.mark.parametrize("input_shape,eps", [((15, 2, 2, 2), 1e-6), ((10, 4), 0.1), ((20, 3, 2, 2), 1e-3)])
    def test_identity(self, input_shape, eps, device, dtype):
        """Assert that data can be recovered by the inverse transform."""
        data = torch.randn(*input_shape, device=device, dtype=dtype)
        zca = kornia.enhance.ZCAWhitening(compute_inv=True, eps=eps).fit(data)

        data_w = zca(data)
        data_hat = zca.inverse_transform(data_w)

        # Looser tolerance on XLA devices, tighter elsewhere.
        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)
        assert_close(data, data_hat, rtol=tol_val, atol=tol_val)

    def test_grad_zca_individual_transforms(self, device, dtype):
        """Check if the gradients of the transforms are correct w.r.t to the input data."""
        data = torch.tensor([[2, 0], [0, 1], [-2, 0], [0, -1]], device=device, dtype=dtype)
        data = utils.tensor_to_gradcheck_var(data)

        # Each wrapper exposes one output of ``zca_mean`` so gradcheck sees a single tensor:
        # index 0 is the transform, 1 the mean, 2 the inverse transform (needs return_inverse=True).
        def zca_T(x):
            return kornia.enhance.zca_mean(x)[0]

        def zca_mu(x):
            return kornia.enhance.zca_mean(x)[1]

        def zca_T_inv(x):
            return kornia.enhance.zca_mean(x, return_inverse=True)[2]

        assert gradcheck(zca_T, (data,), raise_exception=True)
        assert gradcheck(zca_mu, (data,), raise_exception=True)
        assert gradcheck(zca_T_inv, (data,), raise_exception=True)

    def test_grad_zca_with_fit(self, device, dtype):
        """Gradients flow through a combined fit and transform when transforms are not detached."""
        data = torch.tensor([[2, 0], [0, 1], [-2, 0], [0, -1]], device=device, dtype=dtype)
        data = utils.tensor_to_gradcheck_var(data)

        def zca_fit(x):
            zca = kornia.enhance.ZCAWhitening(detach_transforms=False)
            return zca(x, include_fit=True)

        assert gradcheck(zca_fit, (data,), raise_exception=True)

    def test_grad_detach_zca(self, device, dtype):
        """Gradients w.r.t. the input are correct for a module fitted ahead of time.

        The module uses its default ``detach_transforms`` setting; presumably the
        fitted transforms are detached, so only the transform path is checked.
        TODO: confirm.
        """
        data = torch.tensor([[1, 0], [0, 1], [-2, 0], [0, -1]], device=device, dtype=dtype)
        data = utils.tensor_to_gradcheck_var(data)
        zca = kornia.enhance.ZCAWhitening()
        zca.fit(data)

        assert gradcheck(zca, (data,), raise_exception=True)

    def test_not_fitted(self, device, dtype):
        """Calling the module before ``fit`` raises RuntimeError."""
        data = torch.rand(10, 2, device=device, dtype=dtype)
        zca = kornia.enhance.ZCAWhitening()
        # Only the call under test sits inside ``raises``, so a RuntimeError
        # raised during setup cannot make this test pass spuriously.
        with pytest.raises(RuntimeError):
            zca(data)

    def test_not_fitted_inv(self, device, dtype):
        """Calling ``inverse_transform`` before ``fit`` raises RuntimeError."""
        data = torch.rand(10, 2, device=device, dtype=dtype)
        zca = kornia.enhance.ZCAWhitening()
        # Keep setup outside ``raises`` for the same reason as in test_not_fitted.
        with pytest.raises(RuntimeError):
            zca.inverse_transform(data)

    def test_jit(self, device, dtype):
        """A TorchScript-compiled module gives the same output as the eager module."""
        data = torch.rand(10, 3, 1, 2, device=device, dtype=dtype)
        zca = kornia.enhance.ZCAWhitening().fit(data)
        zca_jit = kornia.enhance.ZCAWhitening().fit(data)
        zca_jit = torch.jit.script(zca_jit)

        assert_close(zca_jit(data), zca(data))

    @pytest.mark.parametrize("unbiased", [True, False])
    def test_zca_whiten_func_unbiased(self, unbiased, device, dtype):
        """Functional API: ``zca_whiten`` matches the analytic whitening of unit-axis data."""
        data = self._unit_axis_data(device, dtype)
        expected = self._expected_unit_axis_whitening(data, unbiased)

        actual = kornia.enhance.zca_whiten(data, unbiased=unbiased)

        tol_val: float = utils._get_precision(device, dtype)
        assert_close(actual, expected, atol=tol_val, rtol=tol_val)