# compvis / test /utils /test_memory.py
# Dexter's picture
# Upload folder using huggingface_hub
# 36c95ba verified
import torch
from torch.autograd import gradcheck
import kornia
import kornia.testing as utils
class TestBatchedForward:
    """Tests for ``kornia.utils.memory.batched_forward``."""

    def _check_sift_matches_direct(self, num_patches, device):
        """Assert batched and direct SIFT descriptors agree for ``num_patches`` patches.

        Random single-channel 32x32 patches are described twice: once through
        ``batched_forward`` with a batch size of 32, and once by calling the
        descriptor directly.
        """
        patches = torch.rand(num_patches, 1, 32, 32)
        descriptor = kornia.feature.SIFTDescriptor(32)
        run_batched = kornia.utils.memory.batched_forward
        # The batched call is made before the direct call, and the direct
        # result is the first operand of allclose; keep both orderings.
        batched_out = run_batched(descriptor, patches, device, 32)
        direct_out = descriptor(patches)
        assert torch.allclose(direct_out, batched_out)

    def test_runbatch(self, device):
        """34 patches with batch size 32: more patches than one batch holds."""
        self._check_sift_matches_direct(34, device)

    def test_runone(self, device):
        """16 patches with batch size 32: fewer patches than one batch holds."""
        self._check_sift_matches_direct(16, device)

    def test_gradcheck(self, device):
        """Run ``gradcheck`` on ``batched_forward`` wrapping ``BlobHessian``."""
        shape = (3, 2, 5, 4)  # batch_size, channels, height, width
        img = torch.rand(*shape, device=device)
        # Helper presumably prepares the tensor for gradcheck (the original
        # comment says "to var") -- confirm against kornia.testing.
        img = utils.tensor_to_gradcheck_var(img)
        # Batch size 2 with a batch of 3 inputs.
        inputs = (kornia.feature.BlobHessian(), img, device, 2)
        assert gradcheck(
            kornia.utils.memory.batched_forward,
            inputs,
            raise_exception=True,
            nondet_tol=1e-4,
        )