Spaces:
Running
on
Zero
Running
on
Zero
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary | |
# | |
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
# property and proprietary rights in and to this material, related | |
# documentation and any modifications thereto. Any use, reproduction, | |
# disclosure or distribution of this material and related documentation | |
# without an express license agreement from NVIDIA CORPORATION or | |
# its affiliates is strictly prohibited. | |
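"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high-order gradients between the input and output.
Higher-order gradients with respect to the sampling grid are not supported.
Only works on 2D images (4-D `input` and `grid` tensors) and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.
The custom op is opt-in: set the module-level `enabled` flag to True."""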
import torch | |
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
#----------------------------------------------------------------------------
enabled = False  # Set to True to route grid_sample() through the custom autograd op.
#----------------------------------------------------------------------------
def grid_sample(input, grid):
    """Bilinearly sample `input` at the locations in `grid`, zero-padded.

    When `_should_use_custom_op()` is true, the call goes through the custom
    autograd function so higher-order gradients w.r.t. `input` are available;
    otherwise it falls back to `torch.nn.functional.grid_sample` directly.
    Both paths sample with `mode='bilinear'`, `padding_mode='zeros'` and
    `align_corners=False`.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )
    return _GridSample2dForward.apply(input, grid)
#---------------------------------------------------------------------------- | |
def _should_use_custom_op():
    """Report whether `grid_sample` should use the custom autograd op.

    This simply reflects the module-level `enabled` flag.
    """
    use_custom_op = enabled
    return use_custom_op
#---------------------------------------------------------------------------- | |
class _GridSample2dForward(torch.autograd.Function):
    """Autograd function wrapping the forward bilinear grid sample.

    The forward pass defers to `torch.nn.functional.grid_sample` with the fixed
    settings this module supports. The backward pass is routed through
    `_GridSample2dBackward`, another autograd function, so that the gradient
    computation is itself differentiable.
    """

    @staticmethod
    def forward(ctx, input, grid):
        """Sample `input` at the locations given by `grid`.

        Args:
            ctx: Autograd context; `input` and `grid` are saved for backward.
            input: 4-D tensor to sample from.
            grid: 4-D sampling grid.

        Returns:
            The output of `torch.nn.functional.grid_sample` with
            `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.
        """
        assert input.ndim == 4, f'expected a 4-D input, got ndim={input.ndim}'
        assert grid.ndim == 4, f'expected a 4-D grid, got ndim={grid.ndim}'
        output = torch.nn.functional.grid_sample(input=input,
                                                 grid=grid,
                                                 mode='bilinear',
                                                 padding_mode='zeros',
                                                 align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return `(grad_input, grad_grid)` for the incoming `grad_output`."""
        input, grid = ctx.saved_tensors
        # Compute the gradients through another autograd Function (instead of
        # inline) so autograd can differentiate this backward pass again.
        grad_input, grad_grid = _GridSample2dBackward.apply(
            grad_output, input, grid)
        return grad_input, grad_grid
#---------------------------------------------------------------------------- | |
def _grid_sampler_2d_backward(grad_output, input, grid, output_mask):
    """Call `aten::grid_sampler_2d_backward` in a PyTorch-version-aware way.

    Uses interpolation mode 0 and padding mode 0 with `align_corners=False`.
    These are intended to match the forward call's `mode='bilinear'`,
    `padding_mode='zeros'`, `align_corners=False`.

    Args:
        grad_output: Gradient w.r.t. the grid-sample output.
        input: Tensor that was sampled in the forward pass.
        grid: Sampling grid used in the forward pass.
        output_mask: Pair of bools selecting which of (grad_input, grad_grid)
            to compute. It is only passed to PyTorch versions whose op schema
            accepts it (1.11 and later).

    Returns:
        Tuple `(grad_input, grad_grid)` from the ATen op. Entries disabled in
        `output_mask` presumably come back undefined (None); confirm per version.
    """
    import re

    op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
    # Newer PyTorch releases return `(op, overload_names)` rather than the op.
    if isinstance(op, tuple):
        op = op[0]
    # PyTorch 1.11 appended a required `output_mask` argument to the schema;
    # an unparseable version string is treated as a modern release.
    version = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
    if version is None or (int(version.group(1)), int(version.group(2))) >= (1, 11):
        return op(grad_output, input, grid, 0, 0, False, output_mask)
    return op(grad_output, input, grid, 0, 0, False)


class _GridSample2dBackward(torch.autograd.Function):
    """Differentiable backward pass of the bilinear grid sample.

    `forward` computes the first-order gradients `(grad_input, grad_grid)`.
    `backward` lets autograd differentiate them once more. Higher-order
    gradients are propagated only to `grad_output`. Incoming gradients for
    `grad_grid` are ignored, and requesting gradients w.r.t. `grid` at this
    order trips an assertion.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        """Compute `(grad_input, grad_grid)` for the given `grad_output`."""
        # Only compute the gradients that the saved tensors actually need.
        output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
        grad_input, grad_grid = _grid_sampler_2d_backward(
            grad_output, input, grid, output_mask)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        """Backpropagate gradients of `grad_input` to `grad_output`.

        Returns gradients for `(grad_output, input, grid)`. Only the first
        entry can be non-None.
        """
        _ = grad2_grad_grid  # unused: second order through grad_grid is unsupported
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None
        if ctx.needs_input_grad[0]:
            # grad_input is linear in grad_output, and its adjoint is grid
            # sampling itself, so re-apply the forward op to the incoming grad.
            grad2_grad_output = _GridSample2dForward.apply(
                grad2_grad_input, grid)
        assert not ctx.needs_input_grad[2], \
            'higher-order gradients w.r.t. grid are not supported'
        return grad2_grad_output, grad2_input, grad2_grid
#---------------------------------------------------------------------------- | |