rinong committed on
Commit
210c702
1 Parent(s): 7a331ca

Overhauled editing UI, output to gallery

Files changed (2)
  1. app.py +143 -108
  2. generate_videos.py +23 -153
app.py CHANGED
@@ -1,33 +1,24 @@
  import os
- from posixpath import basename
 
  import torch
  import gradio as gr
 
- import os
- import sys
- import numpy as np
-
  from e4e.models.psp import pSp
  from util import *
  from huggingface_hub import hf_hub_download
 
- import os
- import sys
  import tempfile
- import shutil
  from argparse import Namespace
- from pathlib import Path
  import shutil
 
  import dlib
  import numpy as np
  import torchvision.transforms as transforms
  from torchvision import utils
- from PIL import Image
 
  from model.sg2_model import Generator
- from generate_videos import generate_frames, video_from_interpolations, vid_to_gif
 
  model_dir = "models"
  os.makedirs(model_dir, exist_ok=True)
@@ -120,7 +111,6 @@ class ImageEditor(object):
  print("setup complete")
 
  def get_style_list(self):
- # style_list = ['all', 'list - enter below']
  style_list = []
 
  for key in self.generators:
@@ -146,26 +136,70 @@ class ImageEditor(object):
 
  def get_generators_for_styles(self, output_styles, loop_styles=False):
 
- # if style_string:
- # styles = style_string.split(",")
- # for style in styles:
- # if style not in self.model_list:
- # raise ValueError(f"Encountered style '{style}' in the input style list which is not an available option.")
- # else:
- # styles = style_checkbox_list
-
- if "base" in output_styles: # always start with base if chosen
  output_styles.insert(0, output_styles.pop(output_styles.index("base")))
  if loop_styles:
  output_styles.append(output_styles[0])
 
  return [self.generators[style] for style in output_styles]
 
- def edit_image(self, input, output_styles):
- return self.predict(input, output_styles)
 
- def edit_video(self, input, output_styles, with_editing, video_format, loop_styles):
- return self.predict(input, output_styles, True, with_editing, video_format, loop_styles)
 
  def predict(
  self,
@@ -173,55 +207,57 @@ class ImageEditor(object):
  output_styles, # Style checkbox options.
  generate_video = False, # Generate a video instead of an output image
  with_editing = False, # Apply latent space editing to the generated video
- video_format = "mp4", # Choose gif to display in browser, mp4 for higher-quality downloadable video
  loop_styles = False, # Loop back to the initial style
  ):
 
  # @title Align image
- out_dir = Path(tempfile.mkdtemp())
- out_path = out_dir / "out.jpg"
 
  inverted_latent = self.invert_image(input)
  generators = self.get_generators_for_styles(output_styles, loop_styles)
 
  if not generate_video:
  with torch.no_grad():
- img_list = []
  for g_ema in generators:
- img, _ = g_ema(inverted_latent, input_is_latent=True, truncation=1, randomize_noise=False)
- img_list.append(img)
-
- out_img = torch.cat(img_list, axis=0)
- utils.save_image(out_img, out_path, nrow=int(np.sqrt(out_img.size(0))), normalize=True, scale_each=True, range=(-1, 1))
 
- return str(out_path)
 
- return self.generate_vid(generators, inverted_latent, out_dir, video_format, with_editing)
-
- def generate_vid(self, generators, latent, out_dir, video_format, with_editing):
- np_latent = latent.squeeze(0).cpu().detach().numpy()
- args = {
- 'fps': 24,
- 'target_latents': None,
- 'edit_directions': None,
- 'unedited_frames': 0 if with_editing else 40 * (len(generators) - 1)
- }
-
- args = Namespace(**args)
  with tempfile.TemporaryDirectory() as dirpath:
 
- generate_frames(args, np_latent, generators, dirpath)
- video_from_interpolations(args.fps, dirpath)
 
- gen_path = Path(dirpath) / "out.mp4"
- out_path = out_dir / f"out.{video_format}"
 
- if video_format == 'gif':
- vid_to_gif(gen_path, out_dir, scale=256, fps=args.fps)
- else:
- shutil.copy2(gen_path, out_path)
 
- return str(out_path)
 
  def run_alignment(self, image_path):
  aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
@@ -236,12 +272,12 @@ class ImageEditor(object):
 
  editor = ImageEditor()
 
- def change_component_visibility(component_types, invert_choices):
 
- def visibility_impl(visible):
- return [component_types[idx].update(visible=visible ^ invert_choices[idx]) for idx in range(len(component_types))]
 
- return visibility_impl
 
  # def group_visibility(visible):
  # print("visible: ", visible)
@@ -258,60 +294,59 @@ with blocks:
  gr.Markdown(
  "For more information about the paper and code for training your own models (with examples OR text), see below."
  )
-
  with gr.Row():
 
- with gr.Column():
- input_img = gr.inputs.Image(type="filepath", label="Input image")
  style_choice = gr.inputs.CheckboxGroup(choices=editor.get_style_list(), type="value", label="Choose your styles!")
 
- video_choice = gr.inputs.Checkbox(default=False, label="Generate Video?", optional=False)
-
- loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?", visible=False)
- edit_choice = gr.inputs.Checkbox(default=False, label="With Editing?", visible=False)
- vid_format_choice = gr.inputs.Radio(choices=["gif", "mp4"], type="value", default='mp4', label="Video Format", visible=False)
-
- # img_button = gr.Button("Edit Image")
- # vid_button = gr.Button("Generate Video")
- img_button = gr.Button("Edit Image")
- vid_button = gr.Button("Generate Video", visible=False)
-
- with gr.Column():
- img_output = gr.outputs.Image(type="file")
- vid_output = gr.outputs.Video(visible=False)
-
- visibility_fn = change_component_visibility(component_types=[gr.Checkbox, gr.Radio, gr.Video, gr.Button, gr.Image, gr.Button, gr.Checkbox],
- invert_choices=[False, False, False, False, True, True, False])
-
- video_choice.change(fn=visibility_fn, inputs=video_choice, outputs=[edit_choice, vid_format_choice, vid_output, vid_button, img_output, img_button])
- # video_choice.change(fn=group_visibility, inputs=video_choice, outputs=video_options_group)
- img_button.click(fn=editor.edit_image, inputs=[input_img, style_choice], outputs=img_output)
- vid_button.click(fn=editor.edit_video, inputs=[input_img, style_choice, edit_choice, vid_format_choice, loop_styles], outputs=vid_output)
-
- # with gr.Row():
- # input_img = gr.inputs.Image(type="filepath", label="Input image")
- # style_choice = gr.inputs.CheckboxGroup(choices=editor.get_style_list(), type="value", label="Choose your styles!")
-
- # with gr.Tabs():
- # with gr.TabItem("Edit Images"):
- # with gr.Column():
- # img_button = gr.Button("Edit Image")
- # with gr.Column():
- # img_output = gr.outputs.Image(type="file", label="Output Image")
 
- # with gr.TabItem("Create Video"):
- # with gr.Column():
- # with gr.Row():
- # vid_button = gr.Button("Generate Video")
- # loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
- # edit_choice = gr.inputs.Checkbox(default=False, label="With latent space editing?")
- # vid_format_choice = gr.inputs.Radio(choices=["gif", "mp4"], type="value", default='mp4', label="Video Format")
-
- # with gr.Column():
- # vid_output = gr.outputs.Video(label="Output Video")
 
- # img_button.click(fn=editor.edit_image, inputs=[input_img, style_choice], outputs=img_output)
- # vid_button.click(fn=editor.edit_video, inputs=[input_img, style_choice, edit_choice, vid_format_choice, loop_styles], outputs=vid_output)
 
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.00946' target='_blank'>StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators</a> | <a href='https://stylegan-nada.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/rinongal/StyleGAN-nada' target='_blank'>Code</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=rinong_sgnada' alt='visitor badge'></center>"
  gr.Markdown(article)
 
  import os
+ import random
 
  import torch
  import gradio as gr
 
  from e4e.models.psp import pSp
  from util import *
  from huggingface_hub import hf_hub_download
 
  import tempfile
  from argparse import Namespace
  import shutil
 
  import dlib
  import numpy as np
  import torchvision.transforms as transforms
  from torchvision import utils
 
  from model.sg2_model import Generator
+ from generate_videos import generate_frames, video_from_interpolations, project_code_by_edit_name
 
  model_dir = "models"
  os.makedirs(model_dir, exist_ok=True)
 
  print("setup complete")
 
  def get_style_list(self):
  style_list = []
 
  for key in self.generators:
 
  def get_generators_for_styles(self, output_styles, loop_styles=False):
 
+ if "base" in output_styles: # always start with base if chosen
  output_styles.insert(0, output_styles.pop(output_styles.index("base")))
  if loop_styles:
  output_styles.append(output_styles[0])
 
  return [self.generators[style] for style in output_styles]
 
+ def _pack_edits(func):
+ def inner(self,
+ edit_type_choice,
+ pose_slider,
+ smile_slider,
+ gender_slider,
+ age_slider,
+ hair_slider,
+ src_text_styleclip,
+ tar_text_styleclip,
+ alpha_styleclip,
+ beta_styleclip,
+ *args):
+
+ edit_choices = {"edit_type": edit_type_choice,
+ "pose": pose_slider,
+ "smile": smile_slider,
+ "gender": gender_slider,
+ "age": age_slider,
+ "hair": hair_slider,
+ "src_text": src_text_styleclip,
+ "tar_text": tar_text_styleclip,
+ "alpha": alpha_styleclip,
+ "beta": beta_styleclip}
+
+ return func(self, *args, edit_choices)
+
+ return inner
+
+ def get_target_latents(self, source_latent, edit_choices, generators):
+
+ np_source_latent = source_latent.squeeze(0).cpu().detach().numpy()
+
+ target_latents = []
+
+ if edit_choices["edit_type"] == "InterFaceGAN":
+ for attribute_name in ["pose", "smile", "gender", "age", "hair"]:
+ strength = edit_choices[attribute_name]
+ if strength != 0.0:
+ target_latents.append(project_code_by_edit_name(np_source_latent, attribute_name, strength))
+
+ elif edit_choices["edit_type"] == "StyleCLIP":
+ pass
+
+ # if edit type is none or if all slides were set to 0
+ if not target_latents:
+ target_latents = [source_latent, ] * (len(generators) - 1)
+
+ return target_latents
+
+ @_pack_edits
+ def edit_image(self, input, output_styles, edit_choices):
+ return self.predict(input, output_styles, edit_choices)
 
+ @_pack_edits
+ def edit_video(self, input, output_styles, loop_styles, edit_choices):
+ return self.predict(input, output_styles, True, loop_styles, edit_choices)
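The `_pack_edits` decorator exists because Gradio passes every input component's value as a separate positional argument; the wrapper peels off the ten editing controls, bundles them into a single `edit_choices` dict, and forwards the remaining inputs untouched. A minimal self-contained sketch of the pattern (the `TinyEditor` class and the stubbed `project_code_by_edit_name` are illustrative stand-ins, not part of the commit):

import numpy as np

def project_code_by_edit_name(latent, name, strength):  # stub: the real helper shifts along an InterFaceGAN boundary
    return latent + strength

def _pack_edits(func):
    # Collect the ten editing-control values into one dict and pass the remaining
    # positional inputs (image path, style list, ...) through unchanged.
    def inner(self, edit_type, pose, smile, gender, age, hair,
              src_text, tar_text, alpha, beta, *args):
        edit_choices = {"edit_type": edit_type, "pose": pose, "smile": smile,
                        "gender": gender, "age": age, "hair": hair,
                        "src_text": src_text, "tar_text": tar_text,
                        "alpha": alpha, "beta": beta}
        return func(self, *args, edit_choices)
    return inner

class TinyEditor:
    @_pack_edits
    def edit_image(self, input_img, styles, edit_choices):
        # Mirrors get_target_latents: one target latent per non-zero InterFaceGAN slider.
        source = np.zeros(512)
        targets = [project_code_by_edit_name(source, name, edit_choices[name])
                   for name in ("pose", "smile", "gender", "age", "hair")
                   if edit_choices["edit_type"] == "InterFaceGAN" and edit_choices[name] != 0.0]
        return len(targets)

# Called the way the Gradio click handler would call it: editing controls first,
# then the image and style inputs.
n = TinyEditor().edit_image("InterFaceGAN", 0.5, 0, 0, -0.3, 0, "", "", 0, 0.14,
                            "face.jpg", ["base", "sketch"])
print(n)  # 2 -> the pose and age sliders were non-zero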
 
  def predict(
  self,
  output_styles, # Style checkbox options.
  generate_video = False, # Generate a video instead of an output image
  with_editing = False, # Apply latent space editing to the generated video
  loop_styles = False, # Loop back to the initial style
+ edit_choices = None, # Optional dictionary with edit choice arguments
  ):
 
+ if edit_choices is None:
+ edit_choices = {"edit_type": "None"}
+
  # @title Align image
+ out_dir = tempfile.mkdtemp()
 
  inverted_latent = self.invert_image(input)
  generators = self.get_generators_for_styles(output_styles, loop_styles)
 
+ target_latents = self.get_target_latents(inverted_latent, edit_choices, generators)
+
  if not generate_video:
+ output_paths = []
+
  with torch.no_grad():
  for g_ema in generators:
+ latent_for_gen = random.choice(target_latents)
+ latent_for_gen = [torch.from_numpy(latent_for_gen).float().to(self.device)]
+
+ img, _ = g_ema(latent_for_gen, input_is_latent=True, truncation=1, randomize_noise=False)
+
+ output_path = os.path.join(out_dir, f"out_{len(output_paths)}.jpg")
+ utils.save_image(img, output_path, nrow=1, normalize=True, range=(-1, 1))
+
+ output_paths.append(output_path)
+
+ return output_paths
 
+ return self.generate_vid(generators, inverted_latent, out_dir, with_editing)
+
+ def generate_vid(self, generators, source_latent, target_latents, out_dir):
+
+ fps = 24
+
+ np_latent = source_latent.squeeze(0).cpu().detach().numpy()
+
  with tempfile.TemporaryDirectory() as dirpath:
 
+ generate_frames(np_latent, target_latents, generators, dirpath)
+ video_from_interpolations(fps, dirpath)
 
+ gen_path = os.path.join(dirpath, "out.mp4")
+ out_path = os.path.join(out_dir, "out.mp4")
 
+ shutil.copy2(gen_path, out_path)
 
+ return out_path
 
  def run_alignment(self, image_path):
  aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
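For the image path, predict now pairs each selected style's generator with a randomly chosen edited latent and writes one file per style, returning the list of paths. A pure-Python sketch of that pairing behaviour (stand-in strings instead of real latents and generators):

import random

generators = ["base", "sketch", "pixar"]                      # styles ticked in the checkbox group
target_latents = ["latent + pose edit", "latent + age edit"]  # one per non-zero slider

outputs = [(style, random.choice(target_latents)) for style in generators]
print(outputs)
# Three (style, edit) pairs -> three gallery images; edits vary *between* images
# rather than being combined, as the markdown note in the UI below points out.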
 
 
  editor = ImageEditor()
 
+ # def change_component_visibility(component_types, invert_choices):
 
+ # def visibility_impl(visible):
+ # return [component_types[idx].update(visible=visible ^ invert_choices[idx]) for idx in range(len(component_types))]
 
+ # return visibility_impl
 
  # def group_visibility(visible):
  # print("visible: ", visible)
 
  gr.Markdown(
  "For more information about the paper and code for training your own models (with examples OR text), see below."
  )
+
  with gr.Row():
+ input_img = gr.inputs.Image(type="filepath", label="Input image")
 
+ with gr.Column():
  style_choice = gr.inputs.CheckboxGroup(choices=editor.get_style_list(), type="value", label="Choose your styles!")
 
+ editing_type_choice = gr.Radio(choices=["None", "InterFaceGAN", "StyleCLIP"], label="Choose latent space editing option. For InterFaceGAN and StyleCLIP, set the options below:")
 
+ with gr.Tabs():
+ with gr.TabItem("InterFaceGAN Editing Options"):
+ gr.Markdown("Move the sliders to make the chosen attribute stronger (e.g. the person older) or leave at 0 to disable editing.")
+ gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together")
+
+ pose_slider = gr.Slider(label="Pose", minimum=-1, maximum=1, value=0, step=0.02)
+ smile_slider = gr.Slider(label="Smile", minimum=-1, maximum=1, value=0, step=0.02)
+ gender_slider = gr.Slider(label="Perceived Gender", minimum=-1, maximum=1, value=0, step=0.02)
+ age_slider = gr.Slider(label="Age", minimum=-1, maximum=1, value=0, step=0.02)
+ hair_slider = gr.Slider(label="Hair Length", minimum=-1, maximum=1, value=0, step=0.02)
+
+ ig_edit_choices = [pose_slider, smile_slider, gender_slider, age_slider, hair_slider]
+
+ with gr.TabItem("StyleCLIP Editing Options"):
+ gr.Markdown("Move the sliders to make the chosen attribute stronger (e.g. the person older) or leave at 0 to disable editing.")
+ gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together")
 
+ src_text_styleclip = gr.Textbox(label="Source text")
+ tar_text_styleclip = gr.Textbox(label="Target text")
+
+ alpha_styleclip = gr.Slider(label="Edit strength", minimum=-10, maximum=10, value=0, step=0.1)
+ beta_styleclip = gr.Slider(label="Disentanglement Threshold", minimum=0.08, maximum=0.3, value=0.14, step=0.01)
+
+ sc_edit_choices = [src_text_styleclip, tar_text_styleclip, alpha_styleclip, beta_styleclip]
+
+ with gr.Tabs():
+ with gr.TabItem("Edit Images"):
+ with gr.Column():
+ img_button = gr.Button("Edit Image")
+ with gr.Column():
+ img_output = gr.Gallery(label="Output Images")
+
+ with gr.TabItem("Create Video"):
+ with gr.Row():
+ with gr.Column():
+ vid_button = gr.Button("Generate Video")
+ loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
+
+ with gr.Column():
+ vid_output = gr.outputs.Video(label="Output Video")
+
+ edit_inputs = [editing_type_choice] + ig_edit_choices + sc_edit_choices
+ img_button.click(fn=editor.edit_image, inputs=edit_inputs + [input_img, style_choice], outputs=img_output)
+ vid_button.click(fn=editor.edit_video, inputs=edit_inputs + [input_img, style_choice, loop_styles], outputs=vid_output)
 
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.00946' target='_blank'>StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators</a> | <a href='https://stylegan-nada.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/rinongal/StyleGAN-nada' target='_blank'>Code</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=rinong_sgnada' alt='visitor badge'></center>"
  gr.Markdown(article)
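Because `_pack_edits` strips the first ten positional inputs, the `click()` calls must list the editing controls before the image and style components, which is exactly what `edit_inputs + [input_img, style_choice]` does. A minimal sketch of the gallery wiring itself, assuming a Gradio 3.x Blocks app (`fake_predict` and the layout are illustrative, not the app's real code):

import os, tempfile
import gradio as gr
from PIL import Image

def fake_predict():
    # Stand-in for ImageEditor.predict: write one image per "style" and return the paths.
    out_dir = tempfile.mkdtemp()
    paths = []
    for i, color in enumerate(["red", "green", "blue"]):
        path = os.path.join(out_dir, f"out_{i}.jpg")
        Image.new("RGB", (256, 256), color).save(path)
        paths.append(path)
    return paths  # a list of file paths fills the gr.Gallery

with gr.Blocks() as demo:
    btn = gr.Button("Edit Image")
    gallery = gr.Gallery(label="Output Images")
    btn.click(fn=fake_predict, inputs=[], outputs=gallery)

# demo.launch()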
generate_videos.py CHANGED
@@ -35,12 +35,12 @@ import copy
  VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]
 
  SUGGESTED_DISTANCES = {
- "pose": (3.0, -3.0),
- "smile": (2.0, -2.0),
- "age": (4.0, -4.0),
- "gender": (3.0, -3.0),
- "hair_length": (None, -4.0),
- "beard": (2.0, None)
  }
 
  def project_code(latent_code, boundary, distance=3.0):
@@ -50,21 +50,26 @@ def project_code(latent_code, boundary, distance=3.0):
 
  return latent_code + distance * boundary
 
- def generate_frames(args, source_latent, g_ema_list, output_dir):
 
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
- alphas = np.linspace(0, 1, num=20)
-
- interpolate_func = interpolate_with_boundaries # default
- if args.target_latents: # if provided with targets
- interpolate_func = interpolate_with_target_latents
- if args.unedited_frames: # if only interpolating through generators
- interpolate_func = duplicate_latent
 
- latents = interpolate_func(args, source_latent, alphas)
 
  segments = len(g_ema_list) - 1
  if segments:
  segment_length = len(latents) / segments
@@ -96,50 +101,15 @@ def generate_frames(args, source_latent, g_ema_list, output_dir):
  def interpolate_forward_backward(source_latent, target_latent, alphas):
  latents_forward = [a * target_latent + (1-a) * source_latent for a in alphas] # interpolate from source to target
  latents_backward = latents_forward[::-1] # interpolate from target to source
- return latents_forward + [target_latent] * 20 + latents_backward # forward + short delay at target + return
 
- def duplicate_latent(args, source_latent, alphas):
- return [source_latent for _ in range(args.unedited_frames)]
-
- def interpolate_with_boundaries(args, source_latent, alphas):
- edit_directions = args.edit_directions or ['pose', 'smile', 'gender', 'age', 'hair_length']
-
- # interpolate latent codes with all targets
-
- print("Interpolating latent codes...")
-
- boundary_dir = Path(os.path.abspath(__file__)).parents[0].joinpath("editing", "interfacegan_boundaries")
-
- boundaries_and_distances = []
- for direction_type in edit_directions:
- distances = SUGGESTED_DISTANCES[direction_type]
- boundary = torch.load(os.path.join(boundary_dir, f'{direction_type}.pt'), map_location="cpu").numpy()
-
- for distance in distances:
- if distance:
- boundaries_and_distances.append((boundary, distance))
-
- latents = []
- for boundary, distance in boundaries_and_distances:
-
- target_latent = project_code(source_latent, boundary, distance)
- latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))
-
- return latents
-
- def interpolate_with_target_latents(args, source_latent, alphas):
  # interpolate latent codes with all targets
 
  print("Interpolating latent codes...")
 
  latents = []
- for target_latent_path in args.target_latents:
-
- if target_latent_path == args.source_latent:
- continue
-
- target_latent = np.load(target_latent_path, allow_pickle=True)
-
  latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))
 
  return latents
@@ -157,105 +127,5 @@ def video_from_interpolations(fps, output_dir):
 
  subprocess.call(command)
 
- def merge_videos(output_dir, num_subdirs):
-
- output_file = os.path.join(output_dir, "combined.mp4")
-
- if num_subdirs == 1: # if we only have one video, just copy it over
- shutil.copy2(os.path.join(output_dir, str(0), "out.mp4"), output_file)
- else: # otherwise merge using ffmpeg
- command = ["ffmpeg"]
- for dir in range(num_subdirs):
- command.extend(['-i', os.path.join(output_dir, str(dir), "out.mp4")])
-
- sqrt_subdirs = int(num_subdirs ** .5)
-
- if (sqrt_subdirs ** 2) != num_subdirs:
- raise ValueError("Number of checkpoints cannot be arranged in a square grid")
-
- command.append("-filter_complex")
-
- filter_string = ""
- vstack_string = ""
- for row in range(sqrt_subdirs):
- row_str = ""
- for col in range(sqrt_subdirs):
- row_str += f"[{row * sqrt_subdirs + col}:v]"
-
- letter = chr(ord('A')+row)
- row_str += f"hstack=inputs={sqrt_subdirs}[{letter}];"
- vstack_string += f"[{letter}]"
-
- filter_string += row_str
-
- vstack_string += f"vstack=inputs={sqrt_subdirs}[out]"
- filter_string += vstack_string
-
- command.extend([filter_string, "-map", "[out]", output_file])
-
- subprocess.call(command)
-
- def vid_to_gif(vid_path, output_dir, scale=256, fps=35):
-
- command = ["ffmpeg",
- "-i", f"{vid_path}",
- "-vf", f"fps={fps},scale={scale}:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1]fifo[s2];[s2][p]paletteuse",
- "-loop", "0",
- f"{output_dir}/out.gif"]
-
- subprocess.call(command)
-
-
- if __name__ == '__main__':
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- parser = argparse.ArgumentParser()
-
- parser.add_argument('--size', type=int, default=1024)
- parser.add_argument('--ckpt', type=str, nargs="+", required=True, help="Path to one or more pre-trained generator checkpoints.")
- parser.add_argument('--channel_multiplier', type=int, default=2)
- parser.add_argument('--out_dir', type=str, required=True, help="Directory where output files will be placed")
- parser.add_argument('--source_latent', type=str, required=True, help="Path to an .npy file containing an initial latent code")
- parser.add_argument('--target_latents', nargs="+", type=str, help="A list of paths to .npy files containing target latent codes to interpolate towards, or a directory containing such .npy files.")
- parser.add_argument('--force', '-f', action='store_true', help="Force run with non-empty directory. Image files not overwritten by the proccess may still be included in the final video")
- parser.add_argument('--fps', default=35, type=int, help='Frames per second in the generated videos.')
- parser.add_argument('--edit_directions', nargs="+", type=str, help=f"A list of edit directions to use in video generation (if not using a target latent directory). Available directions are: {VALID_EDITS}")
- parser.add_argument('--unedited_frames', type=int, default=0, help="Used to generate videos with no latent editing. If set to a positive number and target_latents is not provided, will simply duplicate the initial frame <unedited_frames> times.")
-
- args = parser.parse_args()
-
- os.makedirs(args.out_dir, exist_ok=True)
-
- if not args.force and os.listdir(args.out_dir):
- print("Output directory is not empty. Either delete the directory content or re-run with -f.")
- exit(0)
-
- if args.target_latents and len(args.target_latents) == 1 and os.path.isdir(args.target_latents[0]):
- args.target_latents = [os.path.join(args.target_latents[0], file_name) for file_name in os.listdir(args.target_latents[0]) if file_name.endswith(".npy")]
- args.target_latents = sorted(args.target_latents)
-
- args.latent = 512
- args.n_mlp = 8
-
- g_ema = Generator(
- args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
- ).to(device)
-
- source_latent = np.load(args.source_latent, allow_pickle=True)
-
- for idx, ckpt_path in enumerate(args.ckpt):
- print(f"Generating video using checkpoint: {ckpt_path}")
- checkpoint = torch.load(ckpt_path)
-
- g_ema.load_state_dict(checkpoint['g_ema'])
-
- output_dir = os.path.join(args.out_dir, str(idx))
- os.makedirs(output_dir)
-
- generate_frames(args, source_latent, [g_ema], output_dir)
- video_from_interpolations(args.fps, output_dir)
-
- merge_videos(args.out_dir, len(args.ckpt))
 
  VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]
 
  SUGGESTED_DISTANCES = {
+ "pose": 3.0,
+ "smile": 2.0,
+ "age": 4.0,
+ "gender": 3.0,
+ "hair_length": -4.0,
+ "beard": 2.0
  }
 
  def project_code(latent_code, boundary, distance=3.0):
 
  return latent_code + distance * boundary
 
+ def project_code_by_edit_name(latent_code, name, strength):
+ boundary_dir = Path(os.path.abspath(__file__)).parents[0].joinpath("editing", "interfacegan_boundaries")
+
+ distance = SUGGESTED_DISTANCES[name] * strength
+ boundary = torch.load(os.path.join(boundary_dir, f'{name}.pt'), map_location="cpu").numpy()
+
+ return project_code(latent_code, boundary, distance)
+
+ def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
 
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
+ num_alphas = min(20, 60 // len(target_latents))
 
+ alphas = np.linspace(0, 1, num=num_alphas)
+
+ latents = interpolate_with_target_latents(source_latent, target_latents, alphas)
 
  segments = len(g_ema_list) - 1
+
  if segments:
  segment_length = len(latents) / segments
 
  def interpolate_forward_backward(source_latent, target_latent, alphas):
  latents_forward = [a * target_latent + (1-a) * source_latent for a in alphas] # interpolate from source to target
  latents_backward = latents_forward[::-1] # interpolate from target to source
+ return latents_forward + [target_latent] * len(alphas) + latents_backward # forward + short delay at target + return
 
+ def interpolate_with_target_latents(source_latent, target_latents, alphas):
  # interpolate latent codes with all targets
 
  print("Interpolating latent codes...")
 
  latents = []
+ for target_latent in target_latents:
  latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))
 
  return latents
 
  subprocess.call(command)
 
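A standalone sketch of the arithmetic behind the new helpers: the slider value scales the suggested per-attribute distance, and the number of interpolation steps per target shrinks as more targets are requested. The random boundary below is a stand-in for the real editing/interfacegan_boundaries/<name>.pt tensors.

import numpy as np

SUGGESTED_DISTANCES = {"pose": 3.0, "smile": 2.0, "age": 4.0,
                       "gender": 3.0, "hair_length": -4.0, "beard": 2.0}

def project_code(latent_code, boundary, distance=3.0):
    # Same rule as in the diff: move the latent along the attribute boundary.
    return latent_code + distance * boundary

source = np.zeros((18, 512))             # stand-in for an inverted W+ latent code
age_boundary = np.random.randn(1, 512)   # stand-in for age.pt
# Age slider at 0.5 -> distance = 4.0 * 0.5 = 2.0 along the age boundary:
target = project_code(source, age_boundary, SUGGESTED_DISTANCES["age"] * 0.5)

# Frame budget in generate_frames: at most 20 interpolation steps per target,
# fewer when many targets are requested.
for n_targets in (1, 2, 5):
    num_alphas = min(20, 60 // n_targets)
    frames_per_target = 3 * num_alphas   # forward + hold at target + backward
    print(n_targets, num_alphas, n_targets * frames_per_target)
# 1 target -> 20 alphas (60 frames); 2 -> 20 (120); 5 -> 12 (180)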