Erwann Millon commited on
Commit
e0f92a0
1 Parent(s): ec39fe8

refactoring and cleanup

Browse files
Files changed (12) hide show
  1. ImageState.py +118 -74
  2. animation.py +8 -6
  3. app.py +7 -7
  4. backend.py +104 -90
  5. edit.py +17 -12
  6. img_processing.py +40 -36
  7. loaders.py +20 -20
  8. masking.py +21 -23
  9. presets.py +30 -4
  10. prompts.py +31 -7
  11. unwrapped.yaml +0 -37
  12. utils.py +3 -1
ImageState.py CHANGED
@@ -1,183 +1,227 @@
1
- # from align import align_from_path
2
  import gc
 
3
  import imageio
4
  import glob
5
  import uuid
6
  from animation import clear_img_dir
7
- from backend import ImagePromptOptimizer, log
8
- import importlib
9
- import gradio as gr
10
- import matplotlib.pyplot as plt
11
  import torch
12
  import torchvision
13
  import wandb
14
- from icecream import ic
15
- from torch import nn
16
- from torchvision.transforms.functional import resize
17
- from tqdm import tqdm
18
- from transformers import CLIPModel, CLIPProcessor
19
- import lpips
20
- from backend import get_resized_tensor
21
  from edit import blend_paths
22
- from img_processing import *
23
  from img_processing import custom_to_pil
24
- from loaders import load_default
 
25
  num = 0
26
 
27
- class PromptTransformHistory():
 
28
  def __init__(self, iterations) -> None:
29
  self.iterations = iterations
30
  self.transforms = []
31
 
 
32
  class ImageState:
33
- def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
34
  self.vqgan = vqgan
35
  self.device = vqgan.device
36
  self.blend_latent = None
37
  self.quant = True
38
  self.path1 = None
39
  self.path2 = None
 
 
 
40
  self.transform_history = []
41
  self.attn_mask = None
42
  self.prompt_optim = prompt_optimizer
43
  self._load_vectors()
44
  self.init_transforms()
 
45
  def _load_vectors(self):
46
- self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
47
- self.blue_eyes_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
48
- self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
 
 
 
 
 
 
 
49
  def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
50
  images = []
51
- folder = self.state_id
52
  paths = glob.glob(folder + "/*")
53
  frame_duration = total_duration / len(paths)
54
  print(len(paths), "frame dur", frame_duration)
55
  durations = [frame_duration] * len(paths)
56
  if extend_frames:
57
- durations [0] = 1.5
58
- durations [-1] = 3
59
  for file_name in os.listdir(folder):
60
- if file_name.endswith('.png'):
61
  file_path = os.path.join(folder, file_name)
62
  images.append(imageio.imread(file_path))
63
  imageio.mimsave(gif_name, images, duration=durations)
64
  return gif_name
 
65
  def init_transforms(self):
66
  self.blue_eyes = torch.zeros_like(self.lip_vector)
67
  self.lip_size = torch.zeros_like(self.lip_vector)
68
  self.asian_transform = torch.zeros_like(self.lip_vector)
69
  self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
 
70
  def clear_transforms(self):
71
- global num
72
  self.init_transforms()
73
  clear_img_dir("./img_history")
74
- num = 0
75
  return self._render_all_transformations()
76
- def _apply_vector(self, src, vector):
77
- new_latent = torch.lerp(src, src + vector, 1)
78
- return new_latent
79
- def _decode_latent_to_pil(self, latent):
80
  current_im = self.vqgan.decode(latent.to(self.device))[0]
81
  return custom_to_pil(current_im)
 
82
  def _get_mask(self, img, mask=None):
83
  if img and "mask" in img and img["mask"] is not None:
84
  attn_mask = torchvision.transforms.ToTensor()(img["mask"])
85
  attn_mask = torch.ceil(attn_mask[0].to(self.device))
86
  print("mask set successfully")
87
- print(type(attn_mask))
88
- print(attn_mask.shape)
89
  else:
90
  attn_mask = mask
91
  return attn_mask
 
92
  def set_mask(self, img):
93
  self.attn_mask = self._get_mask(img)
94
  x = self.attn_mask.clone()
95
  x = x.detach().cpu()
96
- x = torch.clamp(x, -1., 1.)
97
- x = (x + 1.)/2.
98
  x = x.numpy()
99
  x = (255 * x).astype(np.uint8)
100
  x = Image.fromarray(x, "L")
101
  return x
102
- @torch.no_grad()
 
103
  def _render_all_transformations(self, return_twice=True):
104
  global num
105
- if self.state_id is None:
106
- self.state_id = "./img_history/" + str(uuid.uuid4())
107
- print("redner all", self.state_id)
108
- current_vector_transforms = (self.blue_eyes, self.lip_size, self.asian_transform, sum(self.current_prompt_transforms))
 
 
109
  new_latent = self.blend_latent + sum(current_vector_transforms)
110
  if self.quant:
111
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
112
- image = self._decode_latent_to_pil(new_latent)
113
- img_dir = self.state_id
114
- if not os.path.exists("img_history"):
115
- os.mkdir("./img_history")
116
- if not os.path.exists(img_dir):
117
- os.mkdir(img_dir)
118
- image.save(f"{img_dir}/img_{num:06}.png")
119
  num += 1
120
  return (image, image) if return_twice else image
 
121
  def apply_rb_vector(self, weight):
122
  self.blue_eyes = weight * self.blue_eyes_vector
123
  return self._render_all_transformations()
 
124
  def apply_lip_vector(self, weight):
125
  self.lip_size = weight * self.lip_vector
126
  return self._render_all_transformations()
 
127
  def update_quant(self, val):
128
  self.quant = val
129
  return self._render_all_transformations()
 
130
  def apply_asian_vector(self, weight):
131
  self.asian_transform = weight * self.asian_vector
132
  return self._render_all_transformations()
 
133
  def update_images(self, path1, path2, blend_weight):
134
  if path1 is None and path2 is None:
135
  return None
136
- if path1 is None: path1 = path2
137
- if path2 is None: path2 = path1
 
 
 
 
 
138
  self.path1, self.path2 = path1, path2
139
- if self.state_id:
140
- clear_img_dir(self.state_id)
141
  return self.blend(blend_weight)
142
- @torch.no_grad()
 
143
  def blend(self, weight):
144
- _, latent = blend_paths(self.vqgan, self.path1, self.path2, weight=weight, show=False, device=self.device)
 
 
 
 
 
 
 
145
  self.blend_latent = latent
146
  return self._render_all_transformations()
147
- @torch.no_grad()
 
148
  def rewind(self, index):
149
  if not self.transform_history:
150
- print("no history")
151
  return self._render_all_transformations()
152
  prompt_transform = self.transform_history[-1]
153
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
154
  print(latent_index)
155
- self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index].to(self.device)
 
 
156
  return self._render_all_transformations()
157
- def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
158
- if self.state_id is None:
159
- self.state_id = "./img_history/" + str(uuid.uuid4())
160
- transform_log = PromptTransformHistory(iterations + reconstruction_steps)
161
- transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
162
- self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
 
 
 
 
 
 
 
 
 
 
 
 
163
  if log:
