# BrainFM / utils / interpol / tests / test_gradcheck_pushpull.py
# Source page residue: uploaded by peirong26 ("Upload 187 files"), commit 2571f24 (verified).
import torch
from torch.autograd import gradcheck
from interpol import grid_pull, grid_push, grid_count, grid_grad, add_identity_grid_
import pytest
import inspect
# Global parameters shared by every gradcheck test in this module.
dtype = torch.float64  # same object as torch.double; double precision is advised for gradcheck
shape1 = 3  # number of voxels along each spatial dimension
extrapolate = True  # forwarded as the `extrapolate` argument of every interpol call below
if hasattr(torch, 'use_deterministic_algorithms'):
    # Guarded because the switch may not exist in every installed torch version.
    torch.use_deterministic_algorithms(True)
# Keyword options forwarded to every gradcheck call. The optional ones are
# only added when the installed torch's gradcheck signature accepts them.
kwargs = {'rtol': 1.0, 'raise_exception': True}
kwargs.update(
    (option, value)
    for option, value in (('check_undefined_grad', False), ('nondet_tol', 1e-3))
    if option in inspect.signature(gradcheck).parameters
)
# Parametrization tables for the tests below.
# Each device entry is either a (name, param) pair or a bare device name;
# init_device decodes both forms.
devices = [('cpu', 1)]
if torch.backends.openmp.is_available() or torch.backends.mkl.is_available():
    print('parallel backend available')
    devices.append(('cpu', 10))
if torch.cuda.is_available():
    print('cuda backend available')
    devices.append('cuda')
dims = [1, 2, 3]
bounds = list(range(7))
# Orders 0-2 are paired with every bound code. Orders 3-7 are only paired
# with bound code 3 (noted as "dc2" originally, presumably dct2 -- confirm
# against interpol's bound codes).
# Comprehensions keep the loop variables out of the module namespace.
order_bounds = [(order, bound) for order in range(3) for bound in bounds]
order_bounds += [(order, 3) for order in range(3, 8)]
def make_data(shape, device, dtype):
    """Draw a random batch of two volumes and two sampling grids.

    Args:
        shape: spatial shape as a tuple (it is concatenated to a tuple below).
        device: device on which both tensors are allocated.
        dtype: floating-point dtype of both tensors.

    Returns:
        ``(vol, grid)``. ``vol`` has shape ``(2, 1, *shape)``. ``grid`` has shape
        ``(2, *shape, len(shape))`` and holds standard-normal noise passed
        through ``add_identity_grid_`` (presumably noise added on top of the
        identity coordinates -- confirm against interpol).
    """
    factory = dict(device=device, dtype=dtype)
    ndim = len(shape)
    # The grid is drawn before the volume, so the RNG stream is consumed in
    # the same order as before.
    noisy_grid = torch.randn([2, *shape, ndim], **factory)
    grid = add_identity_grid_(noisy_grid)
    vol = torch.randn((2, 1) + shape, **factory)
    return vol, grid
def init_device(device):
    """Configure the torch backend for a device entry and return a ``torch.device``.

    Args:
        device: either a ``(name, param)`` pair or a bare name. A bare ``'cpu'``
            means ``('cpu', 1)`` and any other bare name means ``(name, 0)``.
            For CPU, ``param`` is the number of threads. For CUDA, it is the
            GPU index.

    Returns:
        ``torch.device('cpu')`` or ``torch.device('cuda:<index>')``.

    Raises:
        AssertionError: if the name is neither ``'cpu'`` nor ``'cuda'``.
    """
    if isinstance(device, (list, tuple)):
        name, index = device
    else:
        name = device
        index = 1 if device == 'cpu' else 0

    if name != 'cuda':
        assert name == 'cpu'
        torch.set_num_threads(index)
        return torch.device(name)

    torch.cuda.set_device(index)
    torch.cuda.init()
    # Clearing the cache is best-effort; a RuntimeError here is ignored.
    try:
        torch.cuda.empty_cache()
    except RuntimeError:
        pass
    return torch.device('{}:{}'.format(name, index))
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_grad(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_grad``.

    Both the volume and the sampling grid require gradients, so gradcheck
    compares analytical and finite-difference derivatives for each of them.
    """
    print(f'grad_{dim}d({interpolation}, {bound}) on {device}')
    torch_device = init_device(device)
    vol, grid = make_data((shape1,) * dim, torch_device, dtype)
    for leaf in (vol, grid):
        leaf.requires_grad_(True)
    inputs = (vol, grid, interpolation, bound, extrapolate)
    assert gradcheck(grid_grad, inputs, **kwargs)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_pull(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_pull``.

    Both the volume and the sampling grid require gradients, so gradcheck
    compares analytical and finite-difference derivatives for each of them.
    """
    print(f'pull_{dim}d({interpolation}, {bound}) on {device}')
    torch_device = init_device(device)
    vol, grid = make_data((shape1,) * dim, torch_device, dtype)
    for leaf in (vol, grid):
        leaf.requires_grad_(True)
    inputs = (vol, grid, interpolation, bound, extrapolate)
    assert gradcheck(grid_pull, inputs, **kwargs)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_push(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_push``, splatting back onto the input's spatial shape.

    Both the volume and the sampling grid require gradients, so gradcheck
    compares analytical and finite-difference derivatives for each of them.
    """
    print(f'push_{dim}d({interpolation}, {bound}) on {device}')
    torch_device = init_device(device)
    spatial_shape = (shape1,) * dim
    vol, grid = make_data(spatial_shape, torch_device, dtype)
    for leaf in (vol, grid):
        leaf.requires_grad_(True)
    inputs = (vol, grid, spatial_shape, interpolation, bound, extrapolate)
    assert gradcheck(grid_push, inputs, **kwargs)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_count(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_count``.

    Only the sampling grid is an input here. The random volume produced by
    ``make_data`` is discarded.
    """
    print(f'count_{dim}d({interpolation}, {bound}) on {device}')
    torch_device = init_device(device)
    spatial_shape = (shape1,) * dim
    grid = make_data(spatial_shape, torch_device, dtype)[1]
    grid.requires_grad_(True)
    inputs = (grid, spatial_shape, interpolation, bound, extrapolate)
    assert gradcheck(grid_count, inputs, **kwargs)