i72sijia commited on
Commit
2b83341
1 Parent(s): c0068b4

Upload conv2d_gradfix.py

Browse files
Files changed (1) hide show
  1. torch_utils/ops/conv2d_gradfix.py +170 -0
torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import warnings
13
+ import contextlib
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24
+
25
+ @contextlib.contextmanager
26
+ def no_weight_gradients():
27
+ global weight_gradients_disabled
28
+ old = weight_gradients_disabled
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54
+ return True
55
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56
+ return False
57
+
58
+ def _tuple_of_ints(xs, ndim):
59
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60
+ assert len(xs) == ndim
61
+ assert all(isinstance(x, int) for x in xs)
62
+ return xs
63
+
64
+ #----------------------------------------------------------------------------
65
+
66
+ _conv2d_gradfix_cache = dict()
67
+
68
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69
+ # Parse arguments.
70
+ ndim = 2
71
+ weight_shape = tuple(weight_shape)
72
+ stride = _tuple_of_ints(stride, ndim)
73
+ padding = _tuple_of_ints(padding, ndim)
74
+ output_padding = _tuple_of_ints(output_padding, ndim)
75
+ dilation = _tuple_of_ints(dilation, ndim)
76
+
77
+ # Lookup from cache.
78
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79
+ if key in _conv2d_gradfix_cache:
80
+ return _conv2d_gradfix_cache[key]
81
+
82
+ # Validate arguments.
83
+ assert groups >= 1
84
+ assert len(weight_shape) == ndim + 2
85
+ assert all(stride[i] >= 1 for i in range(ndim))
86
+ assert all(padding[i] >= 0 for i in range(ndim))
87
+ assert all(dilation[i] >= 0 for i in range(ndim))
88
+ if not transpose:
89
+ assert all(output_padding[i] == 0 for i in range(ndim))
90
+ else: # transpose
91
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92
+
93
+ # Helpers.
94
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95
+ def calc_output_padding(input_shape, output_shape):
96
+ if transpose:
97
+ return [0, 0]
98
+ return [
99
+ input_shape[i + 2]
100
+ - (output_shape[i + 2] - 1) * stride[i]
101
+ - (1 - 2 * padding[i])
102
+ - dilation[i] * (weight_shape[i + 2] - 1)
103
+ for i in range(ndim)
104
+ ]
105
+
106
+ # Forward & backward.
107
+ class Conv2d(torch.autograd.Function):
108
+ @staticmethod
109
+ def forward(ctx, input, weight, bias):
110
+ assert weight.shape == weight_shape
111
+ if not transpose:
112
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113
+ else: # transpose
114
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115
+ ctx.save_for_backward(input, weight)
116
+ return output
117
+
118
+ @staticmethod
119
+ def backward(ctx, grad_output):
120
+ input, weight = ctx.saved_tensors
121
+ grad_input = None
122
+ grad_weight = None
123
+ grad_bias = None
124
+
125
+ if ctx.needs_input_grad[0]:
126
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128
+ assert grad_input.shape == input.shape
129
+
130
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
132
+ assert grad_weight.shape == weight_shape
133
+
134
+ if ctx.needs_input_grad[2]:
135
+ grad_bias = grad_output.sum([0, 2, 3])
136
+
137
+ return grad_input, grad_weight, grad_bias
138
+
139
+ # Gradient with respect to the weights.
140
+ class Conv2dGradWeight(torch.autograd.Function):
141
+ @staticmethod
142
+ def forward(ctx, grad_output, input):
143
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146
+ assert grad_weight.shape == weight_shape
147
+ ctx.save_for_backward(grad_output, input)
148
+ return grad_weight
149
+
150
+ @staticmethod
151
+ def backward(ctx, grad2_grad_weight):
152
+ grad_output, input = ctx.saved_tensors
153
+ grad2_grad_output = None
154
+ grad2_input = None
155
+
156
+ if ctx.needs_input_grad[0]:
157
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158
+ assert grad2_grad_output.shape == grad_output.shape
159
+
160
+ if ctx.needs_input_grad[1]:
161
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163
+ assert grad2_input.shape == input.shape
164
+
165
+ return grad2_grad_output, grad2_input
166
+
167
+ _conv2d_gradfix_cache[key] = Conv2d
168
+ return Conv2d
169
+
170
+ #----------------------------------------------------------------------------