Spaces:
Runtime error
Runtime error
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import time | |
import torch | |
from spatial_correlation_sampler import SpatialCorrelationSampler | |
from tqdm import trange | |
TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000} | |
parser = argparse.ArgumentParser() | |
parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda') | |
parser.add_argument('-b', '--batch-size', type=int, default=16) | |
parser.add_argument('-k', '--kernel-size', type=int, default=3) | |
parser.add_argument('--patch', type=int, default=3) | |
parser.add_argument('--patch_dilation', type=int, default=2) | |
parser.add_argument('-c', '--channel', type=int, default=64) | |
parser.add_argument('--height', type=int, default=100) | |
parser.add_argument('-w', '--width', type=int, default=100) | |
parser.add_argument('-s', '--stride', type=int, default=2) | |
parser.add_argument('-p', '--pad', type=int, default=1) | |
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us') | |
parser.add_argument('-r', '--runs', type=int, default=100) | |
parser.add_argument('--dilation', type=int, default=2) | |
parser.add_argument('-d', '--dtype', choices=['half', 'float', 'double']) | |
args = parser.parse_args() | |
device = torch.device(args.backend) | |
if args.dtype == 'half': | |
dtype = torch.float16 | |
elif args.dtype == 'float': | |
dtype = torch.float32 | |
else: | |
dtype = torch.float64 | |
input1 = torch.randn(args.batch_size, | |
args.channel, | |
args.height, | |
args.width, | |
dtype=dtype, | |
device=device, | |
requires_grad=True) | |
input2 = torch.randn_like(input1) | |
correlation_sampler = SpatialCorrelationSampler( | |
args.kernel_size, | |
args.patch, | |
args.stride, | |
args.pad, | |
args.dilation, | |
args.patch_dilation) | |
# Force CUDA initialization | |
output = correlation_sampler(input1, input2) | |
print(output.size()) | |
output.mean().backward() | |
forward_min = float('inf') | |
forward_time = 0 | |
backward_min = float('inf') | |
backward_time = 0 | |
for _ in trange(args.runs): | |
correlation_sampler.zero_grad() | |
start = time.time() | |
output = correlation_sampler(input1, input2) | |
elapsed = time.time() - start | |
forward_min = min(forward_min, elapsed) | |
forward_time += elapsed | |
output = output.mean() | |
start = time.time() | |
(output.mean()).backward() | |
elapsed = time.time() - start | |
backward_min = min(backward_min, elapsed) | |
backward_time += elapsed | |
scale = TIME_SCALES[args.scale] | |
forward_min *= scale | |
backward_min *= scale | |
forward_average = forward_time / args.runs * scale | |
backward_average = backward_time / args.runs * scale | |
print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format( | |
forward_min, forward_average, backward_min, backward_average, | |
args.scale)) | |