rinong commited on
Commit
628a5a4
·
1 Parent(s): 2c2855f

Added non-gpu ops

Browse files
model/sg2_model.py CHANGED
@@ -8,7 +8,12 @@ from torch import nn
8
  from torch.nn import functional as F
9
  from torch.autograd import Function
10
 
11
- from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
 
 
 
 
 
12
 
13
 
14
  class PixelNorm(nn.Module):
 
8
  from torch.nn import functional as F
9
  from torch.autograd import Function
10
 
11
+ if torch.cuda.is_available():
12
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
13
+ from op.upfirdn2d import upfirdn2d
14
+ else:
15
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
16
+ from op.upfirdn2d_cpu import upfirdn2d
17
 
18
 
19
  class PixelNorm(nn.Module):
op/__init__.py CHANGED
@@ -1,2 +0,0 @@
1
- from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
- from .upfirdn2d import upfirdn2d
 
 
 
op/conv2d_gradfix.py CHANGED
@@ -1,227 +1,227 @@
1
- import contextlib
2
- import warnings
3
-
4
- import torch
5
- from torch import autograd
6
- from torch.nn import functional as F
7
-
8
- enabled = True
9
- weight_gradients_disabled = False
10
-
11
-
12
- @contextlib.contextmanager
13
- def no_weight_gradients():
14
- global weight_gradients_disabled
15
-
16
- old = weight_gradients_disabled
17
- weight_gradients_disabled = True
18
- yield
19
- weight_gradients_disabled = old
20
-
21
-
22
- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
- if could_use_op(input):
24
- return conv2d_gradfix(
25
- transpose=False,
26
- weight_shape=weight.shape,
27
- stride=stride,
28
- padding=padding,
29
- output_padding=0,
30
- dilation=dilation,
31
- groups=groups,
32
- ).apply(input, weight, bias)
33
-
34
- return F.conv2d(
35
- input=input,
36
- weight=weight,
37
- bias=bias,
38
- stride=stride,
39
- padding=padding,
40
- dilation=dilation,
41
- groups=groups,
42
- )
43
-
44
-
45
- def conv_transpose2d(
46
- input,
47
- weight,
48
- bias=None,
49
- stride=1,
50
- padding=0,
51
- output_padding=0,
52
- groups=1,
53
- dilation=1,
54
- ):
55
- if could_use_op(input):
56
- return conv2d_gradfix(
57
- transpose=True,
58
- weight_shape=weight.shape,
59
- stride=stride,
60
- padding=padding,
61
- output_padding=output_padding,
62
- groups=groups,
63
- dilation=dilation,
64
- ).apply(input, weight, bias)
65
-
66
- return F.conv_transpose2d(
67
- input=input,
68
- weight=weight,
69
- bias=bias,
70
- stride=stride,
71
- padding=padding,
72
- output_padding=output_padding,
73
- dilation=dilation,
74
- groups=groups,
75
- )
76
-
77
-
78
- def could_use_op(input):
79
- if (not enabled) or (not torch.backends.cudnn.enabled):
80
- return False
81
-
82
- if input.device.type != "cuda":
83
- return False
84
-
85
- if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
- return True
87
-
88
- warnings.warn(
89
- f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
- )
91
-
92
- return False
93
-
94
-
95
- def ensure_tuple(xs, ndim):
96
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
-
98
- return xs
99
-
100
-
101
- conv2d_gradfix_cache = dict()
102
-
103
-
104
- def conv2d_gradfix(
105
- transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
- ):
107
- ndim = 2
108
- weight_shape = tuple(weight_shape)
109
- stride = ensure_tuple(stride, ndim)
110
- padding = ensure_tuple(padding, ndim)
111
- output_padding = ensure_tuple(output_padding, ndim)
112
- dilation = ensure_tuple(dilation, ndim)
113
-
114
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
- if key in conv2d_gradfix_cache:
116
- return conv2d_gradfix_cache[key]
117
-
118
- common_kwargs = dict(
119
- stride=stride, padding=padding, dilation=dilation, groups=groups
120
- )
121
-
122
- def calc_output_padding(input_shape, output_shape):
123
- if transpose:
124
- return [0, 0]
125
-
126
- return [
127
- input_shape[i + 2]
128
- - (output_shape[i + 2] - 1) * stride[i]
129
- - (1 - 2 * padding[i])
130
- - dilation[i] * (weight_shape[i + 2] - 1)
131
- for i in range(ndim)
132
- ]
133
-
134
- class Conv2d(autograd.Function):
135
- @staticmethod
136
- def forward(ctx, input, weight, bias):
137
- if not transpose:
138
- out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
-
140
- else:
141
- out = F.conv_transpose2d(
142
- input=input,
143
- weight=weight,
144
- bias=bias,
145
- output_padding=output_padding,
146
- **common_kwargs,
147
- )
148
-
149
- ctx.save_for_backward(input, weight)
150
-
151
- return out
152
-
153
- @staticmethod
154
- def backward(ctx, grad_output):
155
- input, weight = ctx.saved_tensors
156
- grad_input, grad_weight, grad_bias = None, None, None
157
-
158
- if ctx.needs_input_grad[0]:
159
- p = calc_output_padding(
160
- input_shape=input.shape, output_shape=grad_output.shape
161
- )
162
- grad_input = conv2d_gradfix(
163
- transpose=(not transpose),
164
- weight_shape=weight_shape,
165
- output_padding=p,
166
- **common_kwargs,
167
- ).apply(grad_output, weight, None)
168
-
169
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
-
172
- if ctx.needs_input_grad[2]:
173
- grad_bias = grad_output.sum((0, 2, 3))
174
-
175
- return grad_input, grad_weight, grad_bias
176
-
177
- class Conv2dGradWeight(autograd.Function):
178
- @staticmethod
179
- def forward(ctx, grad_output, input):
180
- op = torch._C._jit_get_operation(
181
- "aten::cudnn_convolution_backward_weight"
182
- if not transpose
183
- else "aten::cudnn_convolution_transpose_backward_weight"
184
- )
185
- flags = [
186
- torch.backends.cudnn.benchmark,
187
- torch.backends.cudnn.deterministic,
188
- torch.backends.cudnn.allow_tf32,
189
- ]
190
- grad_weight = op(
191
- weight_shape,
192
- grad_output,
193
- input,
194
- padding,
195
- stride,
196
- dilation,
197
- groups,
198
- *flags,
199
- )
200
- ctx.save_for_backward(grad_output, input)
201
-
202
- return grad_weight
203
-
204
- @staticmethod
205
- def backward(ctx, grad_grad_weight):
206
- grad_output, input = ctx.saved_tensors
207
- grad_grad_output, grad_grad_input = None, None
208
-
209
- if ctx.needs_input_grad[0]:
210
- grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
-
212
- if ctx.needs_input_grad[1]:
213
- p = calc_output_padding(
214
- input_shape=input.shape, output_shape=grad_output.shape
215
- )
216
- grad_grad_input = conv2d_gradfix(
217
- transpose=(not transpose),
218
- weight_shape=weight_shape,
219
- output_padding=p,
220
- **common_kwargs,
221
- ).apply(grad_output, grad_grad_weight, None)
222
-
223
- return grad_grad_output, grad_grad_input
224
-
225
- conv2d_gradfix_cache[key] = Conv2d
226
-
227
- return Conv2d
 
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
op/fused_act.py CHANGED
@@ -1,119 +1,86 @@
1
- import os
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from torch.autograd import Function
7
- from torch.utils.cpp_extension import load
8
-
9
-
10
- module_path = os.path.dirname(__file__)
11
- fused = load(
12
- "fused",
13
- sources=[
14
- os.path.join(module_path, "fused_bias_act.cpp"),
15
- os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
- ],
17
- )
18
-
19
-
20
- class FusedLeakyReLUFunctionBackward(Function):
21
- @staticmethod
22
- def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
- ctx.save_for_backward(out)
24
- ctx.negative_slope = negative_slope
25
- ctx.scale = scale
26
-
27
- empty = grad_output.new_empty(0)
28
-
29
- grad_input = fused.fused_bias_act(
30
- grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
- )
32
-
33
- dim = [0]
34
-
35
- if grad_input.ndim > 2:
36
- dim += list(range(2, grad_input.ndim))
37
-
38
- if bias:
39
- grad_bias = grad_input.sum(dim).detach()
40
-
41
- else:
42
- grad_bias = empty
43
-
44
- return grad_input, grad_bias
45
-
46
- @staticmethod
47
- def backward(ctx, gradgrad_input, gradgrad_bias):
48
- out, = ctx.saved_tensors
49
- gradgrad_out = fused.fused_bias_act(
50
- gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51
- )
52
-
53
- return gradgrad_out, None, None, None, None
54
-
55
-
56
- class FusedLeakyReLUFunction(Function):
57
- @staticmethod
58
- def forward(ctx, input, bias, negative_slope, scale):
59
- empty = input.new_empty(0)
60
-
61
- ctx.bias = bias is not None
62
-
63
- if bias is None:
64
- bias = empty
65
-
66
- out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
67
- ctx.save_for_backward(out)
68
- ctx.negative_slope = negative_slope
69
- ctx.scale = scale
70
-
71
- return out
72
-
73
- @staticmethod
74
- def backward(ctx, grad_output):
75
- out, = ctx.saved_tensors
76
-
77
- grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78
- grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79
- )
80
-
81
- if not ctx.bias:
82
- grad_bias = None
83
-
84
- return grad_input, grad_bias, None, None
85
-
86
-
87
- class FusedLeakyReLU(nn.Module):
88
- def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89
- super().__init__()
90
-
91
- if bias:
92
- self.bias = nn.Parameter(torch.zeros(channel))
93
-
94
- else:
95
- self.bias = None
96
-
97
- self.negative_slope = negative_slope
98
- self.scale = scale
99
-
100
- def forward(self, input):
101
- return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102
-
103
-
104
- def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105
- if input.device.type == "cpu":
106
- if bias is not None:
107
- rest_dim = [1] * (input.ndim - bias.ndim - 1)
108
- return (
109
- F.leaky_relu(
110
- input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
111
- )
112
- * scale
113
- )
114
-
115
- else:
116
- return F.leaky_relu(input, negative_slope=0.2) * scale
117
-
118
- else:
119
- return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+ fused = load(
11
+ 'fused',
12
+ sources=[
13
+ os.path.join(module_path, 'fused_bias_act.cpp'),
14
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
15
+ ],
16
+ )
17
+
18
+
19
+ class FusedLeakyReLUFunctionBackward(Function):
20
+ @staticmethod
21
+ def forward(ctx, grad_output, out, negative_slope, scale):
22
+ ctx.save_for_backward(out)
23
+ ctx.negative_slope = negative_slope
24
+ ctx.scale = scale
25
+
26
+ empty = grad_output.new_empty(0)
27
+
28
+ grad_input = fused.fused_bias_act(
29
+ grad_output, empty, out, 3, 1, negative_slope, scale
30
+ )
31
+
32
+ dim = [0]
33
+
34
+ if grad_input.ndim > 2:
35
+ dim += list(range(2, grad_input.ndim))
36
+
37
+ grad_bias = grad_input.sum(dim).detach()
38
+
39
+ return grad_input, grad_bias
40
+
41
+ @staticmethod
42
+ def backward(ctx, gradgrad_input, gradgrad_bias):
43
+ out, = ctx.saved_tensors
44
+ gradgrad_out = fused.fused_bias_act(
45
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
46
+ )
47
+
48
+ return gradgrad_out, None, None, None
49
+
50
+
51
+ class FusedLeakyReLUFunction(Function):
52
+ @staticmethod
53
+ def forward(ctx, input, bias, negative_slope, scale):
54
+ empty = input.new_empty(0)
55
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
56
+ ctx.save_for_backward(out)
57
+ ctx.negative_slope = negative_slope
58
+ ctx.scale = scale
59
+
60
+ return out
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_output):
64
+ out, = ctx.saved_tensors
65
+
66
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
67
+ grad_output, out, ctx.negative_slope, ctx.scale
68
+ )
69
+
70
+ return grad_input, grad_bias, None, None
71
+
72
+
73
+ class FusedLeakyReLU(nn.Module):
74
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
75
+ super().__init__()
76
+
77
+ self.bias = nn.Parameter(torch.zeros(channel))
78
+ self.negative_slope = negative_slope
79
+ self.scale = scale
80
+
81
+ def forward(self, input):
82
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
83
+
84
+
85
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
86
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
op/fused_act_cpu.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.nn import functional as F
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+
11
+
12
+ class FusedLeakyReLU(nn.Module):
13
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
14
+ super().__init__()
15
+
16
+ self.bias = nn.Parameter(torch.zeros(channel))
17
+ self.negative_slope = negative_slope
18
+ self.scale = scale
19
+
20
+ def forward(self, input):
21
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
22
+
23
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
24
+ if input.device.type == "cpu":
25
+ if bias is not None:
26
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
27
+ return (
28
+ F.leaky_relu(
29
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
30
+ )
31
+ * scale
32
+ )
33
+
34
+ else:
35
+ return F.leaky_relu(input, negative_slope=0.2) * scale
36
+
37
+ else:
38
+ return FusedLeakyReLUFunction.apply(
39
+ input.contiguous(), bias, negative_slope, scale
40
+ )
41
+
op/fused_bias_act.cpp CHANGED
@@ -1,32 +1,21 @@
1
-
2
- #include <ATen/ATen.h>
3
- #include <torch/extension.h>
4
-
5
- torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
- const torch::Tensor &bias,
7
- const torch::Tensor &refer, int act, int grad,
8
- float alpha, float scale);
9
-
10
- #define CHECK_CUDA(x) \
11
- TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
- #define CHECK_CONTIGUOUS(x) \
13
- TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
- #define CHECK_INPUT(x) \
15
- CHECK_CUDA(x); \
16
- CHECK_CONTIGUOUS(x)
17
-
18
- torch::Tensor fused_bias_act(const torch::Tensor &input,
19
- const torch::Tensor &bias,
20
- const torch::Tensor &refer, int act, int grad,
21
- float alpha, float scale) {
22
- CHECK_INPUT(input);
23
- CHECK_INPUT(bias);
24
-
25
- at::DeviceGuard guard(input.device());
26
-
27
- return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
- }
29
-
30
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
- m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
  }
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
 
 
 
 
 
 
 
 
 
 
 
21
  }
op/fused_bias_act_kernel.cu CHANGED
@@ -1,105 +1,99 @@
1
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
- //
3
- // This work is made available under the Nvidia Source Code License-NC.
4
- // To view a copy of this license, visit
5
- // https://nvlabs.github.io/stylegan2/license.html
6
-
7
- #include <torch/types.h>
8
-
9
- #include <ATen/ATen.h>
10
- #include <ATen/AccumulateType.h>
11
- #include <ATen/cuda/CUDAApplyUtils.cuh>
12
- #include <ATen/cuda/CUDAContext.h>
13
-
14
-
15
- #include <cuda.h>
16
- #include <cuda_runtime.h>
17
-
18
- template <typename scalar_t>
19
- static __global__ void
20
- fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
- const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
- scalar_t scale, int loop_x, int size_x, int step_b,
23
- int size_b, int use_bias, int use_ref) {
24
- int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
-
26
- scalar_t zero = 0.0;
27
-
28
- for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
- loop_idx++, xi += blockDim.x) {
30
- scalar_t x = p_x[xi];
31
-
32
- if (use_bias) {
33
- x += p_b[(xi / step_b) % size_b];
34
- }
35
-
36
- scalar_t ref = use_ref ? p_ref[xi] : zero;
37
-
38
- scalar_t y;
39
-
40
- switch (act * 10 + grad) {
41
- default:
42
- case 10:
43
- y = x;
44
- break;
45
- case 11:
46
- y = x;
47
- break;
48
- case 12:
49
- y = 0.0;
50
- break;
51
-
52
- case 30:
53
- y = (x > 0.0) ? x : x * alpha;
54
- break;
55
- case 31:
56
- y = (ref > 0.0) ? x : x * alpha;
57
- break;
58
- case 32:
59
- y = 0.0;
60
- break;
61
- }
62
-
63
- out[xi] = y * scale;
64
- }
65
- }
66
-
67
- torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
- const torch::Tensor &bias,
69
- const torch::Tensor &refer, int act, int grad,
70
- float alpha, float scale) {
71
- int curDevice = -1;
72
- cudaGetDevice(&curDevice);
73
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
-
75
- auto x = input.contiguous();
76
- auto b = bias.contiguous();
77
- auto ref = refer.contiguous();
78
-
79
- int use_bias = b.numel() ? 1 : 0;
80
- int use_ref = ref.numel() ? 1 : 0;
81
-
82
- int size_x = x.numel();
83
- int size_b = b.numel();
84
- int step_b = 1;
85
-
86
- for (int i = 1 + 1; i < x.dim(); i++) {
87
- step_b *= x.size(i);
88
- }
89
-
90
- int loop_x = 4;
91
- int block_size = 4 * 32;
92
- int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
-
94
- auto y = torch::empty_like(x);
95
-
96
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
- x.scalar_type(), "fused_bias_act_kernel", [&] {
98
- fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
- y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
- b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
- scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
- });
103
-
104
- return y;
105
  }
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
 
 
 
 
 
 
99
  }
