|
import math |
|
|
|
import pytest |
|
import torch |
|
from torch.autograd import gradcheck |
|
|
|
import kornia |
|
import kornia.testing as utils |
|
from kornia.testing import assert_close |
|
|
|
|
|
class TestBinaryFocalLossWithLogits:
    """Tests for ``kornia.losses.binary_focal_loss_with_logits`` and ``BinaryFocalLossWithLogits``."""

    def test_smoke_none(self, device, dtype):
        logits = torch.rand(2, 3, 2, dtype=dtype, device=device)
        labels = torch.rand(2, 3, 2, dtype=dtype, device=device)

        assert kornia.losses.binary_focal_loss_with_logits(
            logits, labels, alpha=0.5, gamma=2.0, reduction="none"
        ).shape == (2, 3, 2)

    def test_smoke_sum(self, device, dtype):
        logits = torch.rand(2, 3, 2, dtype=dtype, device=device)
        labels = torch.rand(2, 3, 2, dtype=dtype, device=device)

        assert (
            kornia.losses.binary_focal_loss_with_logits(logits, labels, alpha=0.5, gamma=2.0, reduction="sum").shape
            == ()
        )

    def test_smoke_mean(self, device, dtype):
        logits = torch.rand(2, 3, 2, dtype=dtype, device=device)
        labels = torch.rand(2, 3, 2, dtype=dtype, device=device)

        assert (
            kornia.losses.binary_focal_loss_with_logits(logits, labels, alpha=0.5, gamma=2.0, reduction="mean").shape
            == ()
        )

    def test_smoke_mean_flat(self, device, dtype):
        # Flat (B, C) input with no trailing spatial dimensions. This test used to be an
        # exact copy of ``test_smoke_mean`` and therefore never exercised the flat case.
        logits = torch.rand(2, 3, dtype=dtype, device=device)
        labels = torch.rand(2, 3, dtype=dtype, device=device)

        assert (
            kornia.losses.binary_focal_loss_with_logits(logits, labels, alpha=0.5, gamma=2.0, reduction="mean").shape
            == ()
        )

    def test_jit(self, device, dtype):
        logits = torch.rand(2, 3, 2, dtype=dtype, device=device)
        labels = torch.rand(2, 3, 2, dtype=dtype, device=device)

        op = kornia.losses.binary_focal_loss_with_logits
        op_script = torch.jit.script(op)
        actual = op_script(logits, labels, alpha=0.5, gamma=2.0, reduction="none")
        expected = op(logits, labels, alpha=0.5, gamma=2.0, reduction="none")
        assert_close(actual, expected)

    def test_gradcheck(self, device):
        alpha, gamma = 0.5, 2.0
        logits = torch.rand(2, 3, 2).to(device)
        # Targets match the logits' shape and hold binary {0, 1} values, so both the
        # positive and the negative term of the loss get gradient-checked. The previous
        # ``rand(2, 1, 3, 2).long()`` was an all-zero tensor with a broadcast shape.
        labels = torch.randint(0, 2, (2, 3, 2)).to(device, torch.float64)

        logits = utils.tensor_to_gradcheck_var(logits)
        assert gradcheck(
            kornia.losses.binary_focal_loss_with_logits, (logits, labels, alpha, gamma), raise_exception=True
        )

    def test_same_output(self, device, dtype):
        logits = torch.rand(2, 3, 2, dtype=dtype, device=device)
        labels = torch.rand(2, 3, 2, dtype=dtype, device=device)

        kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}

        # The functional and module forms must agree; compare with assert_close rather
        # than relying on the truthiness of a tensor ``==``.
        expected = kornia.losses.binary_focal_loss_with_logits(logits, labels, **kwargs)
        actual = kornia.losses.BinaryFocalLossWithLogits(**kwargs)(logits, labels)
        assert_close(actual, expected)
