adirik commited on
Commit
e90f2c5
1 Parent(s): 3576406

add wrapper

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  filter=lfs diff=lfs merge=lfs -text
37
  *.pkl* filter=lfs diff=lfs merge=lfs -text
38
  filter=lfs diff=lfs merge=lfs -text
 
 
 
36
  filter=lfs diff=lfs merge=lfs -text
37
  *.pkl* filter=lfs diff=lfs merge=lfs -text
38
  filter=lfs diff=lfs merge=lfs -text
39
+ *.dat* filter=lfs diff=lfs merge=lfs -text
40
+ *.pt* filter=lfs diff=lfs merge=lfs -text
find_direction.py CHANGED
@@ -20,7 +20,7 @@ from PIL import Image
20
  from torch_utils import misc
21
  from torch_utils.ops import upfirdn2d
22
  import id_loss
23
-
24
 
25
  def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
26
  misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
@@ -72,13 +72,14 @@ def unravel_index(index, shape):
72
  return tuple(reversed(out))
73
 
74
  def find_direction(
75
- G,
76
  text_prompt: str,
77
  truncation_psi: float = 0.7,
78
  noise_mode: str = "const",
79
  resolution: int = 256,
80
  identity_power: float = 0.5,
81
  ):
 
82
  seeds=np.random.randint(0, 1000, 128)
83
 
84
  batch_size=1