op/upfirdn2d.cpp CHANGED
@@ -1,31 +1,23 @@
1
- #include <ATen/ATen.h>
2
- #include <torch/extension.h>
3
-
4
- torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
- const torch::Tensor &kernel, int up_x, int up_y,
6
- int down_x, int down_y, int pad_x0, int pad_x1,
7
- int pad_y0, int pad_y1);
8
-
9
- #define CHECK_CUDA(x) \
10
- TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
- #define CHECK_CONTIGUOUS(x) \
12
- TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
- #define CHECK_INPUT(x) \
14
- CHECK_CUDA(x); \
15
- CHECK_CONTIGUOUS(x)
16
-
17
- torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
- int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
- int pad_x1, int pad_y0, int pad_y1) {
20
- CHECK_INPUT(input);
21
- CHECK_INPUT(kernel);
22
-
23
- at::DeviceGuard guard(input.device());
24
-
25
- return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
- pad_y0, pad_y1);
27
- }
28
-
29
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
- m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
  }
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
 
 
 
 
 
 
 
 
23
  }
op/upfirdn2d.py CHANGED
@@ -1,209 +1,187 @@
1
- from collections import abc
2
- import os
3
-
4
- import torch
5
- from torch.nn import functional as F
6
- from torch.autograd import Function
7
- from torch.utils.cpp_extension import load
8
-
9
-
10
- module_path = os.path.dirname(__file__)
11
- upfirdn2d_op = load(
12
- "upfirdn2d",
13
- sources=[
14
- os.path.join(module_path, "upfirdn2d.cpp"),
15
- os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
- ],
17
- )
18
-
19
-
20
- class UpFirDn2dBackward(Function):
21
- @staticmethod
22
- def forward(
23
- ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
- ):
25
-
26
- up_x, up_y = up
27
- down_x, down_y = down
28
- g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
-
30
- grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
-
32
- grad_input = upfirdn2d_op.upfirdn2d(
33
- grad_output,
34
- grad_kernel,
35
- down_x,
36
- down_y,
37
- up_x,
38
- up_y,
39
- g_pad_x0,
40
- g_pad_x1,
41
- g_pad_y0,
42
- g_pad_y1,
43
- )
44
- grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
-
46
- ctx.save_for_backward(kernel)
47
-
48
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
-
50
- ctx.up_x = up_x
51
- ctx.up_y = up_y
52
- ctx.down_x = down_x
53
- ctx.down_y = down_y
54
- ctx.pad_x0 = pad_x0
55
- ctx.pad_x1 = pad_x1
56
- ctx.pad_y0 = pad_y0
57
- ctx.pad_y1 = pad_y1
58
- ctx.in_size = in_size
59
- ctx.out_size = out_size
60
-
61
- return grad_input
62
-
63
- @staticmethod
64
- def backward(ctx, gradgrad_input):
65
- kernel, = ctx.saved_tensors
66
-
67
- gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
-
69
- gradgrad_out = upfirdn2d_op.upfirdn2d(
70
- gradgrad_input,
71
- kernel,
72
- ctx.up_x,
73
- ctx.up_y,
74
- ctx.down_x,
75
- ctx.down_y,
76
- ctx.pad_x0,
77
- ctx.pad_x1,
78
- ctx.pad_y0,
79
- ctx.pad_y1,
80
- )
81
- # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
- gradgrad_out = gradgrad_out.view(
83
- ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
- )
85
-
86
- return gradgrad_out, None, None, None, None, None, None, None, None
87
-
88
-
89
- class UpFirDn2d(Function):
90
- @staticmethod
91
- def forward(ctx, input, kernel, up, down, pad):
92
- up_x, up_y = up
93
- down_x, down_y = down
94
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
-
96
- kernel_h, kernel_w = kernel.shape
97
- batch, channel, in_h, in_w = input.shape
98
- ctx.in_size = input.shape
99
-
100
- input = input.reshape(-1, in_h, in_w, 1)
101
-
102
- ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
-
104
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
- ctx.out_size = (out_h, out_w)
107
-
108
- ctx.up = (up_x, up_y)
109
- ctx.down = (down_x, down_y)
110
- ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
-
112
- g_pad_x0 = kernel_w - pad_x0 - 1
113
- g_pad_y0 = kernel_h - pad_y0 - 1
114
- g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
- g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
-
117
- ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
-
119
- out = upfirdn2d_op.upfirdn2d(
120
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
- )
122
- # out = out.view(major, out_h, out_w, minor)
123
- out = out.view(-1, channel, out_h, out_w)
124
-
125
- return out
126
-
127
- @staticmethod
128
- def backward(ctx, grad_output):
129
- kernel, grad_kernel = ctx.saved_tensors
130
-
131
- grad_input = None
132
-
133
- if ctx.needs_input_grad[0]:
134
- grad_input = UpFirDn2dBackward.apply(
135
- grad_output,
136
- kernel,
137
- grad_kernel,
138
- ctx.up,
139
- ctx.down,
140
- ctx.pad,
141
- ctx.g_pad,
142
- ctx.in_size,
143
- ctx.out_size,
144
- )
145
-
146
- return grad_input, None, None, None, None
147
-
148
-
149
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
- if not isinstance(up, abc.Iterable):
151
- up = (up, up)
152
-
153
- if not isinstance(down, abc.Iterable):
154
- down = (down, down)
155
-
156
- if len(pad) == 2:
157
- pad = (pad[0], pad[1], pad[0], pad[1])
158
-
159
- if input.device.type == "cpu":
160
- out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
-
162
- else:
163
- out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
-
165
- return out
166
-
167
-
168
- def upfirdn2d_native(
169
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
- ):
171
- _, channel, in_h, in_w = input.shape
172
- input = input.reshape(-1, in_h, in_w, 1)
173
-
174
- _, in_h, in_w, minor = input.shape
175
- kernel_h, kernel_w = kernel.shape
176
-
177
- out = input.view(-1, in_h, 1, in_w, 1, minor)
178
- out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
- out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
-
181
- out = F.pad(
182
- out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
- )
184
- out = out[
185
- :,
186
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
- :,
189
- ]
190
-
191
- out = out.permute(0, 3, 1, 2)
192
- out = out.reshape(
193
- [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
- )
195
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
- out = F.conv2d(out, w)
197
- out = out.reshape(
198
- -1,
199
- minor,
200
- in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
- )
203
- out = out.permute(0, 2, 3, 1)
204
- out = out[:, ::down_y, ::down_x, :]
205
-
206
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
-
209
- return out.view(-1, channel, out_h, out_w)
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.utils.cpp_extension import load
6
+
7
+
8
+ module_path = os.path.dirname(__file__)
9
+ upfirdn2d_op = load(
10
+ 'upfirdn2d',
11
+ sources=[
12
+ os.path.join(module_path, 'upfirdn2d.cpp'),
13
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
14
+ ],
15
+ )
16
+
17
+
18
+ class UpFirDn2dBackward(Function):
19
+ @staticmethod
20
+ def forward(
21
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
22
+ ):
23
+
24
+ up_x, up_y = up
25
+ down_x, down_y = down
26
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
27
+
28
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
29
+
30
+ grad_input = upfirdn2d_op.upfirdn2d(
31
+ grad_output,
32
+ grad_kernel,
33
+ down_x,
34
+ down_y,
35
+ up_x,
36
+ up_y,
37
+ g_pad_x0,
38
+ g_pad_x1,
39
+ g_pad_y0,
40
+ g_pad_y1,
41
+ )
42
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
43
+
44
+ ctx.save_for_backward(kernel)
45
+
46
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
47
+
48
+ ctx.up_x = up_x
49
+ ctx.up_y = up_y
50
+ ctx.down_x = down_x
51
+ ctx.down_y = down_y
52
+ ctx.pad_x0 = pad_x0
53
+ ctx.pad_x1 = pad_x1
54
+ ctx.pad_y0 = pad_y0
55
+ ctx.pad_y1 = pad_y1
56
+ ctx.in_size = in_size
57
+ ctx.out_size = out_size
58
+
59
+ return grad_input
60
+
61
+ @staticmethod
62
+ def backward(ctx, gradgrad_input):
63
+ kernel, = ctx.saved_tensors
64
+
65
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
66
+
67
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
68
+ gradgrad_input,
69
+ kernel,
70
+ ctx.up_x,
71
+ ctx.up_y,
72
+ ctx.down_x,
73
+ ctx.down_y,
74
+ ctx.pad_x0,
75
+ ctx.pad_x1,
76
+ ctx.pad_y0,
77
+ ctx.pad_y1,
78
+ )
79
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
80
+ gradgrad_out = gradgrad_out.view(
81
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
82
+ )
83
+
84
+ return gradgrad_out, None, None, None, None, None, None, None, None
85
+
86
+
87
+ class UpFirDn2d(Function):
88
+ @staticmethod
89
+ def forward(ctx, input, kernel, up, down, pad):
90
+ up_x, up_y = up
91
+ down_x, down_y = down
92
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
93
+
94
+ kernel_h, kernel_w = kernel.shape
95
+ batch, channel, in_h, in_w = input.shape
96
+ ctx.in_size = input.shape
97
+
98
+ input = input.reshape(-1, in_h, in_w, 1)
99
+
100
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
101
+
102
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
103
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
104
+ ctx.out_size = (out_h, out_w)
105
+
106
+ ctx.up = (up_x, up_y)
107
+ ctx.down = (down_x, down_y)
108
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
109
+
110
+ g_pad_x0 = kernel_w - pad_x0 - 1
111
+ g_pad_y0 = kernel_h - pad_y0 - 1
112
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
113
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
114
+
115
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
116
+
117
+ out = upfirdn2d_op.upfirdn2d(
118
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
119
+ )
120
+ # out = out.view(major, out_h, out_w, minor)
121
+ out = out.view(-1, channel, out_h, out_w)
122
+
123
+ return out
124
+
125
+ @staticmethod
126
+ def backward(ctx, grad_output):
127
+ kernel, grad_kernel = ctx.saved_tensors
128
+
129
+ grad_input = UpFirDn2dBackward.apply(
130
+ grad_output,
131
+ kernel,
132
+ grad_kernel,
133
+ ctx.up,
134
+ ctx.down,
135
+ ctx.pad,
136
+ ctx.g_pad,
137
+ ctx.in_size,
138
+ ctx.out_size,
139
+ )
140
+
141
+ return grad_input, None, None, None, None
142
+
143
+
144
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
145
+ out = UpFirDn2d.apply(
146
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
147
+ )
148
+
149
+ return out
150
+
151
+
152
+ def upfirdn2d_native(
153
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
154
+ ):
155
+ _, in_h, in_w, minor = input.shape
156
+ kernel_h, kernel_w = kernel.shape
157
+
158
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
159
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
160
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
161
+
162
+ out = F.pad(
163
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
164
+ )
165
+ out = out[
166
+ :,
167
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
168
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
169
+ :,
170
+ ]
171
+
172
+ out = out.permute(0, 3, 1, 2)
173
+ out = out.reshape(
174
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
175
+ )
176
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
177
+ out = F.conv2d(out, w)
178
+ out = out.reshape(
179
+ -1,
180
+ minor,
181
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
182
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
183
+ )
184
+ out = out.permute(0, 2, 3, 1)
185
+
186
+ return out[:, ::down_y, ::down_x, :]
187
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
op/upfirdn2d_cpu.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.nn import functional as F
6
+
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+
11
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
12
+ out = upfirdn2d_native(
13
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
14
+ )
15
+
16
+ return out
17
+
18
+
19
+ def upfirdn2d_native(
20
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
21
+ ):
22
+ _, channel, in_h, in_w = input.shape
23
+ input = input.reshape(-1, in_h, in_w, 1)
24
+
25
+ _, in_h, in_w, minor = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
29
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
30
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
31
+
32
+ out = F.pad(
33
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
34
+ )
35
+ out = out[
36
+ :,
37
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
38
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
39
+ :,
40
+ ]
41
+
42
+ out = out.permute(0, 3, 1, 2)
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ out = out.permute(0, 2, 3, 1)
55
+ out = out[:, ::down_y, ::down_x, :]
56
+
57
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
58
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
59
+
60
+ return out.view(-1, channel, out_h, out_w)
op/upfirdn2d_kernel.cu CHANGED
@@ -1,369 +1,272 @@
1
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
- //
3
- // This work is made available under the Nvidia Source Code License-NC.
4
- // To view a copy of this license, visit
5
- // https://nvlabs.github.io/stylegan2/license.html
6
-
7
- #include <torch/types.h>
8
-
9
- #include <ATen/ATen.h>
10
- #include <ATen/AccumulateType.h>
11
- #include <ATen/cuda/CUDAApplyUtils.cuh>
12
- #include <ATen/cuda/CUDAContext.h>
13
-
14
- #include <cuda.h>
15
- #include <cuda_runtime.h>
16
-
17
- static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
- int c = a / b;
19
-
20
- if (c * b > a) {
21
- c--;
22
- }
23
-
24
- return c;
25
- }
26
-
27
- struct UpFirDn2DKernelParams {
28
- int up_x;
29
- int up_y;
30
- int down_x;
31
- int down_y;
32
- int pad_x0;
33
- int pad_x1;
34
- int pad_y0;
35
- int pad_y1;
36
-
37
- int major_dim;
38
- int in_h;
39
- int in_w;
40
- int minor_dim;
41
- int kernel_h;
42
- int kernel_w;
43
- int out_h;
44
- int out_w;
45
- int loop_major;
46
- int loop_x;
47
- };
48
-
49
- template <typename scalar_t>
50
- __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
- const scalar_t *kernel,
52
- const UpFirDn2DKernelParams p) {
53
- int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
- int out_y = minor_idx / p.minor_dim;
55
- minor_idx -= out_y * p.minor_dim;
56
- int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
- int major_idx_base = blockIdx.z * p.loop_major;
58
-
59
- if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
- major_idx_base >= p.major_dim) {
61
- return;
62
- }
63
-
64
- int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
- int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
- int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
- int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
-
69
- for (int loop_major = 0, major_idx = major_idx_base;
70
- loop_major < p.loop_major && major_idx < p.major_dim;
71
- loop_major++, major_idx++) {
72
- for (int loop_x = 0, out_x = out_x_base;
73
- loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
- int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
- int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
- int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
- int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
-
79
- const scalar_t *x_p =
80
- &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
- minor_idx];
82
- const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
- int x_px = p.minor_dim;
84
- int k_px = -p.up_x;
85
- int x_py = p.in_w * p.minor_dim;
86
- int k_py = -p.up_y * p.kernel_w;
87
-
88
- scalar_t v = 0.0f;
89
-
90
- for (int y = 0; y < h; y++) {
91
- for (int x = 0; x < w; x++) {
92
- v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
- x_p += x_px;
94
- k_p += k_px;
95
- }
96
-
97
- x_p += x_py - w * x_px;
98
- k_p += k_py - w * k_px;
99
- }
100
-
101
- out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
- minor_idx] = v;
103
- }
104
- }
105
- }
106
-
107
- template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
- int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
- __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
- const scalar_t *kernel,
111
- const UpFirDn2DKernelParams p) {
112
- const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
- const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
-
115
- __shared__ volatile float sk[kernel_h][kernel_w];
116
- __shared__ volatile float sx[tile_in_h][tile_in_w];
117
-
118
- int minor_idx = blockIdx.x;
119
- int tile_out_y = minor_idx / p.minor_dim;
120
- minor_idx -= tile_out_y * p.minor_dim;
121
- tile_out_y *= tile_out_h;
122
- int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
- int major_idx_base = blockIdx.z * p.loop_major;
124
-
125
- if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
- major_idx_base >= p.major_dim) {
127
- return;
128
- }
129
-
130
- for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
- tap_idx += blockDim.x) {
132
- int ky = tap_idx / kernel_w;
133
- int kx = tap_idx - ky * kernel_w;
134
- scalar_t v = 0.0;
135
-
136
- if (kx < p.kernel_w & ky < p.kernel_h) {
137
- v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
- }
139
-
140
- sk[ky][kx] = v;
141
- }
142
-
143
- for (int loop_major = 0, major_idx = major_idx_base;
144
- loop_major < p.loop_major & major_idx < p.major_dim;
145
- loop_major++, major_idx++) {
146
- for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
- loop_x < p.loop_x & tile_out_x < p.out_w;
148
- loop_x++, tile_out_x += tile_out_w) {
149
- int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
- int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
- int tile_in_x = floor_div(tile_mid_x, up_x);
152
- int tile_in_y = floor_div(tile_mid_y, up_y);
153
-
154
- __syncthreads();
155
-
156
- for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
- in_idx += blockDim.x) {
158
- int rel_in_y = in_idx / tile_in_w;
159
- int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
- int in_x = rel_in_x + tile_in_x;
161
- int in_y = rel_in_y + tile_in_y;
162
-
163
- scalar_t v = 0.0;
164
-
165
- if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
- v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
- p.minor_dim +
168
- minor_idx];
169
- }
170
-
171
- sx[rel_in_y][rel_in_x] = v;
172
- }
173
-
174
- __syncthreads();
175
- for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
- out_idx += blockDim.x) {
177
- int rel_out_y = out_idx / tile_out_w;
178
- int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
- int out_x = rel_out_x + tile_out_x;
180
- int out_y = rel_out_y + tile_out_y;
181
-
182
- int mid_x = tile_mid_x + rel_out_x * down_x;
183
- int mid_y = tile_mid_y + rel_out_y * down_y;
184
- int in_x = floor_div(mid_x, up_x);
185
- int in_y = floor_div(mid_y, up_y);
186
- int rel_in_x = in_x - tile_in_x;
187
- int rel_in_y = in_y - tile_in_y;
188
- int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
- int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
-
191
- scalar_t v = 0.0;
192
-
193
- #pragma unroll
194
- for (int y = 0; y < kernel_h / up_y; y++)
195
- #pragma unroll
196
- for (int x = 0; x < kernel_w / up_x; x++)
197
- v += sx[rel_in_y + y][rel_in_x + x] *
198
- sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
-
200
- if (out_x < p.out_w & out_y < p.out_h) {
201
- out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
- minor_idx] = v;
203
- }
204
- }
205
- }
206
- }
207
- }
208
-
209
- torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
- const torch::Tensor &kernel, int up_x, int up_y,
211
- int down_x, int down_y, int pad_x0, int pad_x1,
212
- int pad_y0, int pad_y1) {
213
- int curDevice = -1;
214
- cudaGetDevice(&curDevice);
215
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
-
217
- UpFirDn2DKernelParams p;
218
-
219
- auto x = input.contiguous();
220
- auto k = kernel.contiguous();
221
-
222
- p.major_dim = x.size(0);
223
- p.in_h = x.size(1);
224
- p.in_w = x.size(2);
225
- p.minor_dim = x.size(3);
226
- p.kernel_h = k.size(0);
227
- p.kernel_w = k.size(1);
228
- p.up_x = up_x;
229
- p.up_y = up_y;
230
- p.down_x = down_x;
231
- p.down_y = down_y;
232
- p.pad_x0 = pad_x0;
233
- p.pad_x1 = pad_x1;
234
- p.pad_y0 = pad_y0;
235
- p.pad_y1 = pad_y1;
236
-
237
- p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
- p.down_y;
239
- p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
- p.down_x;
241
-
242
- auto out =
243
- at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
-
245
- int mode = -1;
246
-
247
- int tile_out_h = -1;
248
- int tile_out_w = -1;
249
-
250
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
- p.kernel_h <= 4 && p.kernel_w <= 4) {
252
- mode = 1;
253
- tile_out_h = 16;
254
- tile_out_w = 64;
255
- }
256
-
257
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
- p.kernel_h <= 3 && p.kernel_w <= 3) {
259
- mode = 2;
260
- tile_out_h = 16;
261
- tile_out_w = 64;
262
- }
263
-
264
- if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
- p.kernel_h <= 4 && p.kernel_w <= 4) {
266
- mode = 3;
267
- tile_out_h = 16;
268
- tile_out_w = 64;
269
- }
270
-
271
- if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
- p.kernel_h <= 2 && p.kernel_w <= 2) {
273
- mode = 4;
274
- tile_out_h = 16;
275
- tile_out_w = 64;
276
- }
277
-
278
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
- p.kernel_h <= 4 && p.kernel_w <= 4) {
280
- mode = 5;
281
- tile_out_h = 8;
282
- tile_out_w = 32;
283
- }
284
-
285
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
- p.kernel_h <= 2 && p.kernel_w <= 2) {
287
- mode = 6;
288
- tile_out_h = 8;
289
- tile_out_w = 32;
290
- }
291
-
292
- dim3 block_size;
293
- dim3 grid_size;
294
-
295
- if (tile_out_h > 0 && tile_out_w > 0) {
296
- p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
- p.loop_x = 1;
298
- block_size = dim3(32 * 8, 1, 1);
299
- grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
- (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
- (p.major_dim - 1) / p.loop_major + 1);
302
- } else {
303
- p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
- p.loop_x = 4;
305
- block_size = dim3(4, 32, 1);
306
- grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
- (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
- (p.major_dim - 1) / p.loop_major + 1);
309
- }
310
-
311
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
- switch (mode) {
313
- case 1:
314
- upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
- x.data_ptr<scalar_t>(),
317
- k.data_ptr<scalar_t>(), p);
318
-
319
- break;
320
-
321
- case 2:
322
- upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
- x.data_ptr<scalar_t>(),
325
- k.data_ptr<scalar_t>(), p);
326
-
327
- break;
328
-
329
- case 3:
330
- upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
- x.data_ptr<scalar_t>(),
333
- k.data_ptr<scalar_t>(), p);
334
-
335
- break;
336
-
337
- case 4:
338
- upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
- x.data_ptr<scalar_t>(),
341
- k.data_ptr<scalar_t>(), p);
342
-
343
- break;
344
-
345
- case 5:
346
- upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
- x.data_ptr<scalar_t>(),
349
- k.data_ptr<scalar_t>(), p);
350
-
351
- break;
352
-
353
- case 6:
354
- upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
- x.data_ptr<scalar_t>(),
357
- k.data_ptr<scalar_t>(), p);
358
-
359
- break;
360
-
361
- default:
362
- upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
- out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
- k.data_ptr<scalar_t>(), p);
365
- }
366
- });
367
-
368
- return out;
369
  }
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+
52
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53
+ __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56
+
57
+ __shared__ volatile float sk[kernel_h][kernel_w];
58
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
59
+
60
+ int minor_idx = blockIdx.x;
61
+ int tile_out_y = minor_idx / p.minor_dim;
62
+ minor_idx -= tile_out_y * p.minor_dim;
63
+ tile_out_y *= tile_out_h;
64
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65
+ int major_idx_base = blockIdx.z * p.loop_major;
66
+
67
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68
+ return;
69
+ }
70
+
71
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72
+ int ky = tap_idx / kernel_w;
73
+ int kx = tap_idx - ky * kernel_w;
74
+ scalar_t v = 0.0;
75
+
76
+ if (kx < p.kernel_w & ky < p.kernel_h) {
77
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78
+ }
79
+
80
+ sk[ky][kx] = v;
81
+ }
82
+
83
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87
+ int tile_in_x = floor_div(tile_mid_x, up_x);
88
+ int tile_in_y = floor_div(tile_mid_y, up_y);
89
+
90
+ __syncthreads();
91
+
92
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93
+ int rel_in_y = in_idx / tile_in_w;
94
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
95
+ int in_x = rel_in_x + tile_in_x;
96
+ int in_y = rel_in_y + tile_in_y;
97
+
98
+ scalar_t v = 0.0;
99
+
100
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102
+ }
103
+
104
+ sx[rel_in_y][rel_in_x] = v;
105
+ }
106
+
107
+ __syncthreads();
108
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109
+ int rel_out_y = out_idx / tile_out_w;
110
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
111
+ int out_x = rel_out_x + tile_out_x;
112
+ int out_y = rel_out_y + tile_out_y;
113
+
114
+ int mid_x = tile_mid_x + rel_out_x * down_x;
115
+ int mid_y = tile_mid_y + rel_out_y * down_y;
116
+ int in_x = floor_div(mid_x, up_x);
117
+ int in_y = floor_div(mid_y, up_y);
118
+ int rel_in_x = in_x - tile_in_x;
119
+ int rel_in_y = in_y - tile_in_y;
120
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122
+
123
+ scalar_t v = 0.0;
124
+
125
+ #pragma unroll
126
+ for (int y = 0; y < kernel_h / up_y; y++)
127
+ #pragma unroll
128
+ for (int x = 0; x < kernel_w / up_x; x++)
129
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130
+
131
+ if (out_x < p.out_w & out_y < p.out_h) {
132
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+
139
+
140
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141
+ int up_x, int up_y, int down_x, int down_y,
142
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143
+ int curDevice = -1;
144
+ cudaGetDevice(&curDevice);
145
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146
+
147
+ UpFirDn2DKernelParams p;
148
+
149
+ auto x = input.contiguous();
150
+ auto k = kernel.contiguous();
151
+
152
+ p.major_dim = x.size(0);
153
+ p.in_h = x.size(1);
154
+ p.in_w = x.size(2);
155
+ p.minor_dim = x.size(3);
156
+ p.kernel_h = k.size(0);
157
+ p.kernel_w = k.size(1);
158
+ p.up_x = up_x;
159
+ p.up_y = up_y;
160
+ p.down_x = down_x;
161
+ p.down_y = down_y;
162
+ p.pad_x0 = pad_x0;
163
+ p.pad_x1 = pad_x1;
164
+ p.pad_y0 = pad_y0;
165
+ p.pad_y1 = pad_y1;
166
+
167
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169
+
170
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171
+
172
+ int mode = -1;
173
+
174
+ int tile_out_h;
175
+ int tile_out_w;
176
+
177
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178
+ mode = 1;
179
+ tile_out_h = 16;
180
+ tile_out_w = 64;
181
+ }
182
+
183
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184
+ mode = 2;
185
+ tile_out_h = 16;
186
+ tile_out_w = 64;
187
+ }
188
+
189
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190
+ mode = 3;
191
+ tile_out_h = 16;
192
+ tile_out_w = 64;
193
+ }
194
+
195
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196
+ mode = 4;
197
+ tile_out_h = 16;
198
+ tile_out_w = 64;
199
+ }
200
+
201
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202
+ mode = 5;
203
+ tile_out_h = 8;
204
+ tile_out_w = 32;
205
+ }
206
+
207
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208
+ mode = 6;
209
+ tile_out_h = 8;
210
+ tile_out_w = 32;
211
+ }
212
+
213
+ dim3 block_size;
214
+ dim3 grid_size;
215
+
216
+ if (tile_out_h > 0 && tile_out_w) {
217
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
218
+ p.loop_x = 1;
219
+ block_size = dim3(32 * 8, 1, 1);
220
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222
+ (p.major_dim - 1) / p.loop_major + 1);
223
+ }
224
+
225
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226
+ switch (mode) {
227
+ case 1:
228
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230
+ );
231
+
232
+ break;
233
+
234
+ case 2:
235
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237
+ );
238
+
239
+ break;
240
+
241
+ case 3:
242
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244
+ );
245
+
246
+ break;
247
+
248
+ case 4:
249
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251
+ );
252
+
253
+ break;
254
+
255
+ case 5:
256
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258
+ );
259
+
260
+ break;
261
+
262
+ case 6:
263
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265
+ );
266
+
267
+ break;
268
+ }
269
+ });
270
+
271
+ return out;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  }