Мясников Филипп Сергеевич committed on
Commit cc78303
1 Parent(s): 45e462d
app.py CHANGED
@@ -15,9 +15,6 @@ from tqdm import tqdm
 import lpips
 import time
 
-
-#from e4e_projection import projection as e4e_projection
-
 from copy import deepcopy
 import imageio
e4e_projection.py DELETED
@@ -1,38 +0,0 @@
-import os
-import sys
-import numpy as np
-from PIL import Image
-import torch
-import torchvision.transforms as transforms
-from argparse import Namespace
-from e4e.models.psp import pSp
-from util import *
-
-
-
-@torch.no_grad()
-def projection(img, name, device='cuda'):
-
-
-    model_path = 'e4e_ffhq_encode.pt'
-    ckpt = torch.load(model_path, map_location='cpu')
-    opts = ckpt['opts']
-    opts['checkpoint_path'] = model_path
-    opts = Namespace(**opts)
-    net = pSp(opts, device).eval().to(device)
-
-    transform = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(256),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
-        ]
-    )
-
-    img = transform(img).unsqueeze(0).to(device)
-    images, w_plus = net(img, randomize_noise=False, return_latents=True)
-    result_file = {}
-    result_file['latent'] = w_plus[0]
-    torch.save(result_file, name)
-    return w_plus[0]
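For reference, the deleted projection() helper embedded an aligned face image into StyleGAN's W+ latent space with a pretrained e4e encoder and cached the latent to disk. A minimal usage sketch, assuming the file were restored and e4e_ffhq_encode.pt present (the image path here is a hypothetical example):

    from PIL import Image
    from e4e_projection import projection  # module deleted in this commit

    img = Image.open('face_aligned.png').convert('RGB')  # hypothetical input
    w_plus = projection(img, 'face_aligned.pt', device='cuda')
    # w_plus is the W+ latent for the image; a copy is also saved to 'face_aligned.pt'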
op/__init__.py DELETED
File without changes
op/__pycache__/__init__.cpython-37.pyc DELETED
Binary file (227 Bytes)
 
op/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (231 Bytes)
 
op/__pycache__/conv2d_gradfix.cpython-37.pyc DELETED
Binary file (5.23 kB)
 
op/__pycache__/conv2d_gradfix.cpython-38.pyc DELETED
Binary file (5.3 kB)
 
op/__pycache__/fused_act.cpython-37.pyc DELETED
Binary file (2.78 kB)
 
op/__pycache__/fused_act.cpython-38.pyc DELETED
Binary file (2.84 kB)
 
op/__pycache__/upfirdn2d.cpython-37.pyc DELETED
Binary file (3.82 kB)
 
op/__pycache__/upfirdn2d.cpython-38.pyc DELETED
Binary file (3.9 kB)
 
op/conv2d_gradfix.py DELETED
@@ -1,227 +0,0 @@
-import contextlib
-import warnings
-
-import torch
-from torch import autograd
-from torch.nn import functional as F
-
-enabled = True
-weight_gradients_disabled = False
-
-
-@contextlib.contextmanager
-def no_weight_gradients():
-    global weight_gradients_disabled
-
-    old = weight_gradients_disabled
-    weight_gradients_disabled = True
-    yield
-    weight_gradients_disabled = old
-
-
-def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
-    if could_use_op(input):
-        return conv2d_gradfix(
-            transpose=False,
-            weight_shape=weight.shape,
-            stride=stride,
-            padding=padding,
-            output_padding=0,
-            dilation=dilation,
-            groups=groups,
-        ).apply(input, weight, bias)
-
-    return F.conv2d(
-        input=input,
-        weight=weight,
-        bias=bias,
-        stride=stride,
-        padding=padding,
-        dilation=dilation,
-        groups=groups,
-    )
-
-
-def conv_transpose2d(
-    input,
-    weight,
-    bias=None,
-    stride=1,
-    padding=0,
-    output_padding=0,
-    groups=1,
-    dilation=1,
-):
-    if could_use_op(input):
-        return conv2d_gradfix(
-            transpose=True,
-            weight_shape=weight.shape,
-            stride=stride,
-            padding=padding,
-            output_padding=output_padding,
-            groups=groups,
-            dilation=dilation,
-        ).apply(input, weight, bias)
-
-    return F.conv_transpose2d(
-        input=input,
-        weight=weight,
-        bias=bias,
-        stride=stride,
-        padding=padding,
-        output_padding=output_padding,
-        dilation=dilation,
-        groups=groups,
-    )
-
-
-def could_use_op(input):
-    if (not enabled) or (not torch.backends.cudnn.enabled):
-        return False
-
-    if input.device.type != "cuda":
-        return False
-
-    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
-        return True
-
-    warnings.warn(
-        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
-    )
-
-    return False
-
-
-def ensure_tuple(xs, ndim):
-    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
-
-    return xs
-
-
-conv2d_gradfix_cache = dict()
-
-
-def conv2d_gradfix(
-    transpose, weight_shape, stride, padding, output_padding, dilation, groups
-):
-    ndim = 2
-    weight_shape = tuple(weight_shape)
-    stride = ensure_tuple(stride, ndim)
-    padding = ensure_tuple(padding, ndim)
-    output_padding = ensure_tuple(output_padding, ndim)
-    dilation = ensure_tuple(dilation, ndim)
-
-    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
-    if key in conv2d_gradfix_cache:
-        return conv2d_gradfix_cache[key]
-
-    common_kwargs = dict(
-        stride=stride, padding=padding, dilation=dilation, groups=groups
-    )
-
-    def calc_output_padding(input_shape, output_shape):
-        if transpose:
-            return [0, 0]
-
-        return [
-            input_shape[i + 2]
-            - (output_shape[i + 2] - 1) * stride[i]
-            - (1 - 2 * padding[i])
-            - dilation[i] * (weight_shape[i + 2] - 1)
-            for i in range(ndim)
-        ]
-
-    class Conv2d(autograd.Function):
-        @staticmethod
-        def forward(ctx, input, weight, bias):
-            if not transpose:
-                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
-
-            else:
-                out = F.conv_transpose2d(
-                    input=input,
-                    weight=weight,
-                    bias=bias,
-                    output_padding=output_padding,
-                    **common_kwargs,
-                )
-
-            ctx.save_for_backward(input, weight)
-
-            return out
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            input, weight = ctx.saved_tensors
-            grad_input, grad_weight, grad_bias = None, None, None
-
-            if ctx.needs_input_grad[0]:
-                p = calc_output_padding(
-                    input_shape=input.shape, output_shape=grad_output.shape
-                )
-                grad_input = conv2d_gradfix(
-                    transpose=(not transpose),
-                    weight_shape=weight_shape,
-                    output_padding=p,
-                    **common_kwargs,
-                ).apply(grad_output, weight, None)
-
-            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
-                grad_weight = Conv2dGradWeight.apply(grad_output, input)
-
-            if ctx.needs_input_grad[2]:
-                grad_bias = grad_output.sum((0, 2, 3))
-
-            return grad_input, grad_weight, grad_bias
-
-    class Conv2dGradWeight(autograd.Function):
-        @staticmethod
-        def forward(ctx, grad_output, input):
-            op = torch._C._jit_get_operation(
-                "aten::cudnn_convolution_backward_weight"
-                if not transpose
-                else "aten::cudnn_convolution_transpose_backward_weight"
-            )
-            flags = [
-                torch.backends.cudnn.benchmark,
-                torch.backends.cudnn.deterministic,
-                torch.backends.cudnn.allow_tf32,
-            ]
-            grad_weight = op(
-                weight_shape,
-                grad_output,
-                input,
-                padding,
-                stride,
-                dilation,
-                groups,
-                *flags,
-            )
-            ctx.save_for_backward(grad_output, input)
-
-            return grad_weight
-
-        @staticmethod
-        def backward(ctx, grad_grad_weight):
-            grad_output, input = ctx.saved_tensors
-            grad_grad_output, grad_grad_input = None, None
-
-            if ctx.needs_input_grad[0]:
-                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
-
-            if ctx.needs_input_grad[1]:
-                p = calc_output_padding(
-                    input_shape=input.shape, output_shape=grad_output.shape
-                )
-                grad_grad_input = conv2d_gradfix(
-                    transpose=(not transpose),
-                    weight_shape=weight_shape,
-                    output_padding=p,
-                    **common_kwargs,
-                ).apply(grad_output, grad_grad_weight, None)
-
-            return grad_grad_output, grad_grad_input
-
-    conv2d_gradfix_cache[key] = Conv2d
-
-    return Conv2d
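The deleted module mirrors the StyleGAN2 conv2d_gradfix workaround: drop-in conv2d/conv_transpose2d wrappers whose custom backward can skip weight gradients inside no_weight_gradients(), which regularization passes use to save memory. A minimal usage sketch, assuming the module were back on the path; the custom path only engages for CUDA tensors under PyTorch 1.7/1.8, otherwise it falls back to F.conv2d and weight gradients are computed as usual:

    import torch
    from op import conv2d_gradfix  # deleted in this commit

    x = torch.randn(1, 3, 8, 8, device='cuda', requires_grad=True)
    w = torch.randn(4, 3, 3, 3, device='cuda', requires_grad=True)

    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)

    y.sum().backward()
    # x.grad is populated; on the custom path, w.grad is skipped inside the context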
op/fused_act.py DELETED
@@ -1,86 +0,0 @@
-import os
-
-import torch
-from torch import nn
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
-
-
-module_path = os.path.dirname(__file__)
-fused = load(
-    'fused',
-    sources=[
-        os.path.join(module_path, 'fused_bias_act.cpp'),
-        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
-    ],
-)
-
-
-class FusedLeakyReLUFunctionBackward(Function):
-    @staticmethod
-    def forward(ctx, grad_output, out, negative_slope, scale):
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        empty = grad_output.new_empty(0)
-
-        grad_input = fused.fused_bias_act(
-            grad_output, empty, out, 3, 1, negative_slope, scale
-        )
-
-        dim = [0]
-
-        if grad_input.ndim > 2:
-            dim += list(range(2, grad_input.ndim))
-
-        grad_bias = grad_input.sum(dim).detach()
-
-        return grad_input, grad_bias
-
-    @staticmethod
-    def backward(ctx, gradgrad_input, gradgrad_bias):
-        out, = ctx.saved_tensors
-        gradgrad_out = fused.fused_bias_act(
-            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
-        )
-
-        return gradgrad_out, None, None, None
-
-
-class FusedLeakyReLUFunction(Function):
-    @staticmethod
-    def forward(ctx, input, bias, negative_slope, scale):
-        empty = input.new_empty(0)
-        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
-        ctx.save_for_backward(out)
-        ctx.negative_slope = negative_slope
-        ctx.scale = scale
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        out, = ctx.saved_tensors
-
-        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
-            grad_output, out, ctx.negative_slope, ctx.scale
-        )
-
-        return grad_input, grad_bias, None, None
-
-
-class FusedLeakyReLU(nn.Module):
-    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
-        super().__init__()
-
-        self.bias = nn.Parameter(torch.zeros(channel))
-        self.negative_slope = negative_slope
-        self.scale = scale
-
-    def forward(self, input):
-        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-
-
-def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
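What the fused CUDA op computes is small enough to state in plain PyTorch: add a per-channel bias, apply a leaky ReLU, and rescale by sqrt(2) to keep activation magnitudes roughly constant. A reference sketch (not the deleted extension itself; fused_leaky_relu_ref is a hypothetical name):

    import torch
    import torch.nn.functional as F

    def fused_leaky_relu_ref(input, bias, negative_slope=0.2, scale=2 ** 0.5):
        # Broadcast the (channel,) bias over every dim except dim 1.
        rest_dim = [1] * (input.ndim - 2)
        out = F.leaky_relu(input + bias.view(1, -1, *rest_dim), negative_slope)
        return out * scale

    x = torch.randn(2, 8, 4, 4)
    y = fused_leaky_relu_ref(x, torch.zeros(8))  # same shape as x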
op/fused_act_cpu.py DELETED
@@ -1,41 +0,0 @@
-import os
-
-import torch
-from torch import nn
-from torch.autograd import Function
-from torch.nn import functional as F
-
-
-module_path = os.path.dirname(__file__)
-
-
-class FusedLeakyReLU(nn.Module):
-    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
-        super().__init__()
-
-        self.bias = nn.Parameter(torch.zeros(channel))
-        self.negative_slope = negative_slope
-        self.scale = scale
-
-    def forward(self, input):
-        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
-
-def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
-    if input.device.type == "cpu":
-        if bias is not None:
-            rest_dim = [1] * (input.ndim - bias.ndim - 1)
-            return (
-                F.leaky_relu(
-                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
-                )
-                * scale
-            )
-
-        else:
-            return F.leaky_relu(input, negative_slope=0.2) * scale
-
-    else:
-        return FusedLeakyReLUFunction.apply(
-            input.contiguous(), bias, negative_slope, scale
-        )
-
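Two quirks are visible in this deleted CPU shim: the non-CPU branch calls FusedLeakyReLUFunction, which the file neither defines nor imports (a CUDA tensor would raise NameError here), and the CPU branch hardcodes negative_slope=0.2 in F.leaky_relu rather than passing the negative_slope argument through.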
op/fused_bias_act.cpp DELETED
@@ -1,21 +0,0 @@
-#include <torch/extension.h>
-
-
-torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
-    int act, int grad, float alpha, float scale);
-
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
-    int act, int grad, float alpha, float scale) {
-    CHECK_CUDA(input);
-    CHECK_CUDA(bias);
-
-    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
-}
op/fused_bias_act_kernel.cu DELETED
@@ -1,99 +0,0 @@
-// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
-//
-// This work is made available under the Nvidia Source Code License-NC.
-// To view a copy of this license, visit
-// https://nvlabs.github.io/stylegan2/license.html
-
-#include <torch/types.h>
-
-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/CUDAApplyUtils.cuh>
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-
-template <typename scalar_t>
-static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
-    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
-    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
-
-    scalar_t zero = 0.0;
-
-    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
-        scalar_t x = p_x[xi];
-
-        if (use_bias) {
-            x += p_b[(xi / step_b) % size_b];
-        }
-
-        scalar_t ref = use_ref ? p_ref[xi] : zero;
-
-        scalar_t y;
-
-        switch (act * 10 + grad) {
-            default:
-            case 10: y = x; break;
-            case 11: y = x; break;
-            case 12: y = 0.0; break;
-
-            case 30: y = (x > 0.0) ? x : x * alpha; break;
-            case 31: y = (ref > 0.0) ? x : x * alpha; break;
-            case 32: y = 0.0; break;
-        }
-
-        out[xi] = y * scale;
-    }
-}
-
-
-torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
-    int act, int grad, float alpha, float scale) {
-    int curDevice = -1;
-    cudaGetDevice(&curDevice);
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
-
-    auto x = input.contiguous();
-    auto b = bias.contiguous();
-    auto ref = refer.contiguous();
-
-    int use_bias = b.numel() ? 1 : 0;
-    int use_ref = ref.numel() ? 1 : 0;
-
-    int size_x = x.numel();
-    int size_b = b.numel();
-    int step_b = 1;
-
-    for (int i = 1 + 1; i < x.dim(); i++) {
-        step_b *= x.size(i);
-    }
-
-    int loop_x = 4;
-    int block_size = 4 * 32;
-    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
-
-    auto y = torch::empty_like(x);
-
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
-        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
-            y.data_ptr<scalar_t>(),
-            x.data_ptr<scalar_t>(),
-            b.data_ptr<scalar_t>(),
-            ref.data_ptr<scalar_t>(),
-            act,
-            grad,
-            alpha,
-            scale,
-            loop_x,
-            size_x,
-            step_b,
-            size_b,
-            use_bias,
-            use_ref
-        );
-    });
-
-    return y;
-}
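In the kernel's act * 10 + grad switch, act selects the activation (1 = linear, 3 = leaky ReLU) and grad the derivative order (0 = forward, 1 = first, 2 = second); for grad = 1 the sign test uses the saved forward output ref rather than the pre-activation input. Case 31 in plain PyTorch, as a sketch (leaky_relu_grad is a hypothetical name):

    import torch

    def leaky_relu_grad(grad_output, ref, alpha=0.2, scale=2 ** 0.5):
        # Matches case 31: y = (ref > 0) ? x : x * alpha, then out = y * scale.
        return torch.where(ref > 0, grad_output, grad_output * alpha) * scale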
op/upfirdn2d.cpp DELETED
@@ -1,23 +0,0 @@
-#include <torch/extension.h>
-
-
-torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
-    int up_x, int up_y, int down_x, int down_y,
-    int pad_x0, int pad_x1, int pad_y0, int pad_y1);
-
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
-    int up_x, int up_y, int down_x, int down_y,
-    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
-    CHECK_CUDA(input);
-    CHECK_CUDA(kernel);
-
-    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
-}
op/upfirdn2d.py DELETED
@@ -1,187 +0,0 @@
-import os
-
-import torch
-from torch.autograd import Function
-from torch.utils.cpp_extension import load
-
-
-module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    'upfirdn2d',
-    sources=[
-        os.path.join(module_path, 'upfirdn2d.cpp'),
-        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-    ],
-)
-
-
-class UpFirDn2dBackward(Function):
-    @staticmethod
-    def forward(
-        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
-    ):
-
-        up_x, up_y = up
-        down_x, down_y = down
-        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
-
-        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
-
-        grad_input = upfirdn2d_op.upfirdn2d(
-            grad_output,
-            grad_kernel,
-            down_x,
-            down_y,
-            up_x,
-            up_y,
-            g_pad_x0,
-            g_pad_x1,
-            g_pad_y0,
-            g_pad_y1,
-        )
-        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
-
-        ctx.save_for_backward(kernel)
-
-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-        ctx.up_x = up_x
-        ctx.up_y = up_y
-        ctx.down_x = down_x
-        ctx.down_y = down_y
-        ctx.pad_x0 = pad_x0
-        ctx.pad_x1 = pad_x1
-        ctx.pad_y0 = pad_y0
-        ctx.pad_y1 = pad_y1
-        ctx.in_size = in_size
-        ctx.out_size = out_size
-
-        return grad_input
-
-    @staticmethod
-    def backward(ctx, gradgrad_input):
-        kernel, = ctx.saved_tensors
-
-        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
-
-        gradgrad_out = upfirdn2d_op.upfirdn2d(
-            gradgrad_input,
-            kernel,
-            ctx.up_x,
-            ctx.up_y,
-            ctx.down_x,
-            ctx.down_y,
-            ctx.pad_x0,
-            ctx.pad_x1,
-            ctx.pad_y0,
-            ctx.pad_y1,
-        )
-        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
-        gradgrad_out = gradgrad_out.view(
-            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
-        )
-
-        return gradgrad_out, None, None, None, None, None, None, None, None
-
-
-class UpFirDn2d(Function):
-    @staticmethod
-    def forward(ctx, input, kernel, up, down, pad):
-        up_x, up_y = up
-        down_x, down_y = down
-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
-
-        kernel_h, kernel_w = kernel.shape
-        batch, channel, in_h, in_w = input.shape
-        ctx.in_size = input.shape
-
-        input = input.reshape(-1, in_h, in_w, 1)
-
-        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
-
-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
-        ctx.out_size = (out_h, out_w)
-
-        ctx.up = (up_x, up_y)
-        ctx.down = (down_x, down_y)
-        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
-
-        g_pad_x0 = kernel_w - pad_x0 - 1
-        g_pad_y0 = kernel_h - pad_y0 - 1
-        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
-        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
-
-        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
-
-        out = upfirdn2d_op.upfirdn2d(
-            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-        )
-        # out = out.view(major, out_h, out_w, minor)
-        out = out.view(-1, channel, out_h, out_w)
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        kernel, grad_kernel = ctx.saved_tensors
-
-        grad_input = UpFirDn2dBackward.apply(
-            grad_output,
-            kernel,
-            grad_kernel,
-            ctx.up,
-            ctx.down,
-            ctx.pad,
-            ctx.g_pad,
-            ctx.in_size,
-            ctx.out_size,
-        )
-
-        return grad_input, None, None, None, None
-
-
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    out = UpFirDn2d.apply(
-        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
-    )
-
-    return out
-
-
-def upfirdn2d_native(
-    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-):
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(
-        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
-    )
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape(
-        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
-    )
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-
-    return out[:, ::down_y, ::down_x, :]
-
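upfirdn2d is the classic upsample, FIR-filter, downsample primitive: insert up - 1 zeros between samples, pad, convolve with the 2D FIR kernel, then keep every down-th sample; the output height is (in_h * up + pad0 + pad1 - kernel_h) // down + 1, and likewise for width. Note also that the dead-code upfirdn2d_native above references F.pad and F.conv2d, but this file never imports torch.nn.functional as F; only the CPU copy below does. A usage sketch, assuming the module were restored (the box filter is an arbitrary example):

    import torch
    from op.upfirdn2d import upfirdn2d  # deleted in this commit

    x = torch.randn(1, 3, 16, 16, device='cuda')
    k = torch.ones(3, 3, device='cuda') / 9.0  # simple 3x3 box filter
    y = upfirdn2d(x, k, up=2, down=1, pad=(1, 1))
    # out_h = (16 * 2 + 1 + 1 - 3) // 1 + 1 = 32, so y.shape == (1, 3, 32, 32)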
op/upfirdn2d_cpu.py DELETED
@@ -1,60 +0,0 @@
-import os
-
-import torch
-from torch.autograd import Function
-from torch.nn import functional as F
-
-
-
-module_path = os.path.dirname(__file__)
-
-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    out = upfirdn2d_native(
-        input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
-    )
-
-    return out
-
-
-def upfirdn2d_native(
-    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
-):
-    _, channel, in_h, in_w = input.shape
-    input = input.reshape(-1, in_h, in_w, 1)
-
-    _, in_h, in_w, minor = input.shape
-    kernel_h, kernel_w = kernel.shape
-
-    out = input.view(-1, in_h, 1, in_w, 1, minor)
-    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
-    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
-
-    out = F.pad(
-        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
-    )
-    out = out[
-        :,
-        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
-        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
-        :,
-    ]
-
-    out = out.permute(0, 3, 1, 2)
-    out = out.reshape(
-        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
-    )
-    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
-    out = F.conv2d(out, w)
-    out = out.reshape(
-        -1,
-        minor,
-        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
-        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
-    )
-    out = out.permute(0, 2, 3, 1)
-    out = out[:, ::down_y, ::down_x, :]
-
-    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
-    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
-
-    return out.view(-1, channel, out_h, out_w)
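This CPU fallback reproduces the same computation in pure PyTorch and sizes its output with (n - k + d) // d, while the Python wrapper above uses (n - k) // d + 1; the two are equal for any positive stride d, so both paths agree on output shapes. A quick check, as a sketch:

    # (n - k) // d + 1 == (n - k + d) // d for positive d, so the CPU fallback
    # and the autograd path compute identical output sizes.
    for n, k, d in [(34, 3, 1), (33, 4, 2), (20, 2, 2)]:
        assert (n - k) // d + 1 == (n - k + d) // d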
op/upfirdn2d_kernel.cu DELETED
@@ -1,272 +0,0 @@
-// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
-//
-// This work is made available under the Nvidia Source Code License-NC.
-// To view a copy of this license, visit
-// https://nvlabs.github.io/stylegan2/license.html
-
-#include <torch/types.h>
-
-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/CUDAApplyUtils.cuh>
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-
-static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
-    int c = a / b;
-
-    if (c * b > a) {
-        c--;
-    }
-
-    return c;
-}
-
-
-struct UpFirDn2DKernelParams {
-    int up_x;
-    int up_y;
-    int down_x;
-    int down_y;
-    int pad_x0;
-    int pad_x1;
-    int pad_y0;
-    int pad_y1;
-
-    int major_dim;
-    int in_h;
-    int in_w;
-    int minor_dim;
-    int kernel_h;
-    int kernel_w;
-    int out_h;
-    int out_w;
-    int loop_major;
-    int loop_x;
-};
-
-
-template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
-__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
-    const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
-    const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
-
-    __shared__ volatile float sk[kernel_h][kernel_w];
-    __shared__ volatile float sx[tile_in_h][tile_in_w];
-
-    int minor_idx = blockIdx.x;
-    int tile_out_y = minor_idx / p.minor_dim;
-    minor_idx -= tile_out_y * p.minor_dim;
-    tile_out_y *= tile_out_h;
-    int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
-    int major_idx_base = blockIdx.z * p.loop_major;
-
-    if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
-        return;
-    }
-
-    for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
-        int ky = tap_idx / kernel_w;
-        int kx = tap_idx - ky * kernel_w;
-        scalar_t v = 0.0;
-
-        if (kx < p.kernel_w & ky < p.kernel_h) {
-            v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
-        }
-
-        sk[ky][kx] = v;
-    }
-
-    for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
-        for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
-            int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
-            int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
-            int tile_in_x = floor_div(tile_mid_x, up_x);
-            int tile_in_y = floor_div(tile_mid_y, up_y);
-
-            __syncthreads();
-
-            for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
-                int rel_in_y = in_idx / tile_in_w;
-                int rel_in_x = in_idx - rel_in_y * tile_in_w;
-                int in_x = rel_in_x + tile_in_x;
-                int in_y = rel_in_y + tile_in_y;
-
-                scalar_t v = 0.0;
-
-                if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
-                    v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
-                }
-
-                sx[rel_in_y][rel_in_x] = v;
-            }
-
-            __syncthreads();
-            for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
-                int rel_out_y = out_idx / tile_out_w;
-                int rel_out_x = out_idx - rel_out_y * tile_out_w;
-                int out_x = rel_out_x + tile_out_x;
-                int out_y = rel_out_y + tile_out_y;
-
-                int mid_x = tile_mid_x + rel_out_x * down_x;
-                int mid_y = tile_mid_y + rel_out_y * down_y;
-                int in_x = floor_div(mid_x, up_x);
-                int in_y = floor_div(mid_y, up_y);
-                int rel_in_x = in_x - tile_in_x;
-                int rel_in_y = in_y - tile_in_y;
-                int kernel_x = (in_x + 1) * up_x - mid_x - 1;
-                int kernel_y = (in_y + 1) * up_y - mid_y - 1;
-
-                scalar_t v = 0.0;
-
-                #pragma unroll
-                for (int y = 0; y < kernel_h / up_y; y++)
-                    #pragma unroll
-                    for (int x = 0; x < kernel_w / up_x; x++)
-                        v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
-
-                if (out_x < p.out_w & out_y < p.out_h) {
-                    out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
-                }
-            }
-        }
-    }
-}
-
-
-torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
-    int up_x, int up_y, int down_x, int down_y,
-    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
-    int curDevice = -1;
-    cudaGetDevice(&curDevice);
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
-
-    UpFirDn2DKernelParams p;
-
-    auto x = input.contiguous();
-    auto k = kernel.contiguous();
-
-    p.major_dim = x.size(0);
-    p.in_h = x.size(1);
-    p.in_w = x.size(2);
-    p.minor_dim = x.size(3);
-    p.kernel_h = k.size(0);
-    p.kernel_w = k.size(1);
-    p.up_x = up_x;
-    p.up_y = up_y;
-    p.down_x = down_x;
-    p.down_y = down_y;
-    p.pad_x0 = pad_x0;
-    p.pad_x1 = pad_x1;
-    p.pad_y0 = pad_y0;
-    p.pad_y1 = pad_y1;
-
-    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
-    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
-
-    auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
-
-    int mode = -1;
-
-    int tile_out_h;
-    int tile_out_w;
-
-    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
-        mode = 1;
-        tile_out_h = 16;
-        tile_out_w = 64;
-    }
-
-    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
-        mode = 2;
-        tile_out_h = 16;
-        tile_out_w = 64;
-    }
-
-    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
-        mode = 3;
-        tile_out_h = 16;
-        tile_out_w = 64;
-    }
-
-    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
-        mode = 4;
-        tile_out_h = 16;
-        tile_out_w = 64;
-    }
-
-    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
-        mode = 5;
-        tile_out_h = 8;
-        tile_out_w = 32;
-    }
-
-    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
-        mode = 6;
-        tile_out_h = 8;
-        tile_out_w = 32;
-    }
-
-    dim3 block_size;
-    dim3 grid_size;
-
-    if (tile_out_h > 0 && tile_out_w) {
-        p.loop_major = (p.major_dim - 1) / 16384 + 1;
-        p.loop_x = 1;
-        block_size = dim3(32 * 8, 1, 1);
-        grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
-                         (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
-                         (p.major_dim - 1) / p.loop_major + 1);
-    }
-
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
-        switch (mode) {
-            case 1:
-                upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
-                    out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
-                );
-
-                break;
-
-            case 2:
-                upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
-                    out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
-                );
-
-                break;
-
-            case 3:
-                upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
-                    out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
-                );
-
-                break;
-
-            case 4:
-                upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
-                    out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
-                );
-
-                break;
-
-            case 5:
-                upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
-                    out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
-                );
-
-                break;
-
-            case 6:
-                upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
-                    out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
-                );
-
-                break;
-        }
-    });
-
-    return out;
-}
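Two details worth noting in this deleted host code: if none of the six mode checks match, mode stays -1, the dispatch switch launches nothing, and the guard if (tile_out_h > 0 && tile_out_w) reads tile_out_h/tile_out_w uninitialized, so the op effectively supports only the up/down/kernel-size combinations StyleGAN2 actually uses. Case 6 also launches the same 4x4-kernel template instantiation as case 5, apparently relying on the zero-padding of the shared sk tile for taps beyond the real kernel size.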