fffiloni aaronb committed on
Commit
6b14aab
0 Parent(s):

Duplicate from aaronb/DragGAN


Co-authored-by: black <aaronb@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: DragGAN
+ emoji: ⚡
+ colorFrom: pink
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.29.0
+ app_file: gradio_app.py
+ pinned: false
+ duplicated_from: aaronb/DragGAN
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
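The YAML block above is the Spaces configuration: it tells the Hub to serve `gradio_app.py` (added later in this commit) with Gradio 3.29.0. As a hedged sketch, assuming a CUDA GPU and that the StyleGAN2 checkpoints can be fetched from the Hub on first use, the same demo can be started outside of Spaces by mirroring the app's own `__main__` block:

    from gradio_app import main  # gradio_app.py is the app_file added in this commit

    demo = main()  # builds the Blocks UI and loads the FFHQ checkpoint (CUDA assumed)
    demo.queue(concurrency_count=1, max_size=20).launch()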
drag_gan.py ADDED
@@ -0,0 +1,238 @@
+ import copy
+ import os
+ import random
+ import urllib.request
+
+ import torch
+ import torch.nn.functional as FF
+ import torch.optim
+ from torchvision import utils
+ from tqdm import tqdm
+
+ from stylegan2.model import Generator
+
+
+ class DownloadProgressBar(tqdm):
+     def update_to(self, b=1, bsize=1, tsize=None):
+         if tsize is not None:
+             self.total = tsize
+         self.update(b * bsize - self.n)
+
+
+ def get_path(base_path):
+     BASE_DIR = os.path.join('checkpoints')
+
+     save_path = os.path.join(BASE_DIR, base_path)
+     if not os.path.exists(save_path):
+         url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
+         print(f'{base_path} not found')
+         print('Trying to download it from Hugging Face: ', url)
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+         download_url(url, save_path)
+         print('Downloaded to ', save_path)
+     return save_path
+
+
+ def download_url(url, output_path):
+     with DownloadProgressBar(unit='B', unit_scale=True,
+                              miniters=1, desc=url.split('/')[-1]) as t:
+         urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
+
+
+ class CustomGenerator(Generator):
+     def prepare(
+         self,
+         styles,
+         inject_index=None,
+         truncation=1,
+         truncation_latent=None,
+         input_is_latent=False,
+         noise=None,
+         randomize_noise=True,
+     ):
+         if not input_is_latent:
+             styles = [self.style(s) for s in styles]
+
+         if noise is None:
+             if randomize_noise:
+                 noise = [None] * self.num_layers
+             else:
+                 noise = [
+                     getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
+                 ]
+
+         if truncation < 1:
+             style_t = []
+
+             for style in styles:
+                 style_t.append(
+                     truncation_latent + truncation * (style - truncation_latent)
+                 )
+
+             styles = style_t
+
+         if len(styles) < 2:
+             inject_index = self.n_latent
+
+             if styles[0].ndim < 3:
+                 latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+
+             else:
+                 latent = styles[0]
+
+         else:
+             if inject_index is None:
+                 inject_index = random.randint(1, self.n_latent - 1)
+
+             latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+             latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+             latent = torch.cat([latent, latent2], 1)
+
+         return latent, noise
+
+     def generate(
+         self,
+         latent,
+         noise,
+     ):
+         out = self.input(latent)
+         out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+         skip = self.to_rgb1(out, latent[:, 1])
+         i = 1
+         for conv1, conv2, noise1, noise2, to_rgb in zip(
+             self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+         ):
+             out = conv1(out, latent[:, i], noise=noise1)
+             out = conv2(out, latent[:, i + 1], noise=noise2)
+             skip = to_rgb(out, latent[:, i + 2], skip)
+             if out.shape[-1] == 256: F = out
+             i += 2
+
+         image = skip
+         F = FF.interpolate(F, image.shape[-2:], mode='bilinear')
+         return image, F
+
+
+ def stylegan2(
+     size=1024,
+     channel_multiplier=2,
+     latent=512,
+     n_mlp=8,
+     ckpt='stylegan2-ffhq-config-f.pt'
+ ):
+     g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
+     checkpoint = torch.load(get_path(ckpt))
+     g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
+     g_ema.requires_grad_(False)
+     g_ema.eval()
+     return g_ema
+
+
+ def bilinear_interpolate_torch(im, y, x):
+     """
+     im : B,C,H,W
+     y : 1,numPoints -- pixel location y float
+     x : 1,numPoints -- pixel location x float
+     """
+
+     x0 = torch.floor(x).long()
+     x1 = x0 + 1
+
+     y0 = torch.floor(y).long()
+     y1 = y0 + 1
+
+     wa = (x1.float() - x) * (y1.float() - y)
+     wb = (x1.float() - x) * (y - y0.float())
+     wc = (x - x0.float()) * (y1.float() - y)
+     wd = (x - x0.float()) * (y - y0.float())
+     # Instead of clamp
+     x1 = x1 - torch.floor(x1 / im.shape[3]).int()
+     y1 = y1 - torch.floor(y1 / im.shape[2]).int()
+     Ia = im[:, :, y0, x0]
+     Ib = im[:, :, y1, x0]
+     Ic = im[:, :, y0, x1]
+     Id = im[:, :, y1, x1]
+
+     return Ia * wa + Ib * wb + Ic * wc + Id * wd
+
+
+ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000):
+     handle_points0 = copy.deepcopy(handle_points)
+     n = len(handle_points)
+     r1, r2, lam, d = 3, 12, 20, 1
+
+     def neighbor(x, y, d):
+         points = []
+         for i in range(x - d, x + d):
+             for j in range(y - d, y + d):
+                 points.append(torch.tensor([i, j]).float().cuda())
+         return points
+
+     F0 = F.detach().clone()
+
+     latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
+     latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
+     optimizer = torch.optim.Adam([latent_trainable], lr=2e-3)
+     for iter in range(max_iters):
+         for s in range(1):
+             optimizer.zero_grad()
+             latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
+             sample2, F2 = g_ema.generate(latent, noise)
+
+             # motion supervision
+             loss = 0
+             for i in range(n):
+                 pi, ti = handle_points[i], target_points[i]
+                 di = (ti - pi) / torch.sum((ti - pi)**2)
+
+                 for qi in neighbor(int(pi[0]), int(pi[1]), r1):
+                     # f1 = F[..., int(qi[0]), int(qi[1])]
+                     # f2 = F2[..., int(qi[0] + di[0]), int(qi[1] + di[1])]
+                     f1 = bilinear_interpolate_torch(F2, qi[0], qi[1]).detach()
+                     f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
+                     loss += FF.l1_loss(f2, f1)
+
+             loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam
+
+             loss.backward()
+             optimizer.step()
+
+         # point tracking
+         with torch.no_grad():
+             sample2, F2 = g_ema.generate(latent, noise)
+             for i in range(n):
+                 pi = handle_points0[i]
+                 # f = F0[..., int(pi[0]), int(pi[1])]
+                 f0 = bilinear_interpolate_torch(F0, pi[0], pi[1])
+                 minv = 1e9
+                 minx = 1e9
+                 miny = 1e9
+                 for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
+                     # f2 = F2[..., int(qi[0]), int(qi[1])]
+                     try:
+                         f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
+                     except:
+                         import ipdb
+                         ipdb.set_trace()
+                     v = torch.norm(f2 - f0, p=1)
+                     if v < minv:
+                         minv = v
+                         minx = int(qi[0])
+                         miny = int(qi[1])
+                 handle_points[i][0] = minx
+                 handle_points[i][1] = miny
+
+         F = F2.detach().clone()
+         if iter % 1 == 0:
+             print(iter, loss.item(), handle_points, target_points)
+             # p = handle_points[0].int()
+             # sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] = sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] * 0
+             # t = target_points[0].int()
+             # sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] = sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] * 255
+
+             # sample2[0, :, 210, 134] = sample2[0, :, 210, 134] * 0
+             # utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
+
+         yield sample2, latent, F2, handle_points
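drag_gan() above is a generator that alternates the two phases of the DragGAN procedure on every iteration: motion supervision (an L1 loss that pulls the features around each handle point one small step towards its target, plus a lam-weighted term that keeps features fixed wherever the mask is 0) and point tracking (a nearest-feature search in an r2 neighbourhood that re-locates each handle). A minimal, hedged sketch of driving it programmatically, assuming a CUDA device and that the FFHQ checkpoint downloads via get_path; the point values are made-up [row, col] coordinates:

    import torch
    from drag_gan import stylegan2, drag_gan

    device = 'cuda'
    g_ema = stylegan2(ckpt='stylegan2-ffhq-config-f.pt').to(device)

    # sample a latent and the feature map F used for supervision and tracking
    z = torch.randn([1, 512], device=device)
    latent, noise = g_ema.prepare([z])
    sample, F = g_ema.generate(latent, noise)

    handle_points = [torch.tensor([600., 500.])]   # made-up [row, col] handle point
    target_points = [torch.tensor([550., 520.])]   # made-up [row, col] target point
    mask = torch.ones(1, 1, 1024, 1024, device=device)  # 1 everywhere = whole image may move

    for image, latent, F, handle_points in drag_gan(g_ema, latent, noise, F,
                                                     handle_points, target_points, mask,
                                                     max_iters=50):
        pass  # `image` is the edited sample after each optimisation step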
gradio_app.py ADDED
@@ -0,0 +1,275 @@
+ import os
+ import gradio as gr
+ import torch
+ import numpy as np
+ import imageio
+ from PIL import Image
+ import uuid
+
+ from drag_gan import drag_gan, stylegan2
+
+ device = 'cuda'
+
+
+ SIZE_TO_CLICK_SIZE = {
+     1024: 5,
+     256: 2
+ }
+
+ CKPT_SIZE = {
+     'stylegan2-ffhq-config-f.pt': 1024,
+     'stylegan2-cat-config-f.pt': 256,
+     'stylegan2-church-config-f.pt': 256,
+     'stylegan2-horse-config-f.pt': 256,
+ }
+
+
+ class ImageMask(gr.components.Image):
+     """
+     Sets: source="upload", tool="sketch"
+     """
+
+     is_template = True
+
+     def __init__(self, **kwargs):
+         super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
+
+     def preprocess(self, x):
+         if x is None:
+             return x
+         if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
+             decode_image = gr.processing_utils.decode_base64_to_image(x)
+             width, height = decode_image.size
+             mask = np.zeros((height, width, 4), dtype=np.uint8)
+             mask[..., -1] = 255
+             mask = self.postprocess(mask)
+             x = {'image': x, 'mask': mask}
+         return super().preprocess(x)
+
+
+ class ModelWrapper:
+     def __init__(self, **kwargs):
+         self.g_ema = stylegan2(**kwargs).to(device)
+
+
+ def to_image(tensor):
+     tensor = tensor.squeeze(0).permute(1, 2, 0)
+     arr = tensor.detach().cpu().numpy()
+     arr = (arr - arr.min()) / (arr.max() - arr.min())
+     arr = arr * 255
+     return arr.astype('uint8')
+
+
+ def add_points_to_image(image, points, size=5):
+     h, w = image.shape[:2]
+
+     for x, y in points['target']:
+         image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [255, 0, 0]
+     for x, y in points['handle']:
+         image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [0, 0, 255]
+
+     return image
+
+
+ def on_click(image, target_point, points, size, evt: gr.SelectData):
+     if target_point:
+         points['target'].append([evt.index[1], evt.index[0]])
+         image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+         return image, str(evt.index), not target_point
+     points['handle'].append([evt.index[1], evt.index[0]])
+     image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+     return image, str(evt.index), not target_point
+
+
+ def on_drag(model, points, max_iters, state, size, mask):
+     if len(points['handle']) == 0:
+         raise gr.Error('You must select at least one handle point and target point.')
+     if len(points['handle']) != len(points['target']):
+         raise gr.Error('You have incomplete handle points; select a target point for each handle or undo the last handle point.')
+     max_iters = int(max_iters)
+     latent = state['latent']
+     noise = state['noise']
+     F = state['F']
+
+     handle_points = [torch.tensor(p).float() for p in points['handle']]
+     target_points = [torch.tensor(p).float() for p in points['target']]
+
+     mask = Image.fromarray(mask['mask']).convert('L')
+     mask = np.array(mask) == 255
+
+     mask = torch.from_numpy(mask).float().to(device)
+     mask = mask.unsqueeze(0).unsqueeze(0)
+
+     step = 0
+     for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,
+                                                       handle_points, target_points, mask,
+                                                       max_iters=max_iters):
+         image = to_image(sample2)
+
+         state['F'] = F
+         state['latent'] = latent
+         state['sample'] = sample2
+         points['handle'] = [p.cpu().numpy().astype('int') for p in handle_points]
+         add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+
+         state['history'].append(image)
+         step += 1
+         yield image, state, step
+
+
+ def on_reset(points, image, state):
+     return {'target': [], 'handle': []}, to_image(state['sample'])
+
+
+ def on_undo(points, image, state, size):
+     image = to_image(state['sample'])
+
+     if len(points['target']) < len(points['handle']):
+         points['handle'] = points['handle'][:-1]
+     else:
+         points['handle'] = points['handle'][:-1]
+         points['target'] = points['target'][:-1]
+
+     add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+     return points, image
+
+
+ def on_change_model(selected, model):
+     size = CKPT_SIZE[selected]
+     model = ModelWrapper(size=size, ckpt=selected)
+     g_ema = model.g_ema
+     sample_z = torch.randn([1, 512], device=device)
+     latent, noise = g_ema.prepare([sample_z])
+     sample, F = g_ema.generate(latent, noise)
+
+     state = {
+         'latent': latent,
+         'noise': noise,
+         'F': F,
+         'sample': sample,
+         'history': []
+     }
+     return model, state, to_image(sample), size
+
+
+ def on_new_image(model):
+     g_ema = model.g_ema
+     sample_z = torch.randn([1, 512], device=device)
+     latent, noise = g_ema.prepare([sample_z])
+     sample, F = g_ema.generate(latent, noise)
+
+     state = {
+         'latent': latent,
+         'noise': noise,
+         'F': F,
+         'sample': sample,
+         'history': []
+     }
+     points = {'target': [], 'handle': []}
+     target_point = False
+     return to_image(sample), to_image(sample), state, points, target_point
+
+
+ def on_max_iter_change(max_iters):
+     return gr.update(maximum=max_iters)
+
+
+ def on_save_files(image, state):
+     os.makedirs('tmp', exist_ok=True)
+     image_name = f'tmp/image_{uuid.uuid4()}.png'
+     video_name = f'tmp/video_{uuid.uuid4()}.mp4'
+     imageio.imsave(image_name, image)
+     imageio.mimsave(video_name, state['history'])
+     return [image_name, video_name]
+
+
+ def on_show_save():
+     return gr.update(visible=True)
+
+
+ def main():
+     torch.cuda.manual_seed(25)
+
+     with gr.Blocks() as demo:
+         wrapped_model = ModelWrapper()
+         model = gr.State(wrapped_model)
+         sample_z = torch.randn([1, 512], device=device)
+         latent, noise = wrapped_model.g_ema.prepare([sample_z])
+         sample, F = wrapped_model.g_ema.generate(latent, noise)
+
+         gr.Markdown(
+             """
+             # DragGAN (Unofficial)
+
+             Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
+
+             [Github](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (not released yet)
+
+             ## Tutorial
+
+             1. (Optional) Draw a mask to indicate the movable region.
+             2. Set up at least one pair of handle and target points.
+             3. Click "Drag it".
+
+             """,
+         )
+         state = gr.State({
+             'latent': latent,
+             'noise': noise,
+             'F': F,
+             'sample': sample,
+             'history': []
+         })
+         points = gr.State({'target': [], 'handle': []})
+         size = gr.State(1024)
+
+         with gr.Row():
+             with gr.Column(scale=0.3):
+                 with gr.Accordion("Model"):
+                     model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value='stylegan2-ffhq-config-f.pt',
+                                                  label='StyleGAN2 model')
+                     max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
+                     new_btn = gr.Button('New Image')
+                 with gr.Accordion('Drag'):
+                     with gr.Row():
+                         with gr.Column(min_width=100):
+                             text = gr.Textbox(label='Selected Point', interactive=False)
+                         with gr.Column(min_width=100):
+                             target_point = gr.Checkbox(label='Target Point', interactive=False)
+                     with gr.Row():
+                         with gr.Column(min_width=100):
+                             reset_btn = gr.Button('Reset All')
+                         with gr.Column(min_width=100):
+                             undo_btn = gr.Button('Undo Last')
+                     with gr.Row():
+                         btn = gr.Button('Drag it', variant='primary')
+
+                 with gr.Accordion('Save', visible=False) as save_panel:
+                     files = gr.Files(value=[])
+
+                 progress = gr.Slider(value=0, maximum=20, label='Progress', interactive=False)
+
+             with gr.Column():
+                 with gr.Tabs():
+                     with gr.Tab('Draw a Mask', id='mask'):
+                         mask = gr.ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
+                     with gr.Tab('Setup Handle Points', id='input'):
+                         image = gr.Image(to_image(sample)).style(height=768, width=768)
+
+         image.select(on_click, [image, target_point, points, size], [image, text, target_point])
+         btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress]).then(
+             on_show_save, outputs=save_panel).then(
+             on_save_files, inputs=[image, state], outputs=[files]
+         )
+         reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
+         undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
+         model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, size])
+         new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
+         max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
+     return demo
+
+
+ if __name__ == '__main__':
+     import fire
+     demo = main()
+     fire.Fire(demo.queue(concurrency_count=1, max_size=20).launch)
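One detail of on_drag above worth spelling out: the sketch layer is converted into a binary mask where painted pixels become 1, and drag_gan only penalises feature drift where the mask is 0, so the painted area is the region that is allowed to move. A small sketch of that preprocessing (the array shapes are illustrative):

    import numpy as np
    import torch

    painted = np.zeros((1024, 1024), dtype=np.uint8)
    painted[300:700, 300:700] = 255                  # pretend the user painted this square

    mask = torch.from_numpy(painted == 255).float()  # 1 = movable, 0 = keep features fixed
    mask = mask.unsqueeze(0).unsqueeze(0)            # (1, 1, H, W), matching the (1 - mask) term in drag_gan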
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ tqdm
+ torch
+ numpy
+ ninja
+ fire
+ imageio
+ torchvision
stylegan2/__pycache__/model.cpython-37.pyc ADDED
Binary file (16 kB).
stylegan2/_init__.py ADDED
File without changes
stylegan2/model.py ADDED
@@ -0,0 +1,696 @@
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
9
+
10
+
11
+ class PixelNorm(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, input):
16
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
17
+
18
+
19
+ def make_kernel(k):
20
+ k = torch.tensor(k, dtype=torch.float32)
21
+
22
+ if k.ndim == 1:
23
+ k = k[None, :] * k[:, None]
24
+
25
+ k /= k.sum()
26
+
27
+ return k
28
+
29
+
30
+ class Upsample(nn.Module):
31
+ def __init__(self, kernel, factor=2):
32
+ super().__init__()
33
+
34
+ self.factor = factor
35
+ kernel = make_kernel(kernel) * (factor ** 2)
36
+ self.register_buffer("kernel", kernel)
37
+
38
+ p = kernel.shape[0] - factor
39
+
40
+ pad0 = (p + 1) // 2 + factor - 1
41
+ pad1 = p // 2
42
+
43
+ self.pad = (pad0, pad1)
44
+
45
+ def forward(self, input):
46
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
47
+
48
+ return out
49
+
50
+
51
+ class Downsample(nn.Module):
52
+ def __init__(self, kernel, factor=2):
53
+ super().__init__()
54
+
55
+ self.factor = factor
56
+ kernel = make_kernel(kernel)
57
+ self.register_buffer("kernel", kernel)
58
+
59
+ p = kernel.shape[0] - factor
60
+
61
+ pad0 = (p + 1) // 2
62
+ pad1 = p // 2
63
+
64
+ self.pad = (pad0, pad1)
65
+
66
+ def forward(self, input):
67
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
68
+
69
+ return out
70
+
71
+
72
+ class Blur(nn.Module):
73
+ def __init__(self, kernel, pad, upsample_factor=1):
74
+ super().__init__()
75
+
76
+ kernel = make_kernel(kernel)
77
+
78
+ if upsample_factor > 1:
79
+ kernel = kernel * (upsample_factor ** 2)
80
+
81
+ self.register_buffer("kernel", kernel)
82
+
83
+ self.pad = pad
84
+
85
+ def forward(self, input):
86
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
87
+
88
+ return out
89
+
90
+
91
+ class EqualConv2d(nn.Module):
92
+ def __init__(
93
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
94
+ ):
95
+ super().__init__()
96
+
97
+ self.weight = nn.Parameter(
98
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
99
+ )
100
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
101
+
102
+ self.stride = stride
103
+ self.padding = padding
104
+
105
+ if bias:
106
+ self.bias = nn.Parameter(torch.zeros(out_channel))
107
+
108
+ else:
109
+ self.bias = None
110
+
111
+ def forward(self, input):
112
+ out = conv2d_gradfix.conv2d(
113
+ input,
114
+ self.weight * self.scale,
115
+ bias=self.bias,
116
+ stride=self.stride,
117
+ padding=self.padding,
118
+ )
119
+
120
+ return out
121
+
122
+ def __repr__(self):
123
+ return (
124
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
125
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
126
+ )
127
+
128
+
129
+ class EqualLinear(nn.Module):
130
+ def __init__(
131
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
132
+ ):
133
+ super().__init__()
134
+
135
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
136
+
137
+ if bias:
138
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
139
+
140
+ else:
141
+ self.bias = None
142
+
143
+ self.activation = activation
144
+
145
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
146
+ self.lr_mul = lr_mul
147
+
148
+ def forward(self, input):
149
+ if self.activation:
150
+ out = F.linear(input, self.weight * self.scale)
151
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
152
+
153
+ else:
154
+ out = F.linear(
155
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
156
+ )
157
+
158
+ return out
159
+
160
+ def __repr__(self):
161
+ return (
162
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
163
+ )
164
+
165
+
166
+ class ModulatedConv2d(nn.Module):
167
+ def __init__(
168
+ self,
169
+ in_channel,
170
+ out_channel,
171
+ kernel_size,
172
+ style_dim,
173
+ demodulate=True,
174
+ upsample=False,
175
+ downsample=False,
176
+ blur_kernel=[1, 3, 3, 1],
177
+ fused=True,
178
+ ):
179
+ super().__init__()
180
+
181
+ self.eps = 1e-8
182
+ self.kernel_size = kernel_size
183
+ self.in_channel = in_channel
184
+ self.out_channel = out_channel
185
+ self.upsample = upsample
186
+ self.downsample = downsample
187
+
188
+ if upsample:
189
+ factor = 2
190
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
191
+ pad0 = (p + 1) // 2 + factor - 1
192
+ pad1 = p // 2 + 1
193
+
194
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
195
+
196
+ if downsample:
197
+ factor = 2
198
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
199
+ pad0 = (p + 1) // 2
200
+ pad1 = p // 2
201
+
202
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
203
+
204
+ fan_in = in_channel * kernel_size ** 2
205
+ self.scale = 1 / math.sqrt(fan_in)
206
+ self.padding = kernel_size // 2
207
+
208
+ self.weight = nn.Parameter(
209
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
210
+ )
211
+
212
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
213
+
214
+ self.demodulate = demodulate
215
+ self.fused = fused
216
+
217
+ def __repr__(self):
218
+ return (
219
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
220
+ f"upsample={self.upsample}, downsample={self.downsample})"
221
+ )
222
+
223
+ def forward(self, input, style):
224
+ batch, in_channel, height, width = input.shape
225
+
226
+ if not self.fused:
227
+ weight = self.scale * self.weight.squeeze(0)
228
+ style = self.modulation(style)
229
+
230
+ if self.demodulate:
231
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
232
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
233
+
234
+ input = input * style.reshape(batch, in_channel, 1, 1)
235
+
236
+ if self.upsample:
237
+ weight = weight.transpose(0, 1)
238
+ out = conv2d_gradfix.conv_transpose2d(
239
+ input, weight, padding=0, stride=2
240
+ )
241
+ out = self.blur(out)
242
+
243
+ elif self.downsample:
244
+ input = self.blur(input)
245
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
246
+
247
+ else:
248
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
249
+
250
+ if self.demodulate:
251
+ out = out * dcoefs.view(batch, -1, 1, 1)
252
+
253
+ return out
254
+
255
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
256
+ weight = self.scale * self.weight * style
257
+
258
+ if self.demodulate:
259
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
260
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
261
+
262
+ weight = weight.view(
263
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
264
+ )
265
+
266
+ if self.upsample:
267
+ input = input.view(1, batch * in_channel, height, width)
268
+ weight = weight.view(
269
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
270
+ )
271
+ weight = weight.transpose(1, 2).reshape(
272
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
273
+ )
274
+ out = conv2d_gradfix.conv_transpose2d(
275
+ input, weight, padding=0, stride=2, groups=batch
276
+ )
277
+ _, _, height, width = out.shape
278
+ out = out.view(batch, self.out_channel, height, width)
279
+ out = self.blur(out)
280
+
281
+ elif self.downsample:
282
+ input = self.blur(input)
283
+ _, _, height, width = input.shape
284
+ input = input.view(1, batch * in_channel, height, width)
285
+ out = conv2d_gradfix.conv2d(
286
+ input, weight, padding=0, stride=2, groups=batch
287
+ )
288
+ _, _, height, width = out.shape
289
+ out = out.view(batch, self.out_channel, height, width)
290
+
291
+ else:
292
+ input = input.view(1, batch * in_channel, height, width)
293
+ out = conv2d_gradfix.conv2d(
294
+ input, weight, padding=self.padding, groups=batch
295
+ )
296
+ _, _, height, width = out.shape
297
+ out = out.view(batch, self.out_channel, height, width)
298
+
299
+ return out
300
+
301
+
302
+ class NoiseInjection(nn.Module):
303
+ def __init__(self):
304
+ super().__init__()
305
+
306
+ self.weight = nn.Parameter(torch.zeros(1))
307
+
308
+ def forward(self, image, noise=None):
309
+ if noise is None:
310
+ batch, _, height, width = image.shape
311
+ noise = image.new_empty(batch, 1, height, width).normal_()
312
+
313
+ return image + self.weight * noise
314
+
315
+
316
+ class ConstantInput(nn.Module):
317
+ def __init__(self, channel, size=4):
318
+ super().__init__()
319
+
320
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
321
+
322
+ def forward(self, input):
323
+ batch = input.shape[0]
324
+ out = self.input.repeat(batch, 1, 1, 1)
325
+
326
+ return out
327
+
328
+
329
+ class StyledConv(nn.Module):
330
+ def __init__(
331
+ self,
332
+ in_channel,
333
+ out_channel,
334
+ kernel_size,
335
+ style_dim,
336
+ upsample=False,
337
+ blur_kernel=[1, 3, 3, 1],
338
+ demodulate=True,
339
+ ):
340
+ super().__init__()
341
+
342
+ self.conv = ModulatedConv2d(
343
+ in_channel,
344
+ out_channel,
345
+ kernel_size,
346
+ style_dim,
347
+ upsample=upsample,
348
+ blur_kernel=blur_kernel,
349
+ demodulate=demodulate,
350
+ )
351
+
352
+ self.noise = NoiseInjection()
353
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
354
+ # self.activate = ScaledLeakyReLU(0.2)
355
+ self.activate = FusedLeakyReLU(out_channel)
356
+
357
+ def forward(self, input, style, noise=None):
358
+ out = self.conv(input, style)
359
+ out = self.noise(out, noise=noise)
360
+ # out = out + self.bias
361
+ out = self.activate(out)
362
+
363
+ return out
364
+
365
+
366
+ class ToRGB(nn.Module):
367
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
368
+ super().__init__()
369
+
370
+ if upsample:
371
+ self.upsample = Upsample(blur_kernel)
372
+
373
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
374
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
375
+
376
+ def forward(self, input, style, skip=None):
377
+ out = self.conv(input, style)
378
+ out = out + self.bias
379
+
380
+ if skip is not None:
381
+ skip = self.upsample(skip)
382
+
383
+ out = out + skip
384
+
385
+ return out
386
+
387
+
388
+ class Generator(nn.Module):
389
+ def __init__(
390
+ self,
391
+ size,
392
+ style_dim,
393
+ n_mlp,
394
+ channel_multiplier=2,
395
+ blur_kernel=[1, 3, 3, 1],
396
+ lr_mlp=0.01,
397
+ ):
398
+ super().__init__()
399
+
400
+ self.size = size
401
+
402
+ self.style_dim = style_dim
403
+
404
+ layers = [PixelNorm()]
405
+
406
+ for i in range(n_mlp):
407
+ layers.append(
408
+ EqualLinear(
409
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
410
+ )
411
+ )
412
+
413
+ self.style = nn.Sequential(*layers)
414
+
415
+ self.channels = {
416
+ 4: 512,
417
+ 8: 512,
418
+ 16: 512,
419
+ 32: 512,
420
+ 64: 256 * channel_multiplier,
421
+ 128: 128 * channel_multiplier,
422
+ 256: 64 * channel_multiplier,
423
+ 512: 32 * channel_multiplier,
424
+ 1024: 16 * channel_multiplier,
425
+ }
426
+
427
+ self.input = ConstantInput(self.channels[4])
428
+ self.conv1 = StyledConv(
429
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
430
+ )
431
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
432
+
433
+ self.log_size = int(math.log(size, 2))
434
+ self.num_layers = (self.log_size - 2) * 2 + 1
435
+
436
+ self.convs = nn.ModuleList()
437
+ self.upsamples = nn.ModuleList()
438
+ self.to_rgbs = nn.ModuleList()
439
+ self.noises = nn.Module()
440
+
441
+ in_channel = self.channels[4]
442
+
443
+ for layer_idx in range(self.num_layers):
444
+ res = (layer_idx + 5) // 2
445
+ shape = [1, 1, 2 ** res, 2 ** res]
446
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
447
+
448
+ for i in range(3, self.log_size + 1):
449
+ out_channel = self.channels[2 ** i]
450
+
451
+ self.convs.append(
452
+ StyledConv(
453
+ in_channel,
454
+ out_channel,
455
+ 3,
456
+ style_dim,
457
+ upsample=True,
458
+ blur_kernel=blur_kernel,
459
+ )
460
+ )
461
+
462
+ self.convs.append(
463
+ StyledConv(
464
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
465
+ )
466
+ )
467
+
468
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
469
+
470
+ in_channel = out_channel
471
+
472
+ self.n_latent = self.log_size * 2 - 2
473
+
474
+ def make_noise(self):
475
+ device = self.input.input.device
476
+
477
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
478
+
479
+ for i in range(3, self.log_size + 1):
480
+ for _ in range(2):
481
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
482
+
483
+ return noises
484
+
485
+ def mean_latent(self, n_latent):
486
+ latent_in = torch.randn(
487
+ n_latent, self.style_dim, device=self.input.input.device
488
+ )
489
+ latent = self.style(latent_in).mean(0, keepdim=True)
490
+
491
+ return latent
492
+
493
+ def get_latent(self, input):
494
+ return self.style(input)
495
+
496
+ def forward(
497
+ self,
498
+ styles,
499
+ return_latents=False,
500
+ inject_index=None,
501
+ truncation=1,
502
+ truncation_latent=None,
503
+ input_is_latent=False,
504
+ noise=None,
505
+ randomize_noise=True,
506
+ ):
507
+ if not input_is_latent:
508
+ styles = [self.style(s) for s in styles]
509
+
510
+ if noise is None:
511
+ if randomize_noise:
512
+ noise = [None] * self.num_layers
513
+ else:
514
+ noise = [
515
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
516
+ ]
517
+
518
+ if truncation < 1:
519
+ style_t = []
520
+
521
+ for style in styles:
522
+ style_t.append(
523
+ truncation_latent + truncation * (style - truncation_latent)
524
+ )
525
+
526
+ styles = style_t
527
+
528
+ if len(styles) < 2:
529
+ inject_index = self.n_latent
530
+
531
+ if styles[0].ndim < 3:
532
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
533
+
534
+ else:
535
+ latent = styles[0]
536
+
537
+ else:
538
+ if inject_index is None:
539
+ inject_index = random.randint(1, self.n_latent - 1)
540
+
541
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
542
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
543
+
544
+ latent = torch.cat([latent, latent2], 1)
545
+
546
+ out = self.input(latent)
547
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
548
+
549
+ skip = self.to_rgb1(out, latent[:, 1])
550
+
551
+ i = 1
552
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
553
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
554
+ ):
555
+ out = conv1(out, latent[:, i], noise=noise1)
556
+ out = conv2(out, latent[:, i + 1], noise=noise2)
557
+ skip = to_rgb(out, latent[:, i + 2], skip)
558
+
559
+ i += 2
560
+
561
+
562
+ image = skip
563
+
564
+ if return_latents:
565
+ return image, latent
566
+
567
+ else:
568
+ return image, None
569
+
570
+
571
+ class ConvLayer(nn.Sequential):
572
+ def __init__(
573
+ self,
574
+ in_channel,
575
+ out_channel,
576
+ kernel_size,
577
+ downsample=False,
578
+ blur_kernel=[1, 3, 3, 1],
579
+ bias=True,
580
+ activate=True,
581
+ ):
582
+ layers = []
583
+
584
+ if downsample:
585
+ factor = 2
586
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
587
+ pad0 = (p + 1) // 2
588
+ pad1 = p // 2
589
+
590
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
591
+
592
+ stride = 2
593
+ self.padding = 0
594
+
595
+ else:
596
+ stride = 1
597
+ self.padding = kernel_size // 2
598
+
599
+ layers.append(
600
+ EqualConv2d(
601
+ in_channel,
602
+ out_channel,
603
+ kernel_size,
604
+ padding=self.padding,
605
+ stride=stride,
606
+ bias=bias and not activate,
607
+ )
608
+ )
609
+
610
+ if activate:
611
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
612
+
613
+ super().__init__(*layers)
614
+
615
+
616
+ class ResBlock(nn.Module):
617
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
618
+ super().__init__()
619
+
620
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
621
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
622
+
623
+ self.skip = ConvLayer(
624
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
625
+ )
626
+
627
+ def forward(self, input):
628
+ out = self.conv1(input)
629
+ out = self.conv2(out)
630
+
631
+ skip = self.skip(input)
632
+ out = (out + skip) / math.sqrt(2)
633
+
634
+ return out
635
+
636
+
637
+ class Discriminator(nn.Module):
638
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
639
+ super().__init__()
640
+
641
+ channels = {
642
+ 4: 512,
643
+ 8: 512,
644
+ 16: 512,
645
+ 32: 512,
646
+ 64: 256 * channel_multiplier,
647
+ 128: 128 * channel_multiplier,
648
+ 256: 64 * channel_multiplier,
649
+ 512: 32 * channel_multiplier,
650
+ 1024: 16 * channel_multiplier,
651
+ }
652
+
653
+ convs = [ConvLayer(3, channels[size], 1)]
654
+
655
+ log_size = int(math.log(size, 2))
656
+
657
+ in_channel = channels[size]
658
+
659
+ for i in range(log_size, 2, -1):
660
+ out_channel = channels[2 ** (i - 1)]
661
+
662
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
663
+
664
+ in_channel = out_channel
665
+
666
+ self.convs = nn.Sequential(*convs)
667
+
668
+ self.stddev_group = 4
669
+ self.stddev_feat = 1
670
+
671
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
672
+ self.final_linear = nn.Sequential(
673
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
674
+ EqualLinear(channels[4], 1),
675
+ )
676
+
677
+ def forward(self, input):
678
+ out = self.convs(input)
679
+
680
+ batch, channel, height, width = out.shape
681
+ group = min(batch, self.stddev_group)
682
+ stddev = out.view(
683
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
684
+ )
685
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
686
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
687
+ stddev = stddev.repeat(group, 1, height, width)
688
+ out = torch.cat([out, stddev], 1)
689
+
690
+ out = self.final_conv(out)
691
+
692
+ out = out.view(batch, -1)
693
+ out = self.final_linear(out)
694
+
695
+ return out
696
+
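model.py above is the widely used PyTorch StyleGAN2 port that CustomGenerator in drag_gan.py subclasses. For reference, a hedged sketch of sampling from the Generator directly (FFHQ-sized config assumed; pretrained weights would still need to be loaded, as drag_gan.stylegan2 does):

    import torch
    from stylegan2.model import Generator

    g = Generator(size=1024, style_dim=512, n_mlp=8).cuda().eval()
    z = torch.randn(1, 512, device='cuda')
    with torch.no_grad():
        image, _ = g([z], truncation=0.7, truncation_latent=g.mean_latent(4096))
    # `image` is a (1, 3, 1024, 1024) tensor, roughly in [-1, 1]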
stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
+ from .upfirdn2d import upfirdn2d
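These ops are JIT-compiled CUDA extensions: importing stylegan2.op triggers torch.utils.cpp_extension.load on the .cpp/.cu sources below, which is why ninja appears in requirements.txt and why the first import can take a while. A hedged smoke test, assuming a working CUDA toolchain:

    import torch
    from stylegan2.op import fused_leaky_relu, upfirdn2d  # first import builds the extensions

    x = torch.randn(1, 8, 16, 16, device='cuda')
    bias = torch.zeros(8, device='cuda')
    y = fused_leaky_relu(x, bias)                  # fused bias-add + leaky ReLU, scaled by sqrt(2)
    k = torch.ones(3, 3, device='cuda') / 9.0
    z = upfirdn2d(x, k, up=1, down=1, pad=(1, 1))  # FIR filtering with optional up/down sampling
    print(y.shape, z.shape)                        # both (1, 8, 16, 16)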
stylegan2/op/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (286 Bytes).
stylegan2/op/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (284 Bytes).
stylegan2/op/__pycache__/conv2d_gradfix.cpython-37.pyc ADDED
Binary file (5.29 kB).
stylegan2/op/__pycache__/conv2d_gradfix.cpython-38.pyc ADDED
Binary file (5.36 kB).
stylegan2/op/__pycache__/fused_act.cpython-37.pyc ADDED
Binary file (3.24 kB).
stylegan2/op/__pycache__/fused_act.cpython-38.pyc ADDED
Binary file (3.3 kB).
stylegan2/op/__pycache__/upfirdn2d.cpython-37.pyc ADDED
Binary file (4.29 kB).
stylegan2/op/__pycache__/upfirdn2d.cpython-38.pyc ADDED
Binary file (4.36 kB).
stylegan2/op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ if bias:
39
+ grad_bias = grad_input.sum(dim).detach()
40
+
41
+ else:
42
+ grad_bias = empty
43
+
44
+ return grad_input, grad_bias
45
+
46
+ @staticmethod
47
+ def backward(ctx, gradgrad_input, gradgrad_bias):
48
+ out, = ctx.saved_tensors
49
+ gradgrad_out = fused.fused_bias_act(
50
+ gradgrad_input.contiguous(),
51
+ gradgrad_bias,
52
+ out,
53
+ 3,
54
+ 1,
55
+ ctx.negative_slope,
56
+ ctx.scale,
57
+ )
58
+
59
+ return gradgrad_out, None, None, None, None
60
+
61
+
62
+ class FusedLeakyReLUFunction(Function):
63
+ @staticmethod
64
+ def forward(ctx, input, bias, negative_slope, scale):
65
+ empty = input.new_empty(0)
66
+
67
+ ctx.bias = bias is not None
68
+
69
+ if bias is None:
70
+ bias = empty
71
+
72
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73
+ ctx.save_for_backward(out)
74
+ ctx.negative_slope = negative_slope
75
+ ctx.scale = scale
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def backward(ctx, grad_output):
81
+ out, = ctx.saved_tensors
82
+
83
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
+ )
86
+
87
+ if not ctx.bias:
88
+ grad_bias = None
89
+
90
+ return grad_input, grad_bias, None, None
91
+
92
+
93
+ class FusedLeakyReLU(nn.Module):
94
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
+ super().__init__()
96
+
97
+ if bias:
98
+ self.bias = nn.Parameter(torch.zeros(channel))
99
+
100
+ else:
101
+ self.bias = None
102
+
103
+ self.negative_slope = negative_slope
104
+ self.scale = scale
105
+
106
+ def forward(self, input):
107
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
+
109
+
110
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
+ if input.device.type == "cpu":
112
+ if bias is not None:
113
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
+ return (
115
+ F.leaky_relu(
116
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
+ )
118
+ * scale
119
+ )
120
+
121
+ else:
122
+ return F.leaky_relu(input, negative_slope=0.2) * scale
123
+
124
+ else:
125
+ return FusedLeakyReLUFunction.apply(
126
+ input.contiguous(), bias, negative_slope, scale
127
+ )
stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,209 @@
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ upfirdn2d_op = load(
12
+ "upfirdn2d",
13
+ sources=[
14
+ os.path.join(module_path, "upfirdn2d.cpp"),
15
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class UpFirDn2dBackward(Function):
21
+ @staticmethod
22
+ def forward(
23
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
+ ):
25
+
26
+ up_x, up_y = up
27
+ down_x, down_y = down
28
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
+
30
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
+
32
+ grad_input = upfirdn2d_op.upfirdn2d(
33
+ grad_output,
34
+ grad_kernel,
35
+ down_x,
36
+ down_y,
37
+ up_x,
38
+ up_y,
39
+ g_pad_x0,
40
+ g_pad_x1,
41
+ g_pad_y0,
42
+ g_pad_y1,
43
+ )
44
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
+
46
+ ctx.save_for_backward(kernel)
47
+
48
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
+
50
+ ctx.up_x = up_x
51
+ ctx.up_y = up_y
52
+ ctx.down_x = down_x
53
+ ctx.down_y = down_y
54
+ ctx.pad_x0 = pad_x0
55
+ ctx.pad_x1 = pad_x1
56
+ ctx.pad_y0 = pad_y0
57
+ ctx.pad_y1 = pad_y1
58
+ ctx.in_size = in_size
59
+ ctx.out_size = out_size
60
+
61
+ return grad_input
62
+
63
+ @staticmethod
64
+ def backward(ctx, gradgrad_input):
65
+ kernel, = ctx.saved_tensors
66
+
67
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
+
69
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
70
+ gradgrad_input,
71
+ kernel,
72
+ ctx.up_x,
73
+ ctx.up_y,
74
+ ctx.down_x,
75
+ ctx.down_y,
76
+ ctx.pad_x0,
77
+ ctx.pad_x1,
78
+ ctx.pad_y0,
79
+ ctx.pad_y1,
80
+ )
81
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
+ gradgrad_out = gradgrad_out.view(
83
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
+ )
85
+
86
+ return gradgrad_out, None, None, None, None, None, None, None, None
87
+
88
+
89
+ class UpFirDn2d(Function):
90
+ @staticmethod
91
+ def forward(ctx, input, kernel, up, down, pad):
92
+ up_x, up_y = up
93
+ down_x, down_y = down
94
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
+
96
+ kernel_h, kernel_w = kernel.shape
97
+ batch, channel, in_h, in_w = input.shape
98
+ ctx.in_size = input.shape
99
+
100
+ input = input.reshape(-1, in_h, in_w, 1)
101
+
102
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
+
104
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
+ ctx.out_size = (out_h, out_w)
107
+
108
+ ctx.up = (up_x, up_y)
109
+ ctx.down = (down_x, down_y)
110
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
+
112
+ g_pad_x0 = kernel_w - pad_x0 - 1
113
+ g_pad_y0 = kernel_h - pad_y0 - 1
114
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
+
117
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
+
119
+ out = upfirdn2d_op.upfirdn2d(
120
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
+ )
122
+ # out = out.view(major, out_h, out_w, minor)
123
+ out = out.view(-1, channel, out_h, out_w)
124
+
125
+ return out
126
+
127
+ @staticmethod
128
+ def backward(ctx, grad_output):
129
+ kernel, grad_kernel = ctx.saved_tensors
130
+
131
+ grad_input = None
132
+
133
+ if ctx.needs_input_grad[0]:
134
+ grad_input = UpFirDn2dBackward.apply(
135
+ grad_output,
136
+ kernel,
137
+ grad_kernel,
138
+ ctx.up,
139
+ ctx.down,
140
+ ctx.pad,
141
+ ctx.g_pad,
142
+ ctx.in_size,
143
+ ctx.out_size,
144
+ )
145
+
146
+ return grad_input, None, None, None, None
147
+
148
+
149
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
+ if not isinstance(up, abc.Iterable):
151
+ up = (up, up)
152
+
153
+ if not isinstance(down, abc.Iterable):
154
+ down = (down, down)
155
+
156
+ if len(pad) == 2:
157
+ pad = (pad[0], pad[1], pad[0], pad[1])
158
+
159
+ if input.device.type == "cpu":
160
+ out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
+
162
+ else:
163
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
+
165
+ return out
166
+
167
+
168
+ def upfirdn2d_native(
169
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
+ ):
171
+ _, channel, in_h, in_w = input.shape
172
+ input = input.reshape(-1, in_h, in_w, 1)
173
+
174
+ _, in_h, in_w, minor = input.shape
175
+ kernel_h, kernel_w = kernel.shape
176
+
177
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
178
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
+
181
+ out = F.pad(
182
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
+ )
184
+ out = out[
185
+ :,
186
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
+ :,
189
+ ]
190
+
191
+ out = out.permute(0, 3, 1, 2)
192
+ out = out.reshape(
193
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
+ )
195
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
+ out = F.conv2d(out, w)
197
+ out = out.reshape(
198
+ -1,
199
+ minor,
200
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
+ )
203
+ out = out.permute(0, 2, 3, 1)
204
+ out = out[:, ::down_y, ::down_x, :]
205
+
206
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
+
209
+ return out.view(-1, channel, out_h, out_w)
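
The module-level upfirdn2d function above is the public entry point: it broadcasts scalar up/down factors and 2-element pads to per-axis values, routes CPU tensors through upfirdn2d_native, and sends CUDA tensors through the custom autograd Function. A minimal usage sketch follows, where the [1, 3, 3, 1] binomial kernel and pad=(2, 1) are illustrative choices (typical for 2x upsampling in StyleGAN2) rather than values read from this file:

    # Hedged sketch: 2x upsampling of an NCHW tensor with a 4-tap FIR filter.
    import torch
    from stylegan2.op.upfirdn2d import upfirdn2d  # importing JIT-compiles the CUDA extension

    k = torch.tensor([1.0, 3.0, 3.0, 1.0])
    kernel = k[None, :] * k[:, None]
    kernel = kernel / kernel.sum() * 4   # scale by up_factor**2 to preserve signal magnitude

    x = torch.randn(1, 3, 8, 8)          # CPU input -> upfirdn2d_native path
    out = upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1))

    # out_h = (in_h*up + pad0 + pad1 - kernel_h + down) // down
    #       = (8*2 + 2 + 1 - 4 + 1) // 1 = 16
    print(out.shape)                      # torch.Size([1, 3, 16, 16])
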
stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
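
The dispatcher above selects one of six tiled, template-specialized kernels when the up/down factors and kernel size match a known configuration, and falls back to upfirdn2d_kernel_large otherwise. A quick way to sanity-check a freshly built extension is to compare it against the pure-PyTorch reference path; this sketch assumes a CUDA device is available and reuses the public wrapper from upfirdn2d.py:

    # Hedged sketch: parity check between the CUDA op and upfirdn2d_native.
    import torch
    from stylegan2.op.upfirdn2d import upfirdn2d

    torch.manual_seed(0)
    k = torch.randn(4, 4)
    x = torch.randn(2, 3, 16, 16)

    ref = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))                      # CPU -> native reference
    out = upfirdn2d(x.cuda(), k.cuda(), up=2, down=1, pad=(2, 1)).cpu()  # up=2, 4x4 kernel -> mode 3

    print(torch.allclose(ref, out, atol=1e-5))  # expected: True
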