i72sijia commited on
Commit
e01477f
1 Parent(s): 1fa5edb

Upload grid_sample_gradfix.py

Browse files
torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import warnings
15
+ import torch
16
+
17
+ # pylint: disable=redefined-builtin
18
+ # pylint: disable=arguments-differ
19
+ # pylint: disable=protected-access
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ enabled = False # Enable the custom op by setting this to true.
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ def grid_sample(input, grid):
28
+ if _should_use_custom_op():
29
+ return _GridSample2dForward.apply(input, grid)
30
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def _should_use_custom_op():
35
+ if not enabled:
36
+ return False
37
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
38
+ return True
39
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40
+ return False
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ class _GridSample2dForward(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, input, grid):
47
+ assert input.ndim == 4
48
+ assert grid.ndim == 4
49
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50
+ ctx.save_for_backward(input, grid)
51
+ return output
52
+
53
+ @staticmethod
54
+ def backward(ctx, grad_output):
55
+ input, grid = ctx.saved_tensors
56
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57
+ return grad_input, grad_grid
58
+
59
+ #----------------------------------------------------------------------------
60
+
61
+ class _GridSample2dBackward(torch.autograd.Function):
62
+ @staticmethod
63
+ def forward(ctx, grad_output, input, grid):
64
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66
+ ctx.save_for_backward(grid)
67
+ return grad_input, grad_grid
68
+
69
+ @staticmethod
70
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
71
+ _ = grad2_grad_grid # unused
72
+ grid, = ctx.saved_tensors
73
+ grad2_grad_output = None
74
+ grad2_input = None
75
+ grad2_grid = None
76
+
77
+ if ctx.needs_input_grad[0]:
78
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79
+
80
+ assert not ctx.needs_input_grad[2]
81
+ return grad2_grad_output, grad2_input, grad2_grid
82
+
83
+ #----------------------------------------------------------------------------