# Spaces: Runtime error
"""Numerically check the gradients of SpatialCorrelationSampler.

Builds two random double-precision inputs on the requested backend and runs
``torch.autograd.gradcheck`` on the sampler, printing ``Ok`` on success.
"""
import argparse

import torch
from torch.autograd import gradcheck

from spatial_correlation_sampler import SpatialCorrelationSampler

parser = argparse.ArgumentParser(
    description='Gradient check for SpatialCorrelationSampler.')
# nargs='?' makes the positional optional. Without it argparse treats the
# positional as required and silently ignores default='cuda'.
parser.add_argument('backend', nargs='?', choices=['cpu', 'cuda'],
                    default='cuda')
parser.add_argument('-b', '--batch-size', type=int, default=2)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
# Accept the dash spelling used by every other flag, and keep the original
# underscore spelling so existing command lines still work.
parser.add_argument('--patch-dilation', '--patch_dilation',
                    dest='patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=2)
# No short flag for height: -h is reserved by argparse for --help.
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('-d', '--dilation', type=int, default=2)
args = parser.parse_args()

device = torch.device(args.backend)
shape = (args.batch_size, args.channel, args.height, args.width)

# gradcheck compares analytical gradients against finite differences, so
# double precision is used to keep those tolerances meaningful.
input1 = torch.randn(*shape, dtype=torch.float64, device=device,
                     requires_grad=True)
input2 = torch.randn(*shape, dtype=torch.float64, device=device,
                     requires_grad=True)

correlation_sampler = SpatialCorrelationSampler(args.kernel_size,
                                                args.patch,
                                                args.stride,
                                                args.pad,
                                                args.dilation,
                                                args.patch_dilation)

# By default gradcheck raises on a gradient mismatch rather than returning
# False, so a failure surfaces as an exception with details.
if gradcheck(correlation_sampler, (input1, input2)):
    print('Ok')