|
import pytest |
|
import torch |
|
|
|
from kornia.testing import assert_close |
|
from kornia.utils import _extract_device_dtype |
|
from kornia.utils.helpers import ( |
|
_torch_histc_cast, |
|
_torch_inverse_cast, |
|
_torch_solve_cast, |
|
_torch_svd_cast, |
|
safe_inverse_with_mask, |
|
safe_solve_with_mask, |
|
) |
|
|
|
|
|
@pytest.mark.parametrize(
    "tensor_list,out_device,out_dtype,will_throw_error",
    [
        # Empty or all-None inputs are expected to fall back to CPU and the default dtype.
        ([], torch.device('cpu'), torch.get_default_dtype(), False),
        ([None, None], torch.device('cpu'), torch.get_default_dtype(), False),
        # A single tensor mixed with None determines the result.
        ([torch.tensor(0, device='cpu', dtype=torch.float16), None], torch.device('cpu'), torch.float16, False),
        ([torch.tensor(0, device='cpu', dtype=torch.float32), None], torch.device('cpu'), torch.float32, False),
        ([torch.tensor(0, device='cpu', dtype=torch.float64), None], torch.device('cpu'), torch.float64, False),
        # Several tensors that agree on device and dtype.
        ([torch.tensor(0, device='cpu', dtype=torch.float16)] * 2, torch.device('cpu'), torch.float16, False),
        ([torch.tensor(0, device='cpu', dtype=torch.float32)] * 2, torch.device('cpu'), torch.float32, False),
        ([torch.tensor(0, device='cpu', dtype=torch.float64)] * 2, torch.device('cpu'), torch.float64, False),
        # Tensors with mismatching dtypes are expected to raise ValueError.
        (
            [torch.tensor(0, device='cpu', dtype=torch.float16), torch.tensor(0, device='cpu', dtype=torch.float64)],
            None,
            None,
            True,
        ),
        (
            [torch.tensor(0, device='cpu', dtype=torch.float32), torch.tensor(0, device='cpu', dtype=torch.float64)],
            None,
            None,
            True,
        ),
        (
            [torch.tensor(0, device='cpu', dtype=torch.float16), torch.tensor(0, device='cpu', dtype=torch.float32)],
            None,
            None,
            True,
        ),
    ],
)
def test_extract_device_dtype(tensor_list, out_device, out_dtype, will_throw_error):
    """Check that ``_extract_device_dtype`` infers a common (device, dtype) or raises ``ValueError``."""
    if not will_throw_error:
        extracted_device, extracted_dtype = _extract_device_dtype(tensor_list)
        assert extracted_device == out_device
        assert extracted_dtype == out_dtype
        return

    with pytest.raises(ValueError):
        _extract_device_dtype(tensor_list)
|
|
|
|
|
class TestInverseCast:
    """Tests for ``_torch_inverse_cast``."""

    @pytest.mark.parametrize("input_shape", [(1, 3, 4, 4), (2, 4, 5, 5)])
    def test_smoke(self, device, dtype, input_shape):
        """Inverting a batch of square matrices preserves the input shape."""
        data = torch.rand(input_shape, device=device, dtype=dtype)
        assert _torch_inverse_cast(data).shape == data.shape

    def test_values(self, device, dtype):
        """The inverse of a 2x2 matrix matches the closed-form result."""
        # det = 4*6 - 7*2 = 10, hence inv = [[6, -7], [-2, 4]] / 10.
        matrix = torch.tensor([[4.0, 7.0], [2.0, 6.0]], device=device, dtype=dtype)
        expected = torch.tensor([[0.6, -0.7], [-0.2, 0.4]], device=device, dtype=dtype)
        assert_close(_torch_inverse_cast(matrix), expected)

    def test_jit(self, device, dtype):
        """The TorchScript-compiled helper agrees with eager execution."""
        data = torch.rand(1, 3, 4, 4, device=device, dtype=dtype)
        scripted = torch.jit.script(_torch_inverse_cast)
        assert_close(_torch_inverse_cast(data), scripted(data))
|
|
|
|
|
class TestHistcCast:
    """Tests for ``_torch_histc_cast``."""

    def test_smoke(self, device, dtype):
        """Counts samples into 4 equal-width bins spanning [0, 3]."""
        # Bin edges are 0, 0.75, 1.5, 2.25, 3: both 1.0 samples land in bin 1, the 2.0 in bin 2.
        samples = torch.tensor([1.0, 2.0, 1.0], device=device, dtype=dtype)
        counts = _torch_histc_cast(samples, bins=4, min=0, max=3)
        assert_close(counts, torch.tensor([0.0, 2.0, 1.0, 0.0], device=device, dtype=dtype))
