erwann committed
Commit d7fcb4c
1 Parent(s): 3b79bb4

reduce memory usage

Files changed (4)
  1. ImageState.py +7 -3
  2. app.py +18 -13
  3. app_backend.py +0 -230
  4. masking.py +1 -1
ImageState.py CHANGED
@@ -1,6 +1,6 @@
 # from align import align_from_path
 from animation import clear_img_dir
-from app_backend import ImagePromptOptimizer, log
+from backend import ImagePromptOptimizer, log
 import importlib
 import gradio as gr
 import matplotlib.pyplot as plt
@@ -13,12 +13,13 @@ from torchvision.transforms.functional import resize
 from tqdm import tqdm
 from transformers import CLIPModel, CLIPProcessor
 import lpips
-from app_backend import get_resized_tensor
+from backend import get_resized_tensor
 from edit import blend_paths
 from img_processing import *
 from img_processing import custom_to_pil
 from loaders import load_default
-
+# from app import vqgan
+global vqgan
 num = 0
 class PromptTransformHistory():
     def __init__(self, iterations) -> None:
@@ -27,6 +28,7 @@ class PromptTransformHistory():
 
 class ImageState:
     def __init__(self, vqgan, prompt_optimizer: ImagePromptOptimizer) -> None:
+        # global vqgan
         self.vqgan = vqgan
         self.device = vqgan.device
         self.blend_latent = None
@@ -59,6 +61,7 @@ class ImageState:
         new_latent = torch.lerp(src, src + vector, 1)
         return new_latent
     def _decode_latent_to_pil(self, latent):
+        # global vqgan
         current_im = self.vqgan.decode(latent.to(self.device))[0]
         return custom_to_pil(current_im)
     # def _get_current_vector_transforms(self):
@@ -95,6 +98,7 @@ class ImageState:
     @torch.no_grad()
     def _render_all_transformations(self, return_twice=True):
         global num
+        # global vqgan
        current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
        new_latent = self.blend_latent + sum(current_vector_transforms)
        if self.quant:
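
Note on the ImageState hunks above: _decode_latent_to_pil and _render_all_transformations only render previews, so they never need gradients. A minimal, self-contained sketch of that pattern (the vqgan.decode and custom_to_pil calls mirror the repo; the standalone function name is illustrative, not part of the codebase):

import torch
from img_processing import custom_to_pil  # repo helper

def decode_latent_to_pil(vqgan, latent):
    # Rendering a preview needs no backprop, so run the decoder without
    # building an autograd graph; this keeps activation memory low.
    with torch.no_grad():
        image = vqgan.decode(latent.to(vqgan.device))[0]
    return custom_to_pil(image)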
app.py CHANGED
@@ -3,29 +3,33 @@ import os
 import sys
 
 import wandb
+import torch
 
 from configs import set_major_global, set_major_local, set_small_local
 
 sys.path.append("taming-transformers")
-import functools
 
 import gradio as gr
 from transformers import CLIPModel, CLIPProcessor
+from lpips import LPIPS
 
 import edit
-from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
+from backend import ImagePromptOptimizer, ProcessorGradientFlow
 from ImageState import ImageState
 from loaders import load_default
 from animation import create_gif
 from prompts import get_random_prompts
 
-device = "cpu"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+global vqgan
 vqgan = load_default(device)
 vqgan.eval()
 processor = ProcessorGradientFlow(device=device)
-clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-clip.to(device)
-promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
+# clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+lpips_fn = LPIPS(net='vgg').to(device)
+clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+promptoptim = ImagePromptOptimizer(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
 def set_img_from_example(state, img):
     return state.update_images(img, img, 0)
 def get_cleared_mask():
@@ -40,7 +44,8 @@ class StateWrapper:
     def apply_lip_vector(state, *args, **kwargs):
         return state, *state[0].apply_lip_vector(*args, **kwargs)
     def apply_prompts(state, *args, **kwargs):
-        return state, *state[0].apply_prompts(*args, **kwargs)
+        for image in state[0].apply_prompts(*args, **kwargs):
+            yield state, *image
     def apply_rb_vector(state, *args, **kwargs):
         return state, *state[0].apply_rb_vector(*args, **kwargs)
     def blend(state, *args, **kwargs):
@@ -56,7 +61,7 @@ class StateWrapper:
     def rewind(state, *args, **kwargs):
         return state, *state[0].rewind(*args, **kwargs)
     def set_mask(state, *args, **kwargs):
-        return state, *state[0].set_mask(*args, **kwargs)
+        return state, state[0].set_mask(*args, **kwargs)
     def update_images(state, *args, **kwargs):
         return state, *state[0].update_images(*args, **kwargs)
     def update_requant(state, *args, **kwargs):
@@ -191,7 +196,7 @@ with gr.Blocks(css="styles.css") as demo:
     # step=1,
     # value=0,
     # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
-    # clear.click(state.clear_transforms, inputs=[state], outputs=[state, out, mask])
+    clear.click(state.clear_transforms, inputs=[state], outputs=[state, out, mask])
     asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
     lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
     # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
@@ -200,11 +205,11 @@ with gr.Blocks(css="styles.css") as demo:
     # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
     base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
     blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
-    small_local.click(set_small_local, outputs=[state, iterations, learning_rate, lpips_weight, reconstruction_steps])
-    major_local.click(set_major_local, outputs=[state, iterations, learning_rate, lpips_weight, reconstruction_steps])
-    major_global.click(set_major_global, outputs=[state, iterations, learning_rate, lpips_weight, reconstruction_steps])
+    small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
+    major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
     apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
     rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
-    set_mask.click(StateWrapper.set_mask, inputs=mask, outputs=testim)
+    set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
 demo.queue()
 demo.launch(debug=True, enable_queue=True)
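
The memory-oriented changes in app.py are visible in the hunks above: a much smaller CLIP checkpoint (clip-vit-base-patch32 instead of clip-vit-large-patch14), a single LPIPS instance constructed once and handed to ImagePromptOptimizer as lpips_fn, and StateWrapper.apply_prompts turned into a generator so intermediate frames are streamed to the client rather than accumulated. A sketch of the generator side as Gradio consumes it (component names and the dummy loop are illustrative only, not part of the repo):

import gradio as gr

def fake_apply_prompts(steps):
    # stand-in for ImageState.apply_prompts, which yields intermediate results;
    # yielding lets Gradio push each update as it is produced instead of
    # holding every intermediate image until the loop finishes
    for i in range(int(steps)):
        yield f"iteration {i}"

with gr.Blocks() as demo:
    steps = gr.Number(value=5)
    out = gr.Textbox()
    run = gr.Button("Apply prompts")
    run.click(fake_apply_prompts, inputs=steps, outputs=out)

demo.queue()   # queuing is required for generator (streaming) handlers
demo.launch()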
app_backend.py DELETED
@@ -1,230 +0,0 @@
-# from functools import cache
-import importlib
-
-import gradio as gr
-import matplotlib.pyplot as plt
-import torch
-import torchvision
-import wandb
-from icecream import ic
-from torch import nn
-from torchvision.transforms.functional import resize
-from tqdm import tqdm
-from transformers import CLIPModel, CLIPProcessor
-import lpips
-from edit import blend_paths
-from img_processing import *
-from img_processing import custom_to_pil
-from loaders import load_default
-import glob
-# global log
-log=False
-
-# ic.disable()
-# ic.enable()
-def get_resized_tensor(x):
-    if len(x.shape) == 2:
-        re = x.unsqueeze(0)
-    else: re = x
-    re = resize(re, (10, 10))
-    return re
-class ProcessorGradientFlow():
-    """
-    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
-    The original processor forces conversion to PIL images, which breaks gradient flow.
-    """
-    def __init__(self, device="cuda") -> None:
-        self.device = device
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
-        self.image_std = [0.26862954, 0.26130258, 0.27577711]
-        self.normalize = torchvision.transforms.Normalize(
-            self.image_mean,
-            self.image_std
-        )
-        self.resize = torchvision.transforms.Resize(224)
-        self.center_crop = torchvision.transforms.CenterCrop(224)
-    def preprocess_img(self, images):
-        images = self.center_crop(images)
-        images = self.resize(images)
-        images = self.center_crop(images)
-        images = self.normalize(images)
-        return images
-    def __call__(self, images=[], **kwargs):
-        processed_inputs = self.processor(**kwargs)
-        processed_inputs["pixel_values"] = self.preprocess_img(images)
-        processed_inputs = {key:value.to(self.device) for (key, value) in processed_inputs.items()}
-        return processed_inputs
-
-class ImagePromptOptimizer(nn.Module):
-    def __init__(self,
-                 vqgan,
-                 clip,
-                 clip_preprocessor,
-                 iterations=100,
-                 lr = 0.01,
-                 save_vector=True,
-                 return_val="vector",
-                 quantize=True,
-                 make_grid=False,
-                 lpips_weight = 6.2) -> None:
-
-        super().__init__()
-        self.latent = None
-        self.device = vqgan.device
-        vqgan.eval()
-        self.vqgan = vqgan
-        self.clip = clip
-        self.iterations = iterations
-        self.lr = lr
-        self.clip_preprocessor = clip_preprocessor
-        self.make_grid = make_grid
-        self.return_val = return_val
-        self.quantize = quantize
-        self.lpips_weight = lpips_weight
-        self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
-    def disc_loss_fn(self, logits):
-        return -torch.mean(logits)
-    def set_latent(self, latent):
-        self.latent = latent.detach().to(self.device)
-    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
-        self.attn_mask = attn_mask
-        self.iterations = iterations
-        self.lr = lr
-        self.lpips_weight = lpips_weight
-        self.reconstruction_steps = reconstruction_steps
-    def forward(self, vector):
-        base_latent = self.latent.detach().requires_grad_()
-        trans_latent = base_latent + vector
-        if self.quantize:
-            z_q, *_ = self.vqgan.quantize(trans_latent)
-        else:
-            z_q = trans_latent
-        dec = self.vqgan.decode(z_q)
-        return dec
-    def _get_clip_similarity(self, prompts, image, weights=None):
-        if isinstance(prompts, str):
-            prompts = [prompts]
-        elif not isinstance(prompts, list):
-            raise TypeError("Provide prompts as string or list of strings")
-        clip_inputs = self.clip_preprocessor(text=prompts,
-            images=image, return_tensors="pt", padding=True)
-        clip_outputs = self.clip(**clip_inputs)
-        similarity_logits = clip_outputs.logits_per_image
-        if weights:
-            similarity_logits *= weights
-        return similarity_logits.sum()
-    def get_similarity_loss(self, pos_prompts, neg_prompts, image):
-        pos_logits = self._get_clip_similarity(pos_prompts, image)
-        if neg_prompts:
-            neg_logits = self._get_clip_similarity(neg_prompts, image)
-        else:
-            neg_logits = torch.tensor([1], device=self.device)
-        loss = -torch.log(pos_logits) + torch.log(neg_logits)
-        return loss
-    def visualize(self, processed_img):
-        if self.make_grid:
-            self.index += 1
-            plt.subplot(1, 13, self.index)
-            plt.imshow(get_pil(processed_img[0]).detach().cpu())
-        else:
-            plt.imshow(get_pil(processed_img[0]).detach().cpu())
-            plt.show()
-    def attn_masking(self, grad):
-        # print("attnmask 1")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
-        newgrad = grad
-        if self.attn_mask is not None:
-            # print("masking mult")
-            newgrad = grad * (self.attn_mask)
-        # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 1")
-        return newgrad
-    def attn_masking2(self, grad):
-        # print("attnmask 2")
-        # print(f"input grad.shape = {grad.shape}")
-        # print(f"input grad = {get_resized_tensor(grad)}")
-        newgrad = grad
-        if self.attn_mask is not None:
-            # print("masking mult")
-            newgrad = grad * ((self.attn_mask - 1) * -1)
-        # print("output grad, ", get_resized_tensor(newgrad))
-        # print("end atn 2")
-        return newgrad
-
-    def optimize(self, latent, pos_prompts, neg_prompts):
-        self.set_latent(latent)
-        # self.make_grid=True
-        transformed_img = self(torch.zeros_like(self.latent, requires_grad=True, device=self.device))
-        original_img = loop_post_process(transformed_img)
-        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
-        optim = torch.optim.Adam([vector], lr=self.lr)
-        if self.make_grid:
-            plt.figure(figsize=(35, 25))
-            self.index = 1
-        for i in tqdm(range(self.iterations)):
-            optim.zero_grad()
-            transformed_img = self(vector)
-            processed_img = loop_post_process(transformed_img) #* self.attn_mask
-            processed_img.retain_grad()
-            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
-            lpips_input.retain_grad()
-            clip_clone = processed_img.clone()
-            clip_clone.register_hook(self.attn_masking)
-            clip_clone.retain_grad()
-            with torch.autocast("cuda"):
-                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
-                print("CLIP loss", clip_loss)
-                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
-                print("LPIPS loss: ", perceptual_loss)
-                if log:
-                    wandb.log({"Perceptual Loss": perceptual_loss})
-                    wandb.log({"CLIP Loss": clip_loss})
-            clip_loss.backward(retain_graph=True)
-            perceptual_loss.backward(retain_graph=True)
-            p2 = processed_img.grad
-            print("Sum Loss", perceptual_loss + clip_loss)
-            optim.step()
-            # if i % self.iterations // 10 == 0:
-            #     self.visualize(transformed_img)
-            yield vector
-        if self.make_grid:
-            plt.savefig(f"plot {pos_prompts[0]}.png")
-            plt.show()
-        print("lpips solo op")
-        for i in range(self.reconstruction_steps):
-            optim.zero_grad()
-            transformed_img = self(vector)
-            processed_img = loop_post_process(transformed_img) #* self.attn_mask
-            processed_img.retain_grad()
-            lpips_input = processed_img.clone()
-            lpips_input.register_hook(self.attn_masking2)
-            lpips_input.retain_grad()
-            with torch.autocast("cuda"):
-                perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
-                if log:
-                    wandb.log({"Perceptual Loss": perceptual_loss})
-                print("LPIPS loss: ", perceptual_loss)
-            perceptual_loss.backward(retain_graph=True)
-            optim.step()
-            yield vector
-        # torch.save(vector, "nose_vector.pt")
-        # print("")
-        # print("DISC STEPS")
-        # print("*************")
-        # for i in range(self.reconstruction_steps):
-        #     optim.zero_grad()
-        #     transformed_img = self(vector)
-        #     processed_img = loop_post_process(transformed_img) #* self.attn_mask
-        #     disc_logits = self.disc(transformed_img)
-        #     disc_loss = self.disc_loss_fn(disc_logits)
-        #     print(f"disc_loss = {disc_loss}")
-        #     if log:
-        #         wandb.log({"Disc Loss": disc_loss})
-        #     print("LPIPS loss: ", perceptual_loss)
-        #     disc_loss.backward(retain_graph=True)
-        #     optim.step()
-        #     yield vector
-        yield vector if self.return_val == "vector" else self.latent + vector
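
The deleted app_backend.py built its own lpips.LPIPS(net='vgg') inside every ImagePromptOptimizer, so the VGG weights were loaded per optimizer on top of the CLIP and VQGAN models. The replacement backend module is not part of this commit; judging only from the new call site in app.py (ImagePromptOptimizer(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)), its constructor presumably accepts a shared metric, roughly as in this sketch (argument names and defaults are assumptions):

from torch import nn

class ImagePromptOptimizer(nn.Module):
    # Sketch only: backend.py is not shown in this commit, so this constructor is
    # inferred from the call in app.py and the deleted app_backend.py version.
    def __init__(self, vqgan, clip, clip_preprocessor, lpips_fn=None,
                 iterations=100, lr=0.01, quantize=True, lpips_weight=6.2) -> None:
        super().__init__()
        self.vqgan = vqgan
        self.clip = clip
        self.clip_preprocessor = clip_preprocessor
        self.iterations = iterations
        self.lr = lr
        self.quantize = quantize
        self.lpips_weight = lpips_weight
        # previously: self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
        # taking an already-constructed metric means VGG is loaded exactly once
        self.perceptual_loss = lpips_fn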
masking.py CHANGED
@@ -13,7 +13,7 @@ from transformers import CLIPModel, CLIPProcessor
 import edit
 # import importlib
 # importlib.reload(edit)
-from app_backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
+from backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
 from loaders import load_default
 
 device = "cuda"