rinong committed
Commit fcf0449
1 Parent(s): 6104a4e

Added StyleCLIP support

Files changed (3)
  1. app.py +32 -6
  2. generate_videos.py +2 -2
  3. styleclip/styleclip_global.py +158 -0
app.py CHANGED
@@ -19,12 +19,16 @@ from torchvision import utils
 
 from model.sg2_model import Generator
 from generate_videos import generate_frames, video_from_interpolations, project_code_by_edit_name
+from styleclip.styleclip_global import project_code_with_styleclip, style_tensor_to_style_dict
+
+import clip
 
 model_dir = "models"
 os.makedirs(model_dir, exist_ok=True)
 
 model_repos = {"e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
                "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
+               "sc_fs3": ("rinong/stylegan-nada-models", "fs3.npy"),
                "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt"),
                "anime": ("rinong/stylegan-nada-models", "anime.pt"),
                "joker": ("rinong/stylegan-nada-models", "joker.pt"),
@@ -70,7 +74,7 @@ class ImageEditor(object):
 
         self.generators = {}
 
-        self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
+        self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib", "sc_fs3"]]
 
         for model in self.model_list:
             g_ema = Generator(
@@ -108,6 +112,10 @@ class ImageEditor(object):
             model_paths["dlib"]
         )
 
+        self.styleclip_fs3 = torch.from_numpy(np.load(model_paths["sc_fs3"])).to(self.device)
+
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+
         print("setup complete")
 
     def get_style_list(self):
@@ -186,7 +194,15 @@ class ImageEditor(object):
                 target_latents.append(project_code_by_edit_name(np_source_latent, attribute_name, strength))
 
         elif edit_choices["edit_type"] == "StyleCLIP":
-            pass
+            source_s_dict = generators[0].get_s_code(source_latent, input_is_latent=True)
+            target_latents.append(project_code_with_styleclip(source_s_dict,
+                                                              edit_choices["src_text"],
+                                                              edit_choices["tar_text"],
+                                                              edit_choices["alpha"],
+                                                              edit_choices["beta"],
+                                                              generators[0],
+                                                              self.styleclip_fs3,
+                                                              self.clip_model))
 
         # if edit type is none or if all slides were set to 0
         if not target_latents:
@@ -228,9 +244,13 @@ class ImageEditor(object):
         with torch.no_grad():
             for g_ema in generators:
                 latent_for_gen = random.choice(target_latents)
-                latent_for_gen = [torch.from_numpy(latent_for_gen).float().to(self.device)]
 
-                img, _ = g_ema(latent_for_gen, input_is_latent=True, truncation=1, randomize_noise=False)
+                if edit_choices["edit_type"] == "StyleCLIP":
+                    latent_for_gen = style_tensor_to_style_dict(latent_for_gen, g_ema)
+                    img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
+                else:
+                    latent_for_gen = [torch.from_numpy(latent_for_gen).float().to(self.device)]
+                    img, _ = g_ema(latent_for_gen, input_is_latent=True, truncation=1, randomize_noise=False)
 
                 output_path = os.path.join(out_dir, f"out_{len(output_paths)}.jpg")
                 utils.save_image(img, output_path, nrow=1, normalize=True, range=(-1, 1))
@@ -294,6 +314,9 @@ with blocks:
     gr.Markdown(
         "For more information about the paper and code for training your own models (with examples OR text), see below."
     )
+
+
+    gr.Markdown("<h4 style='font-size: 110%;margin-top:.5em'>On biases</h4><div>This model relies on StyleGAN and CLIP, both of which are prone to biases, such as poor representation of minorities or the reinforcement of societal stereotypes like gender norms.</div>")
 
     with gr.Row():
         input_img = gr.inputs.Image(type="filepath", label="Input image")
@@ -306,7 +329,8 @@ with blocks:
             with gr.Tabs():
                 with gr.TabItem("InterFaceGAN Editing Options"):
                     gr.Markdown("Move the sliders to make the chosen attribute stronger (e.g. the person older) or leave at 0 to disable editing.")
-                    gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together")
+                    gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together.")
+                    gr.Markdown("Please note that some directions may be entangled. For example, hair length adjustments are likely to also modify the perceived gender.")
 
                     pose_slider = gr.Slider(label="Pose", minimum=-1, maximum=1, value=0, step=0.05)
                     smile_slider = gr.Slider(label="Smile", minimum=-1, maximum=1, value=0, step=0.05)
@@ -343,7 +367,9 @@ with blocks:
             with gr.Row():
                 vid_button = gr.Button("Generate Video")
                 loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
-
+            with gr.Row():
+                gr.Markdown("Warning: Video generation requires the synthesis of hundreds of frames and is expected to take several minutes.")
+                gr.Markdown("To reduce queue times, we significantly reduced the number of video frames. Using more than 3 styles will further reduce the frames per style, leading to quicker transitions. For better control, we recommend cloning the gradio app, adjusting `num_alphas` in `generate_videos.py`, and running the code locally.")
         with gr.Column():
             vid_output = gr.outputs.Video(label="Output Video")
 
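Taken together, the app.py changes route StyleCLIP edits through the generator's S space: the source latent is converted to a per-layer style dict, shifted along a text-derived global direction, and decoded with `input_is_s_code=True`. Below is a minimal standalone sketch of that path; it assumes `g_ema` is this repo's `Generator` (exposing `get_s_code` and `modulation_layers`) with weights already loaded, that `w` is an inverted W+ latent on `device`, and that the prompts, `alpha`, `beta`, and the `fs3.npy` location are illustrative placeholders.

```python
# Hypothetical standalone sketch of the StyleCLIP path wired up in app.py above.
import clip
import numpy as np
import torch

from styleclip.styleclip_global import project_code_with_styleclip, style_tensor_to_style_dict

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/32", device=device)
fs3 = torch.from_numpy(np.load("models/fs3.npy")).to(device)  # global channel-relevance matrix (path illustrative)

# g_ema: this repo's Generator with weights loaded; w: a [1, 18, 512] W+ latent (e.g. an e4e inversion).
source_s_dict = g_ema.get_s_code(w, input_is_latent=True)      # per-layer S-space code of the source
edited_s = project_code_with_styleclip(source_s_dict,
                                       "face", "smiling face",  # neutral / target prompts (illustrative)
                                       4.0, 0.13,               # alpha (edit strength), beta (sparsity threshold)
                                       g_ema, fs3, clip_model)

# Convert the flat style tensor back into a per-layer dict and decode directly from S space.
edited_s_dict = style_tensor_to_style_dict(edited_s, g_ema)
img, _ = g_ema(edited_s_dict, input_is_s_code=True, input_is_latent=True,
               truncation=1, randomize_noise=False)
```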
generate_videos.py CHANGED
@@ -62,14 +62,14 @@ def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    num_alphas = min(20, 60 // len(target_latents))
+    num_alphas = min(10, 30 // len(target_latents))
 
     alphas = np.linspace(0, 1, num=num_alphas)
 
     latents = interpolate_with_target_latents(source_latent, target_latents, alphas)
 
     segments = len(g_ema_list) - 1
-
+
     if segments:
         segment_length = len(latents) / segments
 
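The lower cap on `num_alphas` is what the new frame-count warning in app.py refers to: each interpolation segment now gets at most 10 steps, and the budget shrinks further as more target latents are queued. A quick, purely illustrative comparison of the old and new formulas:

```python
# Interpolation steps per transition under the old and new caps (illustrative target counts).
for n_targets in (1, 2, 3, 4, 6):
    old = min(20, 60 // n_targets)
    new = min(10, 30 // n_targets)
    print(f"{n_targets} target latent(s): {old} -> {new} steps per transition")

# 1 target latent(s): 20 -> 10 steps per transition
# 2 target latent(s): 20 -> 10 steps per transition
# 3 target latent(s): 20 -> 10 steps per transition
# 4 target latent(s): 15 -> 7 steps per transition
# 6 target latent(s): 10 -> 5 steps per transition
```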
styleclip/styleclip_global.py ADDED
@@ -0,0 +1,158 @@
+import numpy as np
+import torch
+from tqdm import tqdm
+from pathlib import Path
+import os
+
+import clip
+
+imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
+
+FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
+                    [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
+
+def zeroshot_classifier(model, classnames, templates, device):
+
+    with torch.no_grad():
+        zeroshot_weights = []
+        for classname in tqdm(classnames):
+            texts = [template.format(classname) for template in templates]  # format with class
+            texts = clip.tokenize(texts).to(device)  # tokenize
+            class_embeddings = model.encode_text(texts)  # embed with text encoder
+            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+            class_embedding = class_embeddings.mean(dim=0)
+            class_embedding /= class_embedding.norm()
+            zeroshot_weights.append(class_embedding)
+        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
+    return zeroshot_weights
+
+
+def get_direction(neutral_class, target_class, beta, di, clip_model=None):
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if clip_model is None:
+        clip_model, _ = clip.load("ViT-B/32", device=device)
+
+    class_names = [neutral_class, target_class]
+    class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)
+
+    dt = class_weights[:, 1] - class_weights[:, 0]
+    dt = dt / dt.norm()
+    relevance = di @ dt
+    mask = relevance.abs() > beta
+    direction = relevance * mask
+    direction_max = direction.abs().max()
+    if direction_max > 0:
+        direction = direction / direction_max
+    else:
+        raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
+                         f' try setting it to a lower value')
+    return direction
+
+def style_tensor_to_style_dict(style_tensor, reference_generator):
+    style_layers = reference_generator.modulation_layers
+
+    style_dict = {}
+    for layer_idx, layer in enumerate(style_layers):
+        style_dict[layer] = style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]]
+
+    return style_dict
+
+def style_dict_to_style_tensor(style_dict, reference_generator):
+    style_layers = reference_generator.modulation_layers
+
+    style_tensor = torch.zeros(1, 9088)
+    for layer in style_dict:
+        layer_idx = style_layers.index(layer)
+        style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]
+
+    return style_tensor
+
+def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
+    edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
+
+    source_s = style_dict_to_style_tensor(source_latent, reference_generator)
+
+    return source_s.to(edit_direction.device) + alpha * edit_direction
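`get_direction` is the core of the new file: the two prompts are embedded with CLIP's prompt-ensembling templates, their normalized difference `dt` is projected onto the precomputed channel-relevance matrix `di` (the `fs3.npy` tensor, one row per S-space channel), and channels whose relevance magnitude falls below `beta` are zeroed. Larger `beta` therefore yields sparser, more disentangled edits, up to the point where every channel is masked and a ValueError is raised. A small sketch of that trade-off; the prompts, `beta` values, and the `fs3.npy` location are illustrative:

```python
# Sketch: how beta controls the sparsity of a StyleCLIP global direction.
import clip
import numpy as np
import torch

from styleclip.styleclip_global import get_direction

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/32", device=device)
fs3 = torch.from_numpy(np.load("models/fs3.npy")).to(device)  # one row per S-space channel (path illustrative)

for beta in (0.10, 0.15, 0.20):
    try:
        direction = get_direction("face", "face with glasses", beta, fs3, clip_model)
        active = int((direction != 0).sum())
        print(f"beta={beta}: {active} / {direction.numel()} style channels affected")
    except ValueError as err:  # raised when beta masks out every channel for this prompt pair
        print(f"beta={beta}: {err}")
```

In `project_code_with_styleclip`, `alpha` then scales this normalized direction before it is added to the flattened S code of the source image.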