|
|
|
|
|
class TestFocalLoss:
    """Tests for the multi-class ``kornia.losses.focal_loss`` and the ``FocalLoss`` module."""

    NUM_CLASSES = 3
    # Positional ``(alpha, gamma)`` values used throughout this class.
    PARAMS = (0.5, 2.0)

    @classmethod
    def _sample_labels(cls, *shape, device):
        """Sample integer class ids in ``[0, NUM_CLASSES)`` on the CPU, then move them to ``device``."""
        return (torch.rand(*shape) * cls.NUM_CLASSES).to(device).long()

    @classmethod
    def _dense_inputs(cls, device, dtype):
        """Return ``(logits, labels)`` with shapes ``(2, C, 3, 2)`` and ``(2, 3, 2)``."""
        logits = torch.rand(2, cls.NUM_CLASSES, 3, 2, device=device, dtype=dtype)
        return logits, cls._sample_labels(2, 3, 2, device=device)

    @classmethod
    def _flat_inputs(cls, device, dtype):
        """Return ``(logits, labels)`` with shapes ``(2, C)`` and ``(2,)``."""
        logits = torch.rand(2, cls.NUM_CLASSES, device=device, dtype=dtype)
        return logits, cls._sample_labels(2, device=device)

    def test_smoke_none(self, device, dtype):
        logits, labels = self._dense_inputs(device, dtype)
        alpha, gamma = self.PARAMS
        loss = kornia.losses.focal_loss(logits, labels, alpha=alpha, gamma=gamma, reduction="none")
        assert loss.shape == (2, 3, 2)

    def test_smoke_sum(self, device, dtype):
        logits, labels = self._dense_inputs(device, dtype)
        alpha, gamma = self.PARAMS
        loss = kornia.losses.focal_loss(logits, labels, alpha=alpha, gamma=gamma, reduction="sum")
        assert loss.shape == ()

    def test_smoke_mean(self, device, dtype):
        logits, labels = self._dense_inputs(device, dtype)
        alpha, gamma = self.PARAMS
        loss = kornia.losses.focal_loss(logits, labels, alpha=alpha, gamma=gamma, reduction="mean")
        assert loss.shape == ()

    def test_smoke_mean_flat(self, device, dtype):
        logits, labels = self._flat_inputs(device, dtype)
        alpha, gamma = self.PARAMS
        loss = kornia.losses.focal_loss(logits, labels, alpha=alpha, gamma=gamma, reduction="mean")
        assert loss.shape == ()

    def test_gradcheck(self, device, dtype):
        logits, labels = self._dense_inputs(device, dtype)
        logits = utils.tensor_to_gradcheck_var(logits)
        assert gradcheck(kornia.losses.focal_loss, (logits, labels, *self.PARAMS), raise_exception=True)

    def test_jit(self, device, dtype):
        logits, labels = self._flat_inputs(device, dtype)

        eager = kornia.losses.focal_loss
        scripted = torch.jit.script(eager)

        assert_close(scripted(logits, labels, *self.PARAMS), eager(logits, labels, *self.PARAMS))

    def test_module(self, device, dtype):
        logits, labels = self._flat_inputs(device, dtype)

        module = kornia.losses.FocalLoss(*self.PARAMS)

        assert_close(module(logits, labels), kornia.losses.focal_loss(logits, labels, *self.PARAMS))
|
|
|
|
|
class TestTverskyLoss:
    """Tests for ``kornia.losses.tversky_loss`` and the ``TverskyLoss`` module."""

    NUM_CLASSES = 3

    @classmethod
    def _random_batch(cls, device, dtype):
        """Return random ``(2, C, 3, 2)`` logits and ``(2, 3, 2)`` integer labels (labels sampled on CPU)."""
        logits = torch.rand(2, cls.NUM_CLASSES, 3, 2, device=device, dtype=dtype)
        labels = (torch.rand(2, 3, 2) * cls.NUM_CLASSES).to(device).long()
        return logits, labels

    def test_smoke(self, device, dtype):
        logits, labels = self._random_batch(device, dtype)
        criterion = kornia.losses.TverskyLoss(alpha=0.5, beta=0.5)
        assert criterion(logits, labels) is not None

    def test_all_zeros(self, device, dtype):
        # Channel 0 dominates every pixel and every label is class 0, so the loss is expected to be ~0.
        logits = torch.zeros(2, self.NUM_CLASSES, 1, 2, device=device, dtype=dtype)
        for channel, value in enumerate((10.0, 1.0, 1.0)):
            logits[:, channel] = value
        labels = torch.zeros(2, 1, 2, device=device, dtype=torch.int64)

        loss = kornia.losses.TverskyLoss(alpha=0.5, beta=0.5)(logits, labels)
        assert_close(loss, torch.zeros_like(loss), atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device, dtype):
        alpha, beta = 0.5, 0.5
        logits, labels = self._random_batch(device, dtype)
        logits = utils.tensor_to_gradcheck_var(logits)
        assert gradcheck(kornia.losses.tversky_loss, (logits, labels, alpha, beta), raise_exception=True)

    def test_jit(self, device, dtype):
        params = (0.5, 0.05)
        logits, labels = self._random_batch(device, dtype)

        eager = kornia.losses.tversky_loss
        scripted = torch.jit.script(eager)

        assert_close(scripted(logits, labels, *params), eager(logits, labels, *params))

    def test_module(self, device, dtype):
        params = (0.5, 0.5)
        logits, labels = self._random_batch(device, dtype)

        module = kornia.losses.TverskyLoss(*params)

        assert_close(module(logits, labels), kornia.losses.tversky_loss(logits, labels, *params))
