import pytest | |
import torch | |
import kornia | |
class TestOneHot:
    """Tests for ``kornia.utils.one_hot``."""

    def test_smoke(self):
        """Each label of a (2, 2, 1) map becomes a 1 along the class axis (dim 1)."""
        num_classes = 4
        # Label map where every class 0..num_classes-1 appears exactly once.
        labels = torch.zeros(2, 2, 1, dtype=torch.int64)
        labels[0, 0, 0] = 0
        labels[0, 1, 0] = 1
        labels[1, 0, 0] = 2
        labels[1, 1, 0] = 3

        # convert labels to one hot tensor
        one_hot = kornia.utils.one_hot(labels, num_classes)

        # The class axis sits at dim 1: (B, H, W) -> (B, C, H, W).
        assert one_hot.shape == (2, num_classes, 2, 1)

        # ``assert pytest.approx(...)`` on its own never compares anything;
        # the approx object must be one side of ``==``. The absolute
        # tolerance leaves room for a small smoothing epsilon in the output
        # (NOTE(review): confirm the epsilon used by kornia.utils.one_hot).
        for b in range(labels.shape[0]):
            for h in range(labels.shape[1]):
                hot_class = labels[b, h, 0].item()
                for cls in range(num_classes):
                    expected = 1.0 if cls == hot_class else 0.0
                    assert one_hot[b, cls, h, 0].item() == pytest.approx(
                        expected, abs=1e-4
                    )