ethanNeuralImage committed
Commit 9d408b4
1 Parent(s): 34e88c0

changing around sg2/op

app.py CHANGED
@@ -27,7 +27,7 @@ from PIL import Image
 opts_args = ['--no_fine_mapper']
 opts = GradioTestOptions().parse(opts_args)
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-opts.device=device
+opts.device = device if opts.device is None else opts.device
 
 mapper_dict = {
     'afro':'./pretrained_models/styleCLIP_mappers/afro_hairstyle.pt',
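The net effect of this one-line change: a device set explicitly through the options now takes precedence, and CUDA autodetection only fills in when none was given. A minimal standalone sketch of the guard (the Opts stub is hypothetical, standing in for the parsed GradioTestOptions result, which this diff does not show):

import torch

class Opts:
    device = None  # stand-in: assume the parser leaves --device unset as None

opts = Opts()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Autodetection is only a fallback; an explicit opts.device
# (e.g. 'cpu' for debugging) survives untouched.
opts.device = device if opts.device is None else opts.device
print(opts.device)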
models/stylegan2/model.py CHANGED
@@ -5,12 +5,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-if torch.cuda.is_available():
-    from models.stylegan2.op.fused_act import FusedLeakyReLU, fused_leaky_relu
-    from op.upfirdn2d import upfirdn2d
-else:
-    from models.stylegan2.op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
-    from models.stylegan2.op.upfirdn2d_cpu import upfirdn2d
+from models.stylegan2.op.fused_act import FusedLeakyReLU, fused_leaky_relu
+from models.stylegan2.op.upfirdn2d import upfirdn2d
 
 
 class PixelNorm(nn.Module):
     def __init__(self):
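The import-time CUDA branch can go because the rewritten ops (see the fused_act.py and upfirdn2d.py diffs below) are plain PyTorch and follow the device of their tensor arguments, so one import path serves CPU and GPU. The removed branch was also internally inconsistent: its CUDA side imported from op.upfirdn2d without the models.stylegan2 prefix. A small sketch of the device-agnostic pattern (illustrative only, not code from this repo):

import torch
from torch.nn import functional as F

def scaled_leaky_relu(x, negative_slope=0.2):
    # Pure-PyTorch ops dispatch on x.device, so no import-time
    # CPU/CUDA branching is needed.
    return F.leaky_relu(x, negative_slope=negative_slope) * (2 ** 0.5)

x = torch.randn(2, 3)
if torch.cuda.is_available():
    x = x.cuda()
print(scaled_leaky_relu(x).device)  # same function on cpu or cuda:0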
models/stylegan2/op/conv2d_gradfix.py DELETED
@@ -1,227 +0,0 @@
-import contextlib
-import warnings
-
-import torch
-from torch import autograd
-from torch.nn import functional as F
-
-enabled = True
-weight_gradients_disabled = False
-
-
-@contextlib.contextmanager
-def no_weight_gradients():
-    global weight_gradients_disabled
-
-    old = weight_gradients_disabled
-    weight_gradients_disabled = True
-    yield
-    weight_gradients_disabled = old
-
-
-def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if could_use_op(input):
-        return conv2d_gradfix(
-            transpose=False,
-            weight_shape=weight.shape,
-            stride=stride,
-            padding=padding,
-            output_padding=0,
-            dilation=dilation,
-            groups=groups,
-        ).apply(input, weight, bias)
-
-    return F.conv2d(
-        input=input,
-        weight=weight,
-        bias=bias,
-        stride=stride,
-        padding=padding,
-        dilation=dilation,
-        groups=groups,
-    )
-
-
-def conv_transpose2d(
-    input,
-    weight,
-    bias=None,
-    stride=1,
-    padding=0,
-    output_padding=0,
-    groups=1,
-    dilation=1,
-):
-    if could_use_op(input):
-        return conv2d_gradfix(
-            transpose=True,
-            weight_shape=weight.shape,
-            stride=stride,
-            padding=padding,
-            output_padding=output_padding,
-            groups=groups,
-            dilation=dilation,
-        ).apply(input, weight, bias)
-
-    return F.conv_transpose2d(
-        input=input,
-        weight=weight,
-        bias=bias,
-        stride=stride,
-        padding=padding,
-        output_padding=output_padding,
-        dilation=dilation,
-        groups=groups,
-    )
-
-
-def could_use_op(input):
-    if (not enabled) or (not torch.backends.cudnn.enabled):
-        return False
-
-    if input.device.type != "cuda":
-        return False
-
-    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
-        return True
-
-    warnings.warn(
-        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
-    )
-
-    return False
-
-
-def ensure_tuple(xs, ndim):
-    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
-
-    return xs
-
-
-conv2d_gradfix_cache = dict()
-
-
-def conv2d_gradfix(
-    transpose, weight_shape, stride, padding, output_padding, dilation, groups
-):
-    ndim = 2
-    weight_shape = tuple(weight_shape)
-    stride = ensure_tuple(stride, ndim)
-    padding = ensure_tuple(padding, ndim)
-    output_padding = ensure_tuple(output_padding, ndim)
-    dilation = ensure_tuple(dilation, ndim)
-
-    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
-    if key in conv2d_gradfix_cache:
-        return conv2d_gradfix_cache[key]
-
-    common_kwargs = dict(
-        stride=stride, padding=padding, dilation=dilation, groups=groups
-    )
-
-    def calc_output_padding(input_shape, output_shape):
-        if transpose:
-            return [0, 0]
-
-        return [
-            input_shape[i + 2]
-            - (output_shape[i + 2] - 1) * stride[i]
-            - (1 - 2 * padding[i])
-            - dilation[i] * (weight_shape[i + 2] - 1)
-            for i in range(ndim)
-        ]
-
-    class Conv2d(autograd.Function):
-        @staticmethod
-        def forward(ctx, input, weight, bias):
-            if not transpose:
-                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
-
-            else:
-                out = F.conv_transpose2d(
-                    input=input,
-                    weight=weight,
-                    bias=bias,
-                    output_padding=output_padding,
-                    **common_kwargs,
-                )
-
-            ctx.save_for_backward(input, weight)
-
-            return out
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            input, weight = ctx.saved_tensors
-            grad_input, grad_weight, grad_bias = None, None, None
-
-            if ctx.needs_input_grad[0]:
-                p = calc_output_padding(
-                    input_shape=input.shape, output_shape=grad_output.shape
-                )
-                grad_input = conv2d_gradfix(
-                    transpose=(not transpose),
-                    weight_shape=weight_shape,
-                    output_padding=p,
-                    **common_kwargs,
-                ).apply(grad_output, weight, None)
-
-            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
-                grad_weight = Conv2dGradWeight.apply(grad_output, input)
-
-            if ctx.needs_input_grad[2]:
-                grad_bias = grad_output.sum((0, 2, 3))
-
-            return grad_input, grad_weight, grad_bias
-
-    class Conv2dGradWeight(autograd.Function):
-        @staticmethod
-        def forward(ctx, grad_output, input):
-            op = torch._C._jit_get_operation(
-                "aten::cudnn_convolution_backward_weight"
-                if not transpose
-                else "aten::cudnn_convolution_transpose_backward_weight"
-            )
-            flags = [
-                torch.backends.cudnn.benchmark,
-                torch.backends.cudnn.deterministic,
-                torch.backends.cudnn.allow_tf32,
-            ]
-            grad_weight = op(
-                weight_shape,
-                grad_output,
-                input,
-                padding,
-                stride,
-                dilation,
-                groups,
-                *flags,
-            )
-            ctx.save_for_backward(grad_output, input)
-
-            return grad_weight
-
-        @staticmethod
-        def backward(ctx, grad_grad_weight):
-            grad_output, input = ctx.saved_tensors
-            grad_grad_output, grad_grad_input = None, None
-
-            if ctx.needs_input_grad[0]:
-                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
-
-            if ctx.needs_input_grad[1]:
-                p = calc_output_padding(
-                    input_shape=input.shape, output_shape=grad_output.shape
-                )
-                grad_grad_input = conv2d_gradfix(
-                    transpose=(not transpose),
-                    weight_shape=weight_shape,
-                    output_padding=p,
-                    **common_kwargs,
-                ).apply(grad_output, grad_grad_weight, None)
-
-            return grad_grad_output, grad_grad_input
-
-    conv2d_gradfix_cache[key] = Conv2d
-
-    return Conv2d
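conv2d_gradfix existed to route convolution weight gradients through cuDNN on PyTorch 1.7/1.8; on any other version, or on CPU, it already fell back to the stock functional ops (see could_use_op above). With the module deleted, the equivalent calls are the standard ones, sketched here under the assumption that no remaining caller relied on the no_weight_gradients() context:

import torch
from torch.nn import functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(6, 3, 3, 3)

# The deleted wrappers mirrored these signatures exactly.
y = F.conv2d(x, w, stride=1, padding=1)
yt = F.conv_transpose2d(y, w, stride=1, padding=1)
print(y.shape, yt.shape)  # torch.Size([1, 6, 8, 8]) torch.Size([1, 3, 8, 8])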
models/stylegan2/op/fused_act.py CHANGED
@@ -2,72 +2,10 @@ import os
 
 import torch
 from torch import nn
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
-
+from torch.nn import functional as F
 
 module_path = os.path.dirname(__file__)
-fused = load(
-    'fused',
-    sources=[
-        os.path.join(module_path, 'fused_bias_act.cpp'),
-        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
-    ],
-)
-
-
-class FusedLeakyReLUFunctionBackward(Function):
-    @staticmethod
-    def forward(ctx, grad_output, out, negative_slope, scale):
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        empty = grad_output.new_empty(0)
-
-        grad_input = fused.fused_bias_act(
-            grad_output, empty, out, 3, 1, negative_slope, scale
-        )
-
-        dim = [0]
-
-        if grad_input.ndim > 2:
-            dim += list(range(2, grad_input.ndim))
-
-        grad_bias = grad_input.sum(dim).detach()
-
-        return grad_input, grad_bias
-
-    @staticmethod
-    def backward(ctx, gradgrad_input, gradgrad_bias):
-        out, = ctx.saved_tensors
-        gradgrad_out = fused.fused_bias_act(
-            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
-        )
-
-        return gradgrad_out, None, None, None
-
-
-class FusedLeakyReLUFunction(Function):
-    @staticmethod
-    def forward(ctx, input, bias, negative_slope, scale):
-        empty = input.new_empty(0)
-        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        out, = ctx.saved_tensors
-
-        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
-            grad_output, out, ctx.negative_slope, ctx.scale
-        )
-
-        return grad_input, grad_bias, None, None
 
 
 class FusedLeakyReLU(nn.Module):
@@ -83,4 +21,20 @@ class FusedLeakyReLU(nn.Module):
 
 
 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    rest_dim = [1] * (input.ndim - bias.ndim - 1)
+    input = input.cuda()
+    if input.ndim == 3:
+        return (
+            F.leaky_relu(
+                input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
+            )
+            * scale
+        )
+    else:
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+            )
+            * scale
+        )
+
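What the rewritten fused_leaky_relu computes, in isolation: add a per-channel bias, apply a leaky ReLU, and rescale by sqrt(2) to keep activation magnitudes roughly constant. Note that the committed version calls input.cuda() unconditionally, so it now requires a GPU; the sketch below drops that call to stay runnable on CPU (the only deviation from the diff):

import torch
from torch.nn import functional as F

def fused_leaky_relu_sketch(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    # NCHW case: broadcast the per-channel bias over batch and spatial dims.
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    return F.leaky_relu(
        input + bias.view(1, bias.shape[0], *rest_dim),
        negative_slope=negative_slope,
    ) * scale

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)
print(fused_leaky_relu_sketch(x, b).shape)  # torch.Size([2, 8, 4, 4])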
models/stylegan2/op/upfirdn2d.py CHANGED
@@ -1,149 +1,16 @@
 import os
 
 import torch
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
+from torch.nn import functional as F
 
 
 module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    'upfirdn2d',
-    sources=[
-        os.path.join(module_path, 'upfirdn2d.cpp'),
-        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-    ],
-)
 
 
-class UpFirDn2dBackward(Function):
-    @staticmethod
-    def forward(
-        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
-    ):
-
-        up_x, up_y = up
-        down_x, down_y = down
-        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
-
-        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
-
-        grad_input = upfirdn2d_op.upfirdn2d(
-            grad_output,
-            grad_kernel,
-            down_x,
-            down_y,
-            up_x,
-            up_y,
-            g_pad_x0,
-            g_pad_x1,
-            g_pad_y0,
-            g_pad_y1,
-        )
-        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
-
-        ctx.save_for_backward(kernel)
-
-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-        ctx.up_x = up_x
-        ctx.up_y = up_y
-        ctx.down_x = down_x
-        ctx.down_y = down_y
-        ctx.pad_x0 = pad_x0
-        ctx.pad_x1 = pad_x1
-        ctx.pad_y0 = pad_y0
-        ctx.pad_y1 = pad_y1
-        ctx.in_size = in_size
-        ctx.out_size = out_size
-
-        return grad_input
-
-    @staticmethod
-    def backward(ctx, gradgrad_input):
-        kernel, = ctx.saved_tensors
-
-        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
-
-        gradgrad_out = upfirdn2d_op.upfirdn2d(
-            gradgrad_input,
-            kernel,
-            ctx.up_x,
-            ctx.up_y,
-            ctx.down_x,
-            ctx.down_y,
-            ctx.pad_x0,
-            ctx.pad_x1,
-            ctx.pad_y0,
-            ctx.pad_y1,
-        )
-        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
-        gradgrad_out = gradgrad_out.view(
-            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
-        )
-
-        return gradgrad_out, None, None, None, None, None, None, None, None
-
-
-class UpFirDn2d(Function):
-    @staticmethod
-    def forward(ctx, input, kernel, up, down, pad):
-        up_x, up_y = up
-        down_x, down_y = down
-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-        kernel_h, kernel_w = kernel.shape
-        batch, channel, in_h, in_w = input.shape
-        ctx.in_size = input.shape
-
-        input = input.reshape(-1, in_h, in_w, 1)
-
-        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
-
-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-        ctx.out_size = (out_h, out_w)
-
-        ctx.up = (up_x, up_y)
-        ctx.down = (down_x, down_y)
-        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
-
-        g_pad_x0 = kernel_w - pad_x0 - 1
-        g_pad_y0 = kernel_h - pad_y0 - 1
-        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
-        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
-
-        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
-
-        out = upfirdn2d_op.upfirdn2d(
-            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-        )
-        # out = out.view(major, out_h, out_w, minor)
-        out = out.view(-1, channel, out_h, out_w)
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        kernel, grad_kernel = ctx.saved_tensors
-
-        grad_input = UpFirDn2dBackward.apply(
-            grad_output,
-            kernel,
-            grad_kernel,
-            ctx.up,
-            ctx.down,
-            ctx.pad,
-            ctx.g_pad,
-            ctx.in_size,
-            ctx.out_size,
-        )
-
-        return grad_input, None, None, None, None
-
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    out = UpFirDn2d.apply(
-        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+    out = upfirdn2d_native(
+        input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
     )
 
     return out
@@ -152,6 +19,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
 def upfirdn2d_native(
     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
 ):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape
 
@@ -182,6 +52,9 @@ def upfirdn2d_native(
         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
     )
     out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
 
-    return out[:, ::down_y, ::down_x, :]
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
 
+    return out.view(-1, channel, out_h, out_w)
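The public upfirdn2d signature is unchanged, so callers in model.py need no edits; only the backend moved from the compiled extension to upfirdn2d_native (upsample by zero-insertion, pad, convolve with the flipped FIR kernel, then downsample by striding). A hedged usage sketch, assuming the rewritten module imports as below from inside this repo:

import torch
from models.stylegan2.op.upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 16, 16)
kernel = torch.ones(2, 2) / 4  # normalized 2x2 box filter
out = upfirdn2d(x, kernel, up=1, down=2, pad=(0, 1))
# out_h = (16*1 + 0 + 1 - 2) // 2 + 1 = 8
print(out.shape)  # torch.Size([1, 3, 8, 8])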