|
|
|
|
|
class TestSvdCast:
    """Tests for ``_torch_svd_cast``."""

    def test_smoke(self, device, dtype):
        """``U @ diag(S) @ V^T`` reconstructs the decomposed batch of matrices."""
        mat = torch.randn(5, 3, 3, device=device, dtype=dtype)
        u, s, v = _torch_svd_cast(mat)
        # The third output is treated as V (not V^H), hence the explicit transpose.
        reconstructed = u @ torch.diag_embed(s) @ v.transpose(-2, -1)
        # float16 needs a far looser tolerance than the other dtypes.
        tol_val: float = 1e-3 if dtype != torch.float16 else 1e-1
        assert_close(mat, reconstructed, atol=tol_val, rtol=tol_val)
|
|
|
|
|
class TestSolveCast:
    """Tests for ``_torch_solve_cast``."""

    def test_smoke(self, device, dtype):
        """The returned solution ``X`` approximately satisfies ``A @ X == B``."""
        lhs = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
        rhs = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

        # Arguments are passed as (B, A); only the first output is checked here.
        solution, _ = _torch_solve_cast(rhs, lhs)
        residual = torch.dist(rhs, lhs.matmul(solution))

        # float16 needs a far looser tolerance than the other dtypes.
        tol_val: float = 1e-4 if dtype != torch.float16 else 1e-1
        assert_close(residual, torch.zeros_like(residual), atol=tol_val, rtol=tol_val)
|
|
|
|
|
class TestSolveWithMask:
    """Tests for ``safe_solve_with_mask``."""

    def test_smoke(self, device, dtype):
        """On random (almost surely non-singular) systems the result agrees with ``_torch_solve_cast``."""
        A = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
        B = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

        X, _, mask = safe_solve_with_mask(B, A)
        X2, _ = _torch_solve_cast(B, A)
        tol_val: float = 1e-1 if dtype == torch.float16 else 1e-4

        # Random Gaussian matrices are non-singular with probability 1, so every system should be
        # flagged valid. Asserting this (rather than skipping the comparison when the mask is empty)
        # prevents the test from passing vacuously if the validity mask is broken.
        assert mask.all()
        assert_close(X[mask], X2[mask], atol=tol_val, rtol=tol_val)

    @pytest.mark.skipif(
        (int(torch.__version__.split('.')[0]) == 1) and (int(torch.__version__.split('.')[1]) < 10),
        reason='<1.10.0 not supporting',
    )
    def test_all_bad(self, device, dtype):
        """All-ones (rank-1, hence singular) systems must all be flagged invalid."""
        A = torch.ones(10, 3, 3, device=device, dtype=dtype)
        # NOTE(review): B is (3, 10) and presumably broadcasts against A's batch of 10 — confirm.
        B = torch.ones(3, 10, device=device, dtype=dtype)

        _, _, mask = safe_solve_with_mask(B, A)
        assert torch.equal(mask, torch.zeros_like(mask))
|
|
|
|
|
class TestInverseWithMask:
    """Tests for ``safe_inverse_with_mask``."""

    def test_smoke(self, device, dtype):
        """A non-singular 2x2 matrix is inverted correctly and flagged as valid."""
        # det = 4*6 - 7*2 = 10, hence inv = [[6, -7], [-2, 4]] / 10.
        matrix = torch.tensor([[4.0, 7.0], [2.0, 6.0]], device=device, dtype=dtype)
        expected = torch.tensor([[0.6, -0.7], [-0.2, 0.4]], device=device, dtype=dtype)

        inverse, valid = safe_inverse_with_mask(matrix)

        assert_close(inverse, expected)
        assert torch.equal(valid, torch.ones_like(valid))

    @pytest.mark.skipif(
        (int(torch.__version__.split('.')[0]) == 1) and (int(torch.__version__.split('.')[1]) < 9),
        reason='<1.9.0 not supporting',
    )
    def test_all_bad(self, device, dtype):
        """A batch of all-ones (singular) matrices yields an all-False validity mask."""
        singular = torch.ones(10, 3, 3, device=device, dtype=dtype)
        _, valid = safe_inverse_with_mask(singular)
        assert torch.equal(valid, torch.zeros_like(valid))
|
|