PKUWilliamYang commited on
Commit
8059447
1 Parent(s): 5e97cdf

Update vtoonify/model/stylegan/op/conv2d_gradfix.py

Browse files
vtoonify/model/stylegan/op/conv2d_gradfix.py CHANGED
@@ -1,227 +1,227 @@
1
- import contextlib
2
- import warnings
3
-
4
- import torch
5
- from torch import autograd
6
- from torch.nn import functional as F
7
-
8
- enabled = True
9
- weight_gradients_disabled = False
10
-
11
-
12
- @contextlib.contextmanager
13
- def no_weight_gradients():
14
- global weight_gradients_disabled
15
-
16
- old = weight_gradients_disabled
17
- weight_gradients_disabled = True
18
- yield
19
- weight_gradients_disabled = old
20
-
21
-
22
- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
- if could_use_op(input):
24
- return conv2d_gradfix(
25
- transpose=False,
26
- weight_shape=weight.shape,
27
- stride=stride,
28
- padding=padding,
29
- output_padding=0,
30
- dilation=dilation,
31
- groups=groups,
32
- ).apply(input, weight, bias)
33
-
34
- return F.conv2d(
35
- input=input,
36
- weight=weight,
37
- bias=bias,
38
- stride=stride,
39
- padding=padding,
40
- dilation=dilation,
41
- groups=groups,
42
- )
43
-
44
-
45
- def conv_transpose2d(
46
- input,
47
- weight,
48
- bias=None,
49
- stride=1,
50
- padding=0,
51
- output_padding=0,
52
- groups=1,
53
- dilation=1,
54
- ):
55
- if could_use_op(input):
56
- return conv2d_gradfix(
57
- transpose=True,
58
- weight_shape=weight.shape,
59
- stride=stride,
60
- padding=padding,
61
- output_padding=output_padding,
62
- groups=groups,
63
- dilation=dilation,
64
- ).apply(input, weight, bias)
65
-
66
- return F.conv_transpose2d(
67
- input=input,
68
- weight=weight,
69
- bias=bias,
70
- stride=stride,
71
- padding=padding,
72
- output_padding=output_padding,
73
- dilation=dilation,
74
- groups=groups,
75
- )
76
-
77
-
78
- def could_use_op(input):
79
- if (not enabled) or (not torch.backends.cudnn.enabled):
80
- return False
81
-
82
- if input.device.type != "cuda":
83
- return False
84
-
85
- if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
- return True
87
-
88
- warnings.warn(
89
- f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
- )
91
-
92
- return False
93
-
94
-
95
- def ensure_tuple(xs, ndim):
96
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
-
98
- return xs
99
-
100
-
101
- conv2d_gradfix_cache = dict()
102
-
103
-
104
- def conv2d_gradfix(
105
- transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
- ):
107
- ndim = 2
108
- weight_shape = tuple(weight_shape)
109
- stride = ensure_tuple(stride, ndim)
110
- padding = ensure_tuple(padding, ndim)
111
- output_padding = ensure_tuple(output_padding, ndim)
112
- dilation = ensure_tuple(dilation, ndim)
113
-
114
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
- if key in conv2d_gradfix_cache:
116
- return conv2d_gradfix_cache[key]
117
-
118
- common_kwargs = dict(
119
- stride=stride, padding=padding, dilation=dilation, groups=groups
120
- )
121
-
122
- def calc_output_padding(input_shape, output_shape):
123
- if transpose:
124
- return [0, 0]
125
-
126
- return [
127
- input_shape[i + 2]
128
- - (output_shape[i + 2] - 1) * stride[i]
129
- - (1 - 2 * padding[i])
130
- - dilation[i] * (weight_shape[i + 2] - 1)
131
- for i in range(ndim)
132
- ]
133
-
134
- class Conv2d(autograd.Function):
135
- @staticmethod
136
- def forward(ctx, input, weight, bias):
137
- if not transpose:
138
- out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
-
140
- else:
141
- out = F.conv_transpose2d(
142
- input=input,
143
- weight=weight,
144
- bias=bias,
145
- output_padding=output_padding,
146
- **common_kwargs,
147
- )
148
-
149
- ctx.save_for_backward(input, weight)
150
-
151
- return out
152
-
153
- @staticmethod
154
- def backward(ctx, grad_output):
155
- input, weight = ctx.saved_tensors
156
- grad_input, grad_weight, grad_bias = None, None, None
157
-
158
- if ctx.needs_input_grad[0]:
159
- p = calc_output_padding(
160
- input_shape=input.shape, output_shape=grad_output.shape
161
- )
162
- grad_input = conv2d_gradfix(
163
- transpose=(not transpose),
164
- weight_shape=weight_shape,
165
- output_padding=p,
166
- **common_kwargs,
167
- ).apply(grad_output, weight, None)
168
-
169
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
-
172
- if ctx.needs_input_grad[2]:
173
- grad_bias = grad_output.sum((0, 2, 3))
174
-
175
- return grad_input, grad_weight, grad_bias
176
-
177
- class Conv2dGradWeight(autograd.Function):
178
- @staticmethod
179
- def forward(ctx, grad_output, input):
180
- op = torch._C._jit_get_operation(
181
- "aten::cudnn_convolution_backward_weight"
182
- if not transpose
183
- else "aten::cudnn_convolution_transpose_backward_weight"
184
- )
185
- flags = [
186
- torch.backends.cudnn.benchmark,
187
- torch.backends.cudnn.deterministic,
188
- torch.backends.cudnn.allow_tf32,
189
- ]
190
- grad_weight = op(
191
- weight_shape,
192
- grad_output,
193
- input,
194
- padding,
195
- stride,
196
- dilation,
197
- groups,
198
- *flags,
199
- )
200
- ctx.save_for_backward(grad_output, input)
201
-
202
- return grad_weight
203
-
204
- @staticmethod
205
- def backward(ctx, grad_grad_weight):
206
- grad_output, input = ctx.saved_tensors
207
- grad_grad_output, grad_grad_input = None, None
208
-
209
- if ctx.needs_input_grad[0]:
210
- grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
-
212
- if ctx.needs_input_grad[1]:
213
- p = calc_output_padding(
214
- input_shape=input.shape, output_shape=grad_output.shape
215
- )
216
- grad_grad_input = conv2d_gradfix(
217
- transpose=(not transpose),
218
- weight_shape=weight_shape,
219
- output_padding=p,
220
- **common_kwargs,
221
- ).apply(grad_output, grad_grad_weight, None)
222
-
223
- return grad_grad_output, grad_grad_input
224
-
225
- conv2d_gradfix_cache[key] = Conv2d
226
-
227
- return Conv2d
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ #warnings.warn(
89
+ # f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ #)
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d