@@ -160,8 +161,9 @@ def find_direction(
160
  styles_direction_grad_el2 = torch.zeros(1, 26, 512, device=device)
161
  styles_direction.requires_grad_()
162
 
163
- global id_loss
164
- id_loss = id_loss.IDLoss("a").to(device).eval()
 
165
 
166
  temp_photos = []
167
  grads = []
@@ -205,7 +207,7 @@ def find_direction(
205
  x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
206
  styles_idx += 3
207
 
208
- identity_loss, _ = id_loss(img, img2)
209
  identity_loss *= id_coeff
210
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
211
  img = (transf(img.permute(0, 3, 1, 2)) / 255).sub_(mean).div_(std)
@@ -238,7 +240,7 @@ def find_direction(
238
  x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
239
  styles_idx += 3
240
 
241
- identity_loss, _ = id_loss(img, img2)
242
  identity_loss *= id_coeff
243
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
244
  img = (transf(img.permute(0, 3, 1, 2)) / 255).sub_(mean).div_(std)
 
20
  from torch_utils import misc
21
  from torch_utils.ops import upfirdn2d
22
  import id_loss
23
+ from copy import deepcopy
24
 
25
  def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
26
  misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
 
72
  return tuple(reversed(out))
73
 
74
  def find_direction(
75
+ GIn,
76
  text_prompt: str,
77
  truncation_psi: float = 0.7,
78
  noise_mode: str = "const",
79
  resolution: int = 256,
80
  identity_power: float = 0.5,
81
  ):
82
+ G = deepcopy(GIn)
83
  seeds=np.random.randint(0, 1000, 128)
84
 
85
  batch_size=1
 
161
  styles_direction_grad_el2 = torch.zeros(1, 26, 512, device=device)
162
  styles_direction.requires_grad_()
163
 
164
+ global id_loss2
165
+ #id_loss = id_loss.IDLoss("a").to(device).eval()
166
+ id_loss2 = id_loss.IDLoss("a").to(device).eval()
167
 
168
  temp_photos = []
169
  grads = []
 
207
  x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
208
  styles_idx += 3
209
 
210
+ identity_loss, _ = id_loss2(img, img2)
211
  identity_loss *= id_coeff
212
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
213
  img = (transf(img.permute(0, 3, 1, 2)) / 255).sub_(mean).div_(std)
 
240
  x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True)
241
  styles_idx += 3
242
 
243
+ identity_loss, _ = id_loss2(img, img2)
244
  identity_loss *= id_coeff
245
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
246
  img = (transf(img.permute(0, 3, 1, 2)) / 255).sub_(mean).div_(std)
generator.py CHANGED
@@ -21,6 +21,7 @@ from torch_utils.ops import conv2d_resample
21
  from torch_utils.ops import upfirdn2d
22
  from torch_utils.ops import bias_act
23
  from torch_utils.ops import fma
 
24
 
25
 
26
  import click
@@ -125,13 +126,14 @@ def unravel_index(index, shape):
125
 
126
 
127
  def w_to_s(
128
- G,
129
- outdir: str,
130
- projected_w: str,
131
  truncation_psi: float = 0.7,
132
  noise_mode: str = "const",
133
  ):
134
-
 
135
  # Use GPU if available
136
  if torch.cuda.is_available():
137
  device = torch.device("cuda")
@@ -144,8 +146,8 @@ def w_to_s(
144
  for i in G.parameters():
145
  i.requires_grad = True
146
 
147
- ws = np.load(projected_w)['w']
148
- ws = torch.tensor(ws, device=device)
149
 
150
  block_ws = []
151
  with torch.autograd.profiler.record_function('split_ws'):
@@ -186,17 +188,18 @@ def w_to_s(
186
 
187
  styles = styles.detach()
188
  np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy())
189
-
190
 
191
  def generate_from_style(
192
- G,
 
 
193
  outdir: str,
194
- s_input: str,
195
- text_prompt: str,
196
  change_power: int,
197
  truncation_psi: float = 0.7,
198
  noise_mode: str = "const",
199
  ):
 
200
  # Use GPU if available
201
  if torch.cuda.is_available():
202
  device = torch.device("cuda")
@@ -225,16 +228,12 @@ def generate_from_style(
225
 
226
  temp_shapes.append(temp_shape)
227
 
228
- if s_input is not None:
229
- styles = np.load(s_input)['s']
230
- styles_direction = np.load(f'{outdir}/direction_'+text_prompt.replace(" ", "_")+'.npz')['s']
231
-
232
- styles_direction = torch.tensor(styles_direction, device=device)
233
- styles = torch.tensor(styles, device=device)
234
 
235
  with torch.no_grad():
236
  imgs = []
237
- grad_changes = [0, 0.25*change_power, 0.5*change_power, 0.75*change_power, change_power]
238
 
239
  for grad_change in grad_changes:
240
  styles += styles_direction*grad_change
@@ -256,5 +255,6 @@ def generate_from_style(
256
 
257
  styles -= styles_direction*grad_change
258
 
259
- img_filepath = f'{outdir}/'+text_prompt.replace(" ", "_")+'_'+str(change_power)+'.jpeg'
260
- PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(img_filepath, quality=95)
 
 
21
  from torch_utils.ops import upfirdn2d
22
  from torch_utils.ops import bias_act
23
  from torch_utils.ops import fma
24
+ from copy import deepcopy
25
 
26
 
27
  import click
 
126
 
127
 
128
  def w_to_s(
129
+ GIn,
130
+ wsIn:np.ndarray,
131
+ outdir: str ="s_out",
132
  truncation_psi: float = 0.7,
133
  noise_mode: str = "const",
134
  ):
135
+ G=deepcopy(GIn)
136
+
137
  # Use GPU if available
138
  if torch.cuda.is_available():
139
  device = torch.device("cuda")
 
146
  for i in G.parameters():
147
  i.requires_grad = True
148
 
149
+ # ws = np.load(projected_w)['w']
150
+ ws = torch.tensor(wsIn, device=device)
151
 
152
  block_ws = []
153
  with torch.autograd.profiler.record_function('split_ws'):
 
188
 
189
  styles = styles.detach()
190
  np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy())
191
+ return styles.cpu().numpy()
192
 
193
  def generate_from_style(
194
+ GIn,
195
+ styles: np.ndarray,
196
+ styles_direction: np.ndarray,
197
  outdir: str,
 
 
198
  change_power: int,
199
  truncation_psi: float = 0.7,
200
  noise_mode: str = "const",
201
  ):
202
+ G=deepcopy(GIn)
203
  # Use GPU if available
204
  if torch.cuda.is_available():
205
  device = torch.device("cuda")
 
228
 
229
  temp_shapes.append(temp_shape)
230
 
231
+ styles_direction = torch.tensor(styles_direction, device=device)
232
+ styles = torch.tensor(styles, device=device)
 
 
 
 
233
 
234
  with torch.no_grad():
235
  imgs = []
236
+ grad_changes = [change_power]
237
 
238
  for grad_change in grad_changes:
239
  styles += styles_direction*grad_change
 
255
 
256
  styles -= styles_direction*grad_change
257
 
258
+ output_image = PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB')
259
+ output_image.save(os.path.join(outdir, 'final_out.png'), quality=95)
260
+ return output_image
pretrained/.DS_Store CHANGED
Binary files a/pretrained/.DS_Store and b/pretrained/.DS_Store differ
 
pretrained/e4e_ffhq_encode.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ace1d9a8c05c10a399bcd500b8dda118f759ff1aac89dbdab7435f2136a0999
3
+ size 1201649680
pretrained/shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
3
+ size 99693937
psp_wrapper.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import sys
3
+
4
+ sys.path.append(".")
5
+ sys.path.append("..")
6
+ sys.path.append("./encoder4editing")
7
+
8
+ from PIL import Image
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+ import dlib
12
+ from utils.alignment import align_face
13
+
14
+
15
+ from utils.common import tensor2im
16
+ from models.psp import pSp # we use the pSp framework to load the e4e encoder.
17
+ experiment_type = 'ffhq_encode'
18
+
19
+ EXPERIMENT_DATA_ARGS = {
20
+ "ffhq_encode": {
21
+ "model_path": "encoder4editing/e4e_ffhq_encode.pt",
22
+ "image_path": "notebooks/images/input_img.jpg"
23
+ },
24
+ }
25
+ # Setup required image transformations
26
+ EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
27
+ EXPERIMENT_ARGS['transform'] = transforms.Compose([
28
+ transforms.Resize((256, 256)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
31
+
32
+ class psp_encoder:
33
+ def __init__(self, model_path: str, shape_predictor_path: str):
34
+ self.ckpt = torch.load(model_path, map_location='cpu')
35
+ self.opts = self.ckpt['opts']
36
+ # update the training options
37
+ self.opts['checkpoint_path'] = model_path
38
+ self.opts= Namespace(**self.opts)
39
+ self.net = pSp(self.opts)
40
+ self.net.eval()
41
+ self.net.cuda()
42
+ self.shape_predictor = dlib.shape_predictor(shape_predictor_path)
43
+
44
+ def get_w(self, image_path):
45
+ original_image = Image.open(image_path)
46
+ original_image = original_image.convert("RGB")
47
+ input_image = align_face(filepath=image_path, predictor=self.shape_predictor)
48
+ resize_dims = (256, 256)
49
+ input_image.resize(resize_dims)
50
+ img_transforms = EXPERIMENT_ARGS['transform']
51
+ transformed_image = img_transforms(input_image)
52
+ with torch.no_grad():
53
+ _, latents = self.net(transformed_image.unsqueeze(0).to("cuda").float(), randomize_noise=False, return_latents=True)
54
+ return latents.cpu().numpy()