File size: 690 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import pytest
import torch
import kornia
class TestOneHot:
    """Tests for ``kornia.utils.one_hot``."""

    def test_smoke(self):
        """Each labelled element must be hot in the channel named by its label."""
        num_classes = 4
        # A (2, 2, 1) int64 label map in which every position gets a distinct class.
        labels = torch.zeros(2, 2, 1, dtype=torch.int64)
        labels[0, 0, 0] = 0
        labels[0, 1, 0] = 1
        labels[1, 0, 0] = 2
        labels[1, 1, 0] = 3

        # convert labels to one hot tensor
        one_hot = kornia.utils.one_hot(labels, num_classes)

        # The class axis is inserted at dim 1: one_hot[b, class, i, j].
        assert one_hot.shape == (2, num_classes, 2, 1)

        # Compare with ``==``. A bare ``assert pytest.approx(x, 1.0)`` is not a
        # comparison: it is always truthy, or an error in recent pytest.
        # A tolerance is kept because the values were originally treated as
        # inexact (one_hot may add a small epsilon; confirm).
        positions = [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)]
        for b, i, j in positions:
            hot_value = one_hot[b, labels[b, i, j], i, j].item()
            assert hot_value == pytest.approx(1.0, abs=1e-4)
|