File size: 9,137 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import warnings

import pytest
import torch
from torch.autograd import gradcheck

import kornia
from kornia.testing import BaseTester  # test utils
from kornia.testing import assert_close
from packaging import version


class TestRawToRgb(BaseTester):
    """Tests for ``kornia.color.raw_to_rgb`` and the ``kornia.color.RawToRgb`` module."""

    def test_smoke(self, device, dtype):
        """An unbatched single-channel (1, H, W) raw image converts to a tensor."""
        C, H, W = 1, 4, 6
        img = torch.rand(C, H, W, device=device, dtype=dtype)
        assert isinstance(kornia.color.raw_to_rgb(img, kornia.color.CFA.BG), torch.Tensor)

    @pytest.mark.parametrize("batch_size, height, width", [(1, 6, 4), (2, 2, 4), (3, 4, 2)])
    def test_cardinality(self, device, dtype, batch_size, height, width):
        """A (B, 1, H, W) raw batch maps to a (B, 3, H, W) RGB batch."""
        img = torch.ones(batch_size, 1, height, width, device=device, dtype=dtype)
        assert kornia.color.raw_to_rgb(img, kornia.color.CFA.BG).shape == (batch_size, 3, height, width)

    def test_exception(self, device, dtype):
        """Invalid inputs are rejected with ``TypeError`` / ``ValueError``."""
        # Not a tensor.
        with pytest.raises(TypeError):
            kornia.color.raw_to_rgb([0.0], kornia.color.CFA.BG)

        # 2-D input: no channel dimension.
        with pytest.raises(ValueError):
            img = torch.ones(1, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GB)

        # Channel dimension of 2 instead of 1.
        with pytest.raises(ValueError):
            img = torch.ones(2, 1, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.RG)

        # Batched input with 3 channels instead of 1.
        with pytest.raises(ValueError):
            img = torch.ones(1, 3, 1, 1, device=device, dtype=dtype)
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

        # Spatial-size checks. The channel count is kept at 1 so the failure can only come
        # from the odd width / height; a channel count != 1 is already rejected above and
        # would mask this check.
        with pytest.raises(ValueError):
            img = torch.ones(1, 2, 1, device=device, dtype=dtype)  # odd width
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

        with pytest.raises(ValueError):
            img = torch.ones(1, 1, 2, device=device, dtype=dtype)  # odd height
            kornia.color.raw_to_rgb(img, kornia.color.CFA.GR)

    # With the current implementation we should get back an identical raw representation when doing
    # raw -> rgb -> raw. Note that with more advanced implementations this may not necessarily be true or desirable.
    def test_forth_and_back(self, device, dtype):  # skipcq: PYL-R0201
        """Round-tripping raw -> rgb -> raw reproduces the input for every CFA."""
        data = torch.rand(1, 80, 80, device=device, dtype=dtype)
        to_raw = kornia.color.rgb_to_raw
        to_rgb = kornia.color.raw_to_rgb

        for cfa in kornia.color.CFA:
            data_out = to_raw(to_rgb(data, cfa=cfa), cfa=cfa)
            assert_close(data_out, data)

    def test_cfas_not_the_same(self, device, dtype):  # skipcq: PYL-R0201
        """Different CFA patterns must produce different RGB output for the same raw data."""
        data = torch.rand(1, 16, 16, device=device, dtype=dtype)
        bg_res = kornia.color.raw_to_rgb(data, kornia.color.CFA.BG)
        rg_res = kornia.color.raw_to_rgb(data, kornia.color.CFA.RG)
        # Use the absolute difference so a deviation in either direction counts.
        assert (bg_res - rg_res).abs().max() > 0.0

    # The outcome will be very different for different implementations.
    # Here we compare against a current baseline; it is safe to update this if the underlying algorithm changes.
    def test_functional(self, device, dtype):  # skipcq: PYL-R0201
        """Compare the BG interpolation of a fixed 4x4 raw image against a hand-checked baseline."""
        data = torch.tensor(
            [[[1, 0.5, 0.2, 0.4], [0.75, 0.25, 0.8, 0.3], [0.65, 0.15, 0.7, 0.2], [0.55, 0.5, 0.6, 0.1]]],
            device=device,
            dtype=dtype,
        )
        # Checked by hand as correct interpolation. Note the ugly replication that happens for Red on the last column
        # and row. We shall accept to live with that.
        expected = torch.tensor(
            [
                [
                    [1.0000, 0.6000, 0.2000, 0.2000],
                    [0.8250, 0.6375, 0.4500, 0.4500],
                    [0.6500, 0.6750, 0.7000, 0.7000],
                    [0.6500, 0.6750, 0.7000, 0.7000],
                ],
                [
                    [0.6250, 0.5000, 0.6250, 0.4000],
                    [0.7500, 0.5500, 0.8000, 0.5500],
                    [0.4000, 0.1500, 0.4375, 0.2000],
                    [0.5500, 0.3625, 0.6000, 0.4000],
                ],
                [
                    [0.2500, 0.2500, 0.2750, 0.3000],
                    [0.2500, 0.2500, 0.2750, 0.3000],
                    [0.3750, 0.3750, 0.2875, 0.2000],
                    [0.5000, 0.5000, 0.3000, 0.1000],
                ],
            ],
            device=device,
            dtype=dtype,
        )

        img_rgb = kornia.color.raw_to_rgb(data, kornia.color.raw.CFA.BG)
        assert_close(img_rgb, expected)

    # If we roll the data and use the matching CFA, the results agree (except on the edges!).
    def test_cfa_on_rolled(self, device, dtype):  # skipcq: PYL-R0201
        """Shifting the data by one pixel is equivalent to switching to the shifted CFA pattern."""
        data = torch.rand(1, 1, 8, 8, device=device, dtype=dtype)
        bgres = kornia.color.raw_to_rgb(data, kornia.color.raw.CFA.BG)
        gbres = kornia.color.raw_to_rgb(data.roll((0, 1), (-2, -1)), kornia.color.raw.CFA.GB)
        grres = kornia.color.raw_to_rgb(data.roll((1, 0), (-2, -1)), kornia.color.raw.CFA.GR)
        rgres = kornia.color.raw_to_rgb(data.roll((1, 1), (-2, -1)), kornia.color.raw.CFA.RG)

        # Only compare interior pixels; border handling differs between the rolled variants.
        assert_close(bgres[:, :, 1:5, 1:5], gbres[:, :, 1:5, 2:6])
        assert_close(bgres[:, :, 1:5, 1:5], grres[:, :, 2:6, 1:5])
        assert_close(bgres[:, :, 1:5, 1:5], rgres[:, :, 2:6, 2:6])

    @pytest.mark.grad
    def test_gradcheck(self, device, dtype):
        """Numerical vs analytical gradients agree (gradcheck needs float64, so ``dtype`` is ignored)."""
        B, C, H, W = 2, 1, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
        assert gradcheck(kornia.color.raw_to_rgb, (img, kornia.color.raw.CFA.BG), raise_exception=True)

    @pytest.mark.jit
    @pytest.mark.skipif(
        version.parse(torch.__version__) < version.parse('1.7.0'),
        reason="This test is not compatible with pytorch < 1.7.0.",
    )
    def test_jit(self, device, dtype):
        """The TorchScript-compiled function matches eager execution."""
        B, C, H, W = 2, 1, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        op = kornia.color.raw_to_rgb
        op_jit = torch.jit.script(op)
        assert_close(op(img, kornia.color.raw.CFA.BG), op_jit(img, kornia.color.raw.CFA.BG))

    @pytest.mark.nn
    def test_module(self, device, dtype):
        """The ``RawToRgb`` module matches the functional ``raw_to_rgb``."""
        B, C, H, W = 2, 1, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        raw_ops = kornia.color.RawToRgb(kornia.color.raw.CFA.BG).to(device, dtype)
        raw_fcn = kornia.color.raw_to_rgb
        assert_close(raw_ops(img), raw_fcn(img, kornia.color.raw.CFA.BG))


