compvis / test /color /test_raw.py
Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
import warnings
import pytest
import torch
from torch.autograd import gradcheck
import kornia
from kornia.testing import BaseTester # test utils
from kornia.testing import assert_close
from packaging import version
class TestRawToRgb(BaseTester):
def test_smoke(self, device, dtype):
C, H, W = 1, 4, 6
img = torch.rand(C, H, W, device=device, dtype=dtype)
assert isinstance(kornia.color.raw_to_rgb(img, kornia.color.CFA.BG), torch.Tensor)
@pytest.mark.parametrize("batch_size, height, width", [(1, 6, 4), (2, 2, 4), (3, 4, 2)])
def test_cardinality(self, device, dtype, batch_size, height, width):
img = torch.ones(batch_size, 1, height, width, device=device, dtype=dtype)
assert kornia.color.raw_to_rgb(img, kornia.color.CFA.BG).shape == (batch_size, 3, height, width)
def test_exception(self, device, dtype):
with pytest.raises(TypeError):
assert kornia.color.raw_to_rgb([0.0], kornia.color.CFA.BG)
with pytest.raises(ValueError):
img = torch.ones(1, 1, device=device, dtype=dtype)
assert kornia.color.raw_to_rgb(img, kornia.color.CFA.GB)
with pytest.raises(ValueError):
img = torch.ones(2, 1, 1, device=device, dtype=dtype)
assert kornia.color.raw_to_rgb(img, kornia.color.CFA.RG)
with pytest.raises(ValueError):
img = torch.ones(1, 3, 1, 1, device=device, dtype=dtype)
assert kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)
# dimensionality test
with pytest.raises(ValueError):
img = torch.ones(3, 2, 1, device=device, dtype=dtype)
assert kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)
# dimensionality test
with pytest.raises(ValueError):
img = torch.ones(3, 1, 2, device=device, dtype=dtype)
assert kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)
# With he current implementations we should get back an identical raw representation when doing raw -> rgb -> raw
# Note that with more advanced implementations this may not necessarily be true or desirable
def test_forth_and_back(self, device, dtype): # skipcq: PYL-R0201
data = torch.rand(1, 80, 80, device=device, dtype=dtype)
raw = kornia.color.rgb_to_raw
rgb = kornia.color.raw_to_rgb
for x in kornia.color.CFA:
data_out = raw(rgb(data, cfa=x), cfa=x)
assert_close(data_out, data)
# make sure different cfas are actually different
def test_cfas_not_the_same(self, device, dtype): # skipcq: PYL-R0201
data = torch.rand(1, 16, 16, device=device, dtype=dtype)
assert (
torch.max(
kornia.color.raw_to_rgb(data, kornia.color.CFA.BG) - kornia.color.raw_to_rgb(data, kornia.color.CFA.RG)
)
> 0.0
)
# The outcome will be very different for different implementations
# Here we compare against a current baseline, it is safe to update this if the underlying algorithm changes
def test_functional(self, device, dtype): # skipcq: PYL-R0201
data = torch.tensor(
[[[1, 0.5, 0.2, 0.4], [0.75, 0.25, 0.8, 0.3], [0.65, 0.15, 0.7, 0.2], [0.55, 0.5, 0.6, 0.1]]],
device=device,
dtype=dtype,
)
# checked by hand as correct interpolation. Note the ugly replication that happens for Red on the last column
# and row. We shall accept to live with that
expected = torch.tensor(
[
[
[1.0000, 0.6000, 0.2000, 0.2000],
[0.8250, 0.6375, 0.4500, 0.4500],
[0.6500, 0.6750, 0.7000, 0.7000],
[0.6500, 0.6750, 0.7000, 0.7000],
],
[
[0.6250, 0.5000, 0.6250, 0.4000],
[0.7500, 0.5500, 0.8000, 0.5500],
[0.4000, 0.1500, 0.4375, 0.2000],
[0.5500, 0.3625, 0.6000, 0.4000],
],
[
[0.2500, 0.2500, 0.2750, 0.3000],
[0.2500, 0.2500, 0.2750, 0.3000],
[0.3750, 0.3750, 0.2875, 0.2000],
[0.5000, 0.5000, 0.3000, 0.1000],
],
],
device=device,
dtype=dtype,
)
img_rgb = kornia.color.raw_to_rgb(data, kornia.color.raw.CFA.BG)
assert_close(img_rgb, expected)
# If we roll the data and the different CFAs they give the same result (expect on edges!)
def test_cfa_on_rolled(self, device, dtype): # skipcq: PYL-R0201
data = torch.rand(1, 1, 8, 8, device=device, dtype=dtype)
bgres = kornia.color.raw_to_rgb(data, kornia.color.raw.CFA.BG)
gbres = kornia.color.raw_to_rgb(data.roll((0, 1), (-2, -1)), kornia.color.raw.CFA.GB)
grres = kornia.color.raw_to_rgb(data.roll((1, 0), (-2, -1)), kornia.color.raw.CFA.GR)
rgres = kornia.color.raw_to_rgb(data.roll((1, 1), (-2, -1)), kornia.color.raw.CFA.RG)
assert_close(bgres[:, :, 1:5, 1:5], gbres[:, :, 1:5, 2:6])
assert_close(bgres[:, :, 1:5, 1:5], grres[:, :, 2:6, 1:5])
assert_close(bgres[:, :, 1:5, 1:5], rgres[:, :, 2:6, 2:6])
@pytest.mark.grad
def test_gradcheck(self, device, dtype):
B, C, H, W = 2, 1, 4, 4
img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
assert gradcheck(kornia.color.raw_to_rgb, (img, kornia.color.raw.CFA.BG), raise_exception=True)
@pytest.mark.jit
def test_jit(self, device, dtype):
if version.parse(torch.__version__) < version.parse('1.7.0'):
warnings.warn(
"This test is not compatible with pytorch < 1.7.0. This message will be removed as soon as we do not "
"support pytorch 1.6.0. `rgb_to_hls()` method for pytorch < 1.7.0 version cannot be compiled with JIT.",
DeprecationWarning,
stacklevel=2,
)
return
B, C, H, W = 2, 1, 4, 4
img = torch.ones(B, C, H, W, device=device, dtype=dtype)
op = kornia.color.raw_to_rgb
op_jit = torch.jit.script(op)
assert_close(op(img, kornia.color.raw.CFA.BG), op_jit(img, kornia.color.raw.CFA.BG))
@pytest.mark.nn
def test_module(self, device, dtype):
B, C, H, W = 2, 1, 4, 4
img = torch.ones(B, C, H, W, device=device, dtype=dtype)
raw_ops = kornia.color.RawToRgb(kornia.color.raw.CFA.BG).to(device, dtype)
raw_fcn = kornia.color.raw_to_rgb
assert_close(raw_ops(img), raw_fcn(img, kornia.color.raw.CFA.BG))
class TestRgbToRaw(BaseTester):
def test_smoke(self, device, dtype):
C, H, W = 3, 4, 6
img = torch.rand(C, H, W, device=device, dtype=dtype)
assert isinstance(kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG), torch.Tensor)
@pytest.mark.parametrize("batch_size, height, width", [(1, 3, 4), (2, 2, 4), (3, 4, 1)])
def test_cardinality(self, device, dtype, batch_size, height, width):
img = torch.ones(batch_size, 3, height, width, device=device, dtype=dtype)
assert kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.GR).shape == (batch_size, 1, height, width)
def test_exception(self, device, dtype):
with pytest.raises(TypeError):
assert kornia.color.rgb_to_raw([0.0], kornia.color.raw.CFA.RG)
with pytest.raises(ValueError):
img = torch.ones(1, 1, device=device, dtype=dtype)
assert kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG)
# Reverse test in rawtorgb is sufficient functional test
@pytest.mark.grad
def test_gradcheck(self, device, dtype):
B, C, H, W = 2, 3, 4, 4
img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
assert gradcheck(kornia.color.rgb_to_raw, (img, kornia.color.raw.CFA.BG), raise_exception=True)
@pytest.mark.jit
def test_jit(self, device, dtype):
if version.parse(torch.__version__) < version.parse('1.7.0'):
warnings.warn(
"This test is not compatible with pytorch < 1.7.0. This message will be removed as soon as we do not "
"support pytorch 1.6.0. `rgb_to_hls()` method for pytorch < 1.7.0 version cannot be compiled with JIT.",
DeprecationWarning,
stacklevel=2,
)
return
B, C, H, W = 2, 3, 4, 4
img = torch.ones(B, C, H, W, device=device, dtype=dtype)
op = kornia.color.rgb_to_raw
op_jit = torch.jit.script(op)
assert_close(op(img, kornia.color.raw.CFA.BG), op_jit(img, kornia.color.raw.CFA.BG))
@pytest.mark.nn
def test_module(self, device, dtype):
B, C, H, W = 2, 3, 4, 4
img = torch.ones(B, C, H, W, device=device, dtype=dtype)
raw_ops = kornia.color.RgbToRaw(kornia.color.raw.CFA.BG).to(device, dtype)
raw_fcn = kornia.color.rgb_to_raw
assert_close(raw_ops(img), raw_fcn(img, kornia.color.raw.CFA.BG))