Erwann Millon commited on
Commit
ec39fe8
β€’
1 Parent(s): 006354e

refactoring and cleanup

Browse files
Files changed (4) hide show
  1. ImageState.py +8 -29
  2. animation.py +3 -0
  3. app.py +0 -12
  4. backend.py +1 -1
ImageState.py CHANGED
@@ -31,7 +31,6 @@ class PromptTransformHistory():
31
 
32
  class ImageState:
33
  def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
34
- # global vqgan
35
  self.vqgan = vqgan
36
  self.device = vqgan.device
37
  self.blend_latent = None
@@ -41,14 +40,11 @@ class ImageState:
41
  self.transform_history = []
42
  self.attn_mask = None
43
  self.prompt_optim = prompt_optimizer
44
- self.state_id = None
45
- print(self.state_id)
46
  self._load_vectors()
47
  self.init_transforms()
48
  def _load_vectors(self):
49
  self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
50
- self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
51
- self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
52
  self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
53
  def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
54
  images = []
@@ -71,7 +67,6 @@ class ImageState:
71
  self.lip_size = torch.zeros_like(self.lip_vector)
72
  self.asian_transform = torch.zeros_like(self.lip_vector)
73
  self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
74
- self.hair_gp = torch.zeros_like(self.lip_vector)
75
  def clear_transforms(self):
76
  global num
77
  self.init_transforms()
@@ -95,25 +90,22 @@ class ImageState:
95
  attn_mask = mask
96
  return attn_mask
97
  def set_mask(self, img):
98
- attn_mask = self._get_mask(img)
99
- self.attn_mask = attn_mask
100
- # attn_mask = torch.ones_like(img, device=self.device)
101
- x = attn_mask.clone()
102
  x = x.detach().cpu()
103
  x = torch.clamp(x, -1., 1.)
104
  x = (x + 1.)/2.
105
  x = x.numpy()
106
- x = (255*x).astype(np.uint8)
107
  x = Image.fromarray(x, "L")
108
  return x
109
  @torch.no_grad()
110
  def _render_all_transformations(self, return_twice=True):
111
  global num
112
- # global vqgan
113
  if self.state_id is None:
114
  self.state_id = "./img_history/" + str(uuid.uuid4())
115
  print("redner all", self.state_id)
116
- current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
117
  new_latent = self.blend_latent + sum(current_vector_transforms)
118
  if self.quant:
119
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
@@ -126,17 +118,13 @@ class ImageState:
126
  image.save(f"{img_dir}/img_{num:06}.png")
127
  num += 1
128
  return (image, image) if return_twice else image
129
- def apply_gp_vector(self, weight):
130
- self.hair_gp = weight * self.green_purple_vector
131
- return self._render_all_transformations()
132
  def apply_rb_vector(self, weight):
133
- self.blue_eyes = weight * self.red_blue_vector
134
  return self._render_all_transformations()
135
  def apply_lip_vector(self, weight):
136
  self.lip_size = weight * self.lip_vector
137
  return self._render_all_transformations()
138
- def update_requant(self, val):
139
- print(f"val = {val}")
140
  self.quant = val
141
  return self._render_all_transformations()
142
  def apply_asian_vector(self, weight):
@@ -144,11 +132,7 @@ class ImageState:
144
  return self._render_all_transformations()
145
  def update_images(self, path1, path2, blend_weight):
146
  if path1 is None and path2 is None:
147
- print("no paths")
148
  return None
149
- if path1 == path2:
150
- print("paths are the same")
151
- print(path1)
152
  if path1 is None: path1 = path2
153
  if path2 is None: path2 = path1
154
  self.path1, self.path2 = path1, path2
@@ -203,9 +187,4 @@ class ImageState:
203
  self.attn_mask = None
204
  self.transform_history.append(transform_log)
205
  gc.collect()
206
- torch.cuda.empty_cache()
207
- # transform = self.prompt_optim.optimize(self.blend_latent,
208
- # positive_prompts,
209
- # negative_prompts)
210
- # self.prompt_transforms = transform
211
- # return self._render_all_transformations()
 
31
 
32
  class ImageState:
33
  def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
 
34
  self.vqgan = vqgan
35
  self.device = vqgan.device
36
  self.blend_latent = None
 
40
  self.transform_history = []
41
  self.attn_mask = None
42
  self.prompt_optim = prompt_optimizer
 
 
43
  self._load_vectors()
44
  self.init_transforms()
45
  def _load_vectors(self):
46
  self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
47
+ self.blue_eyes_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
 
48
  self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
49
  def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
50
  images = []
 
67
  self.lip_size = torch.zeros_like(self.lip_vector)
68
  self.asian_transform = torch.zeros_like(self.lip_vector)
69
  self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
 
70
  def clear_transforms(self):
71
  global num
72
  self.init_transforms()
 
90
  attn_mask = mask
91
  return attn_mask
92
  def set_mask(self, img):
93
+ self.attn_mask = self._get_mask(img)
94
+ x = self.attn_mask.clone()
 
 
95
  x = x.detach().cpu()
96
  x = torch.clamp(x, -1., 1.)
97
  x = (x + 1.)/2.
98
  x = x.numpy()
99
+ x = (255 * x).astype(np.uint8)
100
  x = Image.fromarray(x, "L")
