Schrodingers's picture
Upload folder using huggingface_hub
ffbe0b4
raw
history blame
1.9 kB
import argparse
import torch
# torch.set_printoptions(precision=1, threshold=10000)
from torch.autograd import gradcheck
from spatial_correlation_sampler import SpatialCorrelationSampler
def build_parser():
    """Build the command-line parser for the gradient-check script.

    The single positional argument ``backend`` selects the torch device
    ('cpu' or 'cuda'). It is optional and falls back to 'cuda' when omitted.
    All remaining options are integers that size the random inputs or are
    forwarded to ``SpatialCorrelationSampler``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    # nargs='?' is required for the default to take effect: argparse treats
    # a positional without it as mandatory and never applies ``default``.
    parser.add_argument('backend', nargs='?', choices=['cpu', 'cuda'],
                        default='cuda')
    parser.add_argument('-b', '--batch-size', type=int, default=2)
    parser.add_argument('-k', '--kernel-size', type=int, default=3)
    parser.add_argument('--patch', type=int, default=3)
    parser.add_argument('--patch_dilation', type=int, default=2)
    parser.add_argument('-c', '--channel', type=int, default=2)
    # No '-h' short flag for height: argparse reserves it for --help.
    parser.add_argument('--height', type=int, default=10)
    parser.add_argument('-w', '--width', type=int, default=10)
    parser.add_argument('-s', '--stride', type=int, default=2)
    parser.add_argument('-p', '--pad', type=int, default=1)
    parser.add_argument('-d', '--dilation', type=int, default=2)
    return parser


def main(argv=None):
    """Run ``gradcheck`` on a ``SpatialCorrelationSampler`` and print 'Ok'.

    Two random float64 tensors of shape (batch_size, channel, height, width)
    are created on the selected device, marked as requiring gradients, and
    passed to ``torch.autograd.gradcheck`` together with the sampler.

    Args:
        argv: optional list of command-line tokens; ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    device = torch.device(args.backend)
    shape = (args.batch_size, args.channel, args.height, args.width)

    # Double precision: gradcheck compares analytical gradients against
    # finite differences, and its documentation recommends float64 inputs.
    input1 = torch.randn(*shape, dtype=torch.float64, device=device)
    input2 = torch.randn(*shape, dtype=torch.float64, device=device)
    input1.requires_grad = True
    input2.requires_grad = True

    # Positional order is kept exactly as in the original call.
    correlation_sampler = SpatialCorrelationSampler(args.kernel_size,
                                                    args.patch,
                                                    args.stride,
                                                    args.pad,
                                                    args.dilation,
                                                    args.patch_dilation)

    if gradcheck(correlation_sampler, [input1, input2]):
        print('Ok')


if __name__ == '__main__':
    main()