adirik committed on
Commit
20582f2
1 Parent(s): 27fc598

update demo

Files changed (5)
  1. .DS_Store +0 -0
  2. generate_fromS.py +0 -277
  3. generate_multi.py +0 -403
  4. generate_w.py +0 -148
  5. w_s_converter.py +124 -2
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
generate_fromS.py DELETED
@@ -1,277 +0,0 @@
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Generate images using pretrained network pickle."""
-
- import os
- import re
- import random
- import math
- import time
- import click
- import legacy
- from typing import List, Optional
-
- import cv2
- import clip
- import dnnlib
- import numpy as np
- import torchvision
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
- import PIL.Image
- import matplotlib.pyplot as plt
- import torch
- from torch import linalg as LA
- import torch.nn.functional as F
- from torch_utils import misc
- from torch_utils import persistence
- from torch_utils.ops import conv2d_resample
- from torch_utils.ops import upfirdn2d
- from torch_utils.ops import bias_act
- from torch_utils.ops import fma
-
-
- def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
-     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
-     w_iter = iter(ws.unbind(dim=1))
-     dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
-     memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
-     if fused_modconv is None:
-         with misc.suppress_tracer_warnings(): # this value will be treated as a constant
-             fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
-
-     # Input.
-     if self.in_channels == 0:
-         x = self.const.to(dtype=dtype, memory_format=memory_format)
-         x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
-     else:
-         misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
-         x = x.to(dtype=dtype, memory_format=memory_format)
-
-     # Main layers.
-     if self.in_channels == 0:
-         x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
-     elif self.architecture == 'resnet':
-         y = self.skip(x, gain=np.sqrt(0.5))
-         x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
-         x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
-         x = y.add_(x)
-     else:
-         x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
-         x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs)
-
-     # ToRGB.
-     if img is not None:
-         misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
-         img = upfirdn2d.upsample2d(img, self.resample_filter)
-     if self.is_last or self.architecture == 'skip':
-         y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv)
-         y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
-         img = img.add_(y) if img is not None else y
-
-     assert x.dtype == dtype
-     assert img is None or img.dtype == torch.float32
-     return x, img
-
-
- def unravel_index(index, shape):
-     out = []
-     for dim in reversed(shape):
-         out.append(index % dim)
-         index = index // dim
-     return tuple(reversed(out))
-
-
- def num_range(s: str) -> List[int]:
-     """
-     Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.
-     """
-
-     range_re = re.compile(r'^(\d+)-(\d+)$')
-     m = range_re.match(s)
-     if m:
-         return list(range(int(m.group(1)), int(m.group(2))+1))
-     vals = s.split(',')
-     return [int(x) for x in vals]
-
-
- @click.command()
- @click.pass_context
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
- @click.option('--seeds', type=num_range, help='List of random seeds')
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.7, show_default=True)
- @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
- @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
- @click.option('--s_input', help='Projection result file', type=str, metavar='FILE')
- @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
- @click.option('--text_prompt', help='Text', type=str, required=True)
- @click.option('--change_power', help='Change power', type=int, required=True)
- @click.option('--from_video', 'from_video', is_flag=True, help="generate from video")
-
- def generate_images(
-     ctx: click.Context,
-     network_pkl: str,
-     seeds: Optional[List[int]],
-     truncation_psi: float,
-     noise_mode: str,
-     outdir: str,
-     class_idx: Optional[int],
-     projected_w: Optional[str],
-     s_input: Optional[str],
-     text_prompt: str,
-     change_power: int,
-     from_video: bool,
- ):
-     """
-     Generate images using pretrained network pickle.
-
-     Examples:
-     # Generate curated MetFaces images without truncation (Fig.10 left)
-     python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-     # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
-     python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-     # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
-     python generate.py --outdir=out --seeds=0-35 --class=1 \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
-
-     # Render an image from projected W
-     python generate.py --outdir=out --projected_w=projected_w.npz \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-     """
-
-     print('Loading networks from "%s"...' % network_pkl)
-     # Use GPU if available
-     if torch.cuda.is_available():
-         device = torch.device("cuda")
-     else:
-         device = torch.device("cpu")
-
-     with dnnlib.util.open_url(network_pkl) as f:
-         G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
-
-     os.makedirs(outdir, exist_ok=True)
-
-     # Synthesize the result of a W projection.
-     if projected_w is not None:
-         if seeds is not None:
-             print ('warn: --seeds is ignored when using --projected-w')
-         print(f'Generating images from projected W "{projected_w}"')
-         ws = np.load(projected_w)['w']
-         ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
-         assert ws.shape[1:] == (G.num_ws, G.w_dim)
-         for idx, w in enumerate(ws):
-             img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
-             img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-             img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
-             img.save(f'{outdir}/proj{idx:02d}.png')
-         return
-
-     # Labels
-     label = torch.zeros([1, G.c_dim], device=device).requires_grad_()
-     if G.c_dim != 0:
-         if class_idx is None:
-             ctx.fail('Must specify class label with --class when using a conditional network')
-         label[:, class_idx] = 1
-     else:
-         if class_idx is not None:
-             print ('warn: --class=lbl ignored when running on an unconditional network')
-
-     # Generate images
-     for i in G.parameters():
-         i.requires_grad = False
-
-
-     temp_shapes = []
-     for res in G.synthesis.block_resolutions:
-         block = getattr(G.synthesis, f'b{res}')
-         if res == 4:
-             temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
-             block.conv1.affine = torch.nn.Identity()
-             block.torgb.affine = torch.nn.Identity()
-
-         else:
-             temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
-             block.conv0.affine = torch.nn.Identity()
-             block.conv1.affine = torch.nn.Identity()
-             block.torgb.affine = torch.nn.Identity()
-
-         temp_shapes.append(temp_shape)
-
-
-     if s_input is not None:
-         styles = np.load(s_input)['s']
-         styles_direction = np.load(f'{outdir}/direction_'+text_prompt.replace(" ", "_")+'.npz')['s']
-
-         styles_direction = torch.tensor(styles_direction, device=device)
-         styles = torch.tensor(styles, device=device)
-
-     if from_video and not os.path.isdir(f'{outdir}_video'):
-         os.makedirs(f'{outdir}_video')
-
-     with torch.no_grad():
-         if from_video:
-             name_i = 1000
-             for grad_change in np.arange(0, 1, 0.02)*change_power:
-                 imgs = []
-                 name_i += 1
-
-                 styles += styles_direction*grad_change
-                 styles_idx = 0
-                 x = img = None
-                 for k , res in enumerate(G.synthesis.block_resolutions):
-                     block = getattr(G.synthesis, f'b{res}')
-
-                     if res == 4:
-                         x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
-                         styles_idx += 2
-                     else:
-                         x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
-                         styles_idx += 3
-
-                 img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
-                 imgs.append(img[0].to(torch.uint8).cpu().numpy())
-
-                 styles -= styles_direction*grad_change
-                 img_filepath = '{}_video/{}_{}_{}.jpeg'.format(outdir, text_prompt.replace(" ", "_"), change_power, name_i)
-                 PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(img_filepath, quality=95)
-         else:
-             imgs = []
-             grad_changes = [0, 0.25*change_power, 0.5*change_power, 0.75*change_power, change_power]
-
-             for grad_change in grad_changes:
-                 styles += styles_direction*grad_change
-
-                 styles_idx = 0
-                 x = img = None
-                 for k , res in enumerate(G.synthesis.block_resolutions):
-                     block = getattr(G.synthesis, f'b{res}')
-
-                     if res == 4:
-                         x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
-                         styles_idx += 2
-                     else:
-                         x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
-                         styles_idx += 3
-
-                 img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
-                 imgs.append(img[0].to(torch.uint8).cpu().numpy())
-
-                 styles -= styles_direction*grad_change
-
-             img_filepath = f'{outdir}/'+text_prompt.replace(" ", "_")+'_'+str(change_power)+'.jpeg'
-             PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(img_filepath, quality=95)
-
-
-
- if __name__ == "__main__":
-     generate_images()
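
A note for readers of this and the following diffs: every generator block's `affine` layer is swapped for `torch.nn.Identity` because those affine layers are what map a W latent to per-layer style vectors (S space). Once they are bypassed, precomputed S-space tensors can be fed straight through `block_forward`. A minimal, self-contained sketch of that trick on a hypothetical modulated layer (not the StyleGAN2 code itself):

```python
import torch

class ModulatedLayer(torch.nn.Module):
    """Toy stand-in for a StyleGAN2 conv block's modulation path."""
    def __init__(self, w_dim=512, channels=64):
        super().__init__()
        self.affine = torch.nn.Linear(w_dim, channels)  # w -> per-channel style (S space)

    def forward(self, x, w):
        style = self.affine(w)                          # modulation coefficients
        return x * style[:, :, None, None]

layer = ModulatedLayer()
x = torch.randn(1, 64, 8, 8)

# Precompute the style once, then bypass the affine map with Identity,
# so the layer now accepts S-space vectors directly.
s = layer.affine(torch.randn(1, 512))
layer.affine = torch.nn.Identity()
out = layer(x, s)                                       # feeds styles, not w
```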
generate_multi.py DELETED
@@ -1,403 +0,0 @@
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Generate images using pretrained network pickle."""
-
- import os
- import re
- from typing import List, Optional
- import torchvision
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
- import click
- import dnnlib
- import numpy as np
- import PIL.Image
- import torch
- from torch import linalg as LA
- import clip
- from PIL import Image
- import legacy
- import torch.nn.functional as F
- import cv2
- import matplotlib.pyplot as plt
- from torch_utils import misc
- from torch_utils import persistence
- from torch_utils.ops import conv2d_resample
- from torch_utils.ops import upfirdn2d
- from torch_utils.ops import bias_act
- from torch_utils.ops import fma
- import random
- import math
- import time
- import id_loss
-
-
- def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
-     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
-     w_iter = iter(ws.unbind(dim=1))
-     dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
-     memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
-     if fused_modconv is None:
-         with misc.suppress_tracer_warnings(): # this value will be treated as a constant
-             fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
-
-     # Input.
-     if self.in_channels == 0:
-         x = self.const.to(dtype=dtype, memory_format=memory_format)
-         x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
-     else:
-         misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
-         x = x.to(dtype=dtype, memory_format=memory_format)
-
-     # Main layers.
-     if self.in_channels == 0:
-         x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
-     elif self.architecture == 'resnet':
-         y = self.skip(x, gain=np.sqrt(0.5))
-         x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
-         x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
-         x = y.add_(x)
-     else:
-         x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
-         x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs)
-
-     # ToRGB.
-     if img is not None:
-         misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
-         img = upfirdn2d.upsample2d(img, self.resample_filter)
-     if self.is_last or self.architecture == 'skip':
-         y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv)
-         y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
-         img = img.add_(y) if img is not None else y
-
-     assert x.dtype == dtype
-     assert img is None or img.dtype == torch.float32
-     return x, img
-
- def unravel_index(index, shape):
-     out = []
-     for dim in reversed(shape):
-         out.append(index % dim)
-         index = index // dim
-     return tuple(reversed(out))
-
-
- #----------------------------------------------------------------------------
-
- def num_range(s: str) -> List[int]:
-     '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
-
-     range_re = re.compile(r'^(\d+)-(\d+)$')
-     m = range_re.match(s)
-     if m:
-         return list(range(int(m.group(1)), int(m.group(2))+1))
-     vals = s.split(',')
-     return [int(x) for x in vals]
-
- #----------------------------------------------------------------------------
-
- @click.command()
- @click.pass_context
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
- @click.option('--seeds', type=num_range, help='List of random seeds')
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
- @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
- @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
- @click.option('--projected_s', help='Projection result file', type=str, metavar='FILE')
- @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
- @click.option('--resolution', help='Resolution of output images', type=int, required=True)
- @click.option('--batch_size', help='Batch Size', type=int, required=True)
- @click.option('--identity_power', help='How much change occurs on the face', type=str, required=True)
- def generate_images(
-     ctx: click.Context,
-     network_pkl: str,
-     seeds: Optional[List[int]],
-     truncation_psi: float,
-     noise_mode: str,
-     outdir: str,
-     class_idx: Optional[int],
-     projected_w: Optional[str],
-     projected_s: Optional[str],
-     resolution: int,
-     batch_size: int,
-     identity_power: str
- ):
-     """Generate images using pretrained network pickle.
-
-     Examples:
-
-     \b
-     # Generate curated MetFaces images without truncation (Fig.10 left)
-     python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-     \b
-     # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
-     python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-     \b
-     # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
-     python generate.py --outdir=out --seeds=0-35 --class=1 \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
-
-     \b
-     # Render an image from projected W
-     python generate.py --outdir=out --projected_w=projected_w.npz \\
-         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-     """
-
-     print('Loading networks from "%s"...' % network_pkl)
-     device = torch.device('cuda')
-     with dnnlib.util.open_url(network_pkl) as f:
-         G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
-
-     os.makedirs(outdir, exist_ok=True)
-
-     # Synthesize the result of a W projection.
-     if projected_w is not None:
-         if seeds is not None:
-             print ('warn: --seeds is ignored when using --projected-w')
-         print(f'Generating images from projected W "{projected_w}"')
-         ws = np.load(projected_w)['w']
-         ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
-         assert ws.shape[1:] == (G.num_ws, G.w_dim)
-         for idx, w in enumerate(ws):
-             img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
-             img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-             img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
-         return
-
-     if seeds is None:
-         ctx.fail('--seeds option is required when not using --projected-w')
-
-     # Labels.
-     label = torch.zeros([1, G.c_dim], device=device).requires_grad_()
-     if G.c_dim != 0:
-         if class_idx is None:
-             ctx.fail('Must specify class label with --class when using a conditional network')
-         label[:, class_idx] = 1
-     else:
-         if class_idx is not None:
-             print ('warn: --class=lbl ignored when running on an unconditional network')
-
-     model, preprocess = clip.load("ViT-B/32", device=device)
-
-     text_prompts_file = open("text_prompts.txt")
-     text_prompts = text_prompts_file.read().split("\n")
-     text_prompts_file.close()
-
-     text = clip.tokenize(text_prompts).to(device)
-     text_features = model.encode_text(text)
-
-     # Generate images.
-     for i in G.parameters():
-         i.requires_grad = True
-
-     mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073), dtype=torch.float, device=device)
-     std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711), dtype=torch.float, device=device)
-     if mean.ndim == 1:
-         mean = mean.view(-1, 1, 1)
-     if std.ndim == 1:
-         std = std.view(-1, 1, 1)
-
-     transf = Compose([
-         Resize(224, interpolation=Image.BICUBIC),
-         CenterCrop(224),
-     ])
-
-     styles_array = []
-     print("seeds:", seeds)
-     t1 = time.time()
-     for seed_idx, seed in enumerate(seeds):
-         if seed==seeds[-1]:
-             print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
-         z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
-         ws = G.mapping(z, label, truncation_psi=truncation_psi)
-
-         block_ws = []
-         with torch.autograd.profiler.record_function('split_ws'):
-             misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim])
-             ws = ws.to(torch.float32)
-
-             w_idx = 0
-             for res in G.synthesis.block_resolutions:
-                 block = getattr(G.synthesis, f'b{res}')
-                 block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
-                 w_idx += block.num_conv
-
-         styles = torch.zeros(1,26,512, device=device)
-         styles_idx = 0
-         temp_shapes = []
-         for res, cur_ws in zip(G.synthesis.block_resolutions, block_ws):
-             block = getattr(G.synthesis, f'b{res}')
-
-             if res == 4:
-                 temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
-                 styles[0,:1,:] = block.conv1.affine(cur_ws[0,:1,:])
-                 styles[0,1:2,:] = block.torgb.affine(cur_ws[0,1:2,:])
-                 if seed_idx==(len(seeds)-1):
-                     block.conv1.affine = torch.nn.Identity()
-                     block.torgb.affine = torch.nn.Identity()
-                 styles_idx += 2
-             else:
-                 temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
-                 styles[0,styles_idx:styles_idx+1,:temp_shape[0]] = block.conv0.affine(cur_ws[0,:1,:])
-                 styles[0,styles_idx+1:styles_idx+2,:temp_shape[1]] = block.conv1.affine(cur_ws[0,1:2,:])
-                 styles[0,styles_idx+2:styles_idx+3,:temp_shape[2]] = block.torgb.affine(cur_ws[0,2:3,:])
-                 if seed_idx==(len(seeds)-1):
-                     block.conv0.affine = torch.nn.Identity()
-                     block.conv1.affine = torch.nn.Identity()
-                     block.torgb.affine = torch.nn.Identity()
-                 styles_idx += 3
-             temp_shapes.append(temp_shape)
-
-         styles = styles.detach()
-         styles_array.append(styles)
-
-     resolution_dict = {256: 6, 512: 7, 1024: 8}
-     identity_coefficient_dict = {"high": 2,"medium": 0.5, "low": 0.1, "none": 0}
-     identity_coefficient = identity_coefficient_dict[identity_power]
-     styles_wanted_direction = torch.zeros(1,26,512, device=device)
-     styles_wanted_direction_grad_el2 = torch.zeros(1,26,512, device=device)
-     styles_wanted_direction.requires_grad_()
-
-     global id_loss
-     id_loss = id_loss.IDLoss("a").to(device).eval()
-
-     temp_photos = []
-     grads = []
-     for i in range(math.ceil(len(seeds)/batch_size)):
-         #print(i*batch_size, "processed", time.time()-t1)
-
-         styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device)
-
-         seed = seeds[i]
-
-         styles_idx = 0
-         x2 = img2 = None
-
-         for k , (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)):
-             block = getattr(G.synthesis, f'b{res}')
-             if k>resolution_dict[resolution]:
-                 continue
-
-             if res == 4:
-                 x2, img2 = block_forward(block, x2, img2, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode)
-                 styles_idx += 2
-             else:
-                 x2, img2 = block_forward(block, x2, img2, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode)
-                 styles_idx += 3
-
-         img2_cpu = img2.detach().cpu().numpy()
-         temp_photos.append(img2_cpu)
-         if i>3:
-             continue
-
-         styles2 = styles + styles_wanted_direction
-
-         styles_idx = 0
-         x = img = None
-         for k , (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)):
-             block = getattr(G.synthesis, f'b{res}')
-             if k>resolution_dict[resolution]:
-                 continue
-             if res == 4:
-                 x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode)
-                 styles_idx += 2
-             else:
-                 x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode)
-                 styles_idx += 3
-
-         identity_loss, _ = id_loss(img, img2)
-         identity_loss *= identity_coefficient
-         img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
-         img = (transf(img.permute(0, 3, 1, 2))/255).sub_(mean).div_(std)
-         image_features = model.encode_image(img)
-         cos_sim = -1*F.cosine_similarity(image_features, (text_features[0]).unsqueeze(0))
-         (identity_loss + cos_sim.sum()).backward(retain_graph=True)
-
-     #t1 = time.time()
-
-     for text_counter in range(len(text_prompts)):
-         text_prompt = text_prompts[text_counter]
-         print(text_prompt)
-
-         styles_wanted_direction.grad.data.zero_()
-         styles_wanted_direction_grad_el2 = torch.zeros(1,26,512, device=device)
-         with torch.no_grad():
-             styles_wanted_direction *= 0
-
-         for i in range(math.ceil(len(seeds)/batch_size)):
-             print(i*batch_size, "processed", time.time()-t1)
-
-             styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device)
-
-             seed = seeds[i]
-
-             img2 = torch.tensor(temp_photos[i]).to(device)
-
-             styles2 = styles + styles_wanted_direction
-
-             styles_idx = 0
-             x = img = None
-             for k , (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)):
-                 block = getattr(G.synthesis, f'b{res}')
-                 if k>resolution_dict[resolution]:
-                     continue
-
-                 if res == 4:
-                     x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode)
-                     styles_idx += 2
-                 else:
-                     x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode)
-                     styles_idx += 3
-
-             identity_loss, _ = id_loss(img, img2)
-             identity_loss *= identity_coefficient
-             img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
-             img = (transf(img.permute(0, 3, 1, 2))/255).sub_(mean).div_(std)
-             image_features = model.encode_image(img)
-             cos_sim = -1*F.cosine_similarity(image_features, (text_features[text_counter]).unsqueeze(0))
-             (identity_loss + cos_sim.sum()).backward(retain_graph=True)
-
-             styles_wanted_direction.grad[:, [0, 1, 4, 7, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], :] = 0
-
-             if i%2==1:
-                 styles_wanted_direction.data = styles_wanted_direction - styles_wanted_direction.grad*5
-                 grads.append(styles_wanted_direction.grad.clone())
-                 styles_wanted_direction.grad.data.zero_()
-
-             if i>3:
-                 styles_wanted_direction_grad_el2[grads[-2]*grads[-1]<0] += 1
-
-         styles_wanted_direction_cpu = styles_wanted_direction.detach()
-         styles_wanted_direction_cpu[styles_wanted_direction_grad_el2>(len(seeds)/batch_size)/4] = 0
-         np.savez(f'{outdir}/direction_'+text_prompt.replace(" ", "_")+'.npz', s=styles_wanted_direction_cpu.cpu().numpy())
-
-     print("time passed:", time.time()-t1)
- #----------------------------------------------------------------------------
-
- if __name__ == "__main__":
-     generate_images() # pylint: disable=no-value-for-parameter
-
- #----------------------------------------------------------------------------
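
The heart of this deleted script is a gradient descent on a global S-space direction, driven by a CLIP similarity term plus an identity-preservation term. A stripped-down, hypothetical sketch of one update step; `render` and `clip_image_encoder` are stand-ins for the manual block-by-block synthesis and CLIP's image encoder above, not the repository's API:

```python
import torch
import torch.nn.functional as F

def direction_step(direction, styles, render, clip_image_encoder, text_feature,
                   id_term=0.0, lr=5.0):
    """One gradient step pushing `direction` toward a CLIP text feature.

    direction: (1, 26, 512) leaf tensor with requires_grad=True
    styles:    (N, 26, 512) batch of S-space latents
    render:    differentiable callable mapping styles -> images
    """
    img = render(styles + direction)
    feats = clip_image_encoder(img)
    # Negative cosine similarity: minimizing it pulls images toward the prompt.
    loss = -F.cosine_similarity(feats, text_feature.unsqueeze(0)).sum() + id_term
    loss.backward(retain_graph=True)
    with torch.no_grad():
        direction -= lr * direction.grad   # the script uses lr = 5
        direction.grad.zero_()
    return direction
```

The script additionally zeroes the gradient on most of the 26 style rows (keeping only a few early/mid layers) and, via `styles_wanted_direction_grad_el2`, counts sign flips between consecutive gradients so that unstable channels can be zeroed out of the final saved direction.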
generate_w.py DELETED
@@ -1,148 +0,0 @@
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
- #
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
-
- """Generate images using pretrained network pickle."""
-
- import os
- import re
- from typing import List, Optional
- import torchvision
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
- import click
- import dnnlib
- import numpy as np
- import PIL.Image
- import torch
- from torch import linalg as LA
- import clip
- from PIL import Image
- import legacy
- import torch.nn.functional as F
- import cv2
- import matplotlib.pyplot as plt
- from torch_utils import misc
- from torch_utils import persistence
- from torch_utils.ops import conv2d_resample
- from torch_utils.ops import upfirdn2d
- from torch_utils.ops import bias_act
- from torch_utils.ops import fma
- import random
- import math
- import time
- import id_loss
-
-
- def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
-     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
-     w_iter = iter(ws.unbind(dim=1))
-     dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
-     memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
-     if fused_modconv is None:
-         with misc.suppress_tracer_warnings(): # this value will be treated as a constant
-             fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
-
-     # Input.
-     if self.in_channels == 0:
-         x = self.const.to(dtype=dtype, memory_format=memory_format)
-         x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
-     else:
-         misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
-         x = x.to(dtype=dtype, memory_format=memory_format)
-
-     # Main layers.
-     if self.in_channels == 0:
-         x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
-     elif self.architecture == 'resnet':
-         y = self.skip(x, gain=np.sqrt(0.5))
-         x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
-         x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
-         x = y.add_(x)
-     else:
-         x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
-         x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs)
-
-     # ToRGB.
-     if img is not None:
-         misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
-         img = upfirdn2d.upsample2d(img, self.resample_filter)
-     if self.is_last or self.architecture == 'skip':
-         y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv)
-         y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
-         img = img.add_(y) if img is not None else y
-
-     assert x.dtype == dtype
-     assert img is None or img.dtype == torch.float32
-     return x, img
-
- def unravel_index(index, shape):
-     out = []
-     for dim in reversed(shape):
-         out.append(index % dim)
-         index = index // dim
-     return tuple(reversed(out))
-
-
- def num_range(s: str) -> List[int]:
-     '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
-
-     range_re = re.compile(r'^(\d+)-(\d+)$')
-     m = range_re.match(s)
-     if m:
-         return list(range(int(m.group(1)), int(m.group(2))+1))
-     vals = s.split(',')
-     return [int(x) for x in vals]
-
-
- @click.command()
- @click.pass_context
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
- @click.option('--seeds', type=num_range, help='List of random seeds')
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
- @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
- def generate_images(
-     ctx: click.Context,
-     network_pkl: str,
-     seeds: Optional[List[int]],
-     truncation_psi: float,
-     noise_mode: str,
-     class_idx: Optional[int],
-     projected_w: Optional[str],
-     projected_s: Optional[str]
- ):
-
-     print('Loading networks from "%s"...' % network_pkl)
-     # Use GPU if available
-     if torch.cuda.is_available():
-         device = torch.device("cuda")
-     else:
-         device = torch.device("cpu")
-
-     with dnnlib.util.open_url(network_pkl) as f:
-         G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
-
-     if seeds is None:
-         ctx.fail('--seeds option is required when not using --projected-w')
-
-     # Labels.
-     label = torch.zeros([1, G.c_dim], device=device).requires_grad_()
-     if G.c_dim != 0:
-         if class_idx is None:
-             ctx.fail('Must specify class label with --class when using a conditional network')
-         label[:, class_idx] = 1
-     else:
-         if class_idx is not None:
-             print ('warn: --class=lbl ignored when running on an unconditional network')
-
-     z = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device)
-     ws = G.mapping(z, label, truncation_psi=truncation_psi)
-     np.savez(f'encoder4editing/projected_w.npz', w=ws.detach().cpu().numpy())
-
-
- if __name__ == "__main__":
-     generate_images()
w_s_converter.py CHANGED
@@ -10,7 +10,7 @@
 
 import os
 import re
- from typing import List
+ from typing import List, Optional
 
 import numpy as np
 import torch
@@ -23,6 +23,13 @@ from torch_utils.ops import bias_act
 from torch_utils.ops import fma
 
 
+ import click
+
+ import PIL.Image
+ from torch import linalg as LA
+ import torch.nn.functional as F
+
+
 def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
     w_iter = iter(ws.unbind(dim=1))
@@ -66,13 +73,56 @@ def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None
     return x, img
 
 
+ def block_forward_from_style(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
+     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+     w_iter = iter(ws.unbind(dim=1))
+     dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+     memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+     if fused_modconv is None:
+         with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+             fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
+
+     # Input.
+     if self.in_channels == 0:
+         x = self.const.to(dtype=dtype, memory_format=memory_format)
+         x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+     else:
+         misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
+         x = x.to(dtype=dtype, memory_format=memory_format)
+
+     # Main layers.
+     if self.in_channels == 0:
+         x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
+     elif self.architecture == 'resnet':
+         y = self.skip(x, gain=np.sqrt(0.5))
+         x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+         x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+         x = y.add_(x)
+     else:
+         x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
+         x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs)
+
+     # ToRGB.
+     if img is not None:
+         misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+         img = upfirdn2d.upsample2d(img, self.resample_filter)
+     if self.is_last or self.architecture == 'skip':
+         y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv)
+         y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+         img = img.add_(y) if img is not None else y
+
+     assert x.dtype == dtype
+     assert img is None or img.dtype == torch.float32
+     return x, img
+
+
 def unravel_index(index, shape):
     out = []
     for dim in reversed(shape):
         out.append(index % dim)
         index = index // dim
     return tuple(reversed(out))
-
+
 
 def w_to_s(
     G,
@@ -136,3 +186,75 @@ def w_to_s(
 
     styles = styles.detach()
     np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy())
+
+
+ def generate_from_style(
+     G,
+     outdir: str,
+     s_input: str,
+     text_prompt: str,
+     change_power: int,
+     truncation_psi: float = 0.7,
+     noise_mode: str = "const",
+ ):
+     # Use GPU if available
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     os.makedirs(outdir, exist_ok=True)
+
+     # Generate images
+     for i in G.parameters():
+         i.requires_grad = False
+
+     temp_shapes = []
+     for res in G.synthesis.block_resolutions:
+         block = getattr(G.synthesis, f'b{res}')
+         if res == 4:
+             temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
+             block.conv1.affine = torch.nn.Identity()
+             block.torgb.affine = torch.nn.Identity()
+
+         else:
+             temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
+             block.conv0.affine = torch.nn.Identity()
+             block.conv1.affine = torch.nn.Identity()
+             block.torgb.affine = torch.nn.Identity()
+
+         temp_shapes.append(temp_shape)
+
+     if s_input is not None:
+         styles = np.load(s_input)['s']
+         styles_direction = np.load(f'{outdir}/direction_'+text_prompt.replace(" ", "_")+'.npz')['s']
+
+         styles_direction = torch.tensor(styles_direction, device=device)
+         styles = torch.tensor(styles, device=device)
+
+     with torch.no_grad():
+         imgs = []
+         grad_changes = [0, 0.25*change_power, 0.5*change_power, 0.75*change_power, change_power]
+
+         for grad_change in grad_changes:
+             styles += styles_direction*grad_change
+
+             styles_idx = 0
+             x = img = None
+             for k , res in enumerate(G.synthesis.block_resolutions):
+                 block = getattr(G.synthesis, f'b{res}')
+
+                 if res == 4:
+                     x, img = block_forward_from_style(block, x, img, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
+                     styles_idx += 2
+                 else:
+                     x, img = block_forward_from_style(block, x, img, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
+                     styles_idx += 3
+
+             img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
+             imgs.append(img[0].to(torch.uint8).cpu().numpy())
+
+             styles -= styles_direction*grad_change
+
+         img_filepath = f'{outdir}/'+text_prompt.replace(" ", "_")+'_'+str(change_power)+'.jpeg'
+         PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(img_filepath, quality=95)
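
After this commit, the demo's synthesis path lives in w_s_converter.py: `w_to_s` writes the S-space styles for a latent to `<outdir>/input.npz`, and `generate_from_style` renders a strip of five edits at 0, 0.25, 0.5, 0.75 and 1.0 times `change_power` along a precomputed text direction. A minimal driver sketch under those assumptions; the full parameter list of `w_to_s` is truncated in this diff, so its call is only indicated, and `ffhq.pkl` is a placeholder path, not the demo's actual checkpoint:

```python
import torch
import dnnlib
import legacy
from w_s_converter import generate_from_style  # w_to_s lives in the same file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Any StyleGAN2-ADA network pickle works here.
with dnnlib.util.open_url("ffhq.pkl") as f:
    G = legacy.load_network_pkl(f)["G_ema"].to(device)

# Step 1 (signature not shown in this diff): w_to_s(G, ...) -> out/input.npz

# Step 2: render the edit strip; this reads out/direction_a_smiling_face.npz,
# which the direction-search step must have produced beforehand.
generate_from_style(
    G,
    outdir="out",
    s_input="out/input.npz",
    text_prompt="a smiling face",
    change_power=10,
)
```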