ucalyptus committed
Commit 880233f
1 Parent(s): a6f782a

Add application file
README.md DELETED
@@ -1,13 +0,0 @@
- ---
- title: DragGAN Unofficial
- emoji: 💻
- colorFrom: indigo
- colorTo: yellow
- sdk: gradio
- sdk_version: 3.29.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,76 @@
+ import gradio as gr
+ import torch
+ from drag_gan import stylegan2, drag_gan
+ from PIL import Image
+
+ device = 'cuda'
+ g_ema = stylegan2().to(device)
+
+
+ def to_image(tensor):
+     tensor = tensor.squeeze(0).permute(1, 2, 0)
+     arr = tensor.detach().cpu().numpy()
+     arr = (arr - arr.min()) / (arr.max() - arr.min())
+     arr = arr * 255
+     return arr.astype('uint8')
+
+
+ def on_click(image, target_point, points, evt: gr.SelectData):
+     x = evt.index[1]
+     y = evt.index[0]
+     if target_point:
+         image[x:x + 5, y:y + 5, :] = 255
+         points['target'].append([evt.index[1], evt.index[0]])
+         return image, str(evt.index)
+     points['handle'].append([evt.index[1], evt.index[0]])
+     image[x:x + 5, y:y + 5, :] = 0
+     return image, str(evt.index)
+
+
+ def on_drag(points, max_iters, state):
+     max_iters = int(max_iters)
+     latent = state['latent']
+     noise = state['noise']
+     F = state['F']
+
+     handle_points = [torch.tensor(p).float() for p in points['handle']]
+     target_points = [torch.tensor(p).float() for p in points['target']]
+     mask = torch.zeros((1, 1, 1024, 1024)).to(device)
+     mask[..., 720:820, 390:600] = 1
+     for sample2, latent, F in drag_gan(g_ema, latent, noise, F,
+                                        handle_points, target_points, mask,
+                                        max_iters=max_iters):
+         points = {'target': [], 'handle': []}
+         image = to_image(sample2)
+
+         state['F'] = F
+         state['latent'] = latent
+         yield points, image, state
+
+
+ def main():
+     torch.cuda.manual_seed(25)
+     sample_z = torch.randn([1, 512], device=device)
+     latent, noise = g_ema.prepare([sample_z])
+     sample, F = g_ema.generate(latent, noise)
+
+     with gr.Blocks() as demo:
+         state = gr.State({
+             'latent': latent,
+             'noise': noise,
+             'F': F,
+         })
+         max_iters = gr.Slider(1, 100, 5, label='Max Iterations')
+         image = gr.Image(to_image(sample)).style(height=512, width=512)
+         text = gr.Textbox()
+         btn = gr.Button('Drag it')
+         points = gr.State({'target': [], 'handle': []})
+         target_point = gr.Checkbox(label='Target Point')
+         image.select(on_click, [image, target_point, points], [image, text])
+         btn.click(on_drag, inputs=[points, max_iters, state], outputs=[points, image, state])
+
+     demo.queue(concurrency_count=5, max_size=20).launch()
+
+
+ if __name__ == '__main__':
+     main()
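The Gradio callbacks above only shuttle click coordinates and session state; the actual edit is the generator loop in drag_gan.py. Below is a minimal sketch (not part of the commit) of driving that loop headlessly. It assumes a CUDA GPU and that the FFHQ checkpoint download inside `stylegan2()` succeeds; the handle/target coordinates are made up for illustration, and the mask only matters if the masked feature loss (currently commented out in drag_gan.py) is re-enabled.

```python
import torch
from drag_gan import stylegan2, drag_gan

device = 'cuda'
g_ema = stylegan2().to(device)                    # downloads stylegan2-ffhq-config-f.pt on first use

sample_z = torch.randn([1, 512], device=device)
latent, noise = g_ema.prepare([sample_z])
sample, F = g_ema.generate(latent, noise)

# One handle point dragged toward one target point, in (row, col) pixel coordinates.
handle_points = [torch.tensor([500., 500.])]
target_points = [torch.tensor([530., 530.])]
mask = torch.zeros((1, 1, 1024, 1024), device=device)

for sample2, latent, F in drag_gan(g_ema, latent, noise, F,
                                   handle_points, target_points, mask,
                                   max_iters=5):
    pass  # sample2 holds the edited image after each optimization step
```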
assets/demo.png ADDED
drag_gan.py ADDED
@@ -0,0 +1,243 @@
1
+ import copy
2
+ import os
3
+ import random
4
+ import urllib.request
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as FF
9
+ import torch.optim
10
+ from torchvision import utils
11
+ from tqdm import tqdm
12
+
13
+ from stylegan2.model import Generator
14
+
15
+
16
+ class DownloadProgressBar(tqdm):
17
+ def update_to(self, b=1, bsize=1, tsize=None):
18
+ if tsize is not None:
19
+ self.total = tsize
20
+ self.update(b * bsize - self.n)
21
+
22
+
23
+ def get_path(base_path):
24
+ BASE_DIR = os.path.join('checkpoints')
25
+
26
+ save_path = os.path.join(BASE_DIR, base_path)
27
+ if not os.path.exists(save_path):
28
+ url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
29
+ print(f'{base_path} not found')
30
+ print('Try to download from huggingface: ', url)
31
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
32
+ download_url(url, save_path)
33
+ print('Downloaded to ', save_path)
34
+ return save_path
35
+
36
+
37
+ def download_url(url, output_path):
38
+ with DownloadProgressBar(unit='B', unit_scale=True,
39
+ miniters=1, desc=url.split('/')[-1]) as t:
40
+ urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
41
+
42
+
43
+ class CustomGenerator(Generator):
44
+ def prepare(
45
+ self,
46
+ styles,
47
+ inject_index=None,
48
+ truncation=1,
49
+ truncation_latent=None,
50
+ input_is_latent=False,
51
+ noise=None,
52
+ randomize_noise=True,
53
+ ):
54
+ if not input_is_latent:
55
+ styles = [self.style(s) for s in styles]
56
+
57
+ if noise is None:
58
+ if randomize_noise:
59
+ noise = [None] * self.num_layers
60
+ else:
61
+ noise = [
62
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
63
+ ]
64
+
65
+ if truncation < 1:
66
+ style_t = []
67
+
68
+ for style in styles:
69
+ style_t.append(
70
+ truncation_latent + truncation * (style - truncation_latent)
71
+ )
72
+
73
+ styles = style_t
74
+
75
+ if len(styles) < 2:
76
+ inject_index = self.n_latent
77
+
78
+ if styles[0].ndim < 3:
79
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
80
+
81
+ else:
82
+ latent = styles[0]
83
+
84
+ else:
85
+ if inject_index is None:
86
+ inject_index = random.randint(1, self.n_latent - 1)
87
+
88
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
89
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
90
+
91
+ latent = torch.cat([latent, latent2], 1)
92
+
93
+ return latent, noise
94
+
95
+ def generate(
96
+ self,
97
+ latent,
98
+ noise,
99
+ ):
100
+ out = self.input(latent)
101
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
102
+
103
+ skip = self.to_rgb1(out, latent[:, 1])
104
+ i = 1
105
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
106
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
107
+ ):
108
+ out = conv1(out, latent[:, i], noise=noise1)
109
+ out = conv2(out, latent[:, i + 1], noise=noise2)
110
+ skip = to_rgb(out, latent[:, i + 2], skip)
111
+ if out.shape[-1] == 256: F = out
112
+ i += 2
113
+
114
+ image = skip
115
+ F = FF.interpolate(F, image.shape[-2:], mode='bilinear')
116
+ return image, F
117
+
118
+
119
+ def stylegan2(
120
+ size=1024,
121
+ channel_multiplier=2,
122
+ latent=512,
123
+ n_mlp=8,
124
+ ckpt='stylegan2-ffhq-config-f.pt'
125
+ ):
126
+ g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
127
+ checkpoint = torch.load(get_path(ckpt))
128
+ g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
129
+ g_ema.requires_grad_(False)
130
+ g_ema.eval()
131
+ return g_ema
132
+
133
+
134
+ def bilinear_interpolate_torch(im, y, x):
135
+ """
136
+ im : B,C,H,W
137
+ y : 1,numPoints -- pixel location y float
138
+ x : 1,numPoints -- pixel location x float
139
+ """
140
+
141
+ x0 = torch.floor(x).long()
142
+ x1 = x0 + 1
143
+
144
+ y0 = torch.floor(y).long()
145
+ y1 = y0 + 1
146
+
147
+ wa = (x1.float() - x) * (y1.float() - y)
148
+ wb = (x1.float() - x) * (y - y0.float())
149
+ wc = (x - x0.float()) * (y1.float() - y)
150
+ wd = (x - x0.float()) * (y - y0.float())
151
+ # Instead of clamp
152
+ x1 = x1 - torch.floor(x1 / im.shape[3]).int()
153
+ y1 = y1 - torch.floor(y1 / im.shape[2]).int()
154
+ Ia = im[:, :, y0, x0]
155
+ Ib = im[:, :, y1, x0]
156
+ Ic = im[:, :, y0, x1]
157
+ Id = im[:, :, y1, x1]
158
+
159
+ return Ia * wa + Ib * wb + Ic * wc + Id * wd
160
+
161
+
162
+ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000):
163
+ handle_points0 = copy.deepcopy(handle_points)
164
+ n = len(handle_points)
165
+ r1, r2, lam, d = 3, 12, 20, 1
166
+
167
+ def neighbor(x, y, d):
168
+ points = []
169
+ for i in range(x - d, x + d):
170
+ for j in range(y - d, y + d):
171
+ points.append(torch.tensor([i, j]).float().cuda())
172
+ return points
173
+
174
+ F0 = F.detach().clone()
175
+ # latent = latent.detach().clone().requires_grad_(True)
176
+ latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
177
+ latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
178
+ optimizer = torch.optim.Adam([latent_trainable], lr=2e-3)
179
+ for iter in range(max_iters):
180
+ for s in range(1):
181
+ optimizer.zero_grad()
182
+ latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
183
+ sample2, F2 = g_ema.generate(latent, noise)
184
+
185
+ # motion supervision
186
+ loss = 0
187
+ for i in range(n):
188
+ pi, ti = handle_points[i], target_points[i]
189
+ di = (ti - pi) / torch.sum((ti - pi)**2)
190
+
191
+ for qi in neighbor(int(pi[0]), int(pi[1]), r1):
192
+ # f1 = F[..., int(qi[0]), int(qi[1])]
193
+ # f2 = F2[..., int(qi[0] + di[0]), int(qi[1] + di[1])]
194
+ f1 = bilinear_interpolate_torch(F2, qi[0], qi[1]).detach()
195
+ f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
196
+ loss += FF.l1_loss(f2, f1)
197
+
198
+ # loss += ((F-F0) * (1-mask)).abs().mean() * lam
199
+
200
+ loss.backward()
201
+ optimizer.step()
202
+
203
+ print(latent_trainable[0, 0, :10])
204
+ # if s % 10 ==0:
205
+ # utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
206
+
207
+ # point tracking
208
+ with torch.no_grad():
209
+ sample2, F2 = g_ema.generate(latent, noise)
210
+ for i in range(n):
211
+ pi = handle_points0[i]
212
+ # f = F0[..., int(pi[0]), int(pi[1])]
213
+ f0 = bilinear_interpolate_torch(F0, pi[0], pi[1])
214
+ minv = 1e9
215
+ minx = 1e9
216
+ miny = 1e9
217
+ for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
218
+ # f2 = F2[..., int(qi[0]), int(qi[1])]
219
+ f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
224
+ v = torch.norm(f2 - f0, p=1)
225
+ if v < minv:
226
+ minv = v
227
+ minx = int(qi[0])
228
+ miny = int(qi[1])
229
+ handle_points[i][0] = minx
230
+ handle_points[i][1] = miny
231
+
232
+ F = F2.detach().clone()
233
+ if iter % 1 == 0:
234
+ print(iter, loss.item(), handle_points, target_points)
235
+ # p = handle_points[0].int()
236
+ # sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] = sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] * 0
237
+ # t = target_points[0].int()
238
+ # sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] = sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] * 255
239
+
240
+ # sample2[0, :, 210, 134] = sample2[0, :, 210, 134] * 0
241
+ utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
242
+
243
+ yield sample2, latent, F2
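bilinear_interpolate_torch is what lets both the motion-supervision loss and the point tracking read features at sub-pixel handle positions. A quick sanity check of it (a sketch, not part of the commit; note that importing drag_gan also builds the CUDA ops under stylegan2/op): sampling the center of a 2x2 map should return the mean of its four corners.

```python
import torch
from drag_gan import bilinear_interpolate_torch

im = torch.tensor([[0., 1.],
                   [2., 3.]]).view(1, 1, 2, 2)   # B, C, H, W
y = torch.tensor(0.5)
x = torch.tensor(0.5)

print(bilinear_interpolate_torch(im, y, x))      # tensor([[1.5000]]), the mean of 0, 1, 2, 3
```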
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ torchvision
+ gradio
+ tqdm
stylegan2/_init__.py ADDED
File without changes
stylegan2/model.py ADDED
@@ -0,0 +1,699 @@
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
+
13
+
14
+ class PixelNorm(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, input):
19
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20
+
21
+
22
+ def make_kernel(k):
23
+ k = torch.tensor(k, dtype=torch.float32)
24
+
25
+ if k.ndim == 1:
26
+ k = k[None, :] * k[:, None]
27
+
28
+ k /= k.sum()
29
+
30
+ return k
31
+
32
+
33
+ class Upsample(nn.Module):
34
+ def __init__(self, kernel, factor=2):
35
+ super().__init__()
36
+
37
+ self.factor = factor
38
+ kernel = make_kernel(kernel) * (factor ** 2)
39
+ self.register_buffer("kernel", kernel)
40
+
41
+ p = kernel.shape[0] - factor
42
+
43
+ pad0 = (p + 1) // 2 + factor - 1
44
+ pad1 = p // 2
45
+
46
+ self.pad = (pad0, pad1)
47
+
48
+ def forward(self, input):
49
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50
+
51
+ return out
52
+
53
+
54
+ class Downsample(nn.Module):
55
+ def __init__(self, kernel, factor=2):
56
+ super().__init__()
57
+
58
+ self.factor = factor
59
+ kernel = make_kernel(kernel)
60
+ self.register_buffer("kernel", kernel)
61
+
62
+ p = kernel.shape[0] - factor
63
+
64
+ pad0 = (p + 1) // 2
65
+ pad1 = p // 2
66
+
67
+ self.pad = (pad0, pad1)
68
+
69
+ def forward(self, input):
70
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71
+
72
+ return out
73
+
74
+
75
+ class Blur(nn.Module):
76
+ def __init__(self, kernel, pad, upsample_factor=1):
77
+ super().__init__()
78
+
79
+ kernel = make_kernel(kernel)
80
+
81
+ if upsample_factor > 1:
82
+ kernel = kernel * (upsample_factor ** 2)
83
+
84
+ self.register_buffer("kernel", kernel)
85
+
86
+ self.pad = pad
87
+
88
+ def forward(self, input):
89
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
90
+
91
+ return out
92
+
93
+
94
+ class EqualConv2d(nn.Module):
95
+ def __init__(
96
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97
+ ):
98
+ super().__init__()
99
+
100
+ self.weight = nn.Parameter(
101
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102
+ )
103
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104
+
105
+ self.stride = stride
106
+ self.padding = padding
107
+
108
+ if bias:
109
+ self.bias = nn.Parameter(torch.zeros(out_channel))
110
+
111
+ else:
112
+ self.bias = None
113
+
114
+ def forward(self, input):
115
+ out = conv2d_gradfix.conv2d(
116
+ input,
117
+ self.weight * self.scale,
118
+ bias=self.bias,
119
+ stride=self.stride,
120
+ padding=self.padding,
121
+ )
122
+
123
+ return out
124
+
125
+ def __repr__(self):
126
+ return (
127
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129
+ )
130
+
131
+
132
+ class EqualLinear(nn.Module):
133
+ def __init__(
134
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135
+ ):
136
+ super().__init__()
137
+
138
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139
+
140
+ if bias:
141
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142
+
143
+ else:
144
+ self.bias = None
145
+
146
+ self.activation = activation
147
+
148
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149
+ self.lr_mul = lr_mul
150
+
151
+ def forward(self, input):
152
+ if self.activation:
153
+ out = F.linear(input, self.weight * self.scale)
154
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
155
+
156
+ else:
157
+ out = F.linear(
158
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
159
+ )
160
+
161
+ return out
162
+
163
+ def __repr__(self):
164
+ return (
165
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166
+ )
167
+
168
+
169
+ class ModulatedConv2d(nn.Module):
170
+ def __init__(
171
+ self,
172
+ in_channel,
173
+ out_channel,
174
+ kernel_size,
175
+ style_dim,
176
+ demodulate=True,
177
+ upsample=False,
178
+ downsample=False,
179
+ blur_kernel=[1, 3, 3, 1],
180
+ fused=True,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.eps = 1e-8
185
+ self.kernel_size = kernel_size
186
+ self.in_channel = in_channel
187
+ self.out_channel = out_channel
188
+ self.upsample = upsample
189
+ self.downsample = downsample
190
+
191
+ if upsample:
192
+ factor = 2
193
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
194
+ pad0 = (p + 1) // 2 + factor - 1
195
+ pad1 = p // 2 + 1
196
+
197
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198
+
199
+ if downsample:
200
+ factor = 2
201
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
202
+ pad0 = (p + 1) // 2
203
+ pad1 = p // 2
204
+
205
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206
+
207
+ fan_in = in_channel * kernel_size ** 2
208
+ self.scale = 1 / math.sqrt(fan_in)
209
+ self.padding = kernel_size // 2
210
+
211
+ self.weight = nn.Parameter(
212
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213
+ )
214
+
215
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
+
217
+ self.demodulate = demodulate
218
+ self.fused = fused
219
+
220
+ def __repr__(self):
221
+ return (
222
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223
+ f"upsample={self.upsample}, downsample={self.downsample})"
224
+ )
225
+
226
+ def forward(self, input, style):
227
+ batch, in_channel, height, width = input.shape
228
+
229
+ if not self.fused:
230
+ weight = self.scale * self.weight.squeeze(0)
231
+ style = self.modulation(style)
232
+
233
+ if self.demodulate:
234
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
235
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
236
+
237
+ input = input * style.reshape(batch, in_channel, 1, 1)
238
+
239
+ if self.upsample:
240
+ weight = weight.transpose(0, 1)
241
+ out = conv2d_gradfix.conv_transpose2d(
242
+ input, weight, padding=0, stride=2
243
+ )
244
+ out = self.blur(out)
245
+
246
+ elif self.downsample:
247
+ input = self.blur(input)
248
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
249
+
250
+ else:
251
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
252
+
253
+ if self.demodulate:
254
+ out = out * dcoefs.view(batch, -1, 1, 1)
255
+
256
+ return out
257
+
258
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
259
+ weight = self.scale * self.weight * style
260
+
261
+ if self.demodulate:
262
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
263
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
264
+
265
+ weight = weight.view(
266
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
267
+ )
268
+
269
+ if self.upsample:
270
+ input = input.view(1, batch * in_channel, height, width)
271
+ weight = weight.view(
272
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
273
+ )
274
+ weight = weight.transpose(1, 2).reshape(
275
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
276
+ )
277
+ out = conv2d_gradfix.conv_transpose2d(
278
+ input, weight, padding=0, stride=2, groups=batch
279
+ )
280
+ _, _, height, width = out.shape
281
+ out = out.view(batch, self.out_channel, height, width)
282
+ out = self.blur(out)
283
+
284
+ elif self.downsample:
285
+ input = self.blur(input)
286
+ _, _, height, width = input.shape
287
+ input = input.view(1, batch * in_channel, height, width)
288
+ out = conv2d_gradfix.conv2d(
289
+ input, weight, padding=0, stride=2, groups=batch
290
+ )
291
+ _, _, height, width = out.shape
292
+ out = out.view(batch, self.out_channel, height, width)
293
+
294
+ else:
295
+ input = input.view(1, batch * in_channel, height, width)
296
+ out = conv2d_gradfix.conv2d(
297
+ input, weight, padding=self.padding, groups=batch
298
+ )
299
+ _, _, height, width = out.shape
300
+ out = out.view(batch, self.out_channel, height, width)
301
+
302
+ return out
303
+
304
+
305
+ class NoiseInjection(nn.Module):
306
+ def __init__(self):
307
+ super().__init__()
308
+
309
+ self.weight = nn.Parameter(torch.zeros(1))
310
+
311
+ def forward(self, image, noise=None):
312
+ if noise is None:
313
+ batch, _, height, width = image.shape
314
+ noise = image.new_empty(batch, 1, height, width).normal_()
315
+
316
+ return image + self.weight * noise
317
+
318
+
319
+ class ConstantInput(nn.Module):
320
+ def __init__(self, channel, size=4):
321
+ super().__init__()
322
+
323
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
324
+
325
+ def forward(self, input):
326
+ batch = input.shape[0]
327
+ out = self.input.repeat(batch, 1, 1, 1)
328
+
329
+ return out
330
+
331
+
332
+ class StyledConv(nn.Module):
333
+ def __init__(
334
+ self,
335
+ in_channel,
336
+ out_channel,
337
+ kernel_size,
338
+ style_dim,
339
+ upsample=False,
340
+ blur_kernel=[1, 3, 3, 1],
341
+ demodulate=True,
342
+ ):
343
+ super().__init__()
344
+
345
+ self.conv = ModulatedConv2d(
346
+ in_channel,
347
+ out_channel,
348
+ kernel_size,
349
+ style_dim,
350
+ upsample=upsample,
351
+ blur_kernel=blur_kernel,
352
+ demodulate=demodulate,
353
+ )
354
+
355
+ self.noise = NoiseInjection()
356
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
357
+ # self.activate = ScaledLeakyReLU(0.2)
358
+ self.activate = FusedLeakyReLU(out_channel)
359
+
360
+ def forward(self, input, style, noise=None):
361
+ out = self.conv(input, style)
362
+ out = self.noise(out, noise=noise)
363
+ # out = out + self.bias
364
+ out = self.activate(out)
365
+
366
+ return out
367
+
368
+
369
+ class ToRGB(nn.Module):
370
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
371
+ super().__init__()
372
+
373
+ if upsample:
374
+ self.upsample = Upsample(blur_kernel)
375
+
376
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
377
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
378
+
379
+ def forward(self, input, style, skip=None):
380
+ out = self.conv(input, style)
381
+ out = out + self.bias
382
+
383
+ if skip is not None:
384
+ skip = self.upsample(skip)
385
+
386
+ out = out + skip
387
+
388
+ return out
389
+
390
+
391
+ class Generator(nn.Module):
392
+ def __init__(
393
+ self,
394
+ size,
395
+ style_dim,
396
+ n_mlp,
397
+ channel_multiplier=2,
398
+ blur_kernel=[1, 3, 3, 1],
399
+ lr_mlp=0.01,
400
+ ):
401
+ super().__init__()
402
+
403
+ self.size = size
404
+
405
+ self.style_dim = style_dim
406
+
407
+ layers = [PixelNorm()]
408
+
409
+ for i in range(n_mlp):
410
+ layers.append(
411
+ EqualLinear(
412
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
413
+ )
414
+ )
415
+
416
+ self.style = nn.Sequential(*layers)
417
+
418
+ self.channels = {
419
+ 4: 512,
420
+ 8: 512,
421
+ 16: 512,
422
+ 32: 512,
423
+ 64: 256 * channel_multiplier,
424
+ 128: 128 * channel_multiplier,
425
+ 256: 64 * channel_multiplier,
426
+ 512: 32 * channel_multiplier,
427
+ 1024: 16 * channel_multiplier,
428
+ }
429
+
430
+ self.input = ConstantInput(self.channels[4])
431
+ self.conv1 = StyledConv(
432
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
433
+ )
434
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
435
+
436
+ self.log_size = int(math.log(size, 2))
437
+ self.num_layers = (self.log_size - 2) * 2 + 1
438
+
439
+ self.convs = nn.ModuleList()
440
+ self.upsamples = nn.ModuleList()
441
+ self.to_rgbs = nn.ModuleList()
442
+ self.noises = nn.Module()
443
+
444
+ in_channel = self.channels[4]
445
+
446
+ for layer_idx in range(self.num_layers):
447
+ res = (layer_idx + 5) // 2
448
+ shape = [1, 1, 2 ** res, 2 ** res]
449
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
450
+
451
+ for i in range(3, self.log_size + 1):
452
+ out_channel = self.channels[2 ** i]
453
+
454
+ self.convs.append(
455
+ StyledConv(
456
+ in_channel,
457
+ out_channel,
458
+ 3,
459
+ style_dim,
460
+ upsample=True,
461
+ blur_kernel=blur_kernel,
462
+ )
463
+ )
464
+
465
+ self.convs.append(
466
+ StyledConv(
467
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
468
+ )
469
+ )
470
+
471
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
472
+
473
+ in_channel = out_channel
474
+
475
+ self.n_latent = self.log_size * 2 - 2
476
+
477
+ def make_noise(self):
478
+ device = self.input.input.device
479
+
480
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
481
+
482
+ for i in range(3, self.log_size + 1):
483
+ for _ in range(2):
484
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
485
+
486
+ return noises
487
+
488
+ def mean_latent(self, n_latent):
489
+ latent_in = torch.randn(
490
+ n_latent, self.style_dim, device=self.input.input.device
491
+ )
492
+ latent = self.style(latent_in).mean(0, keepdim=True)
493
+
494
+ return latent
495
+
496
+ def get_latent(self, input):
497
+ return self.style(input)
498
+
499
+ def forward(
500
+ self,
501
+ styles,
502
+ return_latents=False,
503
+ inject_index=None,
504
+ truncation=1,
505
+ truncation_latent=None,
506
+ input_is_latent=False,
507
+ noise=None,
508
+ randomize_noise=True,
509
+ ):
510
+ if not input_is_latent:
511
+ styles = [self.style(s) for s in styles]
512
+
513
+ if noise is None:
514
+ if randomize_noise:
515
+ noise = [None] * self.num_layers
516
+ else:
517
+ noise = [
518
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
519
+ ]
520
+
521
+ if truncation < 1:
522
+ style_t = []
523
+
524
+ for style in styles:
525
+ style_t.append(
526
+ truncation_latent + truncation * (style - truncation_latent)
527
+ )
528
+
529
+ styles = style_t
530
+
531
+ if len(styles) < 2:
532
+ inject_index = self.n_latent
533
+
534
+ if styles[0].ndim < 3:
535
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
536
+
537
+ else:
538
+ latent = styles[0]
539
+
540
+ else:
541
+ if inject_index is None:
542
+ inject_index = random.randint(1, self.n_latent - 1)
543
+
544
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
545
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
546
+
547
+ latent = torch.cat([latent, latent2], 1)
548
+
549
+ out = self.input(latent)
550
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
551
+
552
+ skip = self.to_rgb1(out, latent[:, 1])
553
+
554
+ i = 1
555
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
556
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
557
+ ):
558
+ out = conv1(out, latent[:, i], noise=noise1)
559
+ out = conv2(out, latent[:, i + 1], noise=noise2)
560
+ skip = to_rgb(out, latent[:, i + 2], skip)
561
+
562
+ i += 2
563
+
564
+
565
+ image = skip
566
+
567
+ if return_latents:
568
+ return image, latent
569
+
570
+ else:
571
+ return image, None
572
+
573
+
574
+ class ConvLayer(nn.Sequential):
575
+ def __init__(
576
+ self,
577
+ in_channel,
578
+ out_channel,
579
+ kernel_size,
580
+ downsample=False,
581
+ blur_kernel=[1, 3, 3, 1],
582
+ bias=True,
583
+ activate=True,
584
+ ):
585
+ layers = []
586
+
587
+ if downsample:
588
+ factor = 2
589
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
590
+ pad0 = (p + 1) // 2
591
+ pad1 = p // 2
592
+
593
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
594
+
595
+ stride = 2
596
+ self.padding = 0
597
+
598
+ else:
599
+ stride = 1
600
+ self.padding = kernel_size // 2
601
+
602
+ layers.append(
603
+ EqualConv2d(
604
+ in_channel,
605
+ out_channel,
606
+ kernel_size,
607
+ padding=self.padding,
608
+ stride=stride,
609
+ bias=bias and not activate,
610
+ )
611
+ )
612
+
613
+ if activate:
614
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
615
+
616
+ super().__init__(*layers)
617
+
618
+
619
+ class ResBlock(nn.Module):
620
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
621
+ super().__init__()
622
+
623
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
624
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
625
+
626
+ self.skip = ConvLayer(
627
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
628
+ )
629
+
630
+ def forward(self, input):
631
+ out = self.conv1(input)
632
+ out = self.conv2(out)
633
+
634
+ skip = self.skip(input)
635
+ out = (out + skip) / math.sqrt(2)
636
+
637
+ return out
638
+
639
+
640
+ class Discriminator(nn.Module):
641
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
642
+ super().__init__()
643
+
644
+ channels = {
645
+ 4: 512,
646
+ 8: 512,
647
+ 16: 512,
648
+ 32: 512,
649
+ 64: 256 * channel_multiplier,
650
+ 128: 128 * channel_multiplier,
651
+ 256: 64 * channel_multiplier,
652
+ 512: 32 * channel_multiplier,
653
+ 1024: 16 * channel_multiplier,
654
+ }
655
+
656
+ convs = [ConvLayer(3, channels[size], 1)]
657
+
658
+ log_size = int(math.log(size, 2))
659
+
660
+ in_channel = channels[size]
661
+
662
+ for i in range(log_size, 2, -1):
663
+ out_channel = channels[2 ** (i - 1)]
664
+
665
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
666
+
667
+ in_channel = out_channel
668
+
669
+ self.convs = nn.Sequential(*convs)
670
+
671
+ self.stddev_group = 4
672
+ self.stddev_feat = 1
673
+
674
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
675
+ self.final_linear = nn.Sequential(
676
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
677
+ EqualLinear(channels[4], 1),
678
+ )
679
+
680
+ def forward(self, input):
681
+ out = self.convs(input)
682
+
683
+ batch, channel, height, width = out.shape
684
+ group = min(batch, self.stddev_group)
685
+ stddev = out.view(
686
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
687
+ )
688
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
689
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
690
+ stddev = stddev.repeat(group, 1, height, width)
691
+ out = torch.cat([out, stddev], 1)
692
+
693
+ out = self.final_conv(out)
694
+
695
+ out = out.view(batch, -1)
696
+ out = self.final_linear(out)
697
+
698
+ return out
699
+
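This is the standard rosinality-style StyleGAN2 generator; CustomGenerator in drag_gan.py only splits its forward pass into prepare() and generate() so the intermediate feature map F can be captured and reused during dragging. A minimal sketch of using Generator directly (assuming the CUDA extensions under stylegan2/op build; with no checkpoint loaded the output is just noise-like):

```python
import torch
from stylegan2.model import Generator

g = Generator(size=256, style_dim=512, n_mlp=8).cuda().eval()
z = torch.randn(1, 512, device='cuda')

with torch.no_grad():
    img, _ = g([z])          # forward() takes a list of style codes
print(img.shape)             # torch.Size([1, 3, 256, 256])
```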
stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
+ from .upfirdn2d import upfirdn2d
stylegan2/op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
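could_use_op() only returns True for CUDA tensors on PyTorch 1.7/1.8; everywhere else conv2d() and conv_transpose2d() silently fall back to the regular functional ops. A sketch of that fallback (assuming the package import, which compiles the fused ops, succeeds):

```python
import torch
from stylegan2.op import conv2d_gradfix

x = torch.randn(1, 3, 8, 8)          # CPU tensor -> falls back to F.conv2d
w = torch.randn(4, 3, 3, 3)

y = conv2d_gradfix.conv2d(x, w, padding=1)
print(y.shape)                        # torch.Size([1, 4, 8, 8])
```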
stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ if bias:
39
+ grad_bias = grad_input.sum(dim).detach()
40
+
41
+ else:
42
+ grad_bias = empty
43
+
44
+ return grad_input, grad_bias
45
+
46
+ @staticmethod
47
+ def backward(ctx, gradgrad_input, gradgrad_bias):
48
+ out, = ctx.saved_tensors
49
+ gradgrad_out = fused.fused_bias_act(
50
+ gradgrad_input.contiguous(),
51
+ gradgrad_bias,
52
+ out,
53
+ 3,
54
+ 1,
55
+ ctx.negative_slope,
56
+ ctx.scale,
57
+ )
58
+
59
+ return gradgrad_out, None, None, None, None
60
+
61
+
62
+ class FusedLeakyReLUFunction(Function):
63
+ @staticmethod
64
+ def forward(ctx, input, bias, negative_slope, scale):
65
+ empty = input.new_empty(0)
66
+
67
+ ctx.bias = bias is not None
68
+
69
+ if bias is None:
70
+ bias = empty
71
+
72
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73
+ ctx.save_for_backward(out)
74
+ ctx.negative_slope = negative_slope
75
+ ctx.scale = scale
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def backward(ctx, grad_output):
81
+ out, = ctx.saved_tensors
82
+
83
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
+ )
86
+
87
+ if not ctx.bias:
88
+ grad_bias = None
89
+
90
+ return grad_input, grad_bias, None, None
91
+
92
+
93
+ class FusedLeakyReLU(nn.Module):
94
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
+ super().__init__()
96
+
97
+ if bias:
98
+ self.bias = nn.Parameter(torch.zeros(channel))
99
+
100
+ else:
101
+ self.bias = None
102
+
103
+ self.negative_slope = negative_slope
104
+ self.scale = scale
105
+
106
+ def forward(self, input):
107
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
+
109
+
110
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
+ if input.device.type == "cpu":
112
+ if bias is not None:
113
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
+ return (
115
+ F.leaky_relu(
116
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
+ )
118
+ * scale
119
+ )
120
+
121
+ else:
122
+ return F.leaky_relu(input, negative_slope=0.2) * scale
123
+
124
+ else:
125
+ return FusedLeakyReLUFunction.apply(
126
+ input.contiguous(), bias, negative_slope, scale
127
+ )
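On CPU tensors fused_leaky_relu skips the extension and just adds the bias, applies leaky_relu with slope 0.2, and multiplies by scale (sqrt(2) by default). A sketch of that equivalence (importing the module still compiles the CUDA extension, so this assumes the build succeeds):

```python
import torch
from torch.nn import functional as F
from stylegan2.op.fused_act import fused_leaky_relu

x = torch.randn(2, 8)                 # CPU tensor -> pure-PyTorch branch
bias = torch.zeros(8)

out = fused_leaky_relu(x, bias)
ref = F.leaky_relu(x + bias.view(1, -1), negative_slope=0.2) * (2 ** 0.5)
assert torch.allclose(out, ref)
```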
stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,209 @@
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ upfirdn2d_op = load(
12
+ "upfirdn2d",
13
+ sources=[
14
+ os.path.join(module_path, "upfirdn2d.cpp"),
15
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class UpFirDn2dBackward(Function):
21
+ @staticmethod
22
+ def forward(
23
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
+ ):
25
+
26
+ up_x, up_y = up
27
+ down_x, down_y = down
28
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
+
30
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
+
32
+ grad_input = upfirdn2d_op.upfirdn2d(
33
+ grad_output,
34
+ grad_kernel,
35
+ down_x,
36
+ down_y,
37
+ up_x,
38
+ up_y,
39
+ g_pad_x0,
40
+ g_pad_x1,
41
+ g_pad_y0,
42
+ g_pad_y1,
43
+ )
44
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
+
46
+ ctx.save_for_backward(kernel)
47
+
48
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
+
50
+ ctx.up_x = up_x
51
+ ctx.up_y = up_y
52
+ ctx.down_x = down_x
53
+ ctx.down_y = down_y
54
+ ctx.pad_x0 = pad_x0
55
+ ctx.pad_x1 = pad_x1
56
+ ctx.pad_y0 = pad_y0
57
+ ctx.pad_y1 = pad_y1
58
+ ctx.in_size = in_size
59
+ ctx.out_size = out_size
60
+
61
+ return grad_input
62
+
63
+ @staticmethod
64
+ def backward(ctx, gradgrad_input):
65
+ kernel, = ctx.saved_tensors
66
+
67
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
+
69
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
70
+ gradgrad_input,
71
+ kernel,
72
+ ctx.up_x,
73
+ ctx.up_y,
74
+ ctx.down_x,
75
+ ctx.down_y,
76
+ ctx.pad_x0,
77
+ ctx.pad_x1,
78
+ ctx.pad_y0,
79
+ ctx.pad_y1,
80
+ )
81
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
+ gradgrad_out = gradgrad_out.view(
83
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
+ )
85
+
86
+ return gradgrad_out, None, None, None, None, None, None, None, None
87
+
88
+
89
+ class UpFirDn2d(Function):
90
+ @staticmethod
91
+ def forward(ctx, input, kernel, up, down, pad):
92
+ up_x, up_y = up
93
+ down_x, down_y = down
94
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
+
96
+ kernel_h, kernel_w = kernel.shape
97
+ batch, channel, in_h, in_w = input.shape
98
+ ctx.in_size = input.shape
99
+
100
+ input = input.reshape(-1, in_h, in_w, 1)
101
+
102
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
+
104
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
+ ctx.out_size = (out_h, out_w)
107
+
108
+ ctx.up = (up_x, up_y)
109
+ ctx.down = (down_x, down_y)
110
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
+
112
+ g_pad_x0 = kernel_w - pad_x0 - 1
113
+ g_pad_y0 = kernel_h - pad_y0 - 1
114
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
+
117
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
+
119
+ out = upfirdn2d_op.upfirdn2d(
120
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
+ )
122
+ # out = out.view(major, out_h, out_w, minor)
123
+ out = out.view(-1, channel, out_h, out_w)
124
+
125
+ return out
126
+
127
+ @staticmethod
128
+ def backward(ctx, grad_output):
129
+ kernel, grad_kernel = ctx.saved_tensors
130
+
131
+ grad_input = None
132
+
133
+ if ctx.needs_input_grad[0]:
134
+ grad_input = UpFirDn2dBackward.apply(
135
+ grad_output,
136
+ kernel,
137
+ grad_kernel,
138
+ ctx.up,
139
+ ctx.down,
140
+ ctx.pad,
141
+ ctx.g_pad,
142
+ ctx.in_size,
143
+ ctx.out_size,
144
+ )
145
+
146
+ return grad_input, None, None, None, None
147
+
148
+
149
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
+ if not isinstance(up, abc.Iterable):
151
+ up = (up, up)
152
+
153
+ if not isinstance(down, abc.Iterable):
154
+ down = (down, down)
155
+
156
+ if len(pad) == 2:
157
+ pad = (pad[0], pad[1], pad[0], pad[1])
158
+
159
+ if input.device.type == "cpu":
160
+ out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
+
162
+ else:
163
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
+
165
+ return out
166
+
167
+
168
+ def upfirdn2d_native(
169
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
+ ):
171
+ _, channel, in_h, in_w = input.shape
172
+ input = input.reshape(-1, in_h, in_w, 1)
173
+
174
+ _, in_h, in_w, minor = input.shape
175
+ kernel_h, kernel_w = kernel.shape
176
+
177
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
178
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
+
181
+ out = F.pad(
182
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
+ )
184
+ out = out[
185
+ :,
186
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
+ :,
189
+ ]
190
+
191
+ out = out.permute(0, 3, 1, 2)
192
+ out = out.reshape(
193
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
+ )
195
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
+ out = F.conv2d(out, w)
197
+ out = out.reshape(
198
+ -1,
199
+ minor,
200
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
+ )
203
+ out = out.permute(0, 2, 3, 1)
204
+ out = out[:, ::down_y, ::down_x, :]
205
+
206
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
+
209
+ return out.view(-1, channel, out_h, out_w)
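upfirdn2d pads, upsamples, applies the FIR kernel, and downsamples in a single op, with a native PyTorch path for CPU tensors. A sketch of the 2x blurred upsample that Upsample in stylegan2/model.py performs, using the same [1, 3, 3, 1] kernel and padding (CPU path; importing the module still builds the CUDA extension):

```python
import torch
from stylegan2.op.upfirdn2d import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]
k = k / k.sum() * 4                   # same as make_kernel([1, 3, 3, 1]) * factor ** 2

x = torch.randn(1, 3, 16, 16)
y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(y.shape)                        # torch.Size([1, 3, 32, 32])
```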
stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
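+ // Host entry point: computes the output shape, selects a compile-time
+ // specialization for the common StyleGAN2 up/down/filter configurations
+ // (modes 1-6), and falls back to the generic kernel otherwise.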
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
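+ // Standard upfirdn output length: floor((in * up + pad0 + pad1 - kernel) / down) + 1,
+ // written below as (in * up + pad0 + pad1 - kernel + down) / down with integer division.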
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
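+ // mode / tile_out_* pick one of the compile-time specializations below;
+ // -1 means no specialization matches and the generic kernel is used.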
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
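+ // Launch configuration: tiled modes use 256 threads per block over output
+ // tiles; the fallback maps threads over (out_h * minor_dim, out_w) directly.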
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
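+ // Instantiate for the tensor's dtype (float, double, or half) and launch the
+ // specialization chosen above.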
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
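
For context, a kernel file like this is normally compiled as a PyTorch C++/CUDA extension and exposed to Python through a small companion binding. The sketch below is a minimal, hypothetical example of such a binding; the file name, the exported name `upfirdn2d`, and the docstring are assumptions, and only the `upfirdn2d_op` signature is taken from the code above.

// upfirdn2d.cpp (hypothetical companion file): exposes upfirdn2d_op as a torch extension.
#include <torch/extension.h>

// Declaration of the CUDA entry point defined in the .cu file above.
torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d_op,
        "upsample, FIR filter, downsample, and crop in one fused pass (CUDA)");
}

Built with torch.utils.cpp_extension.load (or a setup.py), this yields a Python-callable op that a thin upfirdn2d wrapper, such as the one StyleGAN2's blur and resampling layers rely on, can dispatch to.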