File size: 1,153 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
from torch.autograd import gradcheck
import kornia
import kornia.testing as utils
class TestBatchedForward:
    """Tests for ``kornia.utils.memory.batched_forward``.

    The SIFT tests check that ``batched_forward(model, data, device, batch_size)``
    returns the same descriptors as a direct ``model(data)`` call. They cover one
    case with more patches than ``batch_size`` and one with fewer. The gradcheck
    test checks that gradients through ``batched_forward`` are correct.
    """

    def _check_matches_unbatched(self, device, num_patches, batch_size=32):
        """Assert that batched SIFT descriptors equal a direct SIFT call.

        Args:
            device: device fixture; the input patches are created on it.
            num_patches: number of 32x32 single-channel patches to describe.
            batch_size: chunk size passed to ``batched_forward``.
        """
        # Create the input on ``device`` so the batched call and the reference
        # call below both receive inputs on the device under test. Previously the
        # patches always stayed on CPU, whatever the fixture value.
        patches = torch.rand(num_patches, 1, 32, 32, device=device)
        sift = kornia.feature.SIFTDescriptor(32)
        desc_batched = kornia.utils.memory.batched_forward(sift, patches, device, batch_size)
        desc = sift(patches)
        # torch.allclose broadcasts, so a result with the wrong number of rows
        # could still compare "close". Pin the shape explicitly first.
        assert desc_batched.shape == desc.shape
        assert torch.allclose(desc, desc_batched)

    def test_runbatch(self, device):
        """More patches (34) than one batch holds (32): results must match."""
        self._check_matches_unbatched(device, num_patches=34, batch_size=32)

    def test_runone(self, device):
        """Fewer patches (16) than the batch size (32): results must match."""
        self._check_matches_unbatched(device, num_patches=16, batch_size=32)

    def test_gradcheck(self, device):
        """Numerically check gradients through ``batched_forward`` with BlobHessian."""
        batch_size, channels, height, width = 3, 2, 5, 4
        img = torch.rand(batch_size, channels, height, width, device=device)
        # NOTE(review): presumably converts to a gradcheck-ready tensor
        # (float64, requires_grad=True); see kornia.testing.tensor_to_gradcheck_var.
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(
            kornia.utils.memory.batched_forward,
            # Chunk size 2 is smaller than the 3 images, so more than one chunk is requested.
            (kornia.feature.BlobHessian(), img, device, 2),
            raise_exception=True,
            nondet_tol=1e-4,
        )
|