adirik commited on
Commit
27fc598
1 Parent(s): b5e8b97

update w_to_s converter

Browse files
Files changed (1) hide show
  1. w_s_converter.py +24 -74
w_s_converter.py CHANGED
@@ -10,24 +10,10 @@
10
 
11
  import os
12
  import re
13
- import random
14
- import math
15
- import time
16
- import click
17
- import legacy
18
- from typing import List, Optional
19
-
20
- import cv2
21
- import clip
22
- import dnnlib
23
  import numpy as np
24
- import torchvision
25
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
26
  import torch
27
- from torch import linalg as LA
28
- import torch.nn.functional as F
29
- from PIL import Image
30
- import matplotlib.pyplot as plt
31
 
32
  from torch_utils import misc
33
  from torch_utils import persistence
@@ -86,56 +72,27 @@ def unravel_index(index, shape):
86
  out.append(index % dim)
87
  index = index // dim
88
  return tuple(reversed(out))
 
89
 
90
-
91
- def num_range(s: str) -> List[int]:
92
- '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
93
-
94
- range_re = re.compile(r'^(\d+)-(\d+)$')
95
- m = range_re.match(s)
96
- if m:
97
- return list(range(int(m.group(1)), int(m.group(2))+1))
98
- vals = s.split(',')
99
- return [int(x) for x in vals]
100
-
101
-
102
- @click.command()
103
- @click.pass_context
104
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
105
- @click.option('--seeds', type=num_range, help='List of random seeds')
106
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
107
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
108
- @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
109
- @click.option('--projected_s', help='Projection result file', type=str, metavar='FILE')
110
- @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
111
-
112
- def generate_images(
113
- ctx: click.Context,
114
- network_pkl: str,
115
- seeds: Optional[List[int]],
116
- truncation_psi: float,
117
- noise_mode: str,
118
  outdir: str,
119
- class_idx: Optional[int],
120
- projected_w: Optional[str],
121
- projected_s: Optional[str]
122
  ):
123
 
124
- print('Loading networks from "%s"...' % network_pkl)
125
  # Use GPU if available
126
  if torch.cuda.is_available():
127
  device = torch.device("cuda")
128
  else:
129
  device = torch.device("cpu")
130
 
131
- with dnnlib.util.open_url(network_pkl) as f:
132
- G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
133
-
134
  os.makedirs(outdir, exist_ok=True)
135
 
136
  # Generate images.
137
  for i in G.parameters():
138
- i.requires_grad = True
139
 
140
  ws = np.load(projected_w)['w']
141
  ws = torch.tensor(ws, device=device)
@@ -145,14 +102,12 @@ def generate_images(
145
  misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim])
146
  ws = ws.to(torch.float32)
147
 
148
-
149
  w_idx = 0
150
  for res in G.synthesis.block_resolutions:
151
  block = getattr(G.synthesis, f'b{res}')
152
  block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
153
  w_idx += block.num_conv
154
 
155
-
156
  styles = torch.zeros(1,26,512, device=device)
157
  styles_idx = 0
158
  temp_shapes = []
@@ -160,29 +115,24 @@ def generate_images(
160
  block = getattr(G.synthesis, f'b{res}')
161
 
162
  if res == 4:
163
- temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
164
- styles[0,:1,:] = block.conv1.affine(cur_ws[0,:1,:])
165
- styles[0,1:2,:] = block.torgb.affine(cur_ws[0,1:2,:])
166
 
167
- block.conv1.affine = torch.nn.Identity()
168
- block.torgb.affine = torch.nn.Identity()
169
- styles_idx += 2
170
  else:
171
- temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
172
- styles[0,styles_idx:styles_idx+1,:temp_shape[0]] = block.conv0.affine(cur_ws[0,:1,:])
173
- styles[0,styles_idx+1:styles_idx+2,:temp_shape[1]] = block.conv1.affine(cur_ws[0,1:2,:])
174
- styles[0,styles_idx+2:styles_idx+3,:temp_shape[2]] = block.torgb.affine(cur_ws[0,2:3,:])
175
-
176
- block.conv0.affine = torch.nn.Identity()
177
- block.conv1.affine = torch.nn.Identity()
178
- block.torgb.affine = torch.nn.Identity()
179
- styles_idx += 3
180
  temp_shapes.append(temp_shape)
181
 
182
-
183
  styles = styles.detach()
184
  np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy())
185
-
186
-
187
- if __name__ == "__main__":
188
- generate_images()
 
10
 
11
  import os
12
  import re
13
+ from typing import List
14
+
 
 
 
 
 
 
 
 
15
  import numpy as np
 
 
16
  import torch
 
 
 
 
17
 
18
  from torch_utils import misc
19
  from torch_utils import persistence
 
72
  out.append(index % dim)
73
  index = index // dim
74
  return tuple(reversed(out))
75
+
76
 
77
+ def w_to_s(
78
+ G,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  outdir: str,
80
+ projected_w: str,
81
+ truncation_psi: float = 0.7,
82
+ noise_mode: str = "const",
83
  ):
84
 
 
85
  # Use GPU if available
86
  if torch.cuda.is_available():
87
  device = torch.device("cuda")
88
  else:
89
  device = torch.device("cpu")
90
 
 
 
 
91
  os.makedirs(outdir, exist_ok=True)
92
 
93
  # Generate images.
94
  for i in G.parameters():
95
+ i.requires_grad = True
96
 
97
  ws = np.load(projected_w)['w']
98
  ws = torch.tensor(ws, device=device)
 
102
  misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim])
103
  ws = ws.to(torch.float32)
104
 
 
105
  w_idx = 0
106
  for res in G.synthesis.block_resolutions:
107
  block = getattr(G.synthesis, f'b{res}')
108
  block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
109
  w_idx += block.num_conv
110
 
 
111
  styles = torch.zeros(1,26,512, device=device)
112
  styles_idx = 0
113
  temp_shapes = []
 
115
  block = getattr(G.synthesis, f'b{res}')
116
 
117
  if res == 4:
118
+ temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
119
+ styles[0,:1,:] = block.conv1.affine(cur_ws[0,:1,:])
120
+ styles[0,1:2,:] = block.torgb.affine(cur_ws[0,1:2,:])
121
 
122
+ block.conv1.affine = torch.nn.Identity()
123
+ block.torgb.affine = torch.nn.Identity()
124
+ styles_idx += 2
125
  else:
126
+ temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0])
127
+ styles[0,styles_idx:styles_idx+1,:temp_shape[0]] = block.conv0.affine(cur_ws[0,:1,:])
128
+ styles[0,styles_idx+1:styles_idx+2,:temp_shape[1]] = block.conv1.affine(cur_ws[0,1:2,:])
129
+ styles[0,styles_idx+2:styles_idx+3,:temp_shape[2]] = block.torgb.affine(cur_ws[0,2:3,:])
130
+
131
+ block.conv0.affine = torch.nn.Identity()
132
+ block.conv1.affine = torch.nn.Identity()
133
+ block.torgb.affine = torch.nn.Identity()
134
+ styles_idx += 3
135
  temp_shapes.append(temp_shape)
136
 
 
137
  styles = styles.detach()
138
  np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy())