erwann commited on
Commit
bea83f6
β€’
1 Parent(s): d7fcb4c

update animation creation

Browse files
Files changed (4) hide show
  1. ImageState.py +31 -9
  2. animation.py +2 -3
  3. app.py +14 -12
  4. configs.py +1 -1
ImageState.py CHANGED
@@ -1,4 +1,7 @@
1
  # from align import align_from_path
 
 
 
2
  from animation import clear_img_dir
3
  from backend import ImagePromptOptimizer, log
4
  import importlib
@@ -38,6 +41,9 @@ class ImageState:
38
  self.transform_history = []
39
  self.attn_mask = None
40
  self.prompt_optim = prompt_optimizer
 
 
 
41
  self._load_vectors()
42
  self.init_transforms()
43
  def _load_vectors(self):
@@ -45,6 +51,24 @@ class ImageState:
45
  self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
46
  self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
47
  self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def init_transforms(self):
49
  self.blue_eyes = torch.zeros_like(self.lip_vector)
50
  self.lip_size = torch.zeros_like(self.lip_vector)
@@ -104,10 +128,10 @@ class ImageState:
104
  if self.quant:
105
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
106
  image = self._decode_latent_to_pil(new_latent)
107
- img_dir = "./img_history"
108
  if not os.path.exists(img_dir):
109
  os.mkdir(img_dir)
110
- image.save(f"./img_history/img_{num:06}.png")
111
  num += 1
112
  return (image, image) if return_twice else image
113
  def apply_gp_vector(self, weight):
@@ -149,14 +173,12 @@ class ImageState:
149
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
150
  print(latent_index)
151
  self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
152
- # print(self.current_prompt_transform)
153
- # print(self.current_prompt_transforms.mean())
154
  return self._render_all_transformations()
155
- def rescale_mask(self, mask):
156
- rep = mask.clone()
157
- rep[mask < 0.03] = -1000000
158
- rep[mask >= 0.03] = 1
159
- return rep
160
  def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
161
  transform_log = PromptTransformHistory(iterations + reconstruction_steps)
162
  transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
 
1
  # from align import align_from_path
2
+ import imageio
3
+ import glob
4
+ import uuid
5
  from animation import clear_img_dir
6
  from backend import ImagePromptOptimizer, log
7
  import importlib
 
41
  self.transform_history = []
42
  self.attn_mask = None
43
  self.prompt_optim = prompt_optimizer
44
+ self.state_id = "./" + str(uuid.uuid4())
45
+ print("NEW INSTANCE")
46
+ print(self.state_id)
47
  self._load_vectors()
48
  self.init_transforms()
49
  def _load_vectors(self):
 
51
  self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
52
  self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
53
  self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
54
+ def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
55
+ images = []
56
+ folder = self.state_id
57
+ paths = glob.glob(folder + "/*")
58
+ frame_duration = total_duration / len(paths)
59
+ print(len(paths), "frame dur", frame_duration)
60
+ durations = [frame_duration] * len(paths)
61
+ if extend_frames:
62
+ durations [0] = 1.5
63
+ durations [-1] = 3
64
+ for file_name in os.listdir(folder):
65
+ if file_name.endswith('.png'):
66
+ file_path = os.path.join(folder, file_name)
67
+ images.append(imageio.imread(file_path))
68
+ # images[0] = images[0].set_meta_data({'duration': 1})
69
+ # images[-1] = images[-1].set_meta_data({'duration': 1})
70
+ imageio.mimsave(gif_name, images, duration=durations)
71
+ return gif_name
72
  def init_transforms(self):
73
  self.blue_eyes = torch.zeros_like(self.lip_vector)
74
  self.lip_size = torch.zeros_like(self.lip_vector)
 
128
  if self.quant:
129
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
130
  image = self._decode_latent_to_pil(new_latent)
131
+ img_dir = self.state_id
132
  if not os.path.exists(img_dir):
133
  os.mkdir(img_dir)
134
+ image.save(f"{img_dir}/img_{num:06}.png")
135
  num += 1
136
  return (image, image) if return_twice else image
137
  def apply_gp_vector(self, weight):
 
173
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
174
  print(latent_index)
175
  self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
 
 
176
  return self._render_all_transformations()
177
+ # def rescale_mask(self, mask):
178
+ # rep = mask.clone()
179
+ # rep[mask < 0.03] = -1000000
180
+ # rep[mask >= 0.03] = 1
181
+ # return rep
182
  def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
183
  transform_log = PromptTransformHistory(iterations + reconstruction_steps)
184
  transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
animation.py CHANGED
@@ -2,15 +2,14 @@ import imageio
2
  import glob
3
  import os
4
 
5
- def clear_img_dir():
6
- img_dir = "./img_history"
7
  if not os.path.exists(img_dir):
8
  os.mkdir(img_dir)
9
  for filename in glob.glob(img_dir+"/*"):
10
  os.remove(filename)
11
 
12
 
13
- def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"):
14
  images = []
15
  paths = glob.glob(folder + "/*")
16
  frame_duration = total_duration / len(paths)
 
2
  import glob
3
  import os
4
 
5
+ def clear_img_dir(img_dir):
 
6
  if not os.path.exists(img_dir):
7
  os.mkdir(img_dir)
8
  for filename in glob.glob(img_dir+"/*"):
9
  os.remove(filename)
10
 
11
 
12
+ def create_gif(total_duration, extend_frames, folder, gif_name="face_edit.gif"):
13
  images = []
