adirik committed
Commit ff50bb1
1 Parent(s): b262a3f

update find_direction

Files changed (2)
  1. .DS_Store +0 -0
  2. find_direction.py +13 -121
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
find_direction.py CHANGED
@@ -8,34 +8,17 @@
 
 """Generate images using pretrained network pickle."""
 
-import os
-import re
-import random
 import math
-import time
-import click
 import legacy
-from typing import List, Optional
-
-import cv2
 import clip
 import dnnlib
 import numpy as np
 import torch
-from torch import linalg as LA
 import torch.nn.functional as F
-import torchvision
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-import PIL.Image
+from torchvision.transforms import Compose, Resize, CenterCrop
 from PIL import Image
-import matplotlib.pyplot as plt
-
 from torch_utils import misc
-from torch_utils import persistence
-from torch_utils.ops import conv2d_resample
 from torch_utils.ops import upfirdn2d
-from torch_utils.ops import bias_act
-from torch_utils.ops import fma
 import id_loss
 
 
@@ -81,8 +64,6 @@ def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None
     assert img is None or img.dtype == torch.float32
     return x, img
 
-
-
 def unravel_index(index, shape):
     out = []
     for dim in reversed(shape):
@@ -90,108 +71,27 @@ def unravel_index(index, shape):
         index = index // dim
     return tuple(reversed(out))
 
-
-def num_range(s: str) -> List[int]:
-    """
-    Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.
-    """
-
-    range_re = re.compile(r'^(\d+)-(\d+)$')
-    m = range_re.match(s)
-    if m:
-        return list(range(int(m.group(1)), int(m.group(2)) + 1))
-    vals = s.split(',')
-    return [int(x) for x in vals]
-
-
-@click.command()
-@click.pass_context
-@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
-@click.option('--seeds', type=num_range, help='List of random seeds')
-@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
-@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
-@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
-@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
-@click.option('--projected_s', help='Projection result file', type=str, metavar='FILE')
-@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
-@click.option('--text_prompt', help='Text', type=str, required=True)
-@click.option('--resolution', help='Resolution of output images', type=int, required=True)
-@click.option('--batch_size', help='Batch Size', type=int, required=True)
-@click.option('--identity_power', help='How much change occurs on the face', type=str, required=True)
-def generate_images(
-    ctx: click.Context,
+def find_direction(
     network_pkl: str,
-    seeds: Optional[List[int]],
-    truncation_psi: float,
-    noise_mode: str,
-    outdir: str,
-    class_idx: Optional[int],
-    projected_w: Optional[str],
-    projected_s: Optional[str],
     text_prompt: str,
-    resolution: int,
-    batch_size: int,
-    identity_power: str,
+    truncation_psi: float = 0.7,
+    noise_mode: str = "const",
+    resolution: int = 256,
+    identity_power: float = 0.5,
 ):
-    """
-    Generate images using pretrained network pickle.
-
-    Examples:
-    # Generate curated MetFaces images without truncation (Fig.10 left)
-    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
-    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
-    python generate.py --outdir=out --seeds=0-35 --class=1 \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
-
-    # Render an image from projected W
-    python generate.py --outdir=out --projected_w=projected_w.npz \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-    """
+    seeds=np.random.randint(0, 1000, 128)
 
+    batch_size=1
     print('Loading networks from "%s"...' % network_pkl)
-    # Use GPU if available
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     with dnnlib.util.open_url(network_pkl) as f:
         G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
 
-    os.makedirs(outdir, exist_ok=True)
-
-    # Synthesize the result of a W projection
-    if projected_w is not None:
-        if seeds is not None:
-            print('warn: --seeds is ignored when using --projected-w')
-        print(f'Generating images from projected W "{projected_w}"')
-        ws = np.load(projected_w)['w']
-        ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
-        assert ws.shape[1:] == (G.num_ws, G.w_dim)
-        for idx, w in enumerate(ws):
-            img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
-            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-            img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
-        return
-
-    if seeds is None:
-        ctx.fail('--seeds option is required when not using --projected-w')
-
     # Labels
+    class_idx=None
     label = torch.zeros([1, G.c_dim], device=device).requires_grad_()
     if G.c_dim != 0:
-        if class_idx is None:
-            ctx.fail('Must specify class label with --class when using a conditional network')
         label[:, class_idx] = 1
-    else:
-        if class_idx is not None:
-            print('warn: --class=lbl ignored when running on an unconditional network')
 
     model, preprocess = clip.load("ViT-B/32", device=device)
     text = clip.tokenize([text_prompt]).to(device)
@@ -211,8 +111,6 @@ def generate_images(
     transf = Compose([Resize(224, interpolation=Image.BICUBIC), CenterCrop(224)])
 
     styles_array = []
-    print("seeds:", seeds)
-    t1 = time.time()
     for seed_idx, seed in enumerate(seeds):
         if seed == seeds[-1]:
             print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
@@ -260,8 +158,7 @@ def generate_images(
         styles_array.append(styles)
 
     resolution_dict = {256: 6, 512: 7, 1024: 8}
-    id_coeff_dict = {"high": 2, "medium": 0.5, "low": 0.1, "none": 0}
-    id_coeff = id_coeff_dict[identity_power]
+    id_coeff = identity_power
     styles_direction = torch.zeros(1, 26, 512, device=device)
     styles_direction_grad_el2 = torch.zeros(1, 26, 512, device=device)
     styles_direction.requires_grad_()
@@ -272,7 +169,6 @@ def generate_images(
     temp_photos = []
    grads = []
     for i in range(math.ceil(len(seeds) / batch_size)):
-        # print(i*batch_size, "processed", time.time()-t1)
 
         styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device)
         seed = seeds[i]
@@ -325,6 +221,7 @@ def generate_images(
     styles_direction *= 0
 
     for i in range(math.ceil(len(seeds) / batch_size)):
+
         seed = seeds[i]
         styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device)
         img2 = torch.tensor(temp_photos[i]).to(device)
@@ -364,9 +261,4 @@ def generate_images(
     styles_direction = styles_direction.detach()
     styles_direction[styles_direction_grad_el2 > (len(seeds) / batch_size) / 4] = 0
 
-    output_filepath = f'{outdir}/direction_' + text_prompt.replace(" ", "_") + '.npz'
-    np.savez(output_filepath, s=styles_direction.cpu().numpy())
-
-
-if __name__ == "__main__":
-    generate_images()
+    return styles_direction.cpu().numpy()
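
With this commit, find_direction.py no longer defines a click CLI: callers import find_direction directly and handle saving themselves. A minimal usage sketch follows; the import path, network pickle name, and output filename are illustrative assumptions rather than part of the commit, and the np.savez call simply mirrors what the removed CLI code did.

import numpy as np

from find_direction import find_direction  # assumes find_direction.py is importable

# Compute a CLIP-guided global style direction for a text prompt.
# "ffhq.pkl" is a placeholder for any StyleGAN2-ADA network pickle.
direction = find_direction(
    network_pkl="ffhq.pkl",
    text_prompt="a photo of a smiling face",
    truncation_psi=0.7,   # defaults shown explicitly for clarity
    noise_mode="const",
    resolution=256,
    identity_power=0.5,
)

# The function now returns a (1, 26, 512) numpy array instead of writing
# a .npz file; persisting the direction is left to the caller.
np.savez("direction_smiling_face.npz", s=direction)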