Erwann Millon committed
Commit eac223c
1 Parent(s): 71b70df

cleanup and refactoring

Files changed (11)
  1. ImageState.py +0 -9
  2. animation.py +1 -6
  3. app.py +1 -1
  4. app_backend.py +21 -55
  5. configs.py +0 -7
  6. edit.py +4 -15
  7. img_processing.py +1 -1
  8. loaders.py +6 -22
  9. utils.py +1 -1
  10. vqgan_latent_ops.py +0 -14
  11. vqgan_only.pt +0 -3
ImageState.py CHANGED
@@ -63,24 +63,15 @@ class ImageState:
     def _decode_latent_to_pil(self, latent):
         current_im = self.vqgan.decode(latent.to(self.device))[0]
         return custom_to_pil(current_im)
-    # def _get_current_vector_transforms(self):
-    #     current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
-    #     return (self.blend_latent, current_vector_transforms)
-    # @cache
     def get_mask(self, img, mask=None):
         if img and "mask" in img and img["mask"] is not None:
             attn_mask = torchvision.transforms.ToTensor()(img["mask"])
             attn_mask = torch.ceil(attn_mask[0].to(self.device))
-            plt.imshow(attn_mask.detach().cpu(), cmap="Blues")
-            plt.show()
-            torch.save(attn_mask, "test_mask.pt")
             print("mask set successfully")
-            # attn_mask = self.rescale_mask(attn_mask)
             print(type(attn_mask))
             print(attn_mask.shape)
         else:
             attn_mask = mask
-        print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
         return attn_mask
     def set_mask(self, img):
         attn_mask = self.get_mask(img)
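
Note on the surviving get_mask logic: it converts a user-drawn Gradio sketch into a hard 0/1 attention mask. A minimal standalone sketch of that conversion, assuming a PIL mask image and a target device (the function name is illustrative, not part of the repo):

import torch
import torchvision
from PIL import Image

def sketch_to_attn_mask(mask_img: Image.Image, device: str = "cpu") -> torch.Tensor:
    # ToTensor maps pixel values into [0, 1], so any painted pixel is > 0.
    mask = torchvision.transforms.ToTensor()(mask_img)
    # Keep one channel and round every painted pixel up to 1.0 on the target device.
    return torch.ceil(mask[0].to(device))
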
animation.py CHANGED
@@ -8,7 +8,6 @@ def clear_img_dir():
     os.mkdir(img_dir)
     for filename in glob.glob(img_dir+"/*"):
         os.remove(filename)
-

 def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
     images = []
@@ -23,12 +22,8 @@ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
         if file_name.endswith('.png'):
             file_path = os.path.join(folder, file_name)
             images.append(imageio.imread(file_path))
-    # images[0] = images[0].set_meta_data({'duration': 1})
-    # images[-1] = images[-1].set_meta_data({'duration': 1})
     imageio.mimsave(gif_name, images, duration=durations)
     return gif_name

 if __name__ == "__main__":
-    # clear_img_dir()
-    create_gif()
-    # make_animation()
+    create_gif()
app.py CHANGED
@@ -4,7 +4,7 @@ import sys

 import wandb

-from configs import set_major_global, set_major_local, set_small_local
+from presets import set_major_global, set_major_local, set_small_local

 sys.path.append("taming-transformers")
 import functools
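
The presets module imported here replaces the deleted configs.py (see below); each preset returns a tuple of gr.Slider.update values. In a Gradio Blocks app such a preset is typically wired to a button so one click repositions several sliders at once; a hedged sketch with hypothetical slider names and ranges (the actual app.py wiring is not shown in this commit):

import gradio as gr
from presets import set_major_global

with gr.Blocks() as demo:
    iterations = gr.Slider(0, 100, label="Iterations")
    lr = gr.Slider(0.0, 1.0, label="Learning rate")
    lpips_weight = gr.Slider(0, 50, label="LPIPS weight")
    reconstruction_steps = gr.Slider(0, 50, label="Reconstruction steps")
    preset_btn = gr.Button("Major global edit")
    # One click pushes the four preset values into the four sliders.
    preset_btn.click(set_major_global, inputs=None,
                     outputs=[iterations, lr, lpips_weight, reconstruction_steps])
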
app_backend.py CHANGED
@@ -81,7 +81,7 @@ class ImagePromptOptimizer(nn.Module):
         self.make_grid = make_grid
         self.return_val = return_val
         self.quantize = quantize
-        self.disc = load_disc(self.device)
+        # self.disc = load_disc(self.device)
         self.lpips_weight = lpips_weight
         self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
     def disc_loss_fn(self, logits):
@@ -89,7 +89,7 @@ class ImagePromptOptimizer(nn.Module):
     def set_latent(self, latent):
         self.latent = latent.detach().to(self.device)
     def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
-        self.attn_mask = attn_mask
+        self._attn_mask = attn_mask
         self.iterations = iterations
         self.lr = lr
         self.lpips_weight = lpips_weight
@@ -131,32 +131,29 @@ class ImagePromptOptimizer(nn.Module):
         else:
             plt.imshow(get_pil(processed_img[0]).detach().cpu())
             plt.show()
-    def attn_masking(self, grad):
-        # print("attnmask 1")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
+    def _attn_mask(self, grad):
         newgrad = grad
-        if self.attn_mask is not None:
-            # print("masking mult")
-            newgrad = grad * (self.attn_mask)
-            # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 1")
+        if self._attn_mask is not None:
+            newgrad = grad * (self._attn_mask)
         return newgrad
-    def attn_masking2(self, grad):
-        # print("attnmask 2")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
+    def _attn_mask_inverse(self, grad):
         newgrad = grad
-        if self.attn_mask is not None:
-            # print("masking mult")
-            newgrad = grad * ((self.attn_mask - 1) * -1)
-            # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 2")
+        if self._attn_mask is not None:
+            newgrad = grad * ((self._attn_mask - 1) * -1)
         return newgrad
+    def _get_next_inputs(self, transformed_img):
+        processed_img = loop_post_process(transformed_img)  # * self.attn_mask
+        processed_img.retain_grad()
+        lpips_input = processed_img.clone()
+        lpips_input.register_hook(self._attn_mask_inverse)
+        lpips_input.retain_grad()
+        clip_input = processed_img.clone()
+        clip_input.register_hook(self._attn_mask)
+        clip_input.retain_grad()
+        return processed_img, lpips_input, clip_input

     def optimize(self, latent, pos_prompts, neg_prompts):
         self.set_latent(latent)
-        # self.make_grid=True
         transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
         original_img = loop_post_process(transformed_img)
         vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
@@ -167,27 +164,14 @@ class ImagePromptOptimizer(nn.Module):
         for i in tqdm(range(self.iterations)):
             optim.zero_grad()
             transformed_img = self(vector)
-            processed_img = loop_post_process(transformed_img) #* self.attn_mask
-            processed_img.retain_grad()
-            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
-            lpips_input.retain_grad()
-            clip_clone = processed_img.clone()
-            clip_clone.register_hook(self.attn_masking)
-            clip_clone.retain_grad()
+            processed_img, lpips_input, clip_input = self._get_next_inputs(transformed_img)
            with torch.autocast("cuda"):
-                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
+                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_input)
                print("CLIP loss", clip_loss)
                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
                print("LPIPS loss: ", perceptual_loss)
-                with torch.no_grad():
-                    disc_logits = self.disc(transformed_img)
-                    disc_loss = self.disc_loss_fn(disc_logits)
-                    print(f"disc_loss = {disc_loss}")
-                    disc_loss2 = self.disc(processed_img)
            if log:
                wandb.log({"Perceptual Loss": perceptual_loss})
-                wandb.log({"Discriminator Loss": disc_loss})
                wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            perceptual_loss.backward(retain_graph=True)
@@ -207,7 +191,7 @@ class ImagePromptOptimizer(nn.Module):
            processed_img = loop_post_process(transformed_img) #* self.attn_mask
            processed_img.retain_grad()
            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
+            lpips_input.register_hook(self._attn_mask_inverse)
            lpips_input.retain_grad()
            with torch.autocast("cuda"):
                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
@@ -216,28 +200,10 @@ class ImagePromptOptimizer(nn.Module):
                disc_loss = self.disc_loss_fn(disc_logits)
                print(f"disc_loss = {disc_loss}")
                disc_loss2 = self.disc(processed_img)
-                # print(f"disc_loss2 = {disc_loss2}")
            if log:
                wandb.log({"Perceptual Loss": perceptual_loss})
            print("LPIPS loss: ", perceptual_loss)
            perceptual_loss.backward(retain_graph=True)
            optim.step()
            yield vector
-        # torch.save(vector, "nose_vector.pt")
-        # print("")
-        # print("DISC STEPS")
-        # print("*************")
-        # for i in range(self.reconstruction_steps):
-        #     optim.zero_grad()
-        #     transformed_img = self(vector)
-        #     processed_img = loop_post_process(transformed_img) #* self.attn_mask
-        #     disc_logits = self.disc(transformed_img)
-        #     disc_loss = self.disc_loss_fn(disc_logits)
-        #     print(f"disc_loss = {disc_loss}")
-        #     if log:
-        #         wandb.log({"Disc Loss": disc_loss})
-        #     print("LPIPS loss: ", perceptual_loss)
-        #     disc_loss.backward(retain_graph=True)
-        #     optim.step()
-        #     yield vector
        yield vector if self.return_val == "vector" else self.latent + vector
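
The _attn_mask / _attn_mask_inverse hooks introduced above confine each loss to a region of the image by scaling gradients rather than pixels: the CLIP branch only receives gradient inside the mask, while the LPIPS branch only receives it outside. A self-contained sketch of that pattern with illustrative names and a toy loss (not the app's actual losses):

import torch

def masked_branches(processed_img: torch.Tensor, attn_mask: torch.Tensor):
    # Each clone shares the forward value but carries its own backward hook,
    # so different losses can be restricted to different spatial regions.
    clip_input = processed_img.clone()
    clip_input.register_hook(lambda grad: grad * attn_mask)           # gradient only inside the mask
    lpips_input = processed_img.clone()
    lpips_input.register_hook(lambda grad: grad * (1.0 - attn_mask))  # gradient only outside the mask
    return clip_input, lpips_input

# Toy check: gradients from the two branches cover complementary regions.
img = torch.rand(1, 3, 8, 8, requires_grad=True)
mask = torch.zeros(1, 3, 8, 8)
mask[..., :4] = 1.0
clip_in, lpips_in = masked_branches(img, mask)
(clip_in.sum() + lpips_in.sum()).backward()
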
configs.py DELETED
@@ -1,7 +0,0 @@
-import gradio as gr
-def set_small_local():
-    return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
-def set_major_local():
-    return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
-def set_major_global():
-    return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
edit.py CHANGED
@@ -17,13 +17,13 @@ from utils import get_device


 def get_embedding(model, path=None, img=None, device="cpu"):
-    assert path is None or img is None, "Input either path or tensor"
+    assert path or img, "Input either path or tensor"
     if img is not None:
         raise NotImplementedError
     x = preprocess(PIL.Image.open(path), target_image_size=256).to(device)
     x_processed = preprocess_vqgan(x)
-    x_latent, _, [_, _, indices] = model.encode(x_processed)
-    return x_latent
+    z, _, [_, _, indices] = model.encode(x_processed)
+    return z


 def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
@@ -47,23 +47,12 @@ def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):

 if __name__ == "__main__":
     device = get_device()
-    # conf_path = "logs/2021-04-23T18-11-19_celebahq_transformer/configs/2021-04-23T18-11-19-project.yaml"
     ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
-    # ckpt_path = "./faceshq/faceshq.pt"
     conf_path = "./unwrapped.yaml"
-    # conf_path = "./faceshq/faceshq.yaml"
     config = load_config(conf_path, display=False)
     model = taming.models.vqgan.VQModel(**config.model.params)
     sd = torch.load("./vqgan_only.pt", map_location="mps")
     model.load_state_dict(sd, strict=True)
     model.to(device)
     blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=.5)
-    plt.show()
-
-    demo = gr.Interface(
-        get_image,
-        inputs=gr.inputs.Image(label="UploadZz a black and white face", type="filepath"),
-        outputs="image",
-        title="Upload a black and white face and get a colorized image!",
-    )
-
+    plt.show()
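
blend_paths, exercised in the __main__ block above, interpolates two images in VQGAN latent space. A rough sketch of the idea, reusing get_embedding from this file and a loaded VQModel; the weighting convention and the bare decode call are illustrative rather than the file's exact implementation:

import torch

def blend_latents(model, path1, path2, weight=0.5, device="cpu"):
    # Encode both images to latents, then linearly interpolate between them.
    z1 = get_embedding(model, path=path1, device=device)
    z2 = get_embedding(model, path=path2, device=device)
    z_blend = weight * z1 + (1.0 - weight) * z2
    # Decode the blended latent back to image space.
    with torch.no_grad():
        return model.decode(z_blend)
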
img_processing.py CHANGED
@@ -32,7 +32,7 @@ def preprocess(img, target_image_size=256, map_dalle=False):
     return img

 def preprocess_vqgan(x):
-    x = 2.*x - 1.
+    x = 2. * x - 1.
     return x

 def custom_to_pil(x, process=True, mode="RGB"):
loaders.py CHANGED
@@ -1,5 +1,4 @@
 import importlib
-
 import numpy as np
 import taming
 import torch
@@ -7,9 +6,8 @@ import yaml
 from omegaconf import OmegaConf
 from PIL import Image
 from taming.models.vqgan import VQModel
-
 from utils import get_device
-# import discriminator
+

 def load_config(config_path, display=False):
     config = OmegaConf.load(config_path)
@@ -17,37 +15,23 @@ def load_config(config_path, display=False):
         print(yaml.dump(OmegaConf.to_container(config)))
     return config

-# def load_disc(device):
-#     dconf = load_config("disc_config.yaml")
-#     sd = torch.load("disc.pt", map_location=device)
-#     # print(sd.keys())
-#     model = discriminator.NLayerDiscriminator()
-#     model.load_state_dict(sd, strict=True)
-#     model.to(device)
-#     return model
-#     print(dconf.keys())
-
 def load_default(device):
-    # device = get_device()
     ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
     conf_path = "./unwrapped.yaml"
     config = load_config(conf_path, display=False)
     model = taming.models.vqgan.VQModel(**config.model.params)
-    sd = torch.load("./vqgan_only.pt", map_location=device)
+    sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
     model.load_state_dict(sd, strict=True)
     model.to(device)
     return model


 def load_vqgan(config, ckpt_path=None, is_gumbel=False):
-    if is_gumbel:
-        model = GumbelVQ(**config.model.params)
-    else:
     model = VQModel(**config.model.params)
-    if ckpt_path is not None:
-        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
-        missing, unexpected = model.load_state_dict(sd, strict=False)
-    return model.eval()
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+    return model.eval()

 def load_ffhq():
     conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
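
After this change the two loading paths differ mainly in where the weights come from: load_default reads a plain state dict from ./model_checkpoints/vqgan_only.pt, while load_vqgan pulls the "state_dict" entry out of a training checkpoint. A hedged usage sketch, assuming the config and checkpoint files named in the diff exist locally:

from loaders import load_config, load_default, load_vqgan

# Option 1: the repo's default CelebA-HQ VQGAN (weights under ./model_checkpoints/).
model = load_default(device="cpu")

# Option 2: build a VQModel from an arbitrary config/checkpoint pair.
config = load_config("./unwrapped.yaml", display=False)
model = load_vqgan(config, ckpt_path="logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt")
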
utils.py CHANGED
@@ -7,10 +7,10 @@ import torch.nn.functional as F
 from skimage.color import lab2rgb, rgb2lab
 from torch import nn

-
 def freeze_module(module):
     for param in module.parameters():
         param.requires_grad = False
+
 def get_device():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     if torch.backends.mps.is_available() and torch.backends.mps.is_built():
vqgan_latent_ops.py DELETED
@@ -1,14 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from gradient_flow_ops import ReplaceGrad
-
-replace_grad = ReplaceGrad.apply
-
-def vector_quantize(x, codebook):
-
-    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
-    indices = d.argmin(-1)
-    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
-    return replace_grad(x_q, x)
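
The deleted vector_quantize helper depended on ReplaceGrad from gradient_flow_ops, which is not shown in this commit. A minimal sketch of the usual straight-through pattern such a helper relies on (an illustration, not the repository's actual gradient_flow_ops code):

import torch

class ReplaceGrad(torch.autograd.Function):
    """Forward the quantized value, but route gradients back to the original input."""

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_out):
        # No gradient for the quantized branch; pass it straight through to x_backward.
        return None, grad_out.sum_to_size(ctx.shape)

replace_grad = ReplaceGrad.apply
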
vqgan_only.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8e39472bae4489764c0ffc70ba84ec7815f245781020ce55cc2e7adc60e580e4
3
- size 288690579