|
import torch |
|
from torch.autograd import gradcheck |
|
|
|
import kornia |
|
import kornia.testing as utils |
|
|
|
|
|
class TestBatchedForward:
    """Tests for ``kornia.utils.memory.batched_forward``.

    Each SIFT test compares ``batched_forward(model, data, device, batch_size)``
    against a plain, un-batched ``model(data)`` call on the same input.
    """

    @staticmethod
    def _assert_matches_unbatched(num_patches, batch_size, device):
        """Check batched SIFT descriptors against an un-batched reference.

        Args:
            num_patches: number of random ``1x32x32`` patches to describe.
            batch_size: micro-batch size passed to ``batched_forward``.
            device: device passed to ``batched_forward``.
        """
        # Patches are created on the default device, not on ``device``; the
        # checks below expect batched_forward to hand results back on the
        # input's device.
        patches = torch.rand(num_patches, 1, 32, 32)
        sift = kornia.feature.SIFTDescriptor(32)

        # Take the reference before calling the function under test, so it
        # cannot depend on any changes batched_forward makes to ``sift``
        # (NOTE(review): presumably it moves the module to ``device``).
        desc = sift(patches)
        desc_batched = kornia.utils.memory.batched_forward(
            sift, patches, device, batch_size
        )

        # torch.allclose broadcasts, so a result with the wrong number of
        # rows could slip through; check the shape explicitly.
        assert desc_batched.shape == desc.shape
        assert desc_batched.shape[0] == num_patches
        assert desc_batched.device == patches.device
        assert torch.allclose(desc, desc_batched)

    def test_runbatch(self, device):
        # 34 = 32 + 2: the input exceeds batch_size and the last chunk is partial.
        self._assert_matches_unbatched(34, 32, device)

    def test_runone(self, device):
        # 16 <= 32: the whole input fits within a single batch.
        self._assert_matches_unbatched(16, 32, device)

    def test_run_exact_multiple(self, device):
        # 64 = 2 * 32: batch_size divides the input evenly, the usual
        # off-by-one boundary for micro-batching loops.
        self._assert_matches_unbatched(64, 32, device)

    def test_gradcheck(self, device):
        """Gradient-check batched_forward when the input is larger than one micro-batch."""
        # 3 images with batch_size=2, so the input does not fit in one micro-batch.
        batch_size, channels, height, width = 3, 2, 5, 4
        img = torch.rand(batch_size, channels, height, width, device=device)
        # NOTE(review): presumably converts to float64 with requires_grad=True,
        # as gradcheck needs; helper is defined in kornia.testing.
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(
            kornia.utils.memory.batched_forward,
            (kornia.feature.BlobHessian(), img, device, 2),
            raise_exception=True,
            nondet_tol=1e-4,
        )
|
|