Spaces:
Configuration error
Configuration error
merge from hf
Browse files- ImageState.py +50 -22
- README.md +12 -0
- animation.py +3 -2
- app.py +121 -96
- app_backend.py → backend.py +5 -10
- configs.py +15 -0
- loaders.py +1 -0
- masking.py +1 -1
- presets.py +16 -0
ImageState.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
# from align import align_from_path
|
|
|
|
|
|
|
|
|
2 |
from animation import clear_img_dir
|
3 |
-
from
|
4 |
-
from functools import cache
|
5 |
import importlib
|
6 |
-
|
7 |
import gradio as gr
|
8 |
import matplotlib.pyplot as plt
|
9 |
import torch
|
@@ -15,13 +17,13 @@ from torchvision.transforms.functional import resize
|
|
15 |
from tqdm import tqdm
|
16 |
from transformers import CLIPModel, CLIPProcessor
|
17 |
import lpips
|
18 |
-
from
|
19 |
from edit import blend_paths
|
20 |
from img_processing import *
|
21 |
from img_processing import custom_to_pil
|
22 |
from loaders import load_default
|
23 |
-
|
24 |
num = 0
|
|
|
25 |
class PromptTransformHistory():
|
26 |
def __init__(self, iterations) -> None:
|
27 |
self.iterations = iterations
|
@@ -29,6 +31,7 @@ class PromptTransformHistory():
|
|
29 |
|
30 |
class ImageState:
|
31 |
def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
|
|
|
32 |
self.vqgan = vqgan
|
33 |
self.device = vqgan.device
|
34 |
self.blend_latent = None
|
@@ -38,6 +41,8 @@ class ImageState:
|
|
38 |
self.transform_history = []
|
39 |
self.attn_mask = None
|
40 |
self.prompt_optim = prompt_optimizer
|
|
|
|
|
41 |
self._load_vectors()
|
42 |
self.init_transforms()
|
43 |
def _load_vectors(self):
|
@@ -45,6 +50,22 @@ class ImageState:
|
|
45 |
self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
|
46 |
self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
|
47 |
self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def init_transforms(self):
|
49 |
self.blue_eyes = torch.zeros_like(self.lip_vector)
|
50 |
self.lip_size = torch.zeros_like(self.lip_vector)
|
@@ -54,7 +75,7 @@ class ImageState:
|
|
54 |
def clear_transforms(self):
|
55 |
global num
|
56 |
self.init_transforms()
|
57 |
-
clear_img_dir()
|
58 |
num = 0
|
59 |
return self._render_all_transformations()
|
60 |
def _apply_vector(self, src, vector):
|
@@ -63,7 +84,7 @@ class ImageState:
|
|
63 |
def _decode_latent_to_pil(self, latent):
|
64 |
current_im = self.vqgan.decode(latent.to(self.device))[0]
|
65 |
return custom_to_pil(current_im)
|
66 |
-
def
|
67 |
if img and "mask" in img and img["mask"] is not None:
|
68 |
attn_mask = torchvision.transforms.ToTensor()(img["mask"])
|
69 |
attn_mask = torch.ceil(attn_mask[0].to(self.device))
|
@@ -74,7 +95,7 @@ class ImageState:
|
|
74 |
attn_mask = mask
|
75 |
return attn_mask
|
76 |
def set_mask(self, img):
|
77 |
-
attn_mask = self.
|
78 |
self.attn_mask = attn_mask
|
79 |
# attn_mask = torch.ones_like(img, device=self.device)
|
80 |
x = attn_mask.clone()
|
@@ -88,15 +109,21 @@ class ImageState:
|
|
88 |
@torch.no_grad()
|
89 |
def _render_all_transformations(self, return_twice=True):
|
90 |
global num
|
|
|
|
|
|
|
|
|
91 |
current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
|
92 |
new_latent = self.blend_latent + sum(current_vector_transforms)
|
93 |
if self.quant:
|
94 |
new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
|
95 |
image = self._decode_latent_to_pil(new_latent)
|
96 |
-
img_dir =
|
|
|
|
|
97 |
if not os.path.exists(img_dir):
|
98 |
os.mkdir(img_dir)
|
99 |
-
image.save(f"
|
100 |
num += 1
|
101 |
return (image, image) if return_twice else image
|
102 |
def apply_gp_vector(self, weight):
|
@@ -112,17 +139,21 @@ class ImageState:
|
|
112 |
print(f"val = {val}")
|
113 |
self.quant = val
|
114 |
return self._render_all_transformations()
|
115 |
-
def
|
116 |
self.asian_transform = weight * self.asian_vector
|
117 |
return self._render_all_transformations()
|
118 |
def update_images(self, path1, path2, blend_weight):
|
119 |
if path1 is None and path2 is None:
|
|
|
120 |
return None
|
|
|
|
|
|
|
121 |
if path1 is None: path1 = path2
|
122 |
if path2 is None: path2 = path1
|
123 |
self.path1, self.path2 = path1, path2
|
124 |
-
|
125 |
-
|
126 |
return self.blend(blend_weight)
|
127 |
@torch.no_grad()
|
128 |
def blend(self, weight):
|
@@ -137,16 +168,11 @@ class ImageState:
|
|
137 |
prompt_transform = self.transform_history[-1]
|
138 |
latent_index = int(index / 100 * (prompt_transform.iterations - 1))
|
139 |
print(latent_index)
|
140 |
-
self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
|
141 |
-
# print(self.current_prompt_transform)
|
142 |
-
# print(self.current_prompt_transforms.mean())
|
143 |
return self._render_all_transformations()
|
144 |
-
def rescale_mask(self, mask):
|
145 |
-
rep = mask.clone()
|
146 |
-
rep[mask < 0.03] = -1000000
|
147 |
-
rep[mask >= 0.03] = 1
|
148 |
-
return rep
|
149 |
def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
|
|
|
|
|
150 |
transform_log = PromptTransformHistory(iterations + reconstruction_steps)
|
151 |
transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
|
152 |
self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
|
@@ -165,7 +191,7 @@ class ImageState:
|
|
165 |
for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
|
166 |
positive_prompts,
|
167 |
negative_prompts)):
|
168 |
-
transform_log.transforms.append(transform.
|
169 |
self.current_prompt_transforms[-1] = transform
|
170 |
with torch.no_grad():
|
171 |
image = self._render_all_transformations(return_twice=False)
|
@@ -176,6 +202,8 @@ class ImageState:
|
|
176 |
wandb.finish()
|
177 |
self.attn_mask = None
|
178 |
self.transform_history.append(transform_log)
|
|
|
|
|
179 |
# transform = self.prompt_optim.optimize(self.blend_latent,
|
180 |
# positive_prompts,
|
181 |
# negative_prompts)
|
|
|
1 |
# from align import align_from_path
|
2 |
+
import gc
|
3 |
+
import imageio
|
4 |
+
import glob
|
5 |
+
import uuid
|
6 |
from animation import clear_img_dir
|
7 |
+
from backend import ImagePromptOptimizer, log
|
|
|
8 |
import importlib
|
|
|
9 |
import gradio as gr
|
10 |
import matplotlib.pyplot as plt
|
11 |
import torch
|
|
|
17 |
from tqdm import tqdm
|
18 |
from transformers import CLIPModel, CLIPProcessor
|
19 |
import lpips
|
20 |
+
from backend import get_resized_tensor
|
21 |
from edit import blend_paths
|
22 |
from img_processing import *
|
23 |
from img_processing import custom_to_pil
|
24 |
from loaders import load_default
|
|
|
25 |
num = 0
|
26 |
+
|
27 |
class PromptTransformHistory():
|
28 |
def __init__(self, iterations) -> None:
|
29 |
self.iterations = iterations
|
|
|
31 |
|
32 |
class ImageState:
|
33 |
def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
|
34 |
+
# global vqgan
|
35 |
self.vqgan = vqgan
|
36 |
self.device = vqgan.device
|
37 |
self.blend_latent = None
|
|
|
41 |
self.transform_history = []
|
42 |
self.attn_mask = None
|
43 |
self.prompt_optim = prompt_optimizer
|
44 |
+
self.state_id = None
|
45 |
+
print(self.state_id)
|
46 |
self._load_vectors()
|
47 |
self.init_transforms()
|
48 |
def _load_vectors(self):
|
|
|
50 |
self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
|
51 |
self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
|
52 |
self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)
|
53 |
+
def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
|
54 |
+
images = []
|
55 |
+
folder = self.state_id
|
56 |
+
paths = glob.glob(folder + "/*")
|
57 |
+
frame_duration = total_duration / len(paths)
|
58 |
+
print(len(paths), "frame dur", frame_duration)
|
59 |
+
durations = [frame_duration] * len(paths)
|
60 |
+
if extend_frames:
|
61 |
+
durations [0] = 1.5
|
62 |
+
durations [-1] = 3
|
63 |
+
for file_name in os.listdir(folder):
|
64 |
+
if file_name.endswith('.png'):
|
65 |
+
file_path = os.path.join(folder, file_name)
|
66 |
+
images.append(imageio.imread(file_path))
|
67 |
+
imageio.mimsave(gif_name, images, duration=durations)
|
68 |
+
return gif_name
|
69 |
def init_transforms(self):
|
70 |
self.blue_eyes = torch.zeros_like(self.lip_vector)
|
71 |
self.lip_size = torch.zeros_like(self.lip_vector)
|
|
|
75 |
def clear_transforms(self):
|
76 |
global num
|
77 |
self.init_transforms()
|
78 |
+
clear_img_dir("./img_history")
|
79 |
num = 0
|
80 |
return self._render_all_transformations()
|
81 |
def _apply_vector(self, src, vector):
|
|
|
84 |
def _decode_latent_to_pil(self, latent):
|
85 |
current_im = self.vqgan.decode(latent.to(self.device))[0]
|
86 |
return custom_to_pil(current_im)
|
87 |
+
def _get_mask(self, img, mask=None):
|
88 |
if img and "mask" in img and img["mask"] is not None:
|
89 |
attn_mask = torchvision.transforms.ToTensor()(img["mask"])
|
90 |
attn_mask = torch.ceil(attn_mask[0].to(self.device))
|
|
|
95 |
attn_mask = mask
|
96 |
return attn_mask
|
97 |
def set_mask(self, img):
|
98 |
+
attn_mask = self._get_mask(img)
|
99 |
self.attn_mask = attn_mask
|
100 |
# attn_mask = torch.ones_like(img, device=self.device)
|
101 |
x = attn_mask.clone()
|
|
|
109 |
@torch.no_grad()
|
110 |
def _render_all_transformations(self, return_twice=True):
|
111 |
global num
|
112 |
+
# global vqgan
|
113 |
+
if self.state_id is None:
|
114 |
+
self.state_id = "./img_history/" + str(uuid.uuid4())
|
115 |
+
print("redner all", self.state_id)
|
116 |
current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
|
117 |
new_latent = self.blend_latent + sum(current_vector_transforms)
|
118 |
if self.quant:
|
119 |
new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
|
120 |
image = self._decode_latent_to_pil(new_latent)
|
121 |
+
img_dir = self.state_id
|
122 |
+
if not os.path.exists("img_history"):
|
123 |
+
os.mkdir("./img_history")
|
124 |
if not os.path.exists(img_dir):
|
125 |
os.mkdir(img_dir)
|
126 |
+
image.save(f"{img_dir}/img_{num:06}.png")
|
127 |
num += 1
|
128 |
return (image, image) if return_twice else image
|
129 |
def apply_gp_vector(self, weight):
|
|
|
139 |
print(f"val = {val}")
|
140 |
self.quant = val
|
141 |
return self._render_all_transformations()
|
142 |
+
def apply_asian_vector(self, weight):
|
143 |
self.asian_transform = weight * self.asian_vector
|
144 |
return self._render_all_transformations()
|
145 |
def update_images(self, path1, path2, blend_weight):
|
146 |
if path1 is None and path2 is None:
|
147 |
+
print("no paths")
|
148 |
return None
|
149 |
+
if path1 == path2:
|
150 |
+
print("paths are the same")
|
151 |
+
print(path1)
|
152 |
if path1 is None: path1 = path2
|
153 |
if path2 is None: path2 = path1
|
154 |
self.path1, self.path2 = path1, path2
|
155 |
+
if self.state_id:
|
156 |
+
clear_img_dir(self.state_id)
|
157 |
return self.blend(blend_weight)
|
158 |
@torch.no_grad()
|
159 |
def blend(self, weight):
|
|
|
168 |
prompt_transform = self.transform_history[-1]
|
169 |
latent_index = int(index / 100 * (prompt_transform.iterations - 1))
|
170 |
print(latent_index)
|
171 |
+
self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index].to(self.device)
|
|
|
|
|
172 |
return self._render_all_transformations()
|
|
|
|
|
|
|
|
|
|
|
173 |
def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
|
174 |
+
if self.state_id is None:
|
175 |
+
self.state_id = "./img_history/" + str(uuid.uuid4())
|
176 |
transform_log = PromptTransformHistory(iterations + reconstruction_steps)
|
177 |
transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
|
178 |
self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
|
|
|
191 |
for i, transform in enumerate(self.prompt_optim.optimize(self.blend_latent,
|
192 |
positive_prompts,
|
193 |
negative_prompts)):
|
194 |
+
transform_log.transforms.append(transform.detach().cpu())
|
195 |
self.current_prompt_transforms[-1] = transform
|
196 |
with torch.no_grad():
|
197 |
image = self._render_all_transformations(return_twice=False)
|
|
|
202 |
wandb.finish()
|
203 |
self.attn_mask = None
|
204 |
self.transform_history.append(transform_log)
|
205 |
+
gc.collect()
|
206 |
+
torch.cuda.empty_cache()
|
207 |
# transform = self.prompt_optim.optimize(self.blend_latent,
|
208 |
# positive_prompts,
|
209 |
# negative_prompts)
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Face Editor
|
3 |
+
emoji: 🪞
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.14.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
animation.py
CHANGED
@@ -2,8 +2,9 @@ import imageio
|
|
2 |
import glob
|
3 |
import os
|
4 |
|
5 |
-
def clear_img_dir():
|
6 |
-
|
|
|
7 |
if not os.path.exists(img_dir):
|
8 |
os.mkdir(img_dir)
|
9 |
for filename in glob.glob(img_dir+"/*"):
|
|
|
2 |
import glob
|
3 |
import os
|
4 |
|
5 |
+
def clear_img_dir(img_dir):
|
6 |
+
if not os.path.exists("img_history"):
|
7 |
+
os.mkdir("img_history")
|
8 |
if not os.path.exists(img_dir):
|
9 |
os.mkdir(img_dir)
|
10 |
for filename in glob.glob(img_dir+"/*"):
|
app.py
CHANGED
@@ -3,40 +3,106 @@ import os
|
|
3 |
import sys
|
4 |
|
5 |
import wandb
|
|
|
6 |
|
7 |
from presets import set_major_global, set_major_local, set_small_local
|
8 |
|
9 |
sys.path.append("taming-transformers")
|
10 |
-
import functools
|
11 |
|
12 |
import gradio as gr
|
13 |
from transformers import CLIPModel, CLIPProcessor
|
|
|
14 |
|
15 |
import edit
|
16 |
-
|
17 |
-
# importlib.reload(edit)
|
18 |
-
from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
|
19 |
from ImageState import ImageState
|
20 |
from loaders import load_default
|
21 |
-
from animation import create_gif
|
22 |
from prompts import get_random_prompts
|
23 |
|
24 |
-
device = "cuda"
|
|
|
|
|
25 |
vqgan = load_default(device)
|
26 |
vqgan.eval()
|
27 |
processor = ProcessorGradientFlow(device=device)
|
28 |
-
clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
def set_img_from_example(img):
|
33 |
return state.update_images(img, img, 0)
|
34 |
def get_cleared_mask():
|
35 |
return gr.Image.update(value=None)
|
36 |
# mask.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
with gr.Blocks(css="styles.css") as demo:
|
|
|
|
|
38 |
with gr.Row():
|
39 |
with gr.Column(scale=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
blue_eyes = gr.Slider(
|
41 |
label="Blue Eyes",
|
42 |
minimum=-.8,
|
@@ -76,120 +142,79 @@ with gr.Blocks(css="styles.css") as demo:
|
|
76 |
maximum=2.,
|
77 |
step=0.07,
|
78 |
)
|
79 |
-
with gr.
|
80 |
-
with gr.Column():
|
81 |
-
gr.Markdown(value="""## Image Upload
|
82 |
-
For best results, crop the photos like in the example pictures""", show_label=False)
|
83 |
-
with gr.Row():
|
84 |
-
base_img = gr.Image(label="Base Image", type="filepath")
|
85 |
-
blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
|
86 |
-
# gr.Markdown("## Image Examples")
|
87 |
-
with gr.Accordion(label="Add Mask", open=False):
|
88 |
-
mask = gr.Image(tool="sketch", interactive=True)
|
89 |
-
gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio bug)")
|
90 |
-
set_mask = gr.Button(value="Set mask")
|
91 |
-
gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
|
92 |
-
testim = gr.Image()
|
93 |
-
clear_mask = gr.Button(value="Clear mask")
|
94 |
-
clear_mask.click(get_cleared_mask, outputs=mask)
|
95 |
-
with gr.Row():
|
96 |
-
gr.Examples(
|
97 |
-
examples=glob.glob("test_pics/*"),
|
98 |
-
inputs=base_img,
|
99 |
-
outputs=blend_img,
|
100 |
-
fn=set_img_from_example,
|
101 |
-
# cache_examples=True,
|
102 |
-
)
|
103 |
-
with gr.Column(scale=1):
|
104 |
-
out = gr.Image()
|
105 |
-
rewind = gr.Slider(value=100,
|
106 |
-
label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
|
107 |
-
minimum=0,
|
108 |
-
maximum=100)
|
109 |
-
|
110 |
-
apply_prompts = gr.Button(value="Apply Prompts", elem_id="apply")
|
111 |
-
clear = gr.Button(value="Clear all transformations (irreversible)", elem_id="warning")
|
112 |
-
with gr.Accordion(label="Save Animation", open=False):
|
113 |
gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
|
114 |
duration = gr.Number(value=10, label="Duration of the animation in seconds")
|
115 |
extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
|
116 |
gif = gr.File(interactive=False)
|
117 |
create_animation = gr.Button(value="Create Animation")
|
118 |
-
create_animation.click(create_gif, inputs=[duration, extend_frames], outputs=gif)
|
119 |
|
120 |
with gr.Column(scale=1):
|
121 |
-
gr.Markdown(value="""##
|
122 |
-
See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits. Negative prompts are highly recommended""", show_label=False)
|
123 |
positive_prompts = gr.Textbox(label="Positive prompts",
|
124 |
-
value="
|
125 |
negative_prompts = gr.Textbox(label="Negative prompts",
|
126 |
-
value="a picture of a
|
127 |
gen_prompts = gr.Button(value="🎲 Random prompts")
|
128 |
gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
|
129 |
with gr.Row():
|
130 |
with gr.Column():
|
131 |
-
gr.Text(value="Prompt Editing Configuration", show_label=False)
|
132 |
with gr.Row():
|
133 |
-
gr.Markdown(value="##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
with gr.Row():
|
135 |
-
with gr.Column():
|
136 |
-
|
137 |
-
with gr.Column():
|
138 |
-
major_local = gr.Button(value="Major Masked Changes (e.g. change hair color or nose size)").style(full_width=False)
|
139 |
-
with gr.Column():
|
140 |
-
major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
|
141 |
iterations = gr.Slider(minimum=10,
|
142 |
-
maximum=
|
143 |
step=1,
|
144 |
value=20,
|
145 |
label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
|
146 |
-
learning_rate = gr.Slider(minimum=
|
147 |
-
maximum=
|
148 |
-
value=1e-
|
149 |
label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
|
150 |
-
with gr.Accordion(label="Advanced Prompt Editing Options", open=False):
|
151 |
lpips_weight = gr.Slider(minimum=0,
|
152 |
maximum=50,
|
153 |
value=1,
|
154 |
-
label="Perceptual
|
155 |
reconstruction_steps = gr.Slider(minimum=0,
|
156 |
maximum=50,
|
157 |
-
value=
|
158 |
step=1,
|
159 |
-
label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that
|
160 |
# discriminator_steps = gr.Slider(minimum=0,
|
161 |
# maximum=50,
|
162 |
# step=1,
|
163 |
# value=0,
|
164 |
# label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
|
165 |
-
clear.click(
|
166 |
-
asian_weight.change(
|
167 |
-
lip_size.change(
|
168 |
-
# hair_green_purple.change(
|
169 |
-
blue_eyes.change(
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
apply_prompts.click(state.apply_prompts, inputs=[positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[out, mask])
|
182 |
-
rewind.change(state.rewind, inputs=[rewind], outputs=[out, mask])
|
183 |
-
set_mask.click(state.set_mask, inputs=mask, outputs=testim)
|
184 |
demo.queue()
|
185 |
-
demo.launch(debug=True,
|
186 |
-
# if __name__ == "__main__":
|
187 |
-
# import argparse
|
188 |
-
# parser = argparse.ArgumentParser()
|
189 |
-
# parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging output')
|
190 |
-
# args = parser.parse_args()
|
191 |
-
# # if args.debug:
|
192 |
-
# # state=None
|
193 |
-
# # promptoptim=None
|
194 |
-
# # else:
|
195 |
-
# main()
|
|
|
3 |
import sys
|
4 |
|
5 |
import wandb
|
6 |
+
import torch
|
7 |
|
8 |
from presets import set_major_global, set_major_local, set_small_local
|
9 |
|
10 |
sys.path.append("taming-transformers")
|
|
|
11 |
|
12 |
import gradio as gr
|
13 |
from transformers import CLIPModel, CLIPProcessor
|
14 |
+
from lpips import LPIPS
|
15 |
|
16 |
import edit
|
17 |
+
from backend import ImagePromptOptimizer, ProcessorGradientFlow
|
|
|
|
|
18 |
from ImageState import ImageState
|
19 |
from loaders import load_default
|
20 |
+
# from animation import create_gif
|
21 |
from prompts import get_random_prompts
|
22 |
|
23 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
+
|
25 |
+
global vqgan
|
26 |
vqgan = load_default(device)
|
27 |
vqgan.eval()
|
28 |
processor = ProcessorGradientFlow(device=device)
|
29 |
+
# clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
30 |
+
lpips_fn = LPIPS(net='vgg').to(device)
|
31 |
+
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
32 |
+
promptoptim = ImagePromptOptimizer(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
|
33 |
+
def set_img_from_example(state, img):
|
34 |
return state.update_images(img, img, 0)
|
35 |
def get_cleared_mask():
|
36 |
return gr.Image.update(value=None)
|
37 |
# mask.clear()
|
38 |
+
|
39 |
+
class StateWrapper:
|
40 |
+
def create_gif(state, *args, **kwargs):
|
41 |
+
return state, state[0].create_gif(*args, **kwargs)
|
42 |
+
def apply_asian_vector(state, *args, **kwargs):
|
43 |
+
return state, *state[0].apply_asian_vector(*args, **kwargs)
|
44 |
+
def apply_gp_vector(state, *args, **kwargs):
|
45 |
+
return state, *state[0].apply_gp_vector(*args, **kwargs)
|
46 |
+
def apply_lip_vector(state, *args, **kwargs):
|
47 |
+
return state, *state[0].apply_lip_vector(*args, **kwargs)
|
48 |
+
def apply_prompts(state, *args, **kwargs):
|
49 |
+
print(state[1])
|
50 |
+
for image in state[0].apply_prompts(*args, **kwargs):
|
51 |
+
yield state, *image
|
52 |
+
def apply_rb_vector(state, *args, **kwargs):
|
53 |
+
return state, *state[0].apply_rb_vector(*args, **kwargs)
|
54 |
+
def blend(state, *args, **kwargs):
|
55 |
+
return state, *state[0].blend(*args, **kwargs)
|
56 |
+
def clear_transforms(state, *args, **kwargs):
|
57 |
+
return state, *state[0].clear_transforms(*args, **kwargs)
|
58 |
+
def init_transforms(state, *args, **kwargs):
|
59 |
+
return state, *state[0].init_transforms(*args, **kwargs)
|
60 |
+
def prompt_optim(state, *args, **kwargs):
|
61 |
+
return state, *state[0].prompt_optim(*args, **kwargs)
|
62 |
+
def rescale_mask(state, *args, **kwargs):
|
63 |
+
return state, *state[0].rescale_mask(*args, **kwargs)
|
64 |
+
def rewind(state, *args, **kwargs):
|
65 |
+
return state, *state[0].rewind(*args, **kwargs)
|
66 |
+
def set_mask(state, *args, **kwargs):
|
67 |
+
return state, state[0].set_mask(*args, **kwargs)
|
68 |
+
def update_images(state, *args, **kwargs):
|
69 |
+
return state, *state[0].update_images(*args, **kwargs)
|
70 |
+
def update_requant(state, *args, **kwargs):
|
71 |
+
return state, *state[0].update_requant(*args, **kwargs)
|
72 |
with gr.Blocks(css="styles.css") as demo:
|
73 |
+
# id = gr.State(str(uuid.uuid4()))
|
74 |
+
state = gr.State([ImageState(vqgan, promptoptim), str(uuid.uuid4())])
|
75 |
with gr.Row():
|
76 |
with gr.Column(scale=1):
|
77 |
+
with gr.Row():
|
78 |
+
with gr.Column():
|
79 |
+
gr.Markdown(value="""## Image Upload
|
80 |
+
For best results, crop the photos like in the example pictures""", show_label=False)
|
81 |
+
with gr.Row():
|
82 |
+
base_img = gr.Image(label="Base Image", type="filepath")
|
83 |
+
blend_img = gr.Image(label="Image for face blending (optional)", type="filepath")
|
84 |
+
with gr.Accordion(label="Add Mask", open=False):
|
85 |
+
mask = gr.Image(tool="sketch", interactive=True)
|
86 |
+
gr.Markdown(value="Note: You must clear the mask using the rewind button every time you want to change the mask (this is a gradio issue)")
|
87 |
+
set_mask = gr.Button(value="Set mask")
|
88 |
+
gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
|
89 |
+
testim = gr.Image()
|
90 |
+
with gr.Row():
|
91 |
+
gr.Examples(
|
92 |
+
examples=glob.glob("test_pics/*"),
|
93 |
+
inputs=base_img,
|
94 |
+
outputs=blend_img,
|
95 |
+
fn=set_img_from_example,
|
96 |
+
)
|
97 |
+
with gr.Column(scale=1):
|
98 |
+
out = gr.Image()
|
99 |
+
rewind = gr.Slider(value=100,
|
100 |
+
label="Rewind back through a prompt transform: Use this to scroll through the iterations of your prompt transformation.",
|
101 |
+
minimum=0,
|
102 |
+
maximum=100)
|
103 |
+
|
104 |
+
apply_prompts = gr.Button(variant="primary", value="🎨 Apply Prompts", elem_id="apply")
|
105 |
+
clear = gr.Button(value="❌ Clear all transformations (irreversible)", elem_id="warning")
|
106 |
blue_eyes = gr.Slider(
|
107 |
label="Blue Eyes",
|
108 |
minimum=-.8,
|
|
|
142 |
maximum=2.,
|
143 |
step=0.07,
|
144 |
)
|
145 |
+
with gr.Accordion(label="💾 Save Animation", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
|
147 |
duration = gr.Number(value=10, label="Duration of the animation in seconds")
|
148 |
extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
|
149 |
gif = gr.File(interactive=False)
|
150 |
create_animation = gr.Button(value="Create Animation")
|
151 |
+
create_animation.click(StateWrapper.create_gif, inputs=[state, duration, extend_frames], outputs=[state, gif])
|
152 |
|
153 |
with gr.Column(scale=1):
|
154 |
+
gr.Markdown(value="""## ✍️ Prompt Editing
|
155 |
+
See readme for a prompting guide. Use the '|' symbol to separate prompts. Use the "Add mask" section to make local edits (Remember to click Set Mask!). Negative prompts are highly recommended""", show_label=False)
|
156 |
positive_prompts = gr.Textbox(label="Positive prompts",
|
157 |
+
value="A picture of a handsome man | a picture of a masculine man",)
|
158 |
negative_prompts = gr.Textbox(label="Negative prompts",
|
159 |
+
value="a picture of a woman | a picture of a feminine person")
|
160 |
gen_prompts = gr.Button(value="🎲 Random prompts")
|
161 |
gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
|
162 |
with gr.Row():
|
163 |
with gr.Column():
|
|
|
164 |
with gr.Row():
|
165 |
+
gr.Markdown(value="## ⚙ Prompt Editing Config", show_label=False)
|
166 |
+
with gr.Accordion(label="Config Tutorial", open=False):
|
167 |
+
gr.Markdown(value="""
|
168 |
+
- If results are not changing enough, increase the learning rate or decrease the perceptual loss weight
|
169 |
+
- To make local edits, use the 'Add Mask' section
|
170 |
+
- If using a mask and the image is changing too much outside of the masked area, try increasing the perceptual loss weight or lowering the learning rate
|
171 |
+
- Use the rewind slider to scroll through the iterations of your prompt transformation, you can resume editing from any point in the history.
|
172 |
+
- I recommend starting prompts with 'a picture of a'
|
173 |
+
- To avoid shifts in gender, you can use 'a person' instead of 'a man' or 'a woman', especially in the negative prompts.
|
174 |
+
- The more 'out-of-domain' the prompts are, the more you need to increase the learning rate and decrease the perceptual loss weight. For example, trying to make a black person have platinum blond hair is more out-of-domain than the same transformation on a caucasian person.
|
175 |
+
- Example: Higher config values, like learning rate: 0.7, perceptual loss weight: 35 can be used to make major out-of-domain changes.
|
176 |
+
""")
|
177 |
with gr.Row():
|
178 |
+
# with gr.Column():
|
179 |
+
presets = gr.Dropdown(value="Select a preset", label="Preset Configs", choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender"])
|
|
|
|
|
|
|
|
|
180 |
iterations = gr.Slider(minimum=10,
|
181 |
+
maximum=60,
|
182 |
step=1,
|
183 |
value=20,
|
184 |
label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
|
185 |
+
learning_rate = gr.Slider(minimum=4e-3,
|
186 |
+
maximum=1,
|
187 |
+
value=1e-1,
|
188 |
label="Learning Rate: How strong the change in each step will be (you should raise this for bigger changes (for example, changing hair color), and lower it for more minor changes. Raise if changes aren't strong enough")
|
|
|
189 |
lpips_weight = gr.Slider(minimum=0,
|
190 |
maximum=50,
|
191 |
value=1,
|
192 |
+
label="Perceptual Loss weight (Keeps areas outside of the mask looking similar to the original. Increase if the rest of the image is changing too much while you're trying to change make a localized edit")
|
193 |
reconstruction_steps = gr.Slider(minimum=0,
|
194 |
maximum=50,
|
195 |
+
value=3,
|
196 |
step=1,
|
197 |
+
label="Steps to run at the end of the optimization, optimizing only the masked perceptual loss. If the edit is changing the identity too much, this setting will run steps at the end that 'pull' the image back towards the original identity")
|
198 |
# discriminator_steps = gr.Slider(minimum=0,
|
199 |
# maximum=50,
|
200 |
# step=1,
|
201 |
# value=0,
|
202 |
# label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
|
203 |
+
clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
|
204 |
+
asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
|
205 |
+
lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
|
206 |
+
# hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
|
207 |
+
blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
|
208 |
+
blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
|
209 |
+
# requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
|
210 |
+
base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
|
211 |
+
blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
|
212 |
+
# small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
|
213 |
+
# major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
|
214 |
+
# major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
|
215 |
+
apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
|
216 |
+
rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
|
217 |
+
set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
|
218 |
+
presets.change(set_preset, inputs=[presets], outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
|
|
|
|
|
|
|
219 |
demo.queue()
|
220 |
+
demo.launch(debug=True, enable_queue=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app_backend.py → backend.py
RENAMED
@@ -17,7 +17,9 @@ from img_processing import *
|
|
17 |
from img_processing import custom_to_pil
|
18 |
from loaders import load_default
|
19 |
import glob
|
20 |
-
|
|
|
|
|
21 |
log=False
|
22 |
|
23 |
# ic.disable()
|
@@ -61,6 +63,7 @@ class ImagePromptOptimizer(nn.Module):
|
|
61 |
vqgan,
|
62 |
clip,
|
63 |
clip_preprocessor,
|
|
|
64 |
iterations=100,
|
65 |
lr = 0.01,
|
66 |
save_vector=True,
|
@@ -81,11 +84,8 @@ class ImagePromptOptimizer(nn.Module):
|
|
81 |
self.make_grid = make_grid
|
82 |
self.return_val = return_val
|
83 |
self.quantize = quantize
|
84 |
-
# self.disc = load_disc(self.device)
|
85 |
self.lpips_weight = lpips_weight
|
86 |
-
self.perceptual_loss =
|
87 |
-
def disc_loss_fn(self, logits):
|
88 |
-
return -torch.mean(logits)
|
89 |
def set_latent(self, latent):
|
90 |
self.latent = latent.detach().to(self.device)
|
91 |
def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
|
@@ -195,11 +195,6 @@ class ImagePromptOptimizer(nn.Module):
|
|
195 |
lpips_input.retain_grad()
|
196 |
with torch.autocast("cuda"):
|
197 |
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
198 |
-
with torch.no_grad():
|
199 |
-
disc_logits = self.disc(transformed_img)
|
200 |
-
disc_loss = self.disc_loss_fn(disc_logits)
|
201 |
-
print(f"disc_loss = {disc_loss}")
|
202 |
-
disc_loss2 = self.disc(processed_img)
|
203 |
if log:
|
204 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
205 |
print("LPIPS loss: ", perceptual_loss)
|
|
|
17 |
from img_processing import custom_to_pil
|
18 |
from loaders import load_default
|
19 |
import glob
|
20 |
+
import gc
|
21 |
+
|
22 |
+
global log
|
23 |
log=False
|
24 |
|
25 |
# ic.disable()
|
|
|
63 |
vqgan,
|
64 |
clip,
|
65 |
clip_preprocessor,
|
66 |
+
lpips_fn,
|
67 |
iterations=100,
|
68 |
lr = 0.01,
|
69 |
save_vector=True,
|
|
|
84 |
self.make_grid = make_grid
|
85 |
self.return_val = return_val
|
86 |
self.quantize = quantize
|
|
|
87 |
self.lpips_weight = lpips_weight
|
88 |
+
self.perceptual_loss = lpips_fn
|
|
|
|
|
89 |
def set_latent(self, latent):
|
90 |
self.latent = latent.detach().to(self.device)
|
91 |
def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
|
|
|
195 |
lpips_input.retain_grad()
|
196 |
with torch.autocast("cuda"):
|
197 |
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
|
|
|
|
|
|
|
|
|
|
198 |
if log:
|
199 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
200 |
print("LPIPS loss: ", perceptual_loss)
|
configs.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
def set_small_local():
|
3 |
+
return (gr.Slider.update(value=18), gr.Slider.update(value=0.15), gr.Slider.update(value=5), gr.Slider.update(value=4))
|
4 |
+
def set_major_local():
|
5 |
+
return (gr.Slider.update(value=25), gr.Slider.update(value=0.187), gr.Slider.update(value=36.6), gr.Slider.update(value=6))
|
6 |
+
def set_major_global():
|
7 |
+
return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=1), gr.Slider.update(value=1))
|
8 |
+
def set_preset(config_str):
|
9 |
+
choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender"]
|
10 |
+
if config_str == choices[0]:
|
11 |
+
return set_small_local()
|
12 |
+
elif config_str == choices[1]:
|
13 |
+
return set_major_local()
|
14 |
+
elif config_str == choices[2]:
|
15 |
+
return set_major_global()
|
loaders.py
CHANGED
@@ -23,6 +23,7 @@ def load_default(device):
|
|
23 |
sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
|
24 |
model.load_state_dict(sd, strict=True)
|
25 |
model.to(device)
|
|
|
26 |
return model
|
27 |
|
28 |
|
|
|
23 |
sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
|
24 |
model.load_state_dict(sd, strict=True)
|
25 |
model.to(device)
|
26 |
+
del sd
|
27 |
return model
|
28 |
|
29 |
|
masking.py
CHANGED
@@ -13,7 +13,7 @@ from transformers import CLIPModel, CLIPProcessor
|
|
13 |
import edit
|
14 |
# import importlib
|
15 |
# importlib.reload(edit)
|
16 |
-
from
|
17 |
from loaders import load_default
|
18 |
|
19 |
device = "cuda"
|
|
|
13 |
import edit
|
14 |
# import importlib
|
15 |
# importlib.reload(edit)
|
16 |
+
from backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
|
17 |
from loaders import load_default
|
18 |
|
19 |
device = "cuda"
|
presets.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
def set_preset(config_str):
|
4 |
+
choices=["Small Masked Changes (e.g. add lipstick)", "Major Masked Changes (e.g. change hair color or nose size)", "Major Global Changes (e.g. change race / gender"]
|
5 |
+
if config_str == choices[0]:
|
6 |
+
return set_small_local()
|
7 |
+
elif config_str == choices[1]:
|
8 |
+
return set_major_local()
|
9 |
+
elif config_str == choices[2]:
|
10 |
+
return set_major_global()
|
11 |
+
def set_small_local():
|
12 |
+
return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
|
13 |
+
def set_major_local():
|
14 |
+
return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
|
15 |
+
def set_major_global():
|
16 |
+
return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
|