164
- wandb.init(reinit=True, project="face-editor")
165
- wandb.config.update({"Positive Prompts": positive_prompts})
166
- wandb.config.update({"Negative Prompts": negative_prompts})
167
- wandb.config.update(dict(
168
- lr=lr,
169
- iterations=iterations,
170
- lpips_weight=lpips_weight
171
- ))
 
 
172
  positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
173
  negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
174
- self.prompt_optim.set_params(lr, iterations, lpips_weight, attn_mask=self.attn_mask, reconstruction_steps=reconstruction_steps)
175
- for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
176
- positive_prompts,
177
- negative_prompts)):
 
 
 
 
 
 
 
 
 
178
  transform_log.transforms.append(transform.detach().cpu())
179
  self.current_prompt_transforms[-1] = transform
180
- with torch.no_grad():
181
  image = self._render_all_transformations(return_twice=False)
182
  if log:
183
  wandb.log({"image": wandb.Image(image)})
@@ -187,4 +231,4 @@ class ImageState:
187
  self.attn_mask = None
188
  self.transform_history.append(transform_log)
189
  gc.collect()
190
- torch.cuda.empty_cache()
 
1
+ import numpy as np
2
  import gc
3
+ import os
4
  import imageio
5
  import glob
6
  import uuid
7
  from animation import clear_img_dir
8
+ from backend import ImagePromptEditor, log
 
 
 
9
  import torch
10
  import torchvision
11
  import wandb
 
 
 
 
 
 
 
12
  from edit import blend_paths
 
13
  from img_processing import custom_to_pil
14
+ from PIL import Image
15
+
16
  num = 0
17
 
18
+
19
+ class PromptTransformHistory:
20
  def __init__(self, iterations) -> None:
21
  self.iterations = iterations
22
  self.transforms = []
23
 
24
+
25
  class ImageState:
26
+ def __init__(self, vqgan, prompt_optimizer: ImagePromptEditor) -> None:
27
  self.vqgan = vqgan
28
  self.device = vqgan.device
29
  self.blend_latent = None
30
  self.quant = True
31
  self.path1 = None
32
  self.path2 = None
33
+ self.img_dir = "./img_history"
34
+ if not os.path.exists(self.img_dir):
35
+ os.mkdir(self.img_dir)
36
  self.transform_history = []
37
  self.attn_mask = None
38
  self.prompt_optim = prompt_optimizer
39
  self._load_vectors()
40
  self.init_transforms()
41
+
42
  def _load_vectors(self):
43
+ self.lip_vector = torch.load(
44
+ "./latent_vectors/lipvector.pt", map_location=self.device
45
+ )
46
+ self.blue_eyes_vector = torch.load(
47
+ "./latent_vectors/2blue_eyes.pt", map_location=self.device
48
+ )
49
+ self.asian_vector = torch.load(
50
+ "./latent_vectors/asian10.pt", map_location=self.device
51
+ )
52
+
53
  def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
54
  images = []
55
+ folder = self.img_dir
56
  paths = glob.glob(folder + "/*")
57
  frame_duration = total_duration / len(paths)
58
  print(len(paths), "frame dur", frame_duration)
59
  durations = [frame_duration] * len(paths)
60
  if extend_frames:
61
+ durations[0] = 1.5
62
+ durations[-1] = 3
63
  for file_name in os.listdir(folder):
64
+ if file_name.endswith(".png"):
65
  file_path = os.path.join(folder, file_name)
66
  images.append(imageio.imread(file_path))
67
  imageio.mimsave(gif_name, images, duration=durations)
68
  return gif_name
69
+
70
  def init_transforms(self):
71
  self.blue_eyes = torch.zeros_like(self.lip_vector)
72
  self.lip_size = torch.zeros_like(self.lip_vector)
73
  self.asian_transform = torch.zeros_like(self.lip_vector)
74
  self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
75
+
76
  def clear_transforms(self):
 
77
  self.init_transforms()
78
  clear_img_dir("./img_history")
 
79
  return self._render_all_transformations()
80
+
81
+ def _latent_to_pil(self, latent):
 
 
82
  current_im = self.vqgan.decode(latent.to(self.device))[0]
83
  return custom_to_pil(current_im)
84
+
85
  def _get_mask(self, img, mask=None):
86
  if img and "mask" in img and img["mask"] is not None:
87
  attn_mask = torchvision.transforms.ToTensor()(img["mask"])
88
  attn_mask = torch.ceil(attn_mask[0].to(self.device))
89
  print("mask set successfully")
 
 
90
  else:
91
  attn_mask = mask
92
  return attn_mask
93
+
94
  def set_mask(self, img):
95
  self.attn_mask = self._get_mask(img)
96
  x = self.attn_mask.clone()
97
  x = x.detach().cpu()
98
+ x = torch.clamp(x, -1.0, 1.0)
99
+ x = (x + 1.0) / 2.0
100
  x = x.numpy()
101
  x = (255 * x).astype(np.uint8)
102
  x = Image.fromarray(x, "L")
103
  return x
104
+
105
+ @torch.inference_mode()
106
  def _render_all_transformations(self, return_twice=True):
107
  global num
108
+ current_vector_transforms = (
109
+ self.blue_eyes,
110
+ self.lip_size,
111
+ self.asian_transform,
112
+ sum(self.current_prompt_transforms),
113
+ )
114
  new_latent = self.blend_latent + sum(current_vector_transforms)
115
  if self.quant:
116
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
117
+ image = self._latent_to_pil(new_latent)
118
+ image.save(f"{self.img_dir}/img_{num:06}.png")
 
 
 
 
 
119
  num += 1
120
  return (image, image) if return_twice else image
121
+
122
  def apply_rb_vector(self, weight):
123
  self.blue_eyes = weight * self.blue_eyes_vector
124
  return self._render_all_transformations()
125
+
126
  def apply_lip_vector(self, weight):
127
  self.lip_size = weight * self.lip_vector
128
  return self._render_all_transformations()
129
+
130
  def update_quant(self, val):
131
  self.quant = val
132
  return self._render_all_transformations()
133
+
134
  def apply_asian_vector(self, weight):
135
  self.asian_transform = weight * self.asian_vector
136
  return self._render_all_transformations()
137
+
138
  def update_images(self, path1, path2, blend_weight):
139
  if path1 is None and path2 is None:
140
  return None
141
+
142
+ # Duplicate paths if one is empty
143
+ if path1 is None:
144
+ path1 = path2
145
+ if path2 is None:
146
+ path2 = path1
147
+
148
  self.path1, self.path2 = path1, path2
149
+ if self.img_dir:
150
+ clear_img_dir(self.img_dir)
151
  return self.blend(blend_weight)
152
+
153
+ @torch.inference_mode()
154
  def blend(self, weight):
155
+ _, latent = blend_paths(
156
+ self.vqgan,
157
+ self.path1,
158
+ self.path2,
159
+ weight=weight,
160
+ show=False,
161
+ device=self.device,
162
+ )
163
  self.blend_latent = latent
164
  return self._render_all_transformations()
165
+
166
+ @torch.inference_mode()
167
  def rewind(self, index):
168
  if not self.transform_history:
169
+ print("No history")
170
  return self._render_all_transformations()
171
  prompt_transform = self.transform_history[-1]
172
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
173
  print(latent_index)
174
+ self.current_prompt_transforms[-1] = prompt_transform.transforms[
175
+ latent_index
176
+ ].to(self.device)
177
  return self._render_all_transformations()