class TestRgbToRaw(BaseTester):
    """Tests for ``kornia.color.rgb_to_raw`` and the ``kornia.color.RgbToRaw`` module.

    Functional correctness is covered by the raw -> rgb -> raw round trip in
    ``TestRawToRgb.test_forth_and_back``.
    """

    def test_smoke(self, device, dtype):
        """An unbatched (3, H, W) RGB image converts to a tensor."""
        C, H, W = 3, 4, 6
        img = torch.rand(C, H, W, device=device, dtype=dtype)
        assert isinstance(kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG), torch.Tensor)

    @pytest.mark.parametrize("batch_size, height, width", [(1, 3, 4), (2, 2, 4), (3, 4, 1)])
    def test_cardinality(self, device, dtype, batch_size, height, width):
        """A (B, 3, H, W) RGB batch maps to a (B, 1, H, W) raw batch."""
        img = torch.ones(batch_size, 3, height, width, device=device, dtype=dtype)
        assert kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.GR).shape == (batch_size, 1, height, width)

    def test_exception(self, device, dtype):
        """Invalid inputs are rejected with ``TypeError`` / ``ValueError``."""
        # Not a tensor.
        with pytest.raises(TypeError):
            kornia.color.rgb_to_raw([0.0], kornia.color.raw.CFA.RG)

        # 2-D input: no channel dimension.
        with pytest.raises(ValueError):
            img = torch.ones(1, 1, device=device, dtype=dtype)
            kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG)

        # Batched input with 1 channel instead of 3.
        with pytest.raises(ValueError):
            img = torch.ones(1, 1, 4, 4, device=device, dtype=dtype)
            kornia.color.rgb_to_raw(img, kornia.color.raw.CFA.BG)

    @pytest.mark.grad
    def test_gradcheck(self, device, dtype):
        """Numerical vs analytical gradients agree (gradcheck needs float64, so ``dtype`` is ignored)."""
        B, C, H, W = 2, 3, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=torch.float64, requires_grad=True)
        assert gradcheck(kornia.color.rgb_to_raw, (img, kornia.color.raw.CFA.BG), raise_exception=True)

    @pytest.mark.jit
    @pytest.mark.skipif(
        version.parse(torch.__version__) < version.parse('1.7.0'),
        reason="This test is not compatible with pytorch < 1.7.0.",
    )
    def test_jit(self, device, dtype):
        """The TorchScript-compiled function matches eager execution."""
        B, C, H, W = 2, 3, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        op = kornia.color.rgb_to_raw
        op_jit = torch.jit.script(op)
        assert_close(op(img, kornia.color.raw.CFA.BG), op_jit(img, kornia.color.raw.CFA.BG))

    @pytest.mark.nn
    def test_module(self, device, dtype):
        """The ``RgbToRaw`` module matches the functional ``rgb_to_raw``."""
        B, C, H, W = 2, 3, 4, 4
        img = torch.ones(B, C, H, W, device=device, dtype=dtype)
        raw_ops = kornia.color.RgbToRaw(kornia.color.raw.CFA.BG).to(device, dtype)
        raw_fcn = kornia.color.rgb_to_raw
        assert_close(raw_ops(img), raw_fcn(img, kornia.color.raw.CFA.BG))