14
  paths = glob.glob(folder + "/*")
15
  frame_duration = total_duration / len(paths)
app.py CHANGED
@@ -6,7 +6,8 @@ import wandb
6
  import torch
7
 
8
  from configs import set_major_global, set_major_local, set_small_local
9
-
 
10
  sys.path.append("taming-transformers")
11
 
12
  import gradio as gr
@@ -37,6 +38,8 @@ def get_cleared_mask():
37
  # mask.clear()
38
 
39
  class StateWrapper:
 
 
40
  def apply_asian_vector(state, *args, **kwargs):
41
  return state, *state[0].apply_asian_vector(*args, **kwargs)
42
  def apply_gp_vector(state, *args, **kwargs):
@@ -141,7 +144,7 @@ with gr.Blocks(css="styles.css") as demo:
141
  minimum=0,
142
  maximum=100)
143
 
144
- apply_prompts = gr.Button(value="🎨 Apply Prompts", elem_id="apply")
145
  clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
146
  with gr.Accordion(label="πŸ’Ύ Save Animation", open=False):
147
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
@@ -149,7 +152,7 @@ with gr.Blocks(css="styles.css") as demo:
149
  extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
150
  gif = gr.File(interactive=False)
151
  create_animation = gr.Button(value="Create Animation")
152
- create_animation.click(create_gif, inputs=[duration, extend_frames], outputs=gif)
153
 
154
  with gr.Column(scale=1):
155
  gr.Markdown(value="""## Text Prompting
@@ -166,12 +169,12 @@ with gr.Blocks(css="styles.css") as demo:
166
  with gr.Row():
167
  gr.Markdown(value="## Preset Configs", show_label=False)
168
  with gr.Row():
169
- with gr.Column():
170
- small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
171
- with gr.Column():
172
- major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
173
- with gr.Column():
174
- major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
175
  iterations = gr.Slider(minimum=10,
176
  maximum=60,
177
  step=1,
@@ -181,14 +184,13 @@ with gr.Blocks(css="styles.css") as demo:
181
  maximum=7e-1,
182
  value=1e-1,
183
  label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
184
- with gr.Accordion(label="Advanced Prompt Editing Options", open=False):
185
  lpips_weight = gr.Slider(minimum=0,
186
  maximum=50,
187
  value=1,
188
  label="Perceptual similarity weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
189
  reconstruction_steps = gr.Slider(minimum=0,
190
  maximum=50,
191
- value=15,
192
  step=1,
193
  label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that will 'pull' the image back towards the original identity")
194
  # discriminator_steps = gr.Slider(minimum=0,
@@ -196,7 +198,7 @@ with gr.Blocks(css="styles.css") as demo:
196
  # step=1,
197
  # value=0,
198
  # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
199
- clear.click(state.clear_transforms, inputs=[state], outputs=[state, out, mask])
200
  asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
201
  lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
202
  # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
 
6
  import torch
7
 
8
  from configs import set_major_global, set_major_local, set_small_local
9
+ import uuid
10
+ # print()'
11
  sys.path.append("taming-transformers")
12
 
13
  import gradio as gr
 
38
  # mask.clear()
39
 
40
  class StateWrapper:
41
+ def create_gif(state, *args, **kwargs):
42
+ return state, state[0].create_gif(*args, **kwargs)
43
  def apply_asian_vector(state, *args, **kwargs):
44
  return state, *state[0].apply_asian_vector(*args, **kwargs)
45
  def apply_gp_vector(state, *args, **kwargs):
 
144
  minimum=0,
145
  maximum=100)
146
 
147
+ apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
148
  clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
149
  with gr.Accordion(label="πŸ’Ύ Save Animation", open=False):
150
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
 
152
  extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
153
  gif = gr.File(interactive=False)
154
  create_animation = gr.Button(value="Create Animation")
155
+ create_animation.click(StateWrapper.create_gif, inputs=[state, duration, extend_frames], outputs=[state, gif])
156
 
157
  with gr.Column(scale=1):
158
  gr.Markdown(value="""## Text Prompting
 
169
  with gr.Row():
170
  gr.Markdown(value="## Preset Configs", show_label=False)
171
  with gr.Row():
172
+ # with gr.Column():
173
+ small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
174
+ # with gr.Column():
175
+ major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
176
+ # with gr.Column():
177
+ major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
178
  iterations = gr.Slider(minimum=10,
179
  maximum=60,
180
  step=1,
 
184
  maximum=7e-1,
185
  value=1e-1,
186
  label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
 
187
  lpips_weight = gr.Slider(minimum=0,
188
  maximum=50,
189
  value=1,
190
  label="Perceptual similarity weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
191
  reconstruction_steps = gr.Slider(minimum=0,
192
  maximum=50,
193
+ value=3,
194
  step=1,
195
  label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that will 'pull' the image back towards the original identity")
196
  # discriminator_steps = gr.Slider(minimum=0,
 
198
  # step=1,
199
  # value=0,
200
  # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
201
+ clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
202
  asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
203
  lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
204
  # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
configs.py CHANGED
@@ -4,4 +4,4 @@ def set_small_local():
4
  def set_major_local():
5
  return (gr.Slider.update(value=25), gr.Slider.update(value=0.2), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
6
  def set_major_global():
7
- return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
 
4
  def set_major_local():
5
  return (gr.Slider.update(value=25), gr.Slider.update(value=0.2), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
6
  def set_major_global():
7
+ return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=1), gr.Slider.update(value=1))