178
+
179
+ def _init_logging(lr, iterations, lpips_weight, positive_prompts, negative_prompts):
180
+ wandb.init(reinit=True, project="face-editor")
181
+ wandb.config.update({"Positive Prompts": positive_prompts})
182
+ wandb.config.update({"Negative Prompts": negative_prompts})
183
+ wandb.config.update(
184
+ dict(lr=lr, iterations=iterations, lpips_weight=lpips_weight)
185
+ )
186
+
187
+ def apply_prompts(
188
+ self,
189
+ positive_prompts,
190
+ negative_prompts,
191
+ lr,
192
+ iterations,
193
+ lpips_weight,
194
+ reconstruction_steps,
195
+ ):
196
  if log:
197
+ self._init_logging(
198
+ lr, iterations, lpips_weight, positive_prompts, negative_prompts
199
+ )
200
+ transform_log = PromptTransformHistory(iterations + reconstruction_steps)
201
+ transform_log.transforms.append(
202
+ torch.zeros_like(self.blend_latent, requires_grad=False)
203
+ )
204
+ self.current_prompt_transforms.append(
205
+ torch.zeros_like(self.blend_latent, requires_grad=False)
206
+ )
207
  positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
208
  negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
209
+ self.prompt_optim.set_params(
210
+ lr,
211
+ iterations,
212
+ lpips_weight,
213
+ attn_mask=self.attn_mask,
214
+ reconstruction_steps=reconstruction_steps,
215
+ )
216
+
217
+ for i, transform in enumerate(
218
+ self.prompt_optim.optimize(
219
+ self.blend_latent, positive_prompts, negative_prompts
220
+ )
221
+ ):
222
  transform_log.transforms.append(transform.detach().cpu())
223
  self.current_prompt_transforms[-1] = transform
224
+ with torch.inference_mode():
225
  image = self._render_all_transformations(return_twice=False)
226
  if log:
227
  wandb.log({"image": wandb.Image(image)})
 
231
  self.attn_mask = None
232
  self.transform_history.append(transform_log)
233
  gc.collect()
234
+ torch.cuda.empty_cache()
animation.py CHANGED
@@ -8,21 +8,23 @@ def clear_img_dir(img_dir):
8
  os.mkdir("img_history")
9
  if not os.path.exists(img_dir):
10
  os.mkdir(img_dir)
11
- for filename in glob.glob(img_dir+"/*"):
12
  os.remove(filename)
13
 
14
 
15
- def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
 
 
16
  images = []
17
  paths = glob.glob(folder + "/*")
18
  frame_duration = total_duration / len(paths)
19
  print(len(paths), "frame dur", frame_duration)
20
  durations = [frame_duration] * len(paths)
21
  if extend_frames:
22
- durations [0] = 1.5
23
- durations [-1] = 3
24
  for file_name in os.listdir(folder):
25
- if file_name.endswith('.png'):
26
  file_path = os.path.join(folder, file_name)
27
  images.append(imageio.imread(file_path))
28
  imageio.mimsave(gif_name, images, duration=durations)
@@ -30,4 +32,4 @@ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="
30
 
31
 
32
  if __name__ == "__main__":
33
- create_gif()
 
8
  os.mkdir("img_history")
9
  if not os.path.exists(img_dir):
10
  os.mkdir(img_dir)
11
+ for filename in glob.glob(img_dir + "/*"):
12
  os.remove(filename)
13
 
14
 
15
+ def create_gif(
16
+ total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"
17
+ ):
18
  images = []
19
  paths = glob.glob(folder + "/*")
20
  frame_duration = total_duration / len(paths)
21
  print(len(paths), "frame dur", frame_duration)
22
  durations = [frame_duration] * len(paths)
23
  if extend_frames:
24
+ durations[0] = 1.5
25
+ durations[-1] = 3
26
  for file_name in os.listdir(folder):
27
+ if file_name.endswith(".png"):
28
  file_path = os.path.join(folder, file_name)
29
  images.append(imageio.imread(file_path))
30
  imageio.mimsave(gif_name, images, duration=durations)
 
32
 
33
 
34
  if __name__ == "__main__":
35
+ create_gif()
app.py CHANGED
@@ -14,7 +14,7 @@ from transformers import CLIPModel, CLIPProcessor
14
  from lpips import LPIPS
15
 
16
  import edit
17
- from backend import ImagePromptOptimizer, ProcessorGradientFlow
18
  from ImageState import ImageState
19
  from loaders import load_default
20
  # from animation import create_gif
@@ -29,14 +29,14 @@ processor = ProcessorGradientFlow(device=device)
29
  # clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
30
  lpips_fn = LPIPS(net='vgg').to(device)
31
  clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
