|
import warnings |
|
|
|
import pytest |
|
import torch |
|
from torch.autograd import gradcheck |
|
|
|
import kornia |
|
from kornia.testing import BaseTester |
|
from kornia.testing import assert_close |
|
from packaging import version |
|
|
|
|
|
class TestRawToRgb(BaseTester):
    """Tests for ``kornia.color.raw_to_rgb`` and the ``RawToRgb`` module."""

    def test_smoke(self, device, dtype):
        C, H, W = 1, 4, 6
        img = torch.rand(C, H, W, device=device, dtype=dtype)
        assert isinstance(kornia.color.raw_to_rgb(img, kornia.color.CFA.BG), torch.Tensor)

    @pytest.mark.parametrize("batch_size, height, width", [(1, 6, 4), (2, 2, 4), (3, 4, 2)])
    def test_cardinality(self, device, dtype, batch_size, height, width):
        # A single raw channel is expanded into three RGB channels; spatial size is kept.
        img = torch.ones(batch_size, 1, height, width, device=device, dtype=dtype)
        assert kornia.color.raw_to_rgb(img, kornia.color.CFA.BG).shape == (batch_size, 3, height, width)

    def test_exception(self, device, dtype):
        # Non-tensor input.
        with pytest.raises(TypeError):
            kornia.color.raw_to_rgb([0.0], kornia.color.CFA.BG)

        # Too few dimensions.
        with pytest.raises(ValueError):
            img = torch.ones(1, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GB)

        # Channel dimension is not 1 (unbatched).
        with pytest.raises(ValueError):
            img = torch.ones(2, 1, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.RG)

        # Channel dimension is not 1 (batched).
        with pytest.raises(ValueError):
            img = torch.ones(1, 3, 1, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

        # Original dimensionality cases (these have C=3, so they also trip the channel check).
        with pytest.raises(ValueError):
            img = torch.ones(3, 2, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

        with pytest.raises(ValueError):
            img = torch.ones(3, 1, 2, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

        # Valid single channel but odd width / odd height: exercises the
        # "H and W divisible by 2" requirement on its own.
        with pytest.raises(ValueError):
            img = torch.ones(1, 2, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

        with pytest.raises(ValueError):
            img = torch.ones(1, 1, 2, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

    def test_forth_and_back(self, device, dtype):
        # raw -> rgb -> raw must reproduce the original raw samples for every CFA pattern.
        data = torch.rand(1, 80, 80, device=device, dtype=dtype)
        raw = kornia.color.rgb_to_raw
        rgb = kornia.color.raw_to_rgb

        for cfa in kornia.color.CFA:
            data_out = raw(rgb(data, cfa=cfa), cfa=cfa)
            assert_close(data_out, data)

    def test_cfas_not_the_same(self, device, dtype):
        # Different CFA patterns must yield different RGB images. The absolute
        # difference is used so a difference of either sign is detected.
        data = torch.rand(1, 16, 16, device=device, dtype=dtype)
        bg = kornia.color.raw_to_rgb(data, kornia.color.CFA.BG)
        rg = kornia.color.raw_to_rgb(data, kornia.color.CFA.RG)
        assert torch.max(torch.abs(bg - rg)) > 0.0

    def test_functional(self, device, dtype):
        data = torch.tensor(
            [[[1, 0.5, 0.2, 0.4], [0.75, 0.25, 0.8, 0.3], [0.65, 0.15, 0.7, 0.2], [0.55, 0.5, 0.6, 0.1]]],
            device=device,
            dtype=dtype,
        )

        expected = torch.tensor(
            [
                [
                    [1.0000, 0.6000, 0.2000, 0.2000],
                    [0.8250, 0.6375, 0.4500, 0.4500],
                    [0.6500, 0.6750, 0.7000, 0.7000],
                    [0.6500, 0.6750, 0.7000, 0.7000],
                ],
                [
                    [0.6250, 0.5000, 0.6250, 0.4000],
                    [0.7500, 0.5500, 0.8000, 0.5500],
                    [0.4000, 0.1500, 0.4375, 0.2000],
                    [0.5500, 0.3625, 0.6000, 0.4000],
                ],
                [
                    [0.2500, 0.2500, 0.2750, 0.3000],
                    [0.2500, 0.2500, 0.2750, 0.3000],
                    [0.3750, 0.3750, 0.2875, 0.2000],
                    [0.5000, 0.5000, 0.3000, 0.1000],
                ],
            ],
            device=device,
            dtype=dtype,
        )

        img_rgb = kornia.color.raw_to_rgb(data, kornia.color.raw.CFA.BG)
        assert_close(img_rgb, expected)

    def test_cfa_on_rolled(self, device, dtype):
        # Rolling a BG mosaic by one pixel along W / H / both re-labels it as a
        # GB / GR / RG mosaic. Decoding with the matching CFA must therefore give
        # the same RGB values, shifted by the same offset. Only interior crops are
        # compared, because torch.roll wraps pixels around at the borders.
        data = torch.rand(1, 1, 8, 8, device=device, dtype=dtype)
        bgres = kornia.color.raw_to_rgb(data, kornia.color.raw.CFA.BG)
        gbres = kornia.color.raw_to_rgb(data.roll((0, 1), (-2, -1)), kornia.color.raw.CFA.GB)
        grres = kornia.color.raw_to_rgb(data.roll((1, 0), (-2, -1)), kornia.color.raw.CFA.GR)
        rgres = kornia.color.raw_to_rgb(data.roll((1, 1), (-2, -1)), kornia.color.raw.CFA.RG)

        assert_close(bgres[:, :, 1:5, 1:5], gbres[:, :, 1:5, 2:6])
        assert_close(bgres[:, :, 1:5, 1:5], grres[:, :, 2:6, 1:5])
        assert_close(bgres[:, :, 1:5, 1:5], rgres[:, :, 2:6, 2:6])

    @pytest.mark.grad
    def test_gradcheck(self, device, dtype):
        # gradcheck needs double precision, so the fixture dtype is intentionally ignored.
        B, C, H, W = 2, 1, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
        assert gradcheck(kornia.color.raw_to_rgb, (img, kornia.color.raw.CFA.BG), raise_exception=True)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        if version.parse(torch.__version__) < version.parse('1.7.0'):
            # Report the test as skipped rather than letting it silently pass.
            pytest.skip("`raw_to_rgb()` cannot be compiled with JIT for pytorch < 1.7.0.")

        B, C, H, W = 2, 1, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        op = kornia.color.raw_to_rgb
        op_jit = torch.jit.script(op)
        assert_close(op(img, kornia.color.raw.CFA.BG), op_jit(img, kornia.color.raw.CFA.BG))

    @pytest.mark.nn
    def test_module(self, device, dtype):
        # The nn.Module wrapper must match the functional API.
        B, C, H, W = 2, 1, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        raw_ops = kornia.color.RawToRgb(kornia.color.raw.CFA.BG).to(device, dtype)
        raw_fcn = kornia.color.raw_to_rgb
        assert_close(raw_ops(img), raw_fcn(img, kornia.color.raw.CFA.BG))
|
|
|
|
|
class TestRgbToRaw(BaseTester):
    """Tests for ``kornia.color.rgb_to_raw`` and the ``RgbToRaw`` module."""

    def test_smoke(self, device, dtype):
        C, H, W = 3, 4, 6
        img = torch.rand(C, H, W, device=device, dtype=dtype)
        assert isinstance(kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG), torch.Tensor)

    @pytest.mark.parametrize("batch_size, height, width", [(1, 3, 4), (2, 2, 4), (3, 4, 1)])
    def test_cardinality(self, device, dtype, batch_size, height, width):
        # Three RGB channels collapse to a single raw channel; spatial size is kept
        # (odd sizes are included here on purpose).
        img = torch.ones(batch_size, 3, height, width, device=device, dtype=dtype)
        assert kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.GR).shape == (batch_size, 1, height, width)

    def test_exception(self, device, dtype):
        # Non-tensor input.
        with pytest.raises(TypeError):
            kornia.color.rgb_to_raw([0.0], kornia.color.raw.CFA.RG)

        # Too few dimensions.
        with pytest.raises(ValueError):
            img = torch.ones(1, 1, device=device, dtype=dtype)
            kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG)

        # Wrong channel count: an RGB input must have exactly 3 channels.
        with pytest.raises(ValueError):
            img = torch.ones(1, 1, 2, 2, device=device, dtype=dtype)
            kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG)

    @pytest.mark.grad
    def test_gradcheck(self, device, dtype):
        # gradcheck needs double precision, so the fixture dtype is intentionally ignored.
        B, C, H, W = 2, 3, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
        assert gradcheck(kornia.color.rgb_to_raw, (img, kornia.color.raw.CFA.BG), raise_exception=True)

    @pytest.mark.jit
    def test_jit(self, device, dtype):
        if version.parse(torch.__version__) < version.parse('1.7.0'):
            # Report the test as skipped rather than letting it silently pass.
            pytest.skip("`rgb_to_raw()` cannot be compiled with JIT for pytorch < 1.7.0.")

        B, C, H, W = 2, 3, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        op = kornia.color.rgb_to_raw
        op_jit = torch.jit.script(op)
        assert_close(op(img, kornia.color.raw.CFA.BG), op_jit(img, kornia.color.raw.CFA.BG))

    @pytest.mark.nn
    def test_module(self, device, dtype):
        # The nn.Module wrapper must match the functional API.
        B, C, H, W = 2, 3, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        raw_ops = kornia.color.RgbToRaw(kornia.color.raw.CFA.BG).to(device, dtype)
        raw_fcn = kornia.color.rgb_to_raw
        assert_close(raw_ops(img), raw_fcn(img, kornia.color.raw.CFA.BG))
|
|