# compvis/test/utils/test_one_hot.py
# Uploaded by Dexter via huggingface_hub (commit 36c95ba, 690 bytes).
import pytest
import torch
import kornia
class TestOneHot:
    """Tests for ``kornia.utils.one_hot``."""

    def test_smoke(self):
        """Each label must select exactly its own class channel.

        The original assertions were of the form
        ``assert pytest.approx(x, 1.0)``, which is always truthy (an approx
        object, with ``1.0`` interpreted as the *relative tolerance*), so the
        test could never fail. Compare values against ``pytest.approx``
        instead.
        """
        num_classes = 4
        # labels is (2, 2, 1) and uses every class id exactly once.
        labels = torch.zeros(2, 2, 1, dtype=torch.int64)
        labels[0, 0, 0] = 0
        labels[0, 1, 0] = 1
        labels[1, 0, 0] = 2
        labels[1, 1, 0] = 3

        # convert labels to one hot tensor
        one_hot = kornia.utils.one_hot(labels, num_classes)

        # The class axis is inserted at dim 1: (2, C, 2, 1).
        assert one_hot.shape == (2, num_classes, 2, 1)

        # NOTE(review): kornia's one_hot may add a small eps to every entry,
        # so use an absolute tolerance rather than exact equality.
        tol = 1e-4
        for b, h in [(0, 0), (0, 1), (1, 0), (1, 1)]:
            hot = int(labels[b, h, 0])
            for c in range(num_classes):
                expected = 1.0 if c == hot else 0.0
                assert one_hot[b, c, h, 0].item() == pytest.approx(expected, abs=tol)