rinong committed on
Commit
3a19a1a
1 Parent(s): b879b4e

Updated app, added video generation and model files

Files changed (3)
  1. app.py +205 -1
  2. generate_videos.py +259 -0
  3. model/sg2_model.py +780 -0
app.py CHANGED
@@ -1,3 +1,207 @@
1
  import gradio as gr
2
 
3
- gr.Interface.load("spaces/eugenesiow/remove-bg").launch()
1
+ import os
2
+
3
+ import torch
4
  import gradio as gr
5
 
7
+ import sys
8
+ import numpy as np
9
+
10
+ from e4e.models.psp import pSp
11
+ from util import *
12
+ from huggingface_hub import hf_hub_download
13
+
16
+ import tempfile
17
+ import shutil
18
+ from argparse import Namespace
19
+ from pathlib import Path
20
+
21
+ import dlib
23
+ import torchvision.transforms as transforms
24
+ from torchvision import utils
25
+ from PIL import Image
26
+
27
+ from model.sg2_model import Generator
28
+ from generate_videos import generate_frames, video_from_interpolations, vid_to_gif
29
+
30
+ model_dir = "models"
31
+ os.makedirs(model_dir, exist_ok=True)
32
+
33
+ models_and_paths = {"akhaliq/JoJoGAN_e4e_ffhq_encode": "e4e_ffhq_encode.pt",
34
+ "akhaliq/jojogan_dlib": "shape_predictor_68_face_landmarks.dat",
35
+ "akhaliq/jojogan-stylegan2-ffhq-config-f": f"{model_dir}/base.pt"}
36
+
37
+ def get_models():
38
+ for repo_id, file_path in models_and_paths.items():
39
+ hf_hub_download(repo_id=repo_id, filename=file_path)
40
+
41
+ model_list = ['base'] + [Path(model_ckpt).stem for model_ckpt in os.listdir(model_dir) if 'base' not in model_ckpt]
42
+
43
+ return model_list
44
+
45
+ model_list = get_models()
46
+
47
+ class ImageEditor(object):
48
+ def __init__(self):
49
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
50
+
51
+ latent_size = 512
52
+ n_mlp = 8
53
+ channel_mult = 2
54
+ model_size = 1024
55
+
56
+ self.generators = {}
57
+
58
+ for model in model_list:
59
+ g_ema = Generator(
60
+ model_size, latent_size, n_mlp, channel_multiplier=channel_mult
61
+ ).to(self.device)
62
+
63
+ checkpoint = torch.load(f"models/{model}.pt", map_location=self.device)
64
+
65
+ g_ema.load_state_dict(checkpoint['g_ema'])
66
+
67
+ self.generators[model] = g_ema
68
+
69
+ self.experiment_args = {"model_path": "e4e_ffhq_encode.pt"}
70
+ self.experiment_args["transform"] = transforms.Compose(
71
+ [
72
+ transforms.Resize((256, 256)),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
75
+ ]
76
+ )
77
+ self.resize_dims = (256, 256)
78
+
79
+ model_path = self.experiment_args["model_path"]
80
+
81
+ ckpt = torch.load(model_path, map_location="cpu")
82
+ opts = ckpt["opts"]
83
+
84
+ opts["checkpoint_path"] = model_path
85
+ opts = Namespace(**opts)
86
+
87
+ self.e4e_net = pSp(opts)
88
+ self.e4e_net.eval()
89
+ self.e4e_net.to(self.device)
90
+
91
+ self.shape_predictor = dlib.shape_predictor(
92
+ models_and_paths["akhaliq/jojogan_dlib"]
93
+ )
94
+
95
+ print("setup complete")
96
+
97
+ def get_style_list(self):
98
+ style_list = ['all', 'list - enter below']
99
+
100
+ for key in self.generators:
101
+ style_list.append(key)
102
+
103
+ return style_list
104
+
105
+ def predict(
106
+ self,
107
+ input, # Input image path
108
+ output_style, # Which output style do you want to use?
109
+ style_list, # Comma-separated list of models to use. Only accepts models from the output_style list
110
+ generate_video, # Generate a video instead of an output image
111
+ with_editing, # Apply latent space editing to the generated video
112
+ video_format # Choose gif to display in browser, mp4 for higher-quality downloadable video
113
+ ):
114
+
115
+ if output_style == 'all':
116
+ styles = model_list
117
+ elif output_style == 'list - enter below':
118
+ styles = style_list.split(",")
119
+ for style in styles:
120
+ if style not in model_list:
121
+ raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
122
+ else:
123
+ styles = [output_style]
124
+
125
+ # Align the input face image
126
+ input_image = self.run_alignment(str(input))
127
+
128
+ input_image = input_image.resize(self.resize_dims)
129
+
130
+ img_transforms = self.experiment_args["transform"]
131
+ transformed_image = img_transforms(input_image)
132
+
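+ # Invert the aligned face into W+ latent space with the e4e encoder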
133
+ with torch.no_grad():
134
+ images, latents = self.run_on_batch(transformed_image.unsqueeze(0))
135
+ result_image, latent = images[0], latents[0]
136
+
137
+ inverted_latent = latent.unsqueeze(0).unsqueeze(1)
138
+ out_dir = Path(tempfile.mkdtemp())
139
+ out_path = out_dir / "out.jpg"
140
+
141
+ generators = [self.generators[style] for style in styles]
142
+
143
+ if not generate_video:
144
+ with torch.no_grad():
145
+ img_list = []
146
+ for g_ema in generators:
147
+ img, _ = g_ema(inverted_latent, input_is_latent=True, truncation=1, randomize_noise=False)
148
+ img_list.append(img)
149
+
150
+ out_img = torch.cat(img_list, dim=0)
151
+ utils.save_image(out_img, out_path, nrow=int(np.sqrt(out_img.size(0))), normalize=True, scale_each=True, range=(-1, 1))
152
+
153
+ return Path(out_path)
154
+
155
+ return self.generate_vid(generators, inverted_latent, out_dir, video_format, with_editing)
156
+
157
+ def generate_vid(self, generators, latent, out_dir, video_format, with_editing):
158
+ np_latent = latent.squeeze(0).cpu().detach().numpy()
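+ # With editing enabled, frames walk along the default InterFaceGAN directions; otherwise the source latent is held for 40 frames per generator transition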
159
+ args = {
160
+ 'fps': 24,
161
+ 'target_latents': None,
162
+ 'edit_directions': None,
163
+ 'unedited_frames': 0 if with_editing else 40 * (len(generators) - 1)
164
+ }
165
+
166
+ args = Namespace(**args)
167
+ with tempfile.TemporaryDirectory() as dirpath:
168
+
169
+ generate_frames(args, np_latent, generators, dirpath)
170
+ video_from_interpolations(args.fps, dirpath)
171
+
172
+ gen_path = Path(dirpath) / "out.mp4"
173
+ out_path = out_dir / f"out.{video_format}"
174
+
175
+ if video_format == 'gif':
176
+ vid_to_gif(gen_path, out_dir, scale=256, fps=args.fps)
177
+ else:
178
+ shutil.copy2(gen_path, out_path)
179
+
180
+ return out_path
181
+
182
+ def run_alignment(self, image_path):
183
+ aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
184
+ print("Aligned image has shape: {}".format(aligned_image.size))
185
+ return aligned_image
186
+
187
+ def run_on_batch(self, inputs):
188
+ images, latents = self.e4e_net(
189
+ inputs.to(self.device).float(), randomize_noise=False, return_latents=True
190
+ )
191
+ return images, latents
192
+
193
+ editor = ImageEditor()
194
+
195
+ title = "StyleGAN-NADA"
196
+ description = "Gradio Demo for StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators (SIGGRAPH 2022). To use it, upload your image and select a target style. More information about the paper and training new models can be found below."
197
+
198
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.00946' target='_blank'>StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators</a> | <a href='https://stylegan-nada.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/rinongal/StyleGAN-nada' target='_blank'>Code</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=rinong_sgnada' alt='visitor badge'></center>"
199
+
200
+ gr.Interface(editor.predict, [gr.inputs.Image(type="pil"),
201
+ gr.inputs.Dropdown(choices=editor.get_style_list(), type="value", default='base', label="Model"),
202
+ gr.inputs.Textbox(lines=1, placeholder=None, default="joker,anime,modigliani", label="Style List", optional=True),
203
+ gr.inputs.Checkbox(default=False, label="Generate Video?", optional=False),
204
+ gr.inputs.Checkbox(default=False, label="With Editing?", optional=False),
205
+ gr.inputs.Radio(choices=["gif", "mp4"], type="value", default='mp4', label="Video Format")],
206
+ gr.outputs.Image(type="file"), title=title, description=description, article=article, allow_flagging=False, allow_screenshot=False).launch()
207
+
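Taken together, app.py downloads the required checkpoints, inverts the uploaded photo into latent space with the e4e encoder, and pushes the inverted code through one or more fine-tuned generators, optionally rendering a video. A minimal usage sketch of the same predict entry point without the Gradio UI; the input path is illustrative, and the checkpoints are assumed to already sit at the paths the script expects:

# Sketch only: assumes ImageEditor from app.py is in scope and the checkpoints
# referenced in models_and_paths have already been placed at their expected paths.
editor = ImageEditor()

result = editor.predict(
    "face.jpg",   # illustrative path to an input photo (run_alignment expects a file path)
    "base",       # a single output style; any entry from get_style_list() works
    "",           # style_list is only consulted for the 'list - enter below' option
    False,        # generate_video
    False,        # with_editing (only relevant when generating a video)
    "mp4",        # video_format (ignored for still images)
)
print(result)     # path to a temporary .jpg holding the stylized output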
generate_videos.py ADDED
@@ -0,0 +1,259 @@
1
+ '''
2
+ Tool for generating editing videos across different domains.
3
+
4
+ Given a set of latent codes and pre-trained models, it will interpolate between the different codes in each of the target domains
5
+ and combine the resulting images into a video.
6
+
7
+ Example run command:
8
+
9
+ python generate_videos.py --ckpt /model_dir/pixar.pt \
10
+ /model_dir/ukiyoe.pt \
11
+ /model_dir/edvard_munch.pt \
12
+ /model_dir/botero.pt \
13
+ --out_dir /output/video/ \
14
+ --source_latent /latents/latent000.npy \
15
+ --target_latents /latents/
16
+
17
+ '''
18
+
19
+ import os
20
+ import argparse
21
+
22
+ import torch
23
+ from torchvision import utils
24
+
25
+ from model.sg2_model import Generator
26
+ from tqdm import tqdm
27
+ from pathlib import Path
28
+
29
+ import numpy as np
30
+
31
+ import subprocess
32
+ import shutil
33
+ import copy
34
+
35
+ VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]
36
+
37
+ SUGGESTED_DISTANCES = {
38
+ "pose": (3.0, -3.0),
39
+ "smile": (2.0, -2.0),
40
+ "age": (4.0, -4.0),
41
+ "gender": (3.0, -3.0),
42
+ "hair_length": (None, -4.0),
43
+ "beard": (2.0, None)
44
+ }
45
+
46
+ def project_code(latent_code, boundary, distance=3.0):
47
+
48
+ if len(boundary) == 2:
49
+ boundary = boundary.reshape(1, 1, -1)
50
+
51
+ return latent_code + distance * boundary
52
+
53
+ def generate_frames(args, source_latent, g_ema_list, output_dir):
54
+
55
+ alphas = np.linspace(0, 1, num=20)
56
+
57
+ interpolate_func = interpolate_with_boundaries # default
58
+ if args.target_latents: # if provided with targets
59
+ interpolate_func = interpolate_with_target_latents
60
+ if args.unedited_frames: # if only interpolating through generators
61
+ interpolate_func = duplicate_latent
62
+
63
+ latents = interpolate_func(args, source_latent, alphas)
64
+
65
+ segments = len(g_ema_list) - 1
66
+ if segments:
67
+ segment_length = len(latents) / segments
68
+
69
+ g_ema = copy.deepcopy(g_ema_list[0])
70
+
71
+ src_pars = dict(g_ema.named_parameters())
72
+ mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
73
+ else:
74
+ g_ema = g_ema_list[0]
75
+
76
+ print("Generating frames for video...")
77
+ for idx, latent in tqdm(enumerate(latents), total=len(latents)):
78
+
79
+ if segments:
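+ # Cross-fade between consecutive fine-tuned generators by linearly blending their parameters across this segment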
80
+ mix_alpha = (idx % segment_length) * 1.0 / segment_length
81
+ segment_id = int(idx // segment_length)
82
+
83
+ for k in src_pars.keys():
84
+ src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)
85
+
86
+ if idx == 0 or segments or latent is not latents[idx - 1]:
87
+ w = torch.from_numpy(latent).float().cuda()
88
+
89
+ with torch.no_grad():
90
+ img, _ = g_ema([w], input_is_latent=True, truncation=1, randomize_noise=False)
91
+
92
+ utils.save_image(img, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1, normalize=True, scale_each=True, range=(-1, 1))
93
+
94
+ def interpolate_forward_backward(source_latent, target_latent, alphas):
95
+ latents_forward = [a * target_latent + (1-a) * source_latent for a in alphas] # interpolate from source to target
96
+ latents_backward = latents_forward[::-1] # interpolate from target to source
97
+ return latents_forward + [target_latent] * 20 + latents_backward # forward + short delay at target + return
98
+
99
+ def duplicate_latent(args, source_latent, alphas):
100
+ return [source_latent for _ in range(args.unedited_frames)]
101
+
102
+ def interpolate_with_boundaries(args, source_latent, alphas):
103
+ edit_directions = args.edit_directions or ['pose', 'smile', 'gender', 'age', 'hair_length']
104
+
105
+ # interpolate latent codes with all targets
106
+
107
+ print("Interpolating latent codes...")
108
+
109
+ boundary_dir = Path(os.path.abspath(__file__)).parents[1].joinpath("editing", "interfacegan_boundaries")
110
+
111
+ boundaries_and_distances = []
112
+ for direction_type in edit_directions:
113
+ distances = SUGGESTED_DISTANCES[direction_type]
114
+ boundary = torch.load(os.path.join(boundary_dir, f'{direction_type}.pt')).cpu().detach().numpy()
115
+
116
+ for distance in distances:
117
+ if distance:
118
+ boundaries_and_distances.append((boundary, distance))
119
+
120
+ latents = []
121
+ for boundary, distance in boundaries_and_distances:
122
+
123
+ target_latent = project_code(source_latent, boundary, distance)
124
+ latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))
125
+
126
+ return latents
127
+
128
+ def interpolate_with_target_latents(args, source_latent, alphas):
129
+ # interpolate latent codes with all targets
130
+
131
+ print("Interpolating latent codes...")
132
+
133
+ latents = []
134
+ for target_latent_path in args.target_latents:
135
+
136
+ if target_latent_path == args.source_latent:
137
+ continue
138
+
139
+ target_latent = np.load(target_latent_path, allow_pickle=True)
140
+
141
+ latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))
142
+
143
+ return latents
144
+
145
+ def video_from_interpolations(fps, output_dir):
146
+
147
+ # combine frames to a video
148
+ command = ["ffmpeg",
149
+ "-r", f"{fps}",
150
+ "-i", f"{output_dir}/%03d.jpg",
151
+ "-c:v", "libx264",
152
+ "-vf", f"fps={fps}",
153
+ "-pix_fmt", "yuv420p",
154
+ f"{output_dir}/out.mp4"]
155
+
156
+ subprocess.call(command)
157
+
158
+ def merge_videos(output_dir, num_subdirs):
159
+
160
+ output_file = os.path.join(output_dir, "combined.mp4")
161
+
162
+ if num_subdirs == 1: # if we only have one video, just copy it over
163
+ shutil.copy2(os.path.join(output_dir, str(0), "out.mp4"), output_file)
164
+ else: # otherwise merge using ffmpeg
165
+ command = ["ffmpeg"]
166
+ for dir in range(num_subdirs):
167
+ command.extend(['-i', os.path.join(output_dir, str(dir), "out.mp4")])
168
+
169
+ sqrt_subdirs = int(num_subdirs ** .5)
170
+
171
+ if (sqrt_subdirs ** 2) != num_subdirs:
172
+ raise ValueError("Number of checkpoints cannot be arranged in a square grid")
173
+
174
+ command.append("-filter_complex")
175
+
176
+ filter_string = ""
177
+ vstack_string = ""
178
+ for row in range(sqrt_subdirs):
179
+ row_str = ""
180
+ for col in range(sqrt_subdirs):
181
+ row_str += f"[{row * sqrt_subdirs + col}:v]"
182
+
183
+ letter = chr(ord('A')+row)
184
+ row_str += f"hstack=inputs={sqrt_subdirs}[{letter}];"
185
+ vstack_string += f"[{letter}]"
186
+
187
+ filter_string += row_str
188
+
189
+ vstack_string += f"vstack=inputs={sqrt_subdirs}[out]"
190
+ filter_string += vstack_string
191
+
192
+ command.extend([filter_string, "-map", "[out]", output_file])
193
+
194
+ subprocess.call(command)
195
+
196
+ def vid_to_gif(vid_path, output_dir, scale=256, fps=35):
197
+
198
+ command = ["ffmpeg",
199
+ "-i", f"{vid_path}",
200
+ "-vf", f"fps={fps},scale={scale}:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1]fifo[s2];[s2][p]paletteuse",
201
+ "-loop", "0",
202
+ f"{output_dir}/out.gif"]
203
+
204
+ subprocess.call(command)
205
+
206
+
207
+ if __name__ == '__main__':
208
+ device = 'cuda'
209
+
210
+ parser = argparse.ArgumentParser()
211
+
212
+ parser.add_argument('--size', type=int, default=1024)
213
+ parser.add_argument('--ckpt', type=str, nargs="+", required=True, help="Path to one or more pre-trained generator checkpoints.")
214
+ parser.add_argument('--channel_multiplier', type=int, default=2)
215
+ parser.add_argument('--out_dir', type=str, required=True, help="Directory where output files will be placed")
216
+ parser.add_argument('--source_latent', type=str, required=True, help="Path to an .npy file containing an initial latent code")
217
+ parser.add_argument('--target_latents', nargs="+", type=str, help="A list of paths to .npy files containing target latent codes to interpolate towards, or a directory containing such .npy files.")
218
+ parser.add_argument('--force', '-f', action='store_true', help="Force run with non-empty directory. Image files not overwritten by the process may still be included in the final video")
219
+ parser.add_argument('--fps', default=35, type=int, help='Frames per second in the generated videos.')
220
+ parser.add_argument('--edit_directions', nargs="+", type=str, help=f"A list of edit directions to use in video generation (if not using a target latent directory). Available directions are: {VALID_EDITS}")
221
+ parser.add_argument('--unedited_frames', type=int, default=0, help="Used to generate videos with no latent editing. If set to a positive number and target_latents is not provided, will simply duplicate the initial frame <unedited_frames> times.")
222
+
223
+ args = parser.parse_args()
224
+
225
+ os.makedirs(args.out_dir, exist_ok=True)
226
+
227
+ if not args.force and os.listdir(args.out_dir):
228
+ print("Output directory is not empty. Either delete the directory content or re-run with -f.")
229
+ exit(0)
230
+
231
+ if args.target_latents and len(args.target_latents) == 1 and os.path.isdir(args.target_latents[0]):
232
+ args.target_latents = [os.path.join(args.target_latents[0], file_name) for file_name in os.listdir(args.target_latents[0]) if file_name.endswith(".npy")]
233
+ args.target_latents = sorted(args.target_latents)
234
+
235
+ args.latent = 512
236
+ args.n_mlp = 8
237
+
238
+ g_ema = Generator(
239
+ args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
240
+ ).to(device)
241
+
242
+ source_latent = np.load(args.source_latent, allow_pickle=True)
243
+
244
+ for idx, ckpt_path in enumerate(args.ckpt):
245
+ print(f"Generating video using checkpoint: {ckpt_path}")
246
+ checkpoint = torch.load(ckpt_path)
247
+
248
+ g_ema.load_state_dict(checkpoint['g_ema'])
249
+
250
+ output_dir = os.path.join(args.out_dir, str(idx))
251
+ os.makedirs(output_dir)
252
+
253
+ generate_frames(args, source_latent, [g_ema], output_dir)
254
+ video_from_interpolations(args.fps, output_dir)
255
+
256
+ merge_videos(args.out_dir, len(args.ckpt))
257
+
258
+
259
+
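The editing videos produced by this script come down to shifting a latent code along an InterFaceGAN boundary (project_code) and walking out and back again (interpolate_forward_backward). A minimal, self-contained sketch of that frame schedule, using random stand-ins for the latent and the boundary instead of real .npy/.pt files:

import numpy as np

# Stand-ins: a (1, 18, 512) W+ code as produced by e4e, and a direction vector
# standing in for one of the editing/interfacegan_boundaries tensors.
source = np.random.randn(1, 18, 512).astype(np.float32)
boundary = np.random.randn(1, 512).astype(np.float32)
boundary /= np.linalg.norm(boundary)

alphas = np.linspace(0, 1, num=20)
target = source + 4.0 * boundary                  # project_code with distance=4.0 (the "age" setting)

forward = [a * target + (1 - a) * source for a in alphas]  # source -> target
frames = forward + [target] * 20 + forward[::-1]           # hold at the target, then walk back
print(len(frames), frames[0].shape)                        # 60 frames, each of shape (1, 18, 512)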
model/sg2_model.py ADDED
@@ -0,0 +1,780 @@
1
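+ # StyleGAN2 building blocks (mapping network, modulated convolutions, Generator, Discriminator),
+ # following the common rosinality-style PyTorch port, extended with S-space (is_s_code) inputs.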
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
+
13
+
14
+ class PixelNorm(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, input):
19
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20
+
21
+
22
+ def make_kernel(k):
23
+ k = torch.tensor(k, dtype=torch.float32)
24
+
25
+ if k.ndim == 1:
26
+ k = k[None, :] * k[:, None]
27
+
28
+ k /= k.sum()
29
+
30
+ return k
31
+
32
+
33
+ class Upsample(nn.Module):
34
+ def __init__(self, kernel, factor=2):
35
+ super().__init__()
36
+
37
+ self.factor = factor
38
+ kernel = make_kernel(kernel) * (factor ** 2)
39
+ self.register_buffer("kernel", kernel)
40
+
41
+ p = kernel.shape[0] - factor
42
+
43
+ pad0 = (p + 1) // 2 + factor - 1
44
+ pad1 = p // 2
45
+
46
+ self.pad = (pad0, pad1)
47
+
48
+ def forward(self, input):
49
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50
+
51
+ return out
52
+
53
+
54
+ class Downsample(nn.Module):
55
+ def __init__(self, kernel, factor=2):
56
+ super().__init__()
57
+
58
+ self.factor = factor
59
+ kernel = make_kernel(kernel)
60
+ self.register_buffer("kernel", kernel)
61
+
62
+ p = kernel.shape[0] - factor
63
+
64
+ pad0 = (p + 1) // 2
65
+ pad1 = p // 2
66
+
67
+ self.pad = (pad0, pad1)
68
+
69
+ def forward(self, input):
70
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71
+
72
+ return out
73
+
74
+
75
+ class Blur(nn.Module):
76
+ def __init__(self, kernel, pad, upsample_factor=1):
77
+ super().__init__()
78
+
79
+ kernel = make_kernel(kernel)
80
+
81
+ if upsample_factor > 1:
82
+ kernel = kernel * (upsample_factor ** 2)
83
+
84
+ self.register_buffer("kernel", kernel)
85
+
86
+ self.pad = pad
87
+
88
+ def forward(self, input):
89
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
90
+
91
+ return out
92
+
93
+
94
+ class EqualConv2d(nn.Module):
95
+ def __init__(
96
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97
+ ):
98
+ super().__init__()
99
+
100
+ self.weight = nn.Parameter(
101
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102
+ )
103
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104
+
105
+ self.stride = stride
106
+ self.padding = padding
107
+
108
+ if bias:
109
+ self.bias = nn.Parameter(torch.zeros(out_channel))
110
+
111
+ else:
112
+ self.bias = None
113
+
114
+ def forward(self, input):
115
+ out = conv2d_gradfix.conv2d(
116
+ input,
117
+ self.weight * self.scale,
118
+ bias=self.bias,
119
+ stride=self.stride,
120
+ padding=self.padding,
121
+ )
122
+
123
+ return out
124
+
125
+ def __repr__(self):
126
+ return (
127
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129
+ )
130
+
131
+
132
+ class EqualLinear(nn.Module):
133
+ def __init__(
134
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135
+ ):
136
+ super().__init__()
137
+
138
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139
+
140
+ if bias:
141
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142
+
143
+ else:
144
+ self.bias = None
145
+
146
+ self.activation = activation
147
+
148
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149
+ self.lr_mul = lr_mul
150
+
151
+ def forward(self, input):
152
+ if self.activation:
153
+ out = F.linear(input, self.weight * self.scale)
154
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
155
+
156
+ else:
157
+ out = F.linear(
158
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
159
+ )
160
+
161
+ return out
162
+
163
+ def __repr__(self):
164
+ return (
165
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166
+ )
167
+
168
+
169
+ class ModulatedConv2d(nn.Module):
170
+ def __init__(
171
+ self,
172
+ in_channel,
173
+ out_channel,
174
+ kernel_size,
175
+ style_dim,
176
+ demodulate=True,
177
+ upsample=False,
178
+ downsample=False,
179
+ blur_kernel=[1, 3, 3, 1],
180
+ fused=True,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.eps = 1e-8
185
+ self.kernel_size = kernel_size
186
+ self.in_channel = in_channel
187
+ self.out_channel = out_channel
188
+ self.upsample = upsample
189
+ self.downsample = downsample
190
+
191
+ if upsample:
192
+ factor = 2
193
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
194
+ pad0 = (p + 1) // 2 + factor - 1
195
+ pad1 = p // 2 + 1
196
+
197
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198
+
199
+ if downsample:
200
+ factor = 2
201
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
202
+ pad0 = (p + 1) // 2
203
+ pad1 = p // 2
204
+
205
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206
+
207
+ fan_in = in_channel * kernel_size ** 2
208
+ self.scale = 1 / math.sqrt(fan_in)
209
+ self.padding = kernel_size // 2
210
+
211
+ self.weight = nn.Parameter(
212
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213
+ )
214
+
215
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
+
217
+ self.demodulate = demodulate
218
+ self.fused = fused
219
+
220
+ def __repr__(self):
221
+ return (
222
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223
+ f"upsample={self.upsample}, downsample={self.downsample})"
224
+ )
225
+
226
+ def forward(self, input, style, is_s_code=False):
227
+ batch, in_channel, height, width = input.shape
228
+
229
+ if not self.fused:
230
+ weight = self.scale * self.weight.squeeze(0)
231
+
232
+ if is_s_code:
233
+ style = style[self.modulation]
234
+ else:
235
+ style = self.modulation(style)
236
+
237
+ if self.demodulate:
238
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
239
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
240
+
241
+ input = input * style.reshape(batch, in_channel, 1, 1)
242
+
243
+ if self.upsample:
244
+ weight = weight.transpose(0, 1)
245
+ out = conv2d_gradfix.conv_transpose2d(
246
+ input, weight, padding=0, stride=2
247
+ )
248
+ out = self.blur(out)
249
+
250
+ elif self.downsample:
251
+ input = self.blur(input)
252
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
253
+
254
+ else:
255
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
256
+
257
+ if self.demodulate:
258
+ out = out * dcoefs.view(batch, -1, 1, 1)
259
+
260
+ return out
261
+
262
+ if is_s_code:
263
+ style = style[self.modulation]
264
+ else:
265
+ style = self.modulation(style)
266
+
267
+ style = style.view(batch, 1, in_channel, 1, 1)
268
+ weight = self.scale * self.weight * style
269
+
270
+ if self.demodulate:
271
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
272
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
273
+
274
+ weight = weight.view(
275
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
276
+ )
277
+
278
+ if self.upsample:
279
+ input = input.view(1, batch * in_channel, height, width)
280
+ weight = weight.view(
281
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
282
+ )
283
+ weight = weight.transpose(1, 2).reshape(
284
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
285
+ )
286
+ out = conv2d_gradfix.conv_transpose2d(
287
+ input, weight, padding=0, stride=2, groups=batch
288
+ )
289
+ _, _, height, width = out.shape
290
+ out = out.view(batch, self.out_channel, height, width)
291
+ out = self.blur(out)
292
+
293
+ elif self.downsample:
294
+ input = self.blur(input)
295
+ _, _, height, width = input.shape
296
+ input = input.view(1, batch * in_channel, height, width)
297
+ out = conv2d_gradfix.conv2d(
298
+ input, weight, padding=0, stride=2, groups=batch
299
+ )
300
+ _, _, height, width = out.shape
301
+ out = out.view(batch, self.out_channel, height, width)
302
+
303
+ else:
304
+ input = input.view(1, batch * in_channel, height, width)
305
+ out = conv2d_gradfix.conv2d(
306
+ input, weight, padding=self.padding, groups=batch
307
+ )
308
+ _, _, height, width = out.shape
309
+ out = out.view(batch, self.out_channel, height, width)
310
+
311
+ return out
312
+
313
+
314
+ class NoiseInjection(nn.Module):
315
+ def __init__(self):
316
+ super().__init__()
317
+
318
+ self.weight = nn.Parameter(torch.zeros(1))
319
+
320
+ def forward(self, image, noise=None):
321
+ if noise is None:
322
+ batch, _, height, width = image.shape
323
+ noise = image.new_empty(batch, 1, height, width).normal_()
324
+
325
+ return image + self.weight * noise
326
+
327
+
328
+ class ConstantInput(nn.Module):
329
+ def __init__(self, channel, size=4):
330
+ super().__init__()
331
+
332
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
333
+
334
+ def forward(self, input, is_s_code=False):
335
+ if not is_s_code:
336
+ batch = input.shape[0]
337
+ else:
338
+ batch = next(iter(input.values())).shape[0]
339
+
340
+ out = self.input.repeat(batch, 1, 1, 1)
341
+
342
+ return out
343
+
344
+
345
+ class StyledConv(nn.Module):
346
+ def __init__(
347
+ self,
348
+ in_channel,
349
+ out_channel,
350
+ kernel_size,
351
+ style_dim,
352
+ upsample=False,
353
+ blur_kernel=[1, 3, 3, 1],
354
+ demodulate=True,
355
+ ):
356
+ super().__init__()
357
+
358
+ self.conv = ModulatedConv2d(
359
+ in_channel,
360
+ out_channel,
361
+ kernel_size,
362
+ style_dim,
363
+ upsample=upsample,
364
+ blur_kernel=blur_kernel,
365
+ demodulate=demodulate,
366
+ )
367
+
368
+ self.noise = NoiseInjection()
369
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
370
+ # self.activate = ScaledLeakyReLU(0.2)
371
+ self.activate = FusedLeakyReLU(out_channel)
372
+
373
+ def forward(self, input, style, noise=None, is_s_code=False):
374
+ out = self.conv(input, style, is_s_code=is_s_code)
375
+ out = self.noise(out, noise=noise)
376
+ # out = out + self.bias
377
+ out = self.activate(out)
378
+
379
+ return out
380
+
381
+
382
+ class ToRGB(nn.Module):
383
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
384
+ super().__init__()
385
+
386
+ if upsample:
387
+ self.upsample = Upsample(blur_kernel)
388
+
389
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
390
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
391
+
392
+ def forward(self, input, style, skip=None, is_s_code=False):
393
+ out = self.conv(input, style, is_s_code=is_s_code)
394
+ out = out + self.bias
395
+
396
+ if skip is not None:
397
+ skip = self.upsample(skip)
398
+
399
+ out = out + skip
400
+
401
+ return out
402
+
403
+
404
+ class Generator(nn.Module):
405
+ def __init__(
406
+ self,
407
+ size,
408
+ style_dim,
409
+ n_mlp,
410
+ channel_multiplier=2,
411
+ blur_kernel=[1, 3, 3, 1],
412
+ lr_mlp=0.01,
413
+ ):
414
+ super().__init__()
415
+
416
+ self.size = size
417
+
418
+ self.style_dim = style_dim
419
+
420
+ layers = [PixelNorm()]
421
+
422
+ for i in range(n_mlp):
423
+ layers.append(
424
+ EqualLinear(
425
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
426
+ )
427
+ )
428
+
429
+ self.style = nn.Sequential(*layers)
430
+
431
+ self.channels = {
432
+ 4: 512,
433
+ 8: 512,
434
+ 16: 512,
435
+ 32: 512,
436
+ 64: 256 * channel_multiplier,
437
+ 128: 128 * channel_multiplier,
438
+ 256: 64 * channel_multiplier,
439
+ 512: 32 * channel_multiplier,
440
+ 1024: 16 * channel_multiplier,
441
+ }
442
+
443
+ self.input = ConstantInput(self.channels[4])
444
+ self.conv1 = StyledConv(
445
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
446
+ )
447
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
448
+
449
+ self.log_size = int(math.log(size, 2))
450
+ self.num_layers = (self.log_size - 2) * 2 + 1
451
+
452
+ self.convs = nn.ModuleList()
453
+ self.upsamples = nn.ModuleList()
454
+ self.to_rgbs = nn.ModuleList()
455
+ self.noises = nn.Module()
456
+
457
+ in_channel = self.channels[4]
458
+
459
+ for layer_idx in range(self.num_layers):
460
+ res = (layer_idx + 5) // 2
461
+ shape = [1, 1, 2 ** res, 2 ** res]
462
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
463
+
464
+ for i in range(3, self.log_size + 1):
465
+ out_channel = self.channels[2 ** i]
466
+
467
+ self.convs.append(
468
+ StyledConv(
469
+ in_channel,
470
+ out_channel,
471
+ 3,
472
+ style_dim,
473
+ upsample=True,
474
+ blur_kernel=blur_kernel,
475
+ )
476
+ )
477
+
478
+ self.convs.append(
479
+ StyledConv(
480
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
481
+ )
482
+ )
483
+
484
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
485
+
486
+ in_channel = out_channel
487
+
488
+ self.n_latent = self.log_size * 2 - 2
489
+
490
+
491
+ self.modulation_layers = [self.conv1.conv.modulation, self.to_rgb1.conv.modulation] + \
492
+ [layer.conv.modulation for layer in self.convs] + \
493
+ [layer.conv.modulation for layer in self.to_rgbs]
494
+
495
+ def make_noise(self):
496
+ device = self.input.input.device
497
+
498
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
499
+
500
+ for i in range(3, self.log_size + 1):
501
+ for _ in range(2):
502
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
503
+
504
+ return noises
505
+
506
+ def mean_latent(self, n_latent):
507
+ latent_in = torch.randn(
508
+ n_latent, self.style_dim, device=self.input.input.device
509
+ )
510
+ latent = self.style(latent_in).mean(0, keepdim=True)
511
+
512
+ return latent
513
+
514
+ def get_latent(self, input):
515
+ return self.style(input)
516
+
517
+ def get_s_code(self, styles, input_is_latent):
518
+
519
+ if not input_is_latent:
520
+ styles = [self.style(s) for s in styles]
521
+
522
+ s_codes = [{layer: layer(s) for layer in self.modulation_layers} for s in styles] * len(styles)
523
+
524
+ return s_codes
525
+
526
+ def forward(
527
+ self,
528
+ styles,
529
+ return_latents=False,
530
+ inject_index=None,
531
+ truncation=1,
532
+ truncation_latent=None,
533
+ input_is_latent=False,
534
+ input_is_s_code=False,
535
+ noise=None,
536
+ randomize_noise=True,
537
+ ):
538
+ if not input_is_s_code:
539
+ return self.forward_with_w(styles, return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise, randomize_noise)
540
+
541
+ return self.forward_with_s(styles, return_latents, noise, randomize_noise)
542
+
543
+ def forward_with_w(
544
+ self,
545
+ styles,
546
+ return_latents=False,
547
+ inject_index=None,
548
+ truncation=1,
549
+ truncation_latent=None,
550
+ input_is_latent=False,
551
+ noise=None,
552
+ randomize_noise=True,
553
+ ):
554
+ if not input_is_latent:
555
+ styles = [self.style(s) for s in styles]
556
+
557
+ if noise is None:
558
+ if randomize_noise:
559
+ noise = [None] * self.num_layers
560
+ else:
561
+ noise = [
562
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
563
+ ]
564
+
565
+ if truncation < 1:
566
+ style_t = []
567
+
568
+ for style in styles:
569
+ style_t.append(
570
+ truncation_latent + truncation * (style - truncation_latent)
571
+ )
572
+
573
+ styles = style_t
574
+
575
+ if len(styles) < 2:
576
+ inject_index = self.n_latent
577
+
578
+ if styles[0].ndim < 3:
579
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
580
+
581
+ else:
582
+ latent = styles[0]
583
+
584
+ else:
585
+ if inject_index is None:
586
+ inject_index = random.randint(1, self.n_latent - 1)
587
+
588
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
589
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
590
+
591
+ latent = torch.cat([latent, latent2], 1)
592
+
593
+ out = self.input(latent)
594
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
595
+
596
+ skip = self.to_rgb1(out, latent[:, 1])
597
+
598
+ i = 1
599
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
600
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
601
+ ):
602
+ out = conv1(out, latent[:, i], noise=noise1)
603
+ out = conv2(out, latent[:, i + 1], noise=noise2)
604
+ skip = to_rgb(out, latent[:, i + 2], skip)
605
+
606
+ i += 2
607
+
608
+ image = skip
609
+
610
+ if return_latents:
611
+ return image, latent
612
+
613
+ else:
614
+ return image, None
615
+
616
+ def forward_with_s(
617
+ self,
618
+ styles,
619
+ return_latents=False,
620
+ noise=None,
621
+ randomize_noise=True,
622
+ ):
623
+
624
+ if noise is None:
625
+ if randomize_noise:
626
+ noise = [None] * self.num_layers
627
+ else:
628
+ noise = [
629
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
630
+ ]
631
+
632
+ out = self.input(styles, is_s_code=True)
633
+ out = self.conv1(out, styles, is_s_code=True, noise=noise[0])
634
+
635
+ skip = self.to_rgb1(out, styles, is_s_code=True)
636
+
637
+ i = 1
638
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
639
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
640
+ ):
641
+ out = conv1(out, styles, is_s_code=True, noise=noise1)
642
+ out = conv2(out, styles, is_s_code=True, noise=noise2)
643
+ skip = to_rgb(out, styles, skip, is_s_code=True)
644
+
645
+ i += 2
646
+
647
+ image = skip
648
+
649
+ if return_latents:
650
+ return image, styles
651
+
652
+ else:
653
+ return image, None
654
+
655
+ class ConvLayer(nn.Sequential):
656
+ def __init__(
657
+ self,
658
+ in_channel,
659
+ out_channel,
660
+ kernel_size,
661
+ downsample=False,
662
+ blur_kernel=[1, 3, 3, 1],
663
+ bias=True,
664
+ activate=True,
665
+ ):
666
+ layers = []
667
+
668
+ if downsample:
669
+ factor = 2
670
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
671
+ pad0 = (p + 1) // 2
672
+ pad1 = p // 2
673
+
674
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
675
+
676
+ stride = 2
677
+ self.padding = 0
678
+
679
+ else:
680
+ stride = 1
681
+ self.padding = kernel_size // 2
682
+
683
+ layers.append(
684
+ EqualConv2d(
685
+ in_channel,
686
+ out_channel,
687
+ kernel_size,
688
+ padding=self.padding,
689
+ stride=stride,
690
+ bias=bias and not activate,
691
+ )
692
+ )
693
+
694
+ if activate:
695
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
696
+
697
+ super().__init__(*layers)
698
+
699
+
700
+ class ResBlock(nn.Module):
701
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
702
+ super().__init__()
703
+
704
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
705
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
706
+
707
+ self.skip = ConvLayer(
708
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
709
+ )
710
+
711
+ def forward(self, input):
712
+ out = self.conv1(input)
713
+ out = self.conv2(out)
714
+
715
+ skip = self.skip(input)
716
+ out = (out + skip) / math.sqrt(2)
717
+
718
+ return out
719
+
720
+
721
+ class Discriminator(nn.Module):
722
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
723
+ super().__init__()
724
+
725
+ channels = {
726
+ 4: 512,
727
+ 8: 512,
728
+ 16: 512,
729
+ 32: 512,
730
+ 64: 256 * channel_multiplier,
731
+ 128: 128 * channel_multiplier,
732
+ 256: 64 * channel_multiplier,
733
+ 512: 32 * channel_multiplier,
734
+ 1024: 16 * channel_multiplier,
735
+ }
736
+
737
+ convs = [ConvLayer(3, channels[size], 1)]
738
+
739
+ log_size = int(math.log(size, 2))
740
+
741
+ in_channel = channels[size]
742
+
743
+ for i in range(log_size, 2, -1):
744
+ out_channel = channels[2 ** (i - 1)]
745
+
746
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
747
+
748
+ in_channel = out_channel
749
+
750
+ self.convs = nn.Sequential(*convs)
751
+
752
+ self.stddev_group = 4
753
+ self.stddev_feat = 1
754
+
755
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
756
+ self.final_linear = nn.Sequential(
757
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
758
+ EqualLinear(channels[4], 1),
759
+ )
760
+
761
+ def forward(self, input):
762
+ out = self.convs(input)
763
+
764
+ batch, channel, height, width = out.shape
765
+ group = min(batch, self.stddev_group)
766
+ stddev = out.view(
767
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
768
+ )
769
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
770
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
771
+ stddev = stddev.repeat(group, 1, height, width)
772
+ out = torch.cat([out, stddev], 1)
773
+
774
+ out = self.final_conv(out)
775
+
776
+ out = out.view(batch, -1)
777
+ out = self.final_linear(out)
778
+
779
+ return out
780
+
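Both scripts consume model/sg2_model.py in the same way: build a Generator for the 1024x1024 FFHQ configuration, load a fine-tuned g_ema state dict, and synthesize from a latent code. A minimal sketch of that loop; the checkpoint path is illustrative, and the custom op kernels imported at the top of the file must be importable:

import torch
from model.sg2_model import Generator

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1024x1024 output, 512-d style space, 8 mapping layers -- the configuration used above
g_ema = Generator(1024, 512, 8, channel_multiplier=2).to(device)
ckpt = torch.load("models/base.pt", map_location=device)  # illustrative checkpoint path
g_ema.load_state_dict(ckpt["g_ema"])
g_ema.eval()

with torch.no_grad():
    z = torch.randn(1, 512, device=device)  # random z code
    w = g_ema.get_latent(z)                 # map z through the MLP to W
    img, _ = g_ema([w], input_is_latent=True, truncation=1, randomize_noise=False)

print(img.shape)  # torch.Size([1, 3, 1024, 1024])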