Merge remote-tracking branch 'hf/feature' into HEAD
- ImageState.py +118 -74
- animation.py +8 -6
- app.py +7 -7
- backend.py +104 -90
- edit.py +17 -12
- img_processing.py +40 -36
- loaders.py +20 -20
- masking.py +21 -23
- presets.py +30 -4
- prompts.py +31 -7
- unwrapped.yaml +0 -37
- utils.py +3 -1
ImageState.py
CHANGED
@@ -1,183 +1,227 @@
import numpy as np
import gc
import os
import imageio
import glob
import uuid
from animation import clear_img_dir
from backend import ImagePromptEditor, log
import torch
import torchvision
import wandb
from edit import blend_paths
from img_processing import custom_to_pil
from PIL import Image

num = 0


class PromptTransformHistory:
    def __init__(self, iterations) -> None:
        self.iterations = iterations
        self.transforms = []


class ImageState:
    def __init__(self, vqgan, prompt_optimizer: ImagePromptEditor) -> None:
        self.vqgan = vqgan
        self.device = vqgan.device
        self.blend_latent = None
        self.quant = True
        self.path1 = None
        self.path2 = None
        self.img_dir = "./img_history"
        if not os.path.exists(self.img_dir):
            os.mkdir(self.img_dir)
        self.transform_history = []
        self.attn_mask = None
        self.prompt_optim = prompt_optimizer
        self._load_vectors()
        self.init_transforms()

    def _load_vectors(self):
        self.lip_vector = torch.load(
            "./latent_vectors/lipvector.pt", map_location=self.device
        )
        self.blue_eyes_vector = torch.load(
            "./latent_vectors/2blue_eyes.pt", map_location=self.device
        )
        self.asian_vector = torch.load(
            "./latent_vectors/asian10.pt", map_location=self.device
        )

    def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
        images = []
        folder = self.img_dir
        paths = glob.glob(folder + "/*")
        frame_duration = total_duration / len(paths)
        print(len(paths), "frame dur", frame_duration)
        durations = [frame_duration] * len(paths)
        if extend_frames:
            durations[0] = 1.5
            durations[-1] = 3
        for file_name in os.listdir(folder):
            if file_name.endswith(".png"):
                file_path = os.path.join(folder, file_name)
                images.append(imageio.imread(file_path))
        imageio.mimsave(gif_name, images, duration=durations)
        return gif_name

    def init_transforms(self):
        self.blue_eyes = torch.zeros_like(self.lip_vector)
        self.lip_size = torch.zeros_like(self.lip_vector)
        self.asian_transform = torch.zeros_like(self.lip_vector)
        self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]

    def clear_transforms(self):
        self.init_transforms()
        clear_img_dir("./img_history")
        return self._render_all_transformations()

    def _latent_to_pil(self, latent):
        current_im = self.vqgan.decode(latent.to(self.device))[0]
        return custom_to_pil(current_im)

    def _get_mask(self, img, mask=None):
        if img and "mask" in img and img["mask"] is not None:
            attn_mask = torchvision.transforms.ToTensor()(img["mask"])
            attn_mask = torch.ceil(attn_mask[0].to(self.device))
            print("mask set successfully")
        else:
            attn_mask = mask
        return attn_mask

    def set_mask(self, img):
        self.attn_mask = self._get_mask(img)
        x = self.attn_mask.clone()
        x = x.detach().cpu()
        x = torch.clamp(x, -1.0, 1.0)
        x = (x + 1.0) / 2.0
        x = x.numpy()
        x = (255 * x).astype(np.uint8)
        x = Image.fromarray(x, "L")
        return x

    @torch.inference_mode()
    def _render_all_transformations(self, return_twice=True):
        global num
        current_vector_transforms = (
            self.blue_eyes,
            self.lip_size,
            self.asian_transform,
            sum(self.current_prompt_transforms),
        )
        new_latent = self.blend_latent + sum(current_vector_transforms)
        if self.quant:
            new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
        image = self._latent_to_pil(new_latent)
        image.save(f"{self.img_dir}/img_{num:06}.png")
        num += 1
        return (image, image) if return_twice else image

    def apply_rb_vector(self, weight):
        self.blue_eyes = weight * self.blue_eyes_vector
        return self._render_all_transformations()

    def apply_lip_vector(self, weight):
        self.lip_size = weight * self.lip_vector
        return self._render_all_transformations()

    def update_quant(self, val):
        self.quant = val
        return self._render_all_transformations()

    def apply_asian_vector(self, weight):
        self.asian_transform = weight * self.asian_vector
        return self._render_all_transformations()

    def update_images(self, path1, path2, blend_weight):
        if path1 is None and path2 is None:
            return None

        # Duplicate paths if one is empty
        if path1 is None:
            path1 = path2
        if path2 is None:
            path2 = path1

        self.path1, self.path2 = path1, path2
        if self.img_dir:
            clear_img_dir(self.img_dir)
        return self.blend(blend_weight)

    @torch.inference_mode()
    def blend(self, weight):
        _, latent = blend_paths(
            self.vqgan,
            self.path1,
            self.path2,
            weight=weight,
            show=False,
            device=self.device,
        )
        self.blend_latent = latent
        return self._render_all_transformations()

    @torch.inference_mode()
    def rewind(self, index):
        if not self.transform_history:
            print("No history")
            return self._render_all_transformations()
        prompt_transform = self.transform_history[-1]
        latent_index = int(index / 100 * (prompt_transform.iterations - 1))
        print(latent_index)
        self.current_prompt_transforms[-1] = prompt_transform.transforms[
            latent_index
        ].to(self.device)
        return self._render_all_transformations()

    def _init_logging(lr, iterations, lpips_weight, positive_prompts, negative_prompts):
        wandb.init(reinit=True, project="face-editor")
        wandb.config.update({"Positive Prompts": positive_prompts})
        wandb.config.update({"Negative Prompts": negative_prompts})
        wandb.config.update(
            dict(lr=lr, iterations=iterations, lpips_weight=lpips_weight)
        )

    def apply_prompts(
        self,
        positive_prompts,
        negative_prompts,
        lr,
        iterations,
        lpips_weight,
        reconstruction_steps,
    ):
        if log:
            self._init_logging(
                lr, iterations, lpips_weight, positive_prompts, negative_prompts
            )
        transform_log = PromptTransformHistory(iterations + reconstruction_steps)
        transform_log.transforms.append(
            torch.zeros_like(self.blend_latent, requires_grad=False)
        )
        self.current_prompt_transforms.append(
            torch.zeros_like(self.blend_latent, requires_grad=False)
        )
        positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
        negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
        self.prompt_optim.set_params(
            lr,
            iterations,
            lpips_weight,
            attn_mask=self.attn_mask,
            reconstruction_steps=reconstruction_steps,
        )

        for i, transform in enumerate(
            self.prompt_optim.optimize(
                self.blend_latent, positive_prompts, negative_prompts
            )
        ):
            transform_log.transforms.append(transform.detach().cpu())
            self.current_prompt_transforms[-1] = transform
            with torch.inference_mode():
                image = self._render_all_transformations(return_twice=False)
            if log:
                wandb.log({"image": wandb.Image(image)})

@@ -187,4 +231,4 @@ class ImageState:
        self.attn_mask = None
        self.transform_history.append(transform_log)
        gc.collect()
        torch.cuda.empty_cache()
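The edits above all funnel through _render_all_transformations: each slider or prompt produces a latent-space offset, the offsets are summed onto the blended latent, and the result is (optionally) re-quantized and decoded. A minimal sketch of that composition pattern, assuming only a VQGAN-like object with quantize/decode methods; the vqgan, base_latent and directions names are illustrative, not part of the repo:

# Sketch of the latent-composition idea behind ImageState (illustrative names).
import torch

def render(vqgan, base_latent, directions, weights, quantize=True):
    # Each edit is a weighted latent direction added onto the blended base latent.
    edited = base_latent + sum(w * d for w, d in zip(weights, directions))
    if quantize:
        edited, _, _ = vqgan.quantize(edited)  # snap back onto the VQ codebook
    return vqgan.decode(edited)[0]             # decoded image tensor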
animation.py
CHANGED
@@ -8,21 +8,23 @@ def clear_img_dir(img_dir):
        os.mkdir("img_history")
    if not os.path.exists(img_dir):
        os.mkdir(img_dir)
    for filename in glob.glob(img_dir + "/*"):
        os.remove(filename)


def create_gif(
    total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"
):
    images = []
    paths = glob.glob(folder + "/*")
    frame_duration = total_duration / len(paths)
    print(len(paths), "frame dur", frame_duration)
    durations = [frame_duration] * len(paths)
    if extend_frames:
        durations[0] = 1.5
        durations[-1] = 3
    for file_name in os.listdir(folder):
        if file_name.endswith(".png"):
            file_path = os.path.join(folder, file_name)
            images.append(imageio.imread(file_path))
    imageio.mimsave(gif_name, images, duration=durations)

@@ -30,4 +32,4 @@


if __name__ == "__main__":
    create_gif()
app.py
CHANGED
@@ -14,7 +14,7 @@ from transformers import CLIPModel, CLIPProcessor
from lpips import LPIPS

import edit
from backend import ImagePromptEditor, ProcessorGradientFlow
from ImageState import ImageState
from loaders import load_default
# from animation import create_gif

@@ -29,14 +29,14 @@ processor = ProcessorGradientFlow(device=device)
# clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
lpips_fn = LPIPS(net='vgg').to(device)
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
promptoptim = ImagePromptEditor(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)

def set_img_from_example(state, img):
    return state.update_images(img, img, 0)
def get_cleared_mask():
    return gr.Image.update(value=None)
class StateWrapper:
    """This extremely ugly code is a hacky fix to allow con"""
    def create_gif(state, *args, **kwargs):
        return state, state[0].create_gif(*args, **kwargs)
    def apply_asian_vector(state, *args, **kwargs):

@@ -46,7 +46,6 @@ class StateWrapper:
    def apply_lip_vector(state, *args, **kwargs):
        return state, *state[0].apply_lip_vector(*args, **kwargs)
    def apply_prompts(state, *args, **kwargs):
        for image in state[0].apply_prompts(*args, **kwargs):
            yield state, *image
    def apply_rb_vector(state, *args, **kwargs):

@@ -69,9 +68,10 @@ class StateWrapper:
        return state, *state[0].update_images(*args, **kwargs)
    def update_requant(state, *args, **kwargs):
        return state, *state[0].update_requant(*args, **kwargs)


with gr.Blocks(css="styles.css") as demo:
    state = gr.State([ImageState(vqgan, promptoptim)])
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
backend.py
CHANGED
@@ -1,77 +1,65 @@
import matplotlib.pyplot as plt
import torch
import torchvision
import wandb
from torch import nn
from tqdm import tqdm
from transformers import CLIPProcessor
from img_processing import get_pil, loop_post_process


global log
log = False

class ProcessorGradientFlow:
    """
    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
    The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.
    """

    def __init__(self, device="cuda") -> None:
        self.device = device
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(
            self.image_mean, self.image_std
        )
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        images = self.center_crop(images)
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, images=[], **kwargs):
        processed_inputs = self.processor(**kwargs)
        processed_inputs["pixel_values"] = self.preprocess_img(images)
        processed_inputs = {
            key: value.to(self.device) for (key, value) in processed_inputs.items()
        }
        return processed_inputs


class ImagePromptEditor(nn.Module):
    def __init__(
        self,
        vqgan,
        clip,
        clip_preprocessor,
        lpips_fn,
        iterations=100,
        lr=0.01,
        save_vector=True,
        return_val="vector",
        quantize=True,
        make_grid=False,
        lpips_weight=6.2,
    ) -> None:

        super().__init__()
        self.latent = None
        self.device = vqgan.device

@@ -86,14 +74,17 @@ class ImagePromptOptimizer(nn.Module):
        self.quantize = quantize
        self.lpips_weight = lpips_weight
        self.perceptual_loss = lpips_fn

    def set_latent(self, latent):
        self.latent = latent.detach().to(self.device)

    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
        self._attn_mask = attn_mask
        self.iterations = iterations
        self.lr = lr
        self.lpips_weight = lpips_weight
        self.reconstruction_steps = reconstruction_steps

    def forward(self, vector):
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + vector

@@ -103,19 +94,22 @@ class ImagePromptOptimizer(nn.Module):
            z_q = trans_latent
        dec = self.vqgan.decode(z_q)
        return dec

    def _get_clip_similarity(self, prompts, image, weights=None):
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list):
            raise TypeError("Provide prompts as string or list of strings")
        clip_inputs = self.clip_preprocessor(
            text=prompts, images=image, return_tensors="pt", padding=True
        )
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights:
            similarity_logits *= weights
        return similarity_logits.sum()

    def _get_CLIP_loss(self, pos_prompts, neg_prompts, image):
        pos_logits = self._get_clip_similarity(pos_prompts, image)
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts, image)

@@ -123,6 +117,7 @@ class ImagePromptOptimizer(nn.Module):
            neg_logits = torch.tensor([1], device=self.device)
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss

    def visualize(self, processed_img):
        if self.make_grid:
            self.index += 1

@@ -131,74 +126,93 @@ class ImagePromptOptimizer(nn.Module):
        else:
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
        plt.show()

    def _attn_mask(self, grad):
        newgrad = grad
        if self._attn_mask is not None:
            newgrad = grad * (self._attn_mask)
        return newgrad

    def _attn_mask_inverse(self, grad):
        newgrad = grad
        if self._attn_mask is not None:
            newgrad = grad * ((self._attn_mask - 1) * -1)
        return newgrad

    def _get_next_inputs(self, transformed_img):
        processed_img = loop_post_process(transformed_img)  # * self.attn_mask
        processed_img.retain_grad()

        lpips_input = processed_img.clone()
        lpips_input.register_hook(self._attn_mask_inverse)
        lpips_input.retain_grad()

        clip_input = processed_img.clone()
        clip_input.register_hook(self._attn_mask)
        clip_input.retain_grad()

        return (processed_img, lpips_input, clip_input)

    def _optimize_CLIP_LPIPS(self, optim, original_img, vector, pos_prompts, neg_prompts):
        optim.zero_grad()
        transformed_img = self(vector)
        processed_img, lpips_input, clip_input = self._get_next_inputs(
            transformed_img
        )
        with torch.autocast("cuda"):
            clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
            print("CLIP loss", clip_loss)
            perceptual_loss = (
                self.perceptual_loss(lpips_input, original_img.clone())
                * self.lpips_weight
            )
            print("LPIPS loss: ", perceptual_loss)
            print("Sum Loss", perceptual_loss + clip_loss)
            if log:
                wandb.log({"Perceptual Loss": perceptual_loss})
                wandb.log({"CLIP Loss": clip_loss})

        # These gradients will be masked if attn_mask has been set
        clip_loss.backward(retain_graph=True)
        perceptual_loss.backward(retain_graph=True)

        optim.step()
        yield vector

    def _optimize_LPIPS(self, vector, original_img, optim):
        optim.zero_grad()
        transformed_img = self(vector)
        processed_img = loop_post_process(transformed_img)  # * self.attn_mask
        processed_img.retain_grad()

        lpips_input = processed_img.clone()
        lpips_input.register_hook(self._attn_mask_inverse)
        lpips_input.retain_grad()
        with torch.autocast("cuda"):
            perceptual_loss = (
                self.perceptual_loss(lpips_input, original_img.clone())
                * self.lpips_weight
            )
        if log:
            wandb.log({"Perceptual Loss": perceptual_loss})
        print("LPIPS loss: ", perceptual_loss)
        perceptual_loss.backward(retain_graph=True)
        optim.step()
        yield vector

    def optimize(self, latent, pos_prompts, neg_prompts):
        self.set_latent(latent)
        transformed_img = self(
            torch.zeros_like(self.latent, requires_grad=True, device=self.device)
        )
        original_img = loop_post_process(transformed_img)
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)

        for i in tqdm(range(self.iterations)):
            yield self._optimize_CLIP_LPIPS(optim, original_img, vector, pos_prompts, neg_prompts)

        print("Running LPIPS optim only")
        for i in range(self.reconstruction_steps):
            yield self._optimize_LPIPS(vector, original_img, transformed_img, optim)
        yield vector if self.return_val == "vector" else self.latent + vector
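The _attn_mask / _attn_mask_inverse hooks above implement masked editing: the CLIP branch only receives gradients inside the user-drawn mask, while the LPIPS branch only receives gradients outside it, so prompt-driven changes stay in the masked region and the rest of the image is pulled back toward the original. A standalone sketch of that register_hook pattern, with placeholder tensors and losses rather than the real CLIP/LPIPS objectives:

# Standalone sketch of masking gradients with register_hook (placeholder data/losses).
import torch

x = torch.randn(1, 3, 8, 8, requires_grad=True)
mask = (torch.rand(1, 1, 8, 8) > 0.5).float()

clip_branch = x.clone()
clip_branch.register_hook(lambda g: g * mask)           # "edit" loss acts inside the mask
lpips_branch = x.clone()
lpips_branch.register_hook(lambda g: g * (1.0 - mask))  # "reconstruction" loss acts outside it

(clip_branch.sum() + lpips_branch.pow(2).sum()).backward()
# x.grad now mixes the two objectives region by region.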
edit.py
CHANGED
@@ -12,7 +12,7 @@ import PIL
import taming
import torch

from loaders import load_config, load_default
from utils import get_device


@@ -25,11 +25,14 @@ def get_embedding(model, path=None, img=None, device="cpu"):
    z, _, [_, _, indices] = model.encode(x_processed)
    return z


def blend_paths(
    model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"
):
    x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
    y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
    x_latent = get_embedding(model, path=path1, device=device)
    y_latent = get_embedding(model, path=path2, device=device)
    z = torch.lerp(x_latent, y_latent, weight)
    if quantize:
        z = model.quantize(z)[0]

@@ -45,14 +48,16 @@
        plt.show()
    return custom_to_pil(decoded), z


if __name__ == "__main__":
    device = get_device()
    model = load_default(device)
    model.to(device)
    blend_paths(
        model,
        "./test_data/face.jpeg",
        "./test_data/face2.jpeg",
        quantize=False,
        weight=0.5,
    )
    plt.show()
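blend_paths encodes both faces and interpolates their latents with torch.lerp, so a weight of 0 reproduces the first image, 1 the second, and anything in between a morph. A tiny sketch of just the interpolation step, using random placeholder latents rather than real VQGAN encodings:

# torch.lerp interpolates elementwise: lerp(a, b, w) == a + w * (b - a).
import torch

a = torch.randn(1, 256, 16, 16)   # placeholder "latent" for image 1
b = torch.randn(1, 256, 16, 16)   # placeholder "latent" for image 2
half = torch.lerp(a, b, 0.5)      # even blend of the two latents
assert torch.allclose(half, a + 0.5 * (b - a))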
img_processing.py
CHANGED
@@ -1,12 +1,9 @@
import io

import numpy as np
import PIL
import requests
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw, ImageFont

@@ -20,10 +17,10 @@ def download_image(url):

def preprocess(img, target_image_size=256, map_dalle=False):
    s = min(img.size)

    if s < target_image_size:
        raise ValueError(f"min dim for image {s} < {target_image_size}")

    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)

@@ -31,42 +28,49 @@ def preprocess(img, target_image_size=256, map_dalle=False):
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return img


def preprocess_vqgan(x):
    x = 2.0 * x - 1.0
    return x


def custom_to_pil(x, process=True, mode="RGB"):
    x = x.detach().cpu()
    if process:
        x = torch.clamp(x, -1.0, 1.0)
        x = (x + 1.0) / 2.0
    x = x.permute(1, 2, 0).numpy()
    if process:
        x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == mode:
        x = x.convert(mode)
    return x


def get_pil(x):
    x = torch.clamp(x, -1.0, 1.0)
    x = (x + 1.0) / 2.0
    x = x.permute(1, 2, 0)
    return x


def loop_post_process(x):
    x = get_pil(x.squeeze())
    return x.permute(2, 0, 1).unsqueeze(0)


def stack_reconstructions(input, x0, x1, x2, x3, titles=[]):
    assert input.size == x1.size == x2.size == x3.size
    w, h = input.size[0], input.size[1]
    img = Image.new("RGB", (5 * w, h))
    img.paste(input, (0, 0))
    img.paste(x0, (1 * w, 0))
    img.paste(x1, (2 * w, 0))
    img.paste(x2, (3 * w, 0))
    img.paste(x3, (4 * w, 0))
    for i, title in enumerate(titles):
        ImageDraw.Draw(img).text(
            (i * w, 0), f"{title}", (255, 255, 255), font=font
        )  # coordinates, text, color, font
    return img
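preprocess_vqgan maps [0, 1] tensors into the [-1, 1] range the VQGAN expects, and custom_to_pil / get_pil undo that with clamp followed by (x + 1) / 2 before converting to uint8. A quick check of that round trip on a dummy tensor (the tensor itself is made up for illustration):

# Round trip between the [0, 1] and [-1, 1] ranges used above (dummy data).
import torch

img = torch.rand(3, 64, 64)           # pretend ToTensor() output in [0, 1]
vq_in = 2.0 * img - 1.0               # preprocess_vqgan scaling
back = torch.clamp(vq_in, -1.0, 1.0)  # decoder outputs get clamped first
back = (back + 1.0) / 2.0             # then shifted back to [0, 1]
assert torch.allclose(back, img, atol=1e-6)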
loaders.py
CHANGED
@@ -10,17 +10,17 @@ from utils import get_device


def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


def load_default(device):
    conf_path = "./celeba_vqgan/unwrapped.yaml"
    config = load_config(conf_path, display=False)
    model = taming.models.vqgan.VQModel(**config.model.params)
    sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)
    model.load_state_dict(sd, strict=True)
    model.to(device)
    del sd

@@ -34,17 +34,14 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


def reconstruct_with_vqgan(x, model):
    z, _, [_, _, indices] = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:

@@ -52,12 +49,13 @@ def get_obj_from_str(string, reload=False):
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    model = instantiate_from_config(config)
    if sd is not None:

@@ -78,5 +76,7 @@ def load_model(config, ckpt, gpu, eval_mode):
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(
        config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode
    )["model"]
    return model, global_step
masking.py
CHANGED
@@ -3,30 +3,28 @@ import sys

import matplotlib.pyplot as plt
import torch
from backend import ImagePromptEditor, ImageState, ProcessorGradientFlow
from loaders import load_default
from transformers import CLIPModel

if __name__ == "__main__":
    sys.path.append("taming-transformers")
    device = "cuda"

    vqgan = load_default(device)
    vqgan.eval()

    processor = ProcessorGradientFlow(device=device)
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip.to(device)

    promptoptim = ImagePromptEditor(vqgan, clip, processor, quantize=True)
    state = ImageState(vqgan, promptoptim)
    mask = torch.load("eyebrow_mask.pt")
    x = state.blend("./test_data/face.jpeg", "./test_data/face2.jpeg", 0.5)
    plt.imshow(x)
    plt.show()
    state.apply_prompts(
        "a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask
    )
    print("done")
presets.py
CHANGED
@@ -1,16 +1,42 @@
import gradio as gr


def set_preset(config_str):
    choices = [
        "Small Masked Changes (e.g. add lipstick)",
        "Major Masked Changes (e.g. change hair color or nose size)",
        "Major Global Changes (e.g. change race / gender",
    ]
    if config_str == choices[0]:
        return set_small_local()
    elif config_str == choices[1]:
        return set_major_local()
    elif config_str == choices[2]:
        return set_major_global()


def set_small_local():
    return (
        gr.Slider.update(value=25),
        gr.Slider.update(value=0.15),
        gr.Slider.update(value=1),
        gr.Slider.update(value=4),
    )


def set_major_local():
    return (
        gr.Slider.update(value=25),
        gr.Slider.update(value=0.25),
        gr.Slider.update(value=35),
        gr.Slider.update(value=10),
    )


def set_major_global():
    return (
        gr.Slider.update(value=30),
        gr.Slider.update(value=0.1),
        gr.Slider.update(value=2),
        gr.Slider.update(value=0.2),
    )
prompts.py
CHANGED
@@ -1,17 +1,41 @@
import random


class PromptSet:
    def __init__(self, pos, neg, config=None):
        self.positive = pos
        self.negative = neg
        self.config = config


example_prompts = (
    PromptSet(
        "a picture of a woman with light blonde hair",
        "a picture of a person with dark hair | a picture of a person with brown hair",
    ),
    PromptSet(
        "A picture of a woman with very thick eyebrows",
        "a picture of a person with very thin eyebrows | a picture of a person with no eyebrows",
    ),
    PromptSet(
        "A picture of a woman wearing bright red lipstick",
        "a picture of a person wearing no lipstick | a picture of a person wearing dark lipstick",
    ),
    PromptSet(
        "A picture of a beautiful chinese woman | a picture of a Japanese woman | a picture of an Asian woman",
        "a picture of a white woman | a picture of an Indian woman | a picture of a black woman",
    ),
    PromptSet(
        "A picture of a handsome man | a picture of a masculine man",
        "a picture of a woman | a picture of a feminine person",
    ),
    PromptSet(
        "A picture of a woman with a very big nose",
        "a picture of a person with a small nose | a picture of a person with a normal nose",
    ),
)


def get_random_prompts():
    prompt = random.choice(example_prompts)
    return prompt.positive, prompt.negative
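Each PromptSet packs several prompts into one string separated by "|"; ImageState.apply_prompts later splits on that character and strips whitespace before handing the pieces to the CLIP loss. A small illustration of that convention, with the example values copied from the tuple above:

# How the "|"-separated prompt strings are split downstream (values from example_prompts).
positive = "A picture of a handsome man | a picture of a masculine man"
negative = "a picture of a woman | a picture of a feminine person"

positive_prompts = [p.strip() for p in positive.split("|")]
negative_prompts = [p.strip() for p in negative.split("|")]
print(positive_prompts)  # ['A picture of a handsome man', 'a picture of a masculine man']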
unwrapped.yaml
DELETED
@@ -1,37 +0,0 @@
model:
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: false
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.DummyLoss
data:
  target: cutlit.DataModuleFromConfig
  params:
    batch_size: 24
    num_workers: 24
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256
utils.py
CHANGED
@@ -7,9 +7,11 @@ import torch.nn.functional as F
from skimage.color import lab2rgb, rgb2lab
from torch import nn


def freeze_module(module):
    for param in module.parameters():
        param.requires_grad = False


def get_device():
    device = "cuda" if torch.cuda.is_available() else "cpu"