|
|
|
|
|
class TestDiceLoss:
    """Tests for ``kornia.losses.dice_loss`` and the ``DiceLoss`` module."""

    NUM_CLASSES = 3

    @classmethod
    def _random_batch(cls, height, width, device, dtype):
        """Return random ``(2, C, H, W)`` logits and ``(2, H, W)`` integer labels (labels sampled on CPU)."""
        logits = torch.rand(2, cls.NUM_CLASSES, height, width, device=device, dtype=dtype)
        labels = (torch.rand(2, height, width) * cls.NUM_CLASSES).to(device).long()
        return logits, labels

    def test_smoke(self, device, dtype):
        logits, labels = self._random_batch(3, 2, device, dtype)
        assert kornia.losses.DiceLoss()(logits, labels) is not None

    def test_all_zeros(self, device, dtype):
        # Channel 0 dominates every pixel and every label is class 0, so the loss is expected to be ~0.
        logits = torch.zeros(2, self.NUM_CLASSES, 1, 2, device=device, dtype=dtype)
        for channel, value in enumerate((10.0, 1.0, 1.0)):
            logits[:, channel] = value
        labels = torch.zeros(2, 1, 2, device=device, dtype=torch.int64)

        loss = kornia.losses.DiceLoss()(logits, labels)
        assert_close(loss, torch.zeros_like(loss), rtol=1e-3, atol=1e-3)

    def test_gradcheck(self, device, dtype):
        logits, labels = self._random_batch(3, 2, device, dtype)
        logits = utils.tensor_to_gradcheck_var(logits)
        assert gradcheck(kornia.losses.dice_loss, (logits, labels), raise_exception=True)

    def test_jit(self, device, dtype):
        logits, labels = self._random_batch(1, 2, device, dtype)

        eager = kornia.losses.dice_loss
        scripted = torch.jit.script(eager)

        assert_close(eager(logits, labels), scripted(logits, labels))

    def test_module(self, device, dtype):
        logits, labels = self._random_batch(1, 2, device, dtype)

        module = kornia.losses.DiceLoss()

        assert_close(kornia.losses.dice_loss(logits, labels), module(logits, labels))
|
|
|
|
|
class TestDepthSmoothnessLoss:
    """Tests for ``kornia.losses.inverse_depth_smoothness_loss`` and ``InverseDepthSmoothnessLoss``.

    Every call passes the arguments as ``(depth, image)``, the same order used by
    ``test_smoke`` and ``test_gradcheck`` (``test_jit``/``test_module`` previously
    passed them swapped).
    """

    @pytest.mark.parametrize("data_shape", [(1, 1, 10, 16), (2, 4, 8, 15)])
    def test_smoke(self, device, dtype, data_shape):
        image = torch.rand(data_shape, device=device, dtype=dtype)
        depth = torch.rand(data_shape, device=device, dtype=dtype)

        criterion = kornia.losses.InverseDepthSmoothnessLoss()
        assert criterion(depth, image) is not None

    def test_jit(self, device, dtype):
        image = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        depth = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        op = kornia.losses.inverse_depth_smoothness_loss
        op_script = torch.jit.script(op)

        assert_close(op(depth, image), op_script(depth, image))

    def test_module(self, device, dtype):
        image = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        depth = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        op = kornia.losses.inverse_depth_smoothness_loss
        op_module = kornia.losses.InverseDepthSmoothnessLoss()

        assert_close(op(depth, image), op_module(depth, image))

    def test_gradcheck(self, device, dtype):
        image = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        depth = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        depth = utils.tensor_to_gradcheck_var(depth)
        image = utils.tensor_to_gradcheck_var(image)
        assert gradcheck(kornia.losses.inverse_depth_smoothness_loss, (depth, image), raise_exception=True)
