Erwann Millon committed
Commit 006354e
Parents: eac223c dc51a98

merge from hf

Files changed (9)
  1. ImageState.py +50 -22
  2. README.md +12 -0
  3. animation.py +3 -2
  4. app.py +121 -96
  5. app_backend.py → backend.py +5 -10
  6. configs.py +15 -0
  7. loaders.py +1 -0
  8. masking.py +1 -1
  9. presets.py +16 -0
ImageState.py CHANGED
@@ -1,9 +1,11 @@
  # from align import align_from_path
+ import gc
+ import imageio
+ import glob
+ import uuid
  from animation import clear_img_dir
- from app_backend import ImagePromptOptimizer, log
- from functools import cache
+ from backend import ImagePromptOptimizer, log
  import importlib
-
  import gradio as gr
  import matplotlib.pyplot as plt
  import torch
@@ -15,13 +17,13 @@ from torchvision.transforms.functional import resize
  from tqdm import tqdm
  from transformers import CLIPModel, CLIPProcessor
  import lpips
- from app_backend import get_resized_tensor
+ from backend import get_resized_tensor
  from edit import blend_paths
  from img_processing import *
  from img_processing import custom_to_pil
  from loaders import load_default
-
  num = 0
+
  class PromptTransformHistory():
  def __init__(self, iterations) -> None:
  self.iterations = iterations
@@ -29,6 +31,7 @@ class PromptTransformHistory():

  class ImageState:
  def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
+ # global vqgan
  self.vqgan = vqgan
  self.device = vqgan.device
  self.blend_latent = None
@@ -38,6 +41,8 @@ class ImageState:
  self.transform_history = []
  self.attn_mask = None
  self.prompt_optim = prompt_optimizer
+ self.state_id = None
+ print(self.state_id)
  self._load_vectors()
  self.init_transforms()
  def _load_vectors(self):
@@ -45,6 +50,22 @@
  self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
  self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
  self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
+ def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
+ images = []
+ folder = self.state_id
+ paths = glob.glob(folder + "/*")
+ frame_duration = total_duration / len(paths)
+ print(len(paths), "frame dur", frame_duration)
+ durations = [frame_duration] * len(paths)
+ if extend_frames:
+ durations[0] = 1.5
+ durations[-1] = 3
+ for file_name in os.listdir(folder):
+ if file_name.endswith('.png'):
+ file_path = os.path.join(folder, file_name)
+ images.append(imageio.imread(file_path))
+ imageio.mimsave(gif_name, images, duration=durations)
+ return gif_name
  def init_transforms(self):
  self.blue_eyes = torch.zeros_like(self.lip_vector)
  self.lip_size = torch.zeros_like(self.lip_vector)
@@ -54,7 +75,7 @@
  def clear_transforms(self):
  global num
  self.init_transforms()
- clear_img_dir()
+ clear_img_dir("./img_history")
  num = 0
  return self._render_all_transformations()
  def _apply_vector(self, src, vector):
@@ -63,7 +84,7 @@
  def _decode_latent_to_pil(self, latent):
  current_im = self.vqgan.decode(latent.to(self.device))[0]
  return custom_to_pil(current_im)
- def get_mask(self, img, mask=None):
+ def _get_mask(self, img, mask=None):
  if img and "mask" in img and img["mask"] is not None:
  attn_mask = torchvision.transforms.ToTensor()(img["mask"])
  attn_mask = torch.ceil(attn_mask[0].to(self.device))
@@ -74,7 +95,7 @@
  attn_mask = mask
  return attn_mask
  def set_mask(self, img):
- attn_mask = self.get_mask(img)
+ attn_mask = self._get_mask(img)
  self.attn_mask = attn_mask
  # attn_mask = torch.ones_like(img, device=self.device)
  x = attn_mask.clone()
@@ -88,15 +109,21 @@
  @torch.no_grad()
  def _render_all_transformations(self, return_twice=True):
  global num
+ # global vqgan
+ if self.state_id is None:
+ self.state_id = "./img_history/" + str(uuid.uuid4())
+ print("render all", self.state_id)
  current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
  new_latent = self.blend_latent + sum(current_vector_transforms)
  if self.quant:
  new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
  image = self._decode_latent_to_pil(new_latent)
- img_dir = "./img_history"
+ img_dir = self.state_id
+ if not os.path.exists("img_history"):
+ os.mkdir("./img_history")
  if not os.path.exists(img_dir):
  os.mkdir(img_dir)
- image.save(f"./img_history/img_{num:06}.png")
+ image.save(f"{img_dir}/img_{num:06}.png")
  num += 1
  return (image, image) if return_twice else image
  def apply_gp_vector(self, weight):
@@ -112,17 +139,21 @@
  print(f"val = {val}")
  self.quant = val
  return self._render_all_transformations()
- def apply_gender_vector(self, weight):
+ def apply_asian_vector(self, weight):
  self.asian_transform = weight * self.asian_vector
  return self._render_all_transformations()
  def update_images(self, path1, path2, blend_weight):
  if path1 is None and path2 is None:
+ print("no paths")
  return None
+ if path1 == path2:
+ print("paths are the same")
+ print(path1)
  if path1 is None: path1 = path2
  if path2 is None: path2 = path1
  self.path1, self.path2 = path1, path2
- # self.aligned_path1 = align_from_path(path1)
- # self.aligned_path2 = align_from_path(path2)
+ if self.state_id:
+ clear_img_dir(self.state_id)
  return self.blend(blend_weight)
  @torch.no_grad()
  def blend(self, weight):
@@ -137,16 +168,11 @@
  prompt_transform = self.transform_history[-1]
  latent_index = int(index / 100 * (prompt_transform.iterations - 1))
  print(latent_index)
- self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
- # print(self.current_prompt_transform)
- # print(self.current_prompt_transforms.mean())
+ self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index].to(self.device)
  return self._render_all_transformations()
- def rescale_mask(self, mask):
- rep = mask.clone()
- rep[mask < 0.03] = -1000000
- rep[mask >= 0.03] = 1
- return rep
  def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
+ if self.state_id is None:
+ self.state_id = "./img_history/" + str(uuid.uuid4())
  transform_log = PromptTransformHistory(iterations + reconstruction_steps)
  transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
  self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
@@ -165,7 +191,7 @@
  for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
  positive_prompts,
  negative_prompts)):
- transform_log.transforms.append(transform.clone().detach())
+ transform_log.transforms.append(transform.detach().cpu())
  self.current_prompt_transforms[-1] = transform
  with torch.no_grad():
  image = self._render_all_transformations(return_twice=False)
@@ -176,6 +202,8 @@
  wandb.finish()
  self.attn_mask = None
  self.transform_history.append(transform_log)
+ gc.collect()
+ torch.cuda.empty_cache()
  # transform = self.prompt_optim.optimize(self.blend_latent,
  # positive_prompts,
  # negative_prompts)
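The new `create_gif` assembles an animation from whatever PNG frames a session has written into its `state_id` folder. A minimal standalone sketch of the same idea, assuming the `img_{num:06}.png` frame layout written by `_render_all_transformations`, and sorting the frames explicitly (the committed version iterates `os.listdir`, whose order is not guaranteed):

```python
# Sketch only: build a GIF from a session's numbered frames.
import glob
import os

import imageio


def frames_to_gif(frame_dir, total_duration, extend_frames=True, gif_name="face_edit.gif"):
    # Zero-padded names like img_000012.png sort correctly as strings.
    paths = sorted(glob.glob(os.path.join(frame_dir, "*.png")))
    durations = [total_duration / len(paths)] * len(paths)
    if extend_frames:
        durations[0] = 1.5   # hold the starting image a bit longer
        durations[-1] = 3.0  # hold the final edit
    images = [imageio.imread(p) for p in paths]
    imageio.mimsave(gif_name, images, duration=durations)
    return gif_name
```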
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Face Editor
+ emoji: 🪞
+ colorFrom: yellow
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.14.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
animation.py CHANGED
@@ -2,8 +2,9 @@ import imageio
  import glob
  import os

- def clear_img_dir():
- img_dir = "./img_history"
+ def clear_img_dir(img_dir):
+ if not os.path.exists("img_history"):
+ os.mkdir("img_history")
  if not os.path.exists(img_dir):
  os.mkdir(img_dir)
  for filename in glob.glob(img_dir+"/*"):
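`clear_img_dir` now takes the directory to wipe, which is what lets `ImageState.update_images` clear only the calling session's frame folder. A small usage sketch (the session path shown is a hypothetical `state_id` value):

```python
# Sketch: wipe one session's frames without touching other sessions.
from animation import clear_img_dir

session_dir = "./img_history/123e4567-e89b-12d3-a456-426614174000"  # hypothetical uuid-based state_id
clear_img_dir(session_dir)
```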
app.py CHANGED
@@ -3,40 +3,106 @@ import os
  import sys

  import wandb
+ import torch

  from presets import set_major_global, set_major_local, set_small_local

  sys.path.append("taming-transformers")
- import functools

  import gradio as gr
  from transformers import CLIPModel, CLIPProcessor
+ from lpips import LPIPS

  import edit
- # import importlib
- # importlib.reload(edit)
- from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
+ from backend import ImagePromptOptimizer, ProcessorGradientFlow
  from ImageState import ImageState
  from loaders import load_default
- from animation import create_gif
+ # from animation import create_gif
  from prompts import get_random_prompts

- device = "cuda"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ global vqgan
  vqgan = load_default(device)
  vqgan.eval()
  processor = ProcessorGradientFlow(device=device)
- clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
- clip.to(device)
- promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
- state = ImageState(vqgan, promptoptim)
- def set_img_from_example(img):
+ # clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+ lpips_fn = LPIPS(net='vgg').to(device)
+ clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ promptoptim = ImagePromptOptimizer(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
+ def set_img_from_example(state, img):
  return state.update_images(img, img, 0)
  def get_cleared_mask():
  return gr.Image.update(value=None)
  # mask.clear()
+
+ class StateWrapper:
+ def create_gif(state, *args, **kwargs):
+ return state, state[0].create_gif(*args, **kwargs)
+ def apply_asian_vector(state, *args, **kwargs):
+ return state, *state[0].apply_asian_vector(*args, **kwargs)
+ def apply_gp_vector(state, *args, **kwargs):
+ return state, *state[0].apply_gp_vector(*args, **kwargs)
+ def apply_lip_vector(state, *args, **kwargs):
+ return state, *state[0].apply_lip_vector(*args, **kwargs)
+ def apply_prompts(state, *args, **kwargs):
+ print(state[1])
+ for image in state[0].apply_prompts(*args, **kwargs):
+ yield state, *image
+ def apply_rb_vector(state, *args, **kwargs):
+ return state, *state[0].apply_rb_vector(*args, **kwargs)
+ def blend(state, *args, **kwargs):
+ return state, *state[0].blend(*args, **kwargs)
+ def clear_transforms(state, *args, **kwargs):
+ return state, *state[0].clear_transforms(*args, **kwargs)
+ def init_transforms(state, *args, **kwargs):
+ return state, *state[0].init_transforms(*args, **kwargs)
+ def prompt_optim(state, *args, **kwargs):
+ return state, *state[0].prompt_optim(*args, **kwargs)
+ def rescale_mask(state, *args, **kwargs):
+ return state, *state[0].rescale_mask(*args, **kwargs)
+ def rewind(state, *args, **kwargs):
+ return state, *state[0].rewind(*args, **kwargs)
+ def set_mask(state, *args, **kwargs):
+ return state, state[0].set_mask(*args, **kwargs)
+ def update_images(state, *args, **kwargs):
+ return state, *state[0].update_images(*args, **kwargs)
+ def update_requant(state, *args, **kwargs):
+ return state, *state[0].update_requant(*args, **kwargs)
  with gr.Blocks(css="styles.css") as demo:
+ # id = gr.State(str(uuid.uuid4()))
+ state = gr.State([ImageState(vqgan, promptoptim), str(uuid.uuid4())])
  with gr.Row():
  with gr.Column(scale=1):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(value="""## Image Upload
+ For best results, crop the photos like in the example pictures""", show_label=False)
+ with gr.Row():
+ base_img = gr.Image(label="Base Image", type="filepath")
+ blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
+ with gr.Accordion(label="Add Mask", open=False):
+ mask = gr.Image(tool="sketch", interactive=True)
+ gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio issue)")
+ set_mask = gr.Button(value="Set mask")
+ gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
+ testim = gr.Image()
+ with gr.Row():
+ gr.Examples(
+ examples=glob.glob("test_pics/*"),
+ inputs=base_img,
+ outputs=blend_img,
+ fn=set_img_from_example,
+ )
+ with gr.Column(scale=1):
+ out = gr.Image()
+ rewind = gr.Slider(value=100,
+ label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
+ minimum=0,
+ maximum=100)
+
+ apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
+ clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
  blue_eyes = gr.Slider(
  label="Blue Eyes",
  minimum=-.8,
@@ -76,120 +142,79 @@ with gr.Blocks(css="styles.css") as demo:
  maximum=2.,
  step=0.07,
  )
- with gr.Row():
- with gr.Column():
- gr.Markdown(value="""## Image Upload
- For best results, crop the photos like in the example pictures""", show_label=False)
- with gr.Row():
- base_img = gr.Image(label="Base Image", type="filepath")
- blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
- # gr.Markdown("## Image Examples")
- with gr.Accordion(label="Add Mask", open=False):
- mask = gr.Image(tool="sketch", interactive=True)
- gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio bug)")
- set_mask = gr.Button(value="Set mask")
- gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
- testim = gr.Image()
- clear_mask = gr.Button(value="Clear mask")
- clear_mask.click(get_cleared_mask, outputs=mask)
- with gr.Row():
- gr.Examples(
- examples=glob.glob("test_pics/*"),
- inputs=base_img,
- outputs=blend_img,
- fn=set_img_from_example,
- # cache_examples=True,
- )
- with gr.Column(scale=1):
- out = gr.Image()
- rewind = gr.Slider(value=100,
- label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
- minimum=0,
- maximum=100)
-
- apply_prompts = gr.Button(value="Apply Prompts", elem_id="apply")
- clear = gr.Button(value="Clear all transformations (irreversible)", elem_id="warning")
- with gr.Accordion(label="Save Animation", open=False):
+ with gr.Accordion(label="💾 Save Animation", open=False):
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
  duration = gr.Number(value=10, label="Duration of the animation in seconds")
  extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
  gif = gr.File(interactive=False)
  create_animation = gr.Button(value="Create Animation")
- create_animation.click(create_gif, inputs=[duration, extend_frames], outputs=gif)
+ create_animation.click(StateWrapper.create_gif, inputs=[state, duration, extend_frames], outputs=[state, gif])

  with gr.Column(scale=1):
- gr.Markdown(value="""## Text Prompting
- See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits. Negative prompts are highly recommended""", show_label=False)
+ gr.Markdown(value="""## ✍️ Prompt Editing
+ See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits (Remember to click Set Mask!). Negative prompts are highly recommended""", show_label=False)
  positive_prompts = gr.Textbox(label="Positive prompts",
- value="a picture of a woman with a very big nose | a picture of a woman with a large wide nose | a woman with an extremely prominent nose")
+ value="A picture of a handsome man | a picture of a masculine man",)
  negative_prompts = gr.Textbox(label="Negative prompts",
- value="a picture of a person with a tiny nose | a picture of a person with a very thin nose")
+ value="a picture of a woman | a picture of a feminine person")
  gen_prompts = gr.Button(value="🎲 Random prompts")
  gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
  with gr.Row():
  with gr.Column():
- gr.Text(value="Prompt Editing Configuration", show_label=False)
  with gr.Row():
- gr.Markdown(value="## Preset Configs", show_label=False)
+ gr.Markdown(value="## Prompt Editing Config", show_label=False)
+ with gr.Accordion(label="Config Tutorial", open=False):
+ gr.Markdown(value="""
+ - If results are not changing enough, increase the learning rate or decrease the perceptual loss weight
+ - To make local edits, use the 'Add Mask' section
+ - If using a mask and the image is changing too much outside of the masked area, try increasing the perceptual loss weight or lowering the learning rate
+ - Use the rewind slider to scroll through the iterations of your prompt transformation; you can resume editing from any point in the history.
+ - I recommend starting prompts with 'a picture of a'
+ - To avoid shifts in gender, you can use 'a person' instead of 'a man' or 'a woman', especially in the negative prompts.
+ - The more 'out-of-domain' the prompts are, the more you need to increase the learning rate and decrease the perceptual loss weight. For example, trying to make a black person have platinum blond hair is more out-of-domain than the same transformation on a caucasian person.
+ - Example: Higher config values, like learning rate: 0.7, perceptual loss weight: 35, can be used to make major out-of-domain changes.
+ """)
  with gr.Row():
- with gr.Column():
- small_local = gr.Button(value="Small Masked Changes (e.g. add lipstick)", elem_id="small_local").style(full_width=False)
- with gr.Column():
- major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
- with gr.Column():
- major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
+ # with gr.Column():
+ presets = gr.Dropdown(value="Select a preset", label="Preset Configs", choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender)"])
  iterations = gr.Slider(minimum=10,
- maximum=300,
+ maximum=60,
  step=1,
  value=20,
  label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out; you can always resume afterwards",)
- learning_rate = gr.Slider(minimum=1e-3,
- maximum=6e-1,
- value=1e-2,
+ learning_rate = gr.Slider(minimum=4e-3,
+ maximum=1,
+ value=1e-1,
  label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
- with gr.Accordion(label="Advanced Prompt Editing Options", open=False):
  lpips_weight = gr.Slider(minimum=0,
  maximum=50,
  value=1,
- label="Perceptual similarity weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
+ label="Perceptual Loss weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to make a localized edit")
  reconstruction_steps = gr.Slider(minimum=0,
  maximum=50,
- value=15,
+ value=3,
  step=1,
- label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that will 'pull' the image back towards the original identity")
+ label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that 'pull' the image back towards the original identity")
  # discriminator_steps = gr.Slider(minimum=0,
  # maximum=50,
  # step=1,
  # value=0,
  # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
- clear.click(state.clear_transforms, outputs=[out, mask])
- asian_weight.change(state.apply_gender_vector, inputs=[asian_weight], outputs=[out, mask])
- lip_size.change(state.apply_lip_vector, inputs=[lip_size], outputs=[out, mask])
- # hair_green_purple.change(state.apply_gp_vector, inputs=[hair_green_purple], outputs=[out, mask])
- blue_eyes.change(state.apply_rb_vector, inputs=[blue_eyes], outputs=[out, mask])
-
- blend_weight.change(state.blend, inputs=[blend_weight], outputs=[out, mask])
- # requantize.change(state.update_requant, inputs=[requantize], outputs=[out, mask])
-
-
- base_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
- blend_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
-
- small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
- major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
- small_local.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
- apply_prompts.click(state.apply_prompts, inputs=[positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[out, mask])
- rewind.change(state.rewind, inputs=[rewind], outputs=[out, mask])
- set_mask.click(state.set_mask, inputs=mask, outputs=testim)
+ clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
+ asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
+ lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
+ # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
+ blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
+ blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
+ # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
+ base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
+ blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
+ # small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+ # major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+ # major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+ apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
+ rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
+ set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
+ presets.change(set_preset, inputs=[presets], outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
  demo.queue()
- demo.launch(debug=True, inbrowser=True)
- # if __name__ == "__main__":
- # import argparse
- # parser = argparse.ArgumentParser()
- # parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging output')
- # args = parser.parse_args()
- # # if args.debug:
- # # state=None
- # # promptoptim=None
- # # else:
- # main()
+ demo.launch(debug=True, enable_queue=True)
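The rewritten app.py no longer shares one module-level `ImageState` across users; it keeps per-session state in `gr.State`, and `StateWrapper` threads that state through every event handler as the first input and output. A minimal sketch of the same pattern, assuming Gradio 3.x behavior where the `gr.State` default value is copied for each browser session (the `Counter` class and UI below are illustrative only, not part of the Space):

```python
# Sketch of the gr.State per-session pattern used by app.py.
import gradio as gr


class Counter:
    # Stand-in for ImageState: a mutable, per-session object.
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return f"clicks this session: {self.value}"


def on_click(state):
    # `state` is this session's copy; return it first so Gradio stores it back.
    return state, state.increment()


with gr.Blocks() as demo:
    state = gr.State(Counter())   # default is copied per session, not shared globally
    label = gr.Textbox()
    btn = gr.Button("Click")
    btn.click(on_click, inputs=[state], outputs=[state, label])

demo.queue()
# demo.launch()
```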
 
 
 
 
 
 
 
 
 
 
app_backend.py → backend.py RENAMED
@@ -17,7 +17,9 @@ from img_processing import *
  from img_processing import custom_to_pil
  from loaders import load_default
  import glob
- # global log
+ import gc
+
+ global log
  log=False

  # ic.disable()
@@ -61,6 +63,7 @@ class ImagePromptOptimizer(nn.Module):
  vqgan,
  clip,
  clip_preprocessor,
+ lpips_fn,
  iterations=100,
  lr = 0.01,
  save_vector=True,
@@ -81,11 +84,8 @@
  self.make_grid = make_grid
  self.return_val = return_val
  self.quantize = quantize
- # self.disc = load_disc(self.device)
  self.lpips_weight = lpips_weight
- self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
- def disc_loss_fn(self, logits):
- return -torch.mean(logits)
+ self.perceptual_loss = lpips_fn
  def set_latent(self, latent):
  self.latent = latent.detach().to(self.device)
  def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
@@ -195,11 +195,6 @@
  lpips_input.retain_grad()
  with torch.autocast("cuda"):
  perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
- with torch.no_grad():
- disc_logits = self.disc(transformed_img)
- disc_loss = self.disc_loss_fn(disc_logits)
- print(f"disc_loss = {disc_loss}")
- disc_loss2 = self.disc(processed_img)
  if log:
  wandb.log({"Perceptual Loss": perceptual_loss})
  print("LPIPS loss: ", perceptual_loss)
configs.py ADDED
@@ -0,0 +1,15 @@
+ import gradio as gr
+ def set_small_local():
+ return (gr.Slider.update(value=18), gr.Slider.update(value=0.15), gr.Slider.update(value=5), gr.Slider.update(value=4))
+ def set_major_local():
+ return (gr.Slider.update(value=25), gr.Slider.update(value=0.187), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
+ def set_major_global():
+ return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=1), gr.Slider.update(value=1))
+ def set_preset(config_str):
+ choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender)"]
+ if config_str == choices[0]:
+ return set_small_local()
+ elif config_str == choices[1]:
+ return set_major_local()
+ elif config_str == choices[2]:
+ return set_major_global()
loaders.py CHANGED
@@ -23,6 +23,7 @@ def load_default(device):
  sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
  model.load_state_dict(sd, strict=True)
  model.to(device)
+ del sd
  return model


masking.py CHANGED
@@ -13,7 +13,7 @@ from transformers import CLIPModel, CLIPProcessor
  import edit
  # import importlib
  # importlib.reload(edit)
- from app_backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
+ from backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
  from loaders import load_default

  device = "cuda"
presets.py ADDED
@@ -0,0 +1,16 @@
+ import gradio as gr
+
+ def set_preset(config_str):
+ choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender)"]
+ if config_str == choices[0]:
+ return set_small_local()
+ elif config_str == choices[1]:
+ return set_major_local()
+ elif config_str == choices[2]:
+ return set_major_global()
+ def set_small_local():
+ return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
+ def set_major_local():
+ return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
+ def set_major_global():
+ return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
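presets.py and configs.py both map the three preset labels to updates for the four sliders wired up in app.py, in the order (iterations, learning rate, perceptual loss weight, reconstruction steps). An equivalent table-driven sketch using the values from presets.py (a reading of the code, not part of the commit):

```python
# Sketch: same preset logic as presets.set_preset, expressed as a lookup table.
import gradio as gr

PRESETS = {
    "Small Masked Changes (e.g. add lipstick)": (25, 0.15, 1, 4),
    "Major Masked Changes (e.g. change hair color or nose size)": (25, 0.25, 35, 10),
    "Major Global Changes (e.g. change race / gender)": (30, 0.1, 2, 0.2),
}


def set_preset(config_str):
    # Output order matches the outputs list in app.py:
    # [iterations, learning_rate, lpips_weight, reconstruction_steps]
    iters, lr, lpips_w, recon = PRESETS[config_str]
    return (gr.Slider.update(value=iters), gr.Slider.update(value=lr),
            gr.Slider.update(value=lpips_w), gr.Slider.update(value=recon))
```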