101
  return x
102
  @torch.no_grad()
103
  def _render_all_transformations(self, return_twice=True):
104
  global num
 
105
  if self.state_id is None:
106
  self.state_id = "./img_history/" + str(uuid.uuid4())
107
  print("redner all", self.state_id)
108
+ current_vector_transforms = (self.blue_eyes, self.lip_size, self.asian_transform, sum(self.current_prompt_transforms))
109
  new_latent = self.blend_latent + sum(current_vector_transforms)
110
  if self.quant:
111
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
 
118
  image.save(f"{img_dir}/img_{num:06}.png")
119
  num += 1
120
  return (image, image) if return_twice else image
 
 
 
121
  def apply_rb_vector(self, weight):
122
+ self.blue_eyes = weight * self.blue_eyes_vector
123
  return self._render_all_transformations()
124
  def apply_lip_vector(self, weight):
125
  self.lip_size = weight * self.lip_vector
126
  return self._render_all_transformations()
127
+ def update_quant(self, val):
 
128
  self.quant = val
129
  return self._render_all_transformations()
130
  def apply_asian_vector(self, weight):
 
132
  return self._render_all_transformations()
133
  def update_images(self, path1, path2, blend_weight):
134
  if path1 is None and path2 is None:
 
135
  return None
 
 
 
136
  if path1 is None: path1 = path2
137
  if path2 is None: path2 = path1
138
  self.path1, self.path2 = path1, path2
 
187
  self.attn_mask = None
188
  self.transform_history.append(transform_log)
189
  gc.collect()
190
+ torch.cuda.empty_cache()
 
 
 
 
 
animation.py CHANGED
@@ -2,6 +2,7 @@ import imageio
2
  import glob
3
  import os
4
 
 
5
  def clear_img_dir(img_dir):
6
  if not os.path.exists("img_history"):
7
  os.mkdir("img_history")
@@ -10,6 +11,7 @@ def clear_img_dir(img_dir):
10
  for filename in glob.glob(img_dir+"/*"):
11
  os.remove(filename)
12
 
 
13
  def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
14
  images = []
15
  paths = glob.glob(folder + "/*")
@@ -26,5 +28,6 @@ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="
26
  imageio.mimsave(gif_name, images, duration=durations)
27
  return gif_name
28
 
 
29
  if __name__ == "__main__":
30
  create_gif()
 
2
  import glob
3
  import os
4
 
5
+
6
  def clear_img_dir(img_dir):
7
  if not os.path.exists("img_history"):
8
  os.mkdir("img_history")
 
11
  for filename in glob.glob(img_dir+"/*"):
12
  os.remove(filename)
13
 
14
+
15
  def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
16
  images = []
17
  paths = glob.glob(folder + "/*")
 
28
  imageio.mimsave(gif_name, images, duration=durations)
29
  return gif_name
30
 
31
+
32
  if __name__ == "__main__":
33
  create_gif()
app.py CHANGED
@@ -100,7 +100,6 @@ with gr.Blocks(css="styles.css") as demo:
100
  label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
101
  minimum=0,
102
  maximum=100)
103
-
104
  apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
105
  clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
106
  blue_eyes = gr.Slider(
@@ -110,13 +109,6 @@ with gr.Blocks(css="styles.css") as demo:
110
  value=0,
111
  step=0.1,
112
  )
113
- # hair_green_purple = gr.Slider(
114
- # label="hair green<->purple ",
115
- # minimum=-.8,
116
- # maximum=.8,
117
- # value=0,
118
- # step=0.1,
119
- # )
120
  lip_size = gr.Slider(
121
  label="Lip Size",
122
  minimum=-1.9,
@@ -131,10 +123,6 @@ with gr.Blocks(css="styles.css") as demo:
131
  maximum=1.,
132
  step=0.1,
133
  )
134
- # requantize = gr.Checkbox(
135
- # label="Requantize Latents (necessary using text prompts)",
136
- # value=True,
137
- # )
138
  asian_weight = gr.Slider(
139
  minimum=-2.,
140
  value=0,
 
100
  label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
101
  minimum=0,
102
  maximum=100)
 
103
  apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
104
  clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
105
  blue_eyes = gr.Slider(
 
109
  value=0,
110
  step=0.1,
111
  )
 
 
 
 
 
 
 
112
  lip_size = gr.Slider(
113
  label="Lip Size",
114
  minimum=-1.9,
 
123
  maximum=1.,
124
  step=0.1,
125
  )
 
 
 
 
126
  asian_weight = gr.Slider(
127
  minimum=-2.,
128
  value=0,
backend.py CHANGED
@@ -33,7 +33,7 @@ def get_resized_tensor(x):
33
  class ProcessorGradientFlow():
34
  """
35
  This wraps the huggingface CLIP processor to allow backprop through the image processing step.
36
- The original processor forces conversion to PIL images, which breaks gradient flow.
37
  """
38
  def __init__(self, device="cuda") -> None:
39
  self.device = device
 
33
  class ProcessorGradientFlow():
34
  """
35
  This wraps the huggingface CLIP processor to allow backprop through the image processing step.
36
+ The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.
37
  """
38
  def __init__(self, device="cuda") -> None:
39
  self.device = device