|
|
|
|
|
class TestSSIMLoss:
    """Tests for ``kornia.losses.ssim_loss`` and the ``SSIMLoss`` module."""

    def test_ssim_equal_none(self, device, dtype):
        # Comparing an image with itself is expected to give a per-pixel loss of zero.
        first = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)
        second = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)
        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)

        for img in (first, second):
            loss_map = kornia.losses.ssim_loss(img, img, window_size=5, reduction="none")
            assert_close(loss_map, torch.zeros_like(img), rtol=tol_val, atol=tol_val)

    @pytest.mark.parametrize("window_size", [5, 11])
    @pytest.mark.parametrize("reduction_type", ["mean", "sum"])
    @pytest.mark.parametrize("batch_shape", [(1, 1, 10, 16), (2, 4, 8, 15)])
    def test_ssim(self, device, dtype, batch_shape, window_size, reduction_type):
        if device.type == 'xla':
            pytest.skip("test highly unstable with tpu")

        img = torch.rand(batch_shape, device=device, dtype=dtype)
        reduced = kornia.losses.ssim_loss(img, img, window_size, reduction=reduction_type)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)
        assert_close(reduced.item(), 0.0, rtol=tol_val, atol=tol_val)

    def test_jit(self, device, dtype):
        img1 = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        img2 = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        window_size, max_val, eps, reduction = 5, 1.0, 1e-6, 'mean'

        eager = kornia.losses.ssim_loss
        scripted = torch.jit.script(eager)

        assert_close(
            eager(img1, img2, window_size, max_val, eps, reduction),
            scripted(img1, img2, window_size, max_val, eps, reduction),
        )

    def test_module(self, device, dtype):
        img1 = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        img2 = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        window_size, max_val, eps, reduction = 5, 1.0, 1e-12, 'mean'

        module = kornia.losses.SSIMLoss(window_size, max_val, eps, reduction)

        assert_close(
            kornia.losses.ssim_loss(img1, img2, window_size, max_val, eps, reduction),
            module(img1, img2),
        )

    def test_gradcheck(self, device, dtype):
        window_size = 3
        img1 = utils.tensor_to_gradcheck_var(torch.rand(1, 1, 5, 4, device=device, dtype=dtype))
        img2 = utils.tensor_to_gradcheck_var(torch.rand(1, 1, 5, 4, device=device, dtype=dtype))

        assert gradcheck(kornia.losses.ssim_loss, (img1, img2, window_size), raise_exception=True, nondet_tol=1e-8)