32
- promptoptim = ImagePromptOptimizer(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
 
33
  def set_img_from_example(state, img):
34
  return state.update_images(img, img, 0)
35
  def get_cleared_mask():
36
  return gr.Image.update(value=None)
37
- # mask.clear()
38
-
39
  class StateWrapper:
 
40
  def create_gif(state, *args, **kwargs):
41
  return state, state[0].create_gif(*args, **kwargs)
42
  def apply_asian_vector(state, *args, **kwargs):
@@ -46,7 +46,6 @@ class StateWrapper:
46
  def apply_lip_vector(state, *args, **kwargs):
47
  return state, *state[0].apply_lip_vector(*args, **kwargs)
48
  def apply_prompts(state, *args, **kwargs):
49
- print(state[1])
50
  for image in state[0].apply_prompts(*args, **kwargs):
51
  yield state, *image
52
  def apply_rb_vector(state, *args, **kwargs):
@@ -69,9 +68,10 @@ class StateWrapper:
69
  return state, *state[0].update_images(*args, **kwargs)
70
  def update_requant(state, *args, **kwargs):
71
  return state, *state[0].update_requant(*args, **kwargs)
 
 
72
  with gr.Blocks(css="styles.css") as demo:
73
- # id = gr.State(str(uuid.uuid4()))
74
- state = gr.State([ImageState(vqgan, promptoptim), str(uuid.uuid4())])
75
  with gr.Row():
76
  with gr.Column(scale=1):
77
  with gr.Row():
 
14
  from lpips import LPIPS
15
 
16
  import edit
17
+ from backend import ImagePromptEditor, ProcessorGradientFlow
18
  from ImageState import ImageState
19
  from loaders import load_default
20
  # from animation import create_gif
 
29
  # clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
30
  lpips_fn = LPIPS(net='vgg').to(device)
31
  clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
32
+ promptoptim = ImagePromptEditor(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
33
+
34
  def set_img_from_example(state, img):
35
  return state.update_images(img, img, 0)
36
  def get_cleared_mask():
37
  return gr.Image.update(value=None)
 
 
38
  class StateWrapper:
39
+ """This extremely ugly code is a hacky fix to allow con"""
40
  def create_gif(state, *args, **kwargs):
41
  return state, state[0].create_gif(*args, **kwargs)
42
  def apply_asian_vector(state, *args, **kwargs):
 
46
  def apply_lip_vector(state, *args, **kwargs):
47
  return state, *state[0].apply_lip_vector(*args, **kwargs)
48
  def apply_prompts(state, *args, **kwargs):
 
49
  for image in state[0].apply_prompts(*args, **kwargs):
50
  yield state, *image
51
  def apply_rb_vector(state, *args, **kwargs):
 
68
  return state, *state[0].update_images(*args, **kwargs)
69
  def update_requant(state, *args, **kwargs):
70
  return state, *state[0].update_requant(*args, **kwargs)
71
+
72
+
73
  with gr.Blocks(css="styles.css") as demo:
74
+ state = gr.State([ImageState(vqgan, promptoptim)])
 
75
  with gr.Row():
76
  with gr.Column(scale=1):
77
  with gr.Row():
backend.py CHANGED
@@ -1,77 +1,65 @@
1
- # from functools import cache
2
- import importlib
3
-
4
- import gradio as gr
5
  import matplotlib.pyplot as plt
6
  import torch
7
  import torchvision
8
  import wandb
9
- from icecream import ic
10
  from torch import nn
11
- from torchvision.transforms.functional import resize
12
  from tqdm import tqdm
13
- from transformers import CLIPModel, CLIPProcessor
14
- import lpips
15
- from edit import blend_paths
16
- from img_processing import *
17
- from img_processing import custom_to_pil
18
- from loaders import load_default
19
- import glob
20
- import gc
21
 
22
  global log
23
- log=False
24
-
25
- # ic.disable()
26
- # ic.enable()
27
- def get_resized_tensor(x):
28
- if len(x.shape) == 2:
29
- re = x.unsqueeze(0)
30
- else: re = x
31
- re = resize(re, (10, 10))
32
- return re
33
- class ProcessorGradientFlow():
34
  """
35
  This wraps the huggingface CLIP processor to allow backprop through the image processing step.
36
- The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.
37
  """
 
38
  def __init__(self, device="cuda") -> None:
39
  self.device = device
40
  self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
41
  self.image_mean = [0.48145466, 0.4578275, 0.40821073]
42
  self.image_std = [0.26862954, 0.26130258, 0.27577711]
43
  self.normalize = torchvision.transforms.Normalize(
44
- self.image_mean,
45
- self.image_std
46
  )
47
  self.resize = torchvision.transforms.Resize(224)
48
  self.center_crop = torchvision.transforms.CenterCrop(224)
 
49
  def preprocess_img(self, images):
50
  images = self.center_crop(images)
51
  images = self.resize(images)
52
  images = self.center_crop(images)
53
  images = self.normalize(images)
54
  return images
 
55
  def __call__(self, images=[], **kwargs):
56
  processed_inputs = self.processor(**kwargs)
57
  processed_inputs["pixel_values"] = self.preprocess_img(images)
58
- processed_inputs = {key:value.to(self.device) for (key, value) in processed_inputs.items()}
 
 
59
  return processed_inputs
60
 
61
- class ImagePromptOptimizer(nn.Module):
62
- def __init__(self,
63
- vqgan,
64
- clip,
65
- clip_preprocessor,
66
- lpips_fn,
67
- iterations=100,
68
- lr = 0.01,
69
- save_vector=True,
70
- return_val="vector",
71
- quantize=True,
72
- make_grid=False,
73
- lpips_weight = 6.2) -> None:
74
-
 
 
 
75
  super().__init__()
76
  self.latent = None
77
  self.device = vqgan.device
@@ -86,14 +74,17 @@ class ImagePromptOptimizer(nn.Module):
86
  self.quantize = quantize
87
  self.lpips_weight = lpips_weight
88
  self.perceptual_loss = lpips_fn
 
89
  def set_latent(self, latent):
90
  self.latent = latent.detach().to(self.device)
 
91
  def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
92
  self._attn_mask = attn_mask
93
  self.iterations = iterations
94
  self.lr = lr
95
  self.lpips_weight = lpips_weight
96
  self.reconstruction_steps = reconstruction_steps
 
97
  def forward(self, vector):
98
  base_latent = self.latent.detach().requires_grad_()
99
  trans_latent = base_latent + vector
@@ -103,19 +94,22 @@ class ImagePromptOptimizer(nn.Module):
103
  z_q = trans_latent
104
  dec = self.vqgan.decode(z_q)
105
  return dec
 
106
  def _get_clip_similarity(self, prompts, image, weights=None):
107
  if isinstance(prompts, str):
108
  prompts = [prompts]
109
  elif not isinstance(prompts, list):
110
  raise TypeError("Provide prompts as string or list of strings")
111
- clip_inputs = self.clip_preprocessor(text=prompts,
112
- images=image, return_tensors="pt", padding=True)
 
113
  clip_outputs = self.clip(**clip_inputs)
114
  similarity_logits = clip_outputs.logits_per_image
115
  if weights:
116
  similarity_logits *= weights
117
  return similarity_logits.sum()
118
- def get_similarity_loss(self, pos_prompts, neg_prompts, image):
 
119
  pos_logits = self._get_clip_similarity(pos_prompts, image)
120
  if neg_prompts:
121
  neg_logits = self._get_clip_similarity(neg_prompts, image)
@@ -123,6 +117,7 @@ class ImagePromptOptimizer(nn.Module):
123
  neg_logits = torch.tensor([1], device=self.device)
124
  loss = -torch.log(pos_logits) + torch.log(neg_logits)
125
  return loss
 
126
  def visualize(self, processed_img):
127
  if self.make_grid:
128
  self.index += 1
@@ -131,74 +126,93 @@ class ImagePromptOptimizer(nn.Module):
131
  else:
132
  plt.imshow(get_pil(processed_img[0]).detach().cpu())
133
  plt.show()
 
134
  def _attn_mask(self, grad):
135
  newgrad = grad
136
  if self._attn_mask is not None:
137
  newgrad = grad * (self._attn_mask)
138
  return newgrad
 
139
  def _attn_mask_inverse(self, grad):
140
  newgrad = grad
141
  if self._attn_mask is not None:
142
  newgrad = grad * ((self._attn_mask - 1) * -1)
143
  return newgrad
 
144
  def _get_next_inputs(self, transformed_img):
145
- processed_img = loop_post_process(transformed_img) #* self.attn_mask
146
  processed_img.retain_grad()
 
147
  lpips_input = processed_img.clone()
148
  lpips_input.register_hook(self._attn_mask_inverse)
149
  lpips_input.retain_grad()
 
150
  clip_input = processed_img.clone()
151
  clip_input.register_hook(self._attn_mask)
152
  clip_input.retain_grad()
153
- return processed_img, lpips_input, clip_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  def optimize(self, latent, pos_prompts, neg_prompts):
156
  self.set_latent(latent)
157
- transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
 
 
158
  original_img = loop_post_process(transformed_img)
159
  vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
160
  optim = torch.optim.Adam([vector], lr=self.lr)
161
- if self.make_grid:
162
- plt.figure(figsize=(35, 25))
163
- self.index = 1
164
  for i in tqdm(range(self.iterations)):
165
- optim.zero_grad()
166
- transformed_img = self(vector)
167
- processed_img, lpips_input, clip_input = self._get_next_inputs(transformed_img)
168
- with torch.autocast("cuda"):
169
- clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_input)
170
- print("CLIP loss", clip_loss)
171
- perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
172
- print("LPIPS loss: ", perceptual_loss)
173
- if log:
174
- wandb.log({"Perceptual Loss": perceptual_loss})
175
- wandb.log({"CLIP Loss": clip_loss})
176
- clip_loss.backward(retain_graph=True)
177
- perceptual_loss.backward(retain_graph=True)
178
- p2 = processed_img.grad
179
- print("Sum Loss", perceptual_loss + clip_loss)
180
- optim.step()
181
- # if i % self.iterations // 10 == 0:
182
- # self.visualize(transformed_img)
183
- yield vector
184
- if self.make_grid:
185
- plt.savefig(f"plot {pos_prompts[0]}.png")
186
- plt.show()
187
- print("lpips solo op")
188
  for i in range(self.reconstruction_steps):
189
- optim.zero_grad()
190
- transformed_img = self(vector)
191
- processed_img = loop_post_process(transformed_img) #* self.attn_mask
192
- processed_img.retain_grad()
193
- lpips_input = processed_img.clone()
194
- lpips_input.register_hook(self._attn_mask_inverse)
195
- lpips_input.retain_grad()
196
- with torch.autocast("cuda"):
197
- perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
198
- if log:
199
- wandb.log({"Perceptual Loss": perceptual_loss})
200
- print("LPIPS loss: ", perceptual_loss)
201
- perceptual_loss.backward(retain_graph=True)
202
- optim.step()
203
- yield vector
204
  yield vector if self.return_val == "vector" else self.latent + vector
 
 
 
 
 
1
  import matplotlib.pyplot as plt
2
  import torch
3
  import torchvision
4
  import wandb
 
5
  from torch import nn
 
6
  from tqdm import tqdm
7
+ from transformers import CLIPProcessor
8
+ from img_processing import get_pil, loop_post_process
9
+
 
 
 
 
 
10
 
11
  global log
12
+ log = False
13
+
14
+ class ProcessorGradientFlow:
 
 
 
 
 
 
 
 
15
  """
16
  This wraps the huggingface CLIP processor to allow backprop through the image processing step.
17
+ The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.
18
  """
19
+
20
  def __init__(self, device="cuda") -> None:
21
  self.device = device
22
  self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
23
  self.image_mean = [0.48145466, 0.4578275, 0.40821073]
24
  self.image_std = [0.26862954, 0.26130258, 0.27577711]
25
  self.normalize = torchvision.transforms.Normalize(
26
+ self.image_mean, self.image_std
 
27
  )
28
  self.resize = torchvision.transforms.Resize(224)
29
  self.center_crop = torchvision.transforms.CenterCrop(224)
30
+
31
  def preprocess_img(self, images):
32
  images = self.center_crop(images)
33
  images = self.resize(images)
34
  images = self.center_crop(images)
35
  images = self.normalize(images)
36
  return images
37
+
38
  def __call__(self, images=[], **kwargs):
39
  processed_inputs = self.processor(**kwargs)
40
  processed_inputs["pixel_values"] = self.preprocess_img(images)
41
+ processed_inputs = {
42
+ key: value.to(self.device) for (key, value) in processed_inputs.items()
43
+ }
44
  return processed_inputs
45
 
46
+
47
+ class ImagePromptEditor(nn.Module):
48
+ def __init__(
49
+ self,
50
+ vqgan,
51
+ clip,
52
+ clip_preprocessor,
53
+ lpips_fn,
54
+ iterations=100,
55
+ lr=0.01,
56
+ save_vector=True,
57
+ return_val="vector",
58
+ quantize=True,
59
+ make_grid=False,
60
+ lpips_weight=6.2,
61
+ ) -> None:
62
+
63
  super().__init__()
64
  self.latent = None
65
  self.device = vqgan.device
 
74
  self.quantize = quantize
75
  self.lpips_weight = lpips_weight
76
  self.perceptual_loss = lpips_fn
77
+
78
  def set_latent(self, latent):
79
  self.latent = latent.detach().to(self.device)
80
+
81
  def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
82
  self._attn_mask = attn_mask
83
  self.iterations = iterations
84
  self.lr = lr
85
  self.lpips_weight = lpips_weight
86
  self.reconstruction_steps = reconstruction_steps
87
+
88
  def forward(self, vector):
89
  base_latent = self.latent.detach().requires_grad_()
90
  trans_latent = base_latent + vector
 
94
  z_q = trans_latent
95
  dec = self.vqgan.decode(z_q)
96
  return dec
97
+
98
  def _get_clip_similarity(self, prompts, image, weights=None):
99
  if isinstance(prompts, str):
100
  prompts = [prompts]
101
  elif not isinstance(prompts, list):
102
  raise TypeError("Provide prompts as string or list of strings")
103
+ clip_inputs = self.clip_preprocessor(
104
+ text=prompts, images=image, return_tensors="pt", padding=True
105
+ )
106
  clip_outputs = self.clip(**clip_inputs)
107
  similarity_logits = clip_outputs.logits_per_image
108
  if weights:
109
  similarity_logits *= weights
110
  return similarity_logits.sum()
111
+
112
+ def _get_CLIP_loss(self, pos_prompts, neg_prompts, image):
113
  pos_logits = self._get_clip_similarity(pos_prompts, image)
114
  if neg_prompts:
115
  neg_logits = self._get_clip_similarity(neg_prompts, image)
 
117
  neg_logits = torch.tensor([1], device=self.device)
118
  loss = -torch.log(pos_logits) + torch.log(neg_logits)
119
  return loss
120
+
121
  def visualize(self, processed_img):
122
  if self.make_grid:
123
  self.index += 1
 
126
  else:
127
  plt.imshow(get_pil(processed_img[0]).detach().cpu())
128
  plt.show()
129
+
130
  def _attn_mask(self, grad):
131
  newgrad = grad
132
  if self._attn_mask is not None:
133
  newgrad = grad * (self._attn_mask)
134
  return newgrad
135
+
136
  def _attn_mask_inverse(self, grad):
137
  newgrad = grad
138
  if self._attn_mask is not None:
139
  newgrad = grad * ((self._attn_mask - 1) * -1)
140
  return newgrad
141
+
142
  def _get_next_inputs(self, transformed_img):
143
+ processed_img = loop_post_process(transformed_img) # * self.attn_mask
144
  processed_img.retain_grad()
145
+
146
  lpips_input = processed_img.clone()
147
  lpips_input.register_hook(self._attn_mask_inverse)
148
  lpips_input.retain_grad()
149
+
150
  clip_input = processed_img.clone()
151
  clip_input.register_hook(self._attn_mask)
152
  clip_input.retain_grad()
153
+
154
+ return (processed_img, lpips_input, clip_input)
155
+
156
+ def _optimize_CLIP_LPIPS(self, optim, original_img, vector, pos_prompts, neg_prompts):
157
+ optim.zero_grad()
158
+ transformed_img = self(vector)
159
+ processed_img, lpips_input, clip_input = self._get_next_inputs(
160
+ transformed_img
161
+ )
162
+ with torch.autocast("cuda"):
163
+ clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
164
+ print("CLIP loss", clip_loss)
165
+ perceptual_loss = (
166
+ self.perceptual_loss(lpips_input, original_img.clone())
167
+ * self.lpips_weight
168
+ )
169
+ print("LPIPS loss: ", perceptual_loss)
170
+ print("Sum Loss", perceptual_loss + clip_loss)
171
+ if log:
172
+ wandb.log({"Perceptual Loss": perceptual_loss})
173
+ wandb.log({"CLIP Loss": clip_loss})
174
+
175
+ # These gradients will be masked if attn_mask has been set
176
+ clip_loss.backward(retain_graph=True)
177
+ perceptual_loss.backward(retain_graph=True)
178
+
179
+ optim.step()
180
+ yield vector
181
+
182
+ def _optimize_LPIPS(self, vector, original_img, optim):
183
+ optim.zero_grad()
184
+ transformed_img = self(vector)
185
+ processed_img = loop_post_process(transformed_img) # * self.attn_mask
186
+ processed_img.retain_grad()
187
+
188
+ lpips_input = processed_img.clone()
189
+ lpips_input.register_hook(self._attn_mask_inverse)
190
+ lpips_input.retain_grad()
191
+ with torch.autocast("cuda"):
192
+ perceptual_loss = (
193
+ self.perceptual_loss(lpips_input, original_img.clone())
194
+ * self.lpips_weight
195
+ )
196
+ if log:
197
+ wandb.log({"Perceptual Loss": perceptual_loss})
198
+ print("LPIPS loss: ", perceptual_loss)
199
+ perceptual_loss.backward(retain_graph=True)
200
+ optim.step()
201
+ yield vector
202
 
203
  def optimize(self, latent, pos_prompts, neg_prompts):
204
  self.set_latent(latent)
205
+ transformed_img = self(
206
+ torch.zeros_like(self.latent, requires_grad=True, device=self.device)
207
+ )
208
  original_img = loop_post_process(transformed_img)
209
  vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
210
  optim = torch.optim.Adam([vector], lr=self.lr)
211
+
 
 
212
  for i in tqdm(range(self.iterations)):
213
+ yield self._optimize_CLIP_LPIPS(optim, original_img, vector, pos_prompts, neg_prompts)
214
+
215
+ print("Running LPIPS optim only")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  for i in range(self.reconstruction_steps):
217
+ yield self._optimize_LPIPS(vector, original_img, transformed_img, optim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  yield vector if self.return_val == "vector" else self.latent + vector
edit.py CHANGED
@@ -12,7 +12,7 @@ import PIL
12
  import taming
13
  import torch
14
 
15
- from loaders import load_config
16
  from utils import get_device
17
 
18
 
@@ -25,11 +25,14 @@ def get_embedding(model, path=None, img=None, device="cpu"):
25
  z, _, [_, _, indices] = model.encode(x_processed)
26
  return z
27
 
28
-
29
- def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
 
 
30
  x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
31
  y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
32
- x_latent, y_latent = get_embedding(model, path=path1, device=device), get_embedding(model, path=path2, device=device)
 
33
  z = torch.lerp(x_latent, y_latent, weight)
34
  if quantize:
35
  z = model.quantize(z)[0]
@@ -45,14 +48,16 @@ def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, devi
45
  plt.show()
46
  return custom_to_pil(decoded), z
47
 
 
48
  if __name__ == "__main__":
49
  device = get_device()
50
- ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
51
- conf_path = "./unwrapped.yaml"
52
- config = load_config(conf_path, display=False)
53
- model = taming.models.vqgan.VQModel(**config.model.params)
54
- sd = torch.load("./vqgan_only.pt", map_location="mps")
55
- model.load_state_dict(sd, strict=True)
56
  model.to(device)
57
- blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=.5)
58
- plt.show()
 
 
 
 
 
 
 
12
  import taming
13
  import torch
14
 
15
+ from loaders import load_config, load_default
16
  from utils import get_device
17
 
18
 
 
25
  z, _, [_, _, indices] = model.encode(x_processed)
26
  return z
27
 
28
+
29
+ def blend_paths(
30
+ model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"
31
+ ):
32
  x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
33
  y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
34
+ x_latent = get_embedding(model, path=path1, device=device)
35
+ y_latent = get_embedding(model, path=path2, device=device)
36
  z = torch.lerp(x_latent, y_latent, weight)
37
  if quantize:
38
  z = model.quantize(z)[0]
 
48
  plt.show()
49
  return custom_to_pil(decoded), z
50
 
51
+
52
  if __name__ == "__main__":
53
  device = get_device()
54
+ model = load_default(device)
 
 
 
 
 
55
  model.to(device)
56
+ blend_paths(
57
+ model,
58
+ "./test_data/face.jpeg",
59
+ "./test_data/face2.jpeg",
60
+ quantize=False,
61
+ weight=0.5,
62
+ )
63
+ plt.show()
img_processing.py CHANGED
@@ -1,12 +1,9 @@
1
  import io
2
- import os
3
- import sys
4
 
5
  import numpy as np
6
  import PIL
7
  import requests
8
  import torch
9
- import torch.nn.functional as F
10
  import torchvision.transforms as T
11
  import torchvision.transforms.functional as TF
12
  from PIL import Image, ImageDraw, ImageFont
@@ -20,10 +17,10 @@ def download_image(url):
20
 
21
  def preprocess(img, target_image_size=256, map_dalle=False):
22
  s = min(img.size)
23
-
24
  if s < target_image_size:
25
- raise ValueError(f'min dim for image {s} < {target_image_size}')
26
-
27
  r = target_image_size / s
28
  s = (round(r * img.size[1]), round(r * img.size[0]))
29
  img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
@@ -31,42 +28,49 @@ def preprocess(img, target_image_size=256, map_dalle=False):
31
  img = torch.unsqueeze(T.ToTensor()(img), 0)
32
  return img
33
 
 
34
  def preprocess_vqgan(x):
35
- x = 2. * x - 1.
36
- return x
 
37
 
38
  def custom_to_pil(x, process=True, mode="RGB"):
39
- x = x.detach().cpu()
40
- if process:
41
- x = torch.clamp(x, -1., 1.)
42
- x = (x + 1.)/2.
43
- x = x.permute(1,2,0).numpy()
44
- if process:
45
- x = (255*x).astype(np.uint8)
46
- x = Image.fromarray(x)
47
- if not x.mode == mode:
48
- x = x.convert(mode)
49
- return x
 
50
 
51
  def get_pil(x):
52
- x = torch.clamp(x, -1., 1.)
53
- x = (x + 1.)/2.
54
- x = x.permute(1,2,0)
55
- return x
 
56
 
57
  def loop_post_process(x):
58
- x = get_pil(x.squeeze())
59
- return x.permute(2, 0, 1).unsqueeze(0)
 
60
 
61
  def stack_reconstructions(input, x0, x1, x2, x3, titles=[]):
62
- assert input.size == x1.size == x2.size == x3.size
63
- w, h = input.size[0], input.size[1]
64
- img = Image.new("RGB", (5*w, h))
65
- img.paste(input, (0,0))
66
- img.paste(x0, (1*w,0))
67
- img.paste(x1, (2*w,0))
68
- img.paste(x2, (3*w,0))
69
- img.paste(x3, (4*w,0))
70
- for i, title in enumerate(titles):
71
- ImageDraw.Draw(img).text((i*w, 0), f'{title}', (255, 255, 255), font=font) # coordinates, text, color, font
72
- return img
 
 
 
1
  import io
 
 
2
 
3
  import numpy as np
4
  import PIL
5
  import requests
6
  import torch
 
7
  import torchvision.transforms as T
8
  import torchvision.transforms.functional as TF
9
  from PIL import Image, ImageDraw, ImageFont
 
17
 
18
  def preprocess(img, target_image_size=256, map_dalle=False):
19
  s = min(img.size)
20
+
21
  if s < target_image_size:
22
+ raise ValueError(f"min dim for image {s} < {target_image_size}")
23
+
24
  r = target_image_size / s
25
  s = (round(r * img.size[1]), round(r * img.size[0]))
26
  img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
 
28
  img = torch.unsqueeze(T.ToTensor()(img), 0)
29
  return img
30
 
31
+
32
  def preprocess_vqgan(x):
33
+ x = 2.0 * x - 1.0
34
+ return x
35
+
36
 
37
  def custom_to_pil(x, process=True, mode="RGB"):
38
+ x = x.detach().cpu()
39
+ if process:
40
+ x = torch.clamp(x, -1.0, 1.0)
41
+ x = (x + 1.0) / 2.0
42
+ x = x.permute(1, 2, 0).numpy()
43
+ if process:
44
+ x = (255 * x).astype(np.uint8)
45
+ x = Image.fromarray(x)
46
+ if not x.mode == mode:
47
+ x = x.convert(mode)
48
+ return x
49
+
50
 
51
  def get_pil(x):
52
+ x = torch.clamp(x, -1.0, 1.0)
53
+ x = (x + 1.0) / 2.0
54
+ x = x.permute(1, 2, 0)
55
+ return x
56
+
57
 
58
  def loop_post_process(x):
59
+ x = get_pil(x.squeeze())
60
+ return x.permute(2, 0, 1).unsqueeze(0)
61
+
62
 
63
  def stack_reconstructions(input, x0, x1, x2, x3, titles=[]):
64
+ assert input.size == x1.size == x2.size == x3.size
65
+ w, h = input.size[0], input.size[1]
66
+ img = Image.new("RGB", (5 * w, h))
67
+ img.paste(input, (0, 0))
68
+ img.paste(x0, (1 * w, 0))
69
+ img.paste(x1, (2 * w, 0))
70
+ img.paste(x2, (3 * w, 0))
71
+ img.paste(x3, (4 * w, 0))
72
+ for i, title in enumerate(titles):
73
+ ImageDraw.Draw(img).text(
74
+ (i * w, 0), f"{title}", (255, 255, 255), font=font
75
+ ) # coordinates, text, color, font
76
+ return img
loaders.py CHANGED
@@ -10,17 +10,17 @@ from utils import get_device
10
 
11
 
12
  def load_config(config_path, display=False):
13
- config = OmegaConf.load(config_path)
14
- if display:
15
- print(yaml.dump(OmegaConf.to_container(config)))
16
- return config
 
17
 
18
  def load_default(device):
19
- ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
20
- conf_path = "./unwrapped.yaml"
21
  config = load_config(conf_path, display=False)
22
  model = taming.models.vqgan.VQModel(**config.model.params)
23
- sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
24
  model.load_state_dict(sd, strict=True)
25
  model.to(device)
26
  del sd
@@ -34,17 +34,14 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
34
  missing, unexpected = model.load_state_dict(sd, strict=False)
35
  return model.eval()
36
 
37
- def load_ffhq():
38
- conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
39
- ckpt = "2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"
40
- vqgan = load_model(load_config(conf), ckpt, True, True)[0]
41
 
42
  def reconstruct_with_vqgan(x, model):
43
- # could also use model(x) for reconstruction but use explicit encoding and decoding here
44
- z, _, [_, _, indices] = model.encode(x)
45
- print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
46
- xrec = model.decode(z)
47
- return xrec
 
48
  def get_obj_from_str(string, reload=False):
49
  module, cls = string.rsplit(".", 1)
50
  if reload:
@@ -52,12 +49,13 @@ def get_obj_from_str(string, reload=False):
52
  importlib.reload(module_imp)
53
  return getattr(importlib.import_module(module, package=None), cls)
54
 
55
- def instantiate_from_config(config):
56
 
57
- if not "target" in config:
 
58
  raise KeyError("Expected key `target` to instantiate.")
59
  return get_obj_from_str(config["target"])(**config.get("params", dict()))
60
 
 
61
  def load_model_from_config(config, sd, gpu=True, eval_mode=True):
62
  model = instantiate_from_config(config)
63
  if sd is not None:
@@ -78,5 +76,7 @@ def load_model(config, ckpt, gpu, eval_mode):
78
  else:
79
  pl_sd = {"state_dict": None}
80
  global_step = None
81
- model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
82
- return model, global_step
 
 
 
10
 
11
 
12
  def load_config(config_path, display=False):
13
+ config = OmegaConf.load(config_path)
14
+ if display:
15
+ print(yaml.dump(OmegaConf.to_container(config)))
16
+ return config
17
+
18
 
19
  def load_default(device):
20
+ conf_path = "./celeba_vqgan/unwrapped.yaml"
 
21
  config = load_config(conf_path, display=False)
22
  model = taming.models.vqgan.VQModel(**config.model.params)
23
+ sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)
24
  model.load_state_dict(sd, strict=True)
25
  model.to(device)
26
  del sd
 
34
  missing, unexpected = model.load_state_dict(sd, strict=False)
35
  return model.eval()
36
 
 
 
 
 
37
 
38
  def reconstruct_with_vqgan(x, model):
39
+ z, _, [_, _, indices] = model.encode(x)
40
+ print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
41
+ xrec = model.decode(z)
42
+ return xrec
43
+
44
+
45
  def get_obj_from_str(string, reload=False):
46
  module, cls = string.rsplit(".", 1)
47
  if reload:
 
49
  importlib.reload(module_imp)
50
  return getattr(importlib.import_module(module, package=None), cls)
51
 
 
52
 
53
+ def instantiate_from_config(config):
54
+ if "target" not in config:
55
  raise KeyError("Expected key `target` to instantiate.")
56
  return get_obj_from_str(config["target"])(**config.get("params", dict()))
57
 
58
+
59
  def load_model_from_config(config, sd, gpu=True, eval_mode=True):
60
  model = instantiate_from_config(config)
61
  if sd is not None:
 
76
  else:
77
  pl_sd = {"state_dict": None}
78
  global_step = None
79
+ model = load_model_from_config(
80
+ config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode
81
+ )["model"]
82
+ return model, global_step
masking.py CHANGED
@@ -3,30 +3,28 @@ import sys
3
 
4
  import matplotlib.pyplot as plt
5
  import torch
 
 
 
6
 
7
- sys.path.append("taming-transformers")
8
- import functools
 
9
 
10
- import gradio as gr
11
- from transformers import CLIPModel, CLIPProcessor
12
 
13
- import edit
14
- # import importlib
15
- # importlib.reload(edit)
16
- from backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
17
- from loaders import load_default
18
 
19
- device = "cuda"
20
- vqgan = load_default(device)
21
- vqgan.eval()
22
- processor = ProcessorGradientFlow(device=device)
23
- clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
24
- clip.to(device)
25
- promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
26
- state = ImageState(vqgan, promptoptim)
27
- mask = torch.load("eyebrow_mask.pt")
28
- x = state.blend("./test_data/face.jpeg", "./test_data/face2.jpeg", 0.5)
29
- plt.imshow(x)
30
- plt.show()
31
- state.apply_prompts("a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask)
32
- print('done')
 
3
 
4
  import matplotlib.pyplot as plt
5
  import torch
6
+ from backend import ImagePromptEditor, ImageState, ProcessorGradientFlow
7
+ from loaders import load_default
8
+ from transformers import CLIPModel
9
 
10
+ if __name__ == "__main__":
11
+ sys.path.append("taming-transformers")
12
+ device = "cuda"
13
 
14
+ vqgan = load_default(device)
15
+ vqgan.eval()
16
 
17
+ processor = ProcessorGradientFlow(device=device)
18
+ clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
19
+ clip.to(device)
 
 
20
 
21
+ promptoptim = ImagePromptEditor(vqgan, clip, processor, quantize=True)
22
+ state = ImageState(vqgan, promptoptim)
23
+ mask = torch.load("eyebrow_mask.pt")
24
+ x = state.blend("./test_data/face.jpeg", "./test_data/face2.jpeg", 0.5)
25
+ plt.imshow(x)
26
+ plt.show()
27
+ state.apply_prompts(
28
+ "a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask
29
+ )
30
+ print("done")
 
 
 
 
presets.py CHANGED
@@ -1,16 +1,42 @@
1
  import gradio as gr
2
 
 
3
  def set_preset(config_str):
4
- choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender"]
 
 
 
 
5
  if config_str == choices[0]:
6
  return set_small_local()
7
  elif config_str == choices[1]:
8
  return set_major_local()
9
  elif config_str == choices[2]:
10
  return set_major_global()
 
 
11
  def set_small_local():
12
- return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
 
 
 
 
 
 
 
13
  def set_major_local():
14
- return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
 
 
 
 
 
 
 
15
  def set_major_global():
16
- return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+
4
  def set_preset(config_str):
5
+ choices = [
6
+ "Small Masked Changes (e.g. add lipstick)",
7
+ "Major Masked Changes (e.g. change hair color or nose size)",
8
+ "Major Global Changes (e.g. change race / gender",
9
+ ]
10
  if config_str == choices[0]:
11
  return set_small_local()
12
  elif config_str == choices[1]:
13
  return set_major_local()
14
  elif config_str == choices[2]:
15
  return set_major_global()
16
+
17
+
18
  def set_small_local():
19
+ return (
20
+ gr.Slider.update(value=25),
21
+ gr.Slider.update(value=0.15),
22
+ gr.Slider.update(value=1),
23
+ gr.Slider.update(value=4),
24
+ )
25
+
26
+
27
  def set_major_local():
28
+ return (
29
+ gr.Slider.update(value=25),
30
+ gr.Slider.update(value=0.25),
31
+ gr.Slider.update(value=35),
32
+ gr.Slider.update(value=10),
33
+ )
34
+
35
+
36
  def set_major_global():
37
+ return (
38
+ gr.Slider.update(value=30),
39
+ gr.Slider.update(value=0.1),
40
+ gr.Slider.update(value=2),
41
+ gr.Slider.update(value=0.2),
42
+ )
prompts.py CHANGED
@@ -1,17 +1,41 @@
1
  import random
 
 
2
  class PromptSet:
3
  def __init__(self, pos, neg, config=None):
4
  self.positive = pos
5
  self.negative = neg
6
  self.config = config
 
 
7
  example_prompts = (
8
- PromptSet("a picture of a woman with light blonde hair", "a picture of a person with dark hair | a picture of a person with brown hair"),
9
- PromptSet("A picture of a woman with very thick eyebrows", "a picture of a person with very thin eyebrows | a picture of a person with no eyebrows"),
10
- PromptSet("A picture of a woman wearing bright red lipstick", "a picture of a person wearing no lipstick | a picture of a person wearing dark lipstick"),
11
- PromptSet("A picture of a beautiful chinese woman | a picture of a Japanese woman | a picture of an Asian woman", "a picture of a white woman | a picture of an Indian woman | a picture of a black woman"),
12
- PromptSet("A picture of a handsome man | a picture of a masculine man", "a picture of a woman | a picture of a feminine person"),
13
- PromptSet("A picture of a woman with a very big nose", "a picture of a person with a small nose | a picture of a person with a normal nose"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
 
 
15
  def get_random_prompts():
16
  prompt = random.choice(example_prompts)
17
- return prompt.positive, prompt.negative
 
1
  import random
2
+
3
+
4
  class PromptSet:
5
  def __init__(self, pos, neg, config=None):
6
  self.positive = pos
7
  self.negative = neg
8
  self.config = config
9
+
10
+
11
  example_prompts = (
12
+ PromptSet(
13
+ "a picture of a woman with light blonde hair",
14
+ "a picture of a person with dark hair | a picture of a person with brown hair",
15
+ ),
16
+ PromptSet(
17
+ "A picture of a woman with very thick eyebrows",
18
+ "a picture of a person with very thin eyebrows | a picture of a person with no eyebrows",
19
+ ),
20
+ PromptSet(
21
+ "A picture of a woman wearing bright red lipstick",
22
+ "a picture of a person wearing no lipstick | a picture of a person wearing dark lipstick",
23
+ ),
24
+ PromptSet(
25
+ "A picture of a beautiful chinese woman | a picture of a Japanese woman | a picture of an Asian woman",
26
+ "a picture of a white woman | a picture of an Indian woman | a picture of a black woman",
27
+ ),
28
+ PromptSet(
29
+ "A picture of a handsome man | a picture of a masculine man",
30
+ "a picture of a woman | a picture of a feminine person",
31
+ ),
32
+ PromptSet(
33
+ "A picture of a woman with a very big nose",
34
+ "a picture of a person with a small nose | a picture of a person with a normal nose",
35
+ ),
36
  )
37
+
38
+
39
  def get_random_prompts():
40
  prompt = random.choice(example_prompts)
41
+ return prompt.positive, prompt.negative
unwrapped.yaml DELETED
@@ -1,37 +0,0 @@
1
- model:
2
- target: taming.models.vqgan.VQModel
3
- params:
4
- embed_dim: 256
5
- n_embed: 1024
6
- ddconfig:
7
- double_z: false
8
- z_channels: 256
9
- resolution: 256
10
- in_channels: 3
11
- out_ch: 3
12
- ch: 128
13
- ch_mult:
14
- - 1
15
- - 1
16
- - 2
17
- - 2
18
- - 4
19
- num_res_blocks: 2
20
- attn_resolutions:
21
- - 16
22
- dropout: 0.0
23
- lossconfig:
24
- target: taming.modules.losses.vqperceptual.DummyLoss
25
- data:
26
- target: cutlit.DataModuleFromConfig
27
- params:
28
- batch_size: 24
29
- num_workers: 24
30
- train:
31
- target: taming.data.faceshq.CelebAHQTrain
32
- params:
33
- size: 256
34
- validation:
35
- target: taming.data.faceshq.CelebAHQValidation
36
- params:
37
- size: 256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -7,9 +7,11 @@ import torch.nn.functional as F
7
  from skimage.color import lab2rgb, rgb2lab
8
  from torch import nn
9
 
 
10
  def freeze_module(module):
11
  for param in module.parameters():
12
- param.requires_grad = False
 
13
 
14
  def get_device():
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
7
  from skimage.color import lab2rgb, rgb2lab
8
  from torch import nn
9
 
10
+
11
  def freeze_module(module):
12
  for param in module.parameters():
13
+ param.requires_grad = False
14
+
15
 
16
  def get_device():
17
  device = "cuda" if torch.cuda.is_available() else "cpu"