|
|
|
|
|
class TestDivergenceLoss:
    """Tests for ``kornia.losses.js_div_loss_2d`` and ``kornia.losses.kl_div_loss_2d``."""

    # (input, target, expected) cases shared by the reduced JS-divergence tests.
    _JS_CASES = [
        (torch.full((1, 1, 2, 4), 0.125), torch.full((1, 1, 2, 4), 0.125), 0.0),
        (torch.full((1, 7, 2, 4), 0.125), torch.full((1, 7, 2, 4), 0.125), 0.0),
        (torch.full((1, 7, 2, 4), 0.125), torch.zeros((1, 7, 2, 4)), 0.346574),
        (torch.zeros((1, 7, 2, 4)), torch.full((1, 7, 2, 4), 0.125), 0.346574),
    ]

    # (input, target, expected) cases shared by the reduced KL-divergence tests.
    _KL_CASES = [
        (torch.full((1, 1, 2, 4), 0.125), torch.full((1, 1, 2, 4), 0.125), 0.0),
        (torch.full((1, 7, 2, 4), 0.125), torch.full((1, 7, 2, 4), 0.125), 0.0),
        (torch.full((1, 7, 2, 4), 0.125), torch.zeros((1, 7, 2, 4)), 0.0),
        (torch.zeros((1, 7, 2, 4)), torch.full((1, 7, 2, 4), 0.125), math.inf),
    ]

    @staticmethod
    def _make_noncontiguous(tensor):
        """Return a non-contiguous tensor with the same shape as ``tensor``, sharing its storage.

        The storage is viewed with the shape reversed and the axes are then permuted
        back. Element order is shuffled, so this is only meaningful for constant inputs
        such as the parametrized cases above. An explicit ``permute`` replaces the
        former ``.T``, which PyTorch deprecates for tensors that are not 2-D.
        """
        reversed_axes = tuple(range(tensor.dim() - 1, -1, -1))
        return tensor.view(tensor.shape[::-1]).permute(*reversed_axes)

    @pytest.mark.parametrize('input,target,expected', _JS_CASES)
    def test_js_div_loss_2d(self, device, dtype, input, target, expected):
        actual = kornia.losses.js_div_loss_2d(input.to(device, dtype), target.to(device, dtype))
        assert_close(actual.item(), expected)

    @pytest.mark.parametrize('input,target,expected', _KL_CASES)
    def test_kl_div_loss_2d(self, device, dtype, input, target, expected):
        actual = kornia.losses.kl_div_loss_2d(input.to(device, dtype), target.to(device, dtype))
        assert_close(actual.item(), expected)

    @pytest.mark.parametrize(
        'input,target,expected',
        [
            (torch.full((1, 1, 2, 4), 0.125), torch.full((1, 1, 2, 4), 0.125), torch.full((1, 1), 0.0)),
            (torch.full((1, 7, 2, 4), 0.125), torch.full((1, 7, 2, 4), 0.125), torch.full((1, 7), 0.0)),
            (torch.full((1, 7, 2, 4), 0.125), torch.zeros((1, 7, 2, 4)), torch.full((1, 7), 0.0)),
            (torch.zeros((1, 7, 2, 4)), torch.full((1, 7, 2, 4), 0.125), torch.full((1, 7), math.inf)),
        ],
    )
    def test_kl_div_loss_2d_without_reduction(self, device, dtype, input, target, expected):
        actual = kornia.losses.kl_div_loss_2d(input.to(device, dtype), target.to(device, dtype), reduction='none')
        assert_close(actual, expected.to(device, dtype))

    @pytest.mark.parametrize('input,target,expected', _KL_CASES)
    def test_noncontiguous_kl(self, device, dtype, input, target, expected):
        input = self._make_noncontiguous(input.to(device, dtype))
        target = self._make_noncontiguous(target.to(device, dtype))
        actual = kornia.losses.kl_div_loss_2d(input, target).item()
        assert_close(actual, expected)

    @pytest.mark.parametrize('input,target,expected', _JS_CASES)
    def test_noncontiguous_js(self, device, dtype, input, target, expected):
        input = self._make_noncontiguous(input.to(device, dtype))
        target = self._make_noncontiguous(target.to(device, dtype))
        actual = kornia.losses.js_div_loss_2d(input, target).item()
        assert_close(actual, expected)

    def test_gradcheck_kl(self, device, dtype):
        input = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)
        target = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)

        input = utils.tensor_to_gradcheck_var(input)
        target = utils.tensor_to_gradcheck_var(target)
        assert gradcheck(kornia.losses.kl_div_loss_2d, (input, target), raise_exception=True)

    def test_gradcheck_js(self, device, dtype):
        input = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)
        target = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)

        input = utils.tensor_to_gradcheck_var(input)
        target = utils.tensor_to_gradcheck_var(target)
        assert gradcheck(kornia.losses.js_div_loss_2d, (input, target), raise_exception=True)

    def test_jit_kl(self, device, dtype):
        # KL/JS divergences take logarithms of their inputs. Negative entries from
        # ``randn`` would typically turn the result into NaN, so the scripted/eager
        # comparison would only compare NaNs. Non-negative ``rand`` maps keep it finite.
        input = torch.rand((2, 4, 10, 16), dtype=dtype, device=device)
        target = torch.rand((2, 4, 10, 16), dtype=dtype, device=device)
        args = (input, target)
        op = kornia.losses.kl_div_loss_2d
        op_jit = torch.jit.script(op)
        assert_close(op(*args), op_jit(*args), rtol=1e-5, atol=1e-5)

    def test_jit_js(self, device, dtype):
        # Non-negative inputs so the comparison is made on finite values (see test_jit_kl).
        input = torch.rand((2, 4, 10, 16), dtype=dtype, device=device)
        target = torch.rand((2, 4, 10, 16), dtype=dtype, device=device)
        args = (input, target)
        op = kornia.losses.js_div_loss_2d
        op_jit = torch.jit.script(op)
        assert_close(op(*args), op_jit(*args), rtol=1e-5, atol=1e-5)
|
|
|
|
|
class TestTotalVariation:
    """Tests for ``kornia.losses.total_variation`` and the ``TotalVariation`` module."""

    # Constant images have no neighbouring differences; expected variation is zero per image.
    _CONSTANT_CASES = [
        (torch.ones(3, 4, 5), torch.zeros(())),
        (2 * torch.ones(2, 3, 4, 5), torch.zeros(2)),
    ]

    # Single (C, H, W) images paired with their reference total variation.
    _CASES_3D = [
        (
            torch.tensor(
                [
                    [
                        [0.11747694, 0.5717714, 0.89223915, 0.2929412, 0.63556224],
                        [0.5371079, 0.13416398, 0.7782737, 0.21392655, 0.1757018],
                        [0.62360305, 0.8563448, 0.25304103, 0.68539226, 0.6956515],
                        [0.9350611, 0.01694632, 0.78724295, 0.4760313, 0.73099905],
                    ],
                    [
                        [0.4788819, 0.45253807, 0.932798, 0.5721999, 0.7612051],
                        [0.5455887, 0.8836531, 0.79551977, 0.6677338, 0.74293613],
                        [0.4830376, 0.16420758, 0.15784949, 0.21445751, 0.34168917],
                        [0.8675162, 0.5468113, 0.6117004, 0.01305223, 0.17554593],
                    ],
                    [
                        [0.6423703, 0.5561105, 0.54304767, 0.20339686, 0.8553698],
                        [0.98024786, 0.31562763, 0.10122144, 0.17686582, 0.26260805],
                        [0.20522952, 0.14523649, 0.8601968, 0.02593213, 0.7382898],
                        [0.71935296, 0.9625162, 0.42287344, 0.07979459, 0.9149871],
                    ],
                ]
            ),
            torch.tensor(33.001236),
        ),
        (
            torch.tensor([[[0.09094203, 0.32630223, 0.8066123], [0.10921168, 0.09534764, 0.48588026]]]),
            torch.tensor(1.6900232),
        ),
    ]

    # Batched (B, C, H, W) images paired with one reference value per batch element.
    _CASES_4D = [
        (
            torch.tensor(
                [
                    [
                        [[0.8756, 0.0920], [0.8034, 0.3107]],
                        [[0.3069, 0.2981], [0.9399, 0.7944]],
                        [[0.6269, 0.1494], [0.2493, 0.8490]],
                    ],
                    [
                        [[0.3256, 0.9923], [0.2856, 0.9104]],
                        [[0.4107, 0.4387], [0.2742, 0.0095]],
                        [[0.7064, 0.3674], [0.6139, 0.2487]],
                    ],
                ]
            ),
            torch.tensor([5.0054283, 3.1870906]),
        ),
        (
            torch.tensor(
                [
                    [[[0.1104, 0.2284, 0.4371], [0.4569, 0.1906, 0.8035]]],
                    [[[0.0552, 0.6831, 0.8310], [0.3589, 0.5044, 0.0802]]],
                    [[[0.5078, 0.5703, 0.9110], [0.4765, 0.8401, 0.2754]]],
                ]
            ),
            torch.tensor([1.9565653, 2.5786452, 2.2681699]),
        ),
    ]

    # Tensors whose rank the loss must reject with ValueError.
    _INVALID_DIMS = [torch.rand(2, 3, 4, 5, 3), torch.rand(3, 1)]

    # Non-tensor inputs the loss must reject with TypeError.
    _INVALID_TYPES = [1, [1, 2]]

    @pytest.mark.parametrize('input, expected', _CONSTANT_CASES)
    def test_tv_on_constant(self, device, dtype, input, expected):
        result = kornia.losses.total_variation(input.to(device, dtype))
        assert_close(result, expected.to(device, dtype))

    @pytest.mark.parametrize('input, expected', _CASES_3D)
    def test_tv_on_3d(self, device, dtype, input, expected):
        result = kornia.losses.total_variation(input.to(device, dtype))
        assert_close(result, expected.to(device, dtype))

    @pytest.mark.parametrize('input, expected', _CASES_4D)
    def test_tv_on_4d(self, device, dtype, input, expected):
        result = kornia.losses.total_variation(input.to(device, dtype))
        assert_close(result, expected.to(device, dtype), rtol=1e-4, atol=1e-4)

    @pytest.mark.parametrize('input', _INVALID_DIMS)
    def test_tv_on_invalid_dims(self, device, dtype, input):
        moved = input.to(device, dtype)
        with pytest.raises(ValueError):
            kornia.losses.total_variation(moved)

    @pytest.mark.parametrize('input', _INVALID_TYPES)
    def test_tv_on_invalid_types(self, device, dtype, input):
        with pytest.raises(TypeError):
            kornia.losses.total_variation(input)

    def test_jit(self, device, dtype):
        image = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        eager = kornia.losses.total_variation
        scripted = torch.jit.script(eager)

        assert_close(eager(image), scripted(image))

    def test_module(self, device, dtype):
        image = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)

        module = kornia.losses.TotalVariation()

        assert_close(kornia.losses.total_variation(image), module(image))

    def test_gradcheck(self, device, dtype):
        image = utils.tensor_to_gradcheck_var(torch.rand(1, 2, 3, 4, device=device, dtype=dtype))
        assert gradcheck(kornia.losses.total_variation, (image,), raise_exception=True)
|
|
|
|
|
class TestPSNRLoss:
    """Tests for ``kornia.losses.psnr_loss`` and the ``PSNRLoss`` module."""

    def test_smoke(self, device, dtype):
        input = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        target = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)

        criterion = kornia.losses.PSNRLoss(1.0)
        loss = criterion(input, target)

        assert loss is not None

    def test_type(self, device, dtype):
        # Non-tensor arguments must be rejected.
        criterion = kornia.losses.PSNRLoss(1.0).to(device, dtype)
        with pytest.raises(Exception):
            criterion(1, 2)

    def test_shape(self, device, dtype):
        # Build the mismatched-shape inputs on the parametrized device/dtype so the expected
        # failure is exercised under each configuration, instead of always with CPU tensors
        # of the default dtype.
        criterion = kornia.losses.PSNRLoss(1.0).to(device, dtype)
        input = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        target = torch.rand(2, 3, 3, device=device, dtype=dtype)
        with pytest.raises(Exception):
            criterion(input, target)

    def test_loss(self, device, dtype):
        # MSE between 1 and 1.2 is 0.04; with max_val=2.0 the expected value is
        # -10 * log10(2.0 ** 2 / 0.04) = -20.
        input = torch.ones(1, device=device, dtype=dtype)
        expected = torch.tensor(-20.0, device=device, dtype=dtype)
        actual = kornia.losses.psnr_loss(input, 1.2 * input, 2.0)
        assert_close(actual, expected)

    def test_jit(self, device, dtype):
        input = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        target = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)

        args = (input, target, 1.0)

        op = kornia.losses.psnr_loss
        op_script = torch.jit.script(op)

        assert_close(op(*args), op_script(*args))

    def test_module(self, device, dtype):
        input = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        target = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)

        args = (input, target, 1.0)

        op = kornia.losses.psnr_loss
        op_module = kornia.losses.PSNRLoss(1.0)

        assert_close(op(*args), op_module(input, target))

    def test_gradcheck(self, device, dtype):
        input = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        target = torch.rand(2, 3, 3, 2, device=device, dtype=dtype)
        input = utils.tensor_to_gradcheck_var(input)
        target = utils.tensor_to_gradcheck_var(target)
        assert gradcheck(kornia.losses.psnr_loss, (input, target, 1.0), raise_exception=True)
|
|