erwann committed
Commit 29bbf75
1 Parent(s): 0be9cd5

refactoring optimization loop

Files changed (3)
  1. ImageState.py +4 -4
  2. animation.py +0 -4
  3. backend.py +49 -47
ImageState.py CHANGED
@@ -102,7 +102,7 @@ class ImageState:
         x = Image.fromarray(x, "L")
         return x
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def _render_all_transformations(self, return_twice=True):
         global num
         current_vector_transforms = (
@@ -150,7 +150,7 @@ class ImageState:
         clear_img_dir(self.img_dir)
         return self.blend(blend_weight)
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def blend(self, weight):
         _, latent = blend_paths(
             self.vqgan,
@@ -163,7 +163,7 @@ class ImageState:
         self.blend_latent = latent
         return self._render_all_transformations()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def rewind(self, index):
         if not self.transform_history:
             print("No history")
@@ -221,7 +221,7 @@ class ImageState:
     ):
         transform_log.transforms.append(transform.detach().cpu())
         self.current_prompt_transforms[-1] = transform
-        with torch.inference_mode():
+        with torch.no_grad():
            image = self._render_all_transformations(return_twice=False)
            if log:
                wandb.log({"image": wandb.Image(image)})
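Note on the decorator swap: torch.inference_mode() marks every tensor it produces as an inference tensor that can never re-enter autograd, while torch.no_grad() only suspends gradient recording. Since the rendered images here feed a gradient-based editing loop, no_grad() is the safer choice. A minimal standalone sketch of the difference (not this repo's code, assuming current PyTorch semantics):

# Tensors created under inference_mode are permanently excluded from
# autograd; tensors created under no_grad are ordinary tensors.
import torch

with torch.inference_mode():
    x = torch.ones(3)
with torch.no_grad():
    y = torch.ones(3)

w = torch.randn(3, requires_grad=True)
(y * w).sum().backward()      # works: y has no grad history but is usable
try:
    (x * w).sum().backward()  # raises: inference tensors can't be saved for backward
except RuntimeError as err:
    print("inference-mode tensor rejected:", err)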
animation.py CHANGED
@@ -4,10 +4,6 @@ import os
 
 
 def clear_img_dir(img_dir):
-    if not os.path.exists("img_history"):
-        os.mkdir("img_history")
-    if not os.path.exists(img_dir):
-        os.mkdir(img_dir)
     for filename in glob.glob(img_dir + "/*"):
         os.remove(filename)
 
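With the mkdir calls removed, clear_img_dir now assumes its directory already exists; creation presumably happens wherever the image directory is configured. A caller-side sketch of the guarantee that implies (ensure_img_dir is a hypothetical helper, not part of this commit):

import glob
import os

def ensure_img_dir(img_dir):
    # hypothetical replacement for the removed os.mkdir calls;
    # idempotent, and creates parents in one call
    os.makedirs(img_dir, exist_ok=True)

def clear_img_dir(img_dir):
    for filename in glob.glob(img_dir + "/*"):
        os.remove(filename)

ensure_img_dir("img_history/example")
clear_img_dir("img_history/example")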
backend.py CHANGED
@@ -140,7 +140,7 @@ class ImagePromptEditor(nn.Module):
         return newgrad
 
     def _get_next_inputs(self, transformed_img):
-        processed_img = loop_post_process(transformed_img)  # * self.attn_mask
+        processed_img = loop_post_process(transformed_img)
         processed_img.retain_grad()
 
         lpips_input = processed_img.clone()
@@ -154,51 +154,53 @@ class ImagePromptEditor(nn.Module):
         return (processed_img, lpips_input, clip_input)
 
     def _optimize_CLIP_LPIPS(self, optim, original_img, vector, pos_prompts, neg_prompts):
-        optim.zero_grad()
-        transformed_img = self(vector)
-        processed_img, lpips_input, clip_input = self._get_next_inputs(
-            transformed_img
-        )
-        with torch.autocast("cuda"):
-            clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
-            print("CLIP loss", clip_loss)
-            perceptual_loss = (
-                self.perceptual_loss(lpips_input, original_img.clone())
-                * self.lpips_weight
-            )
-            print("LPIPS loss: ", perceptual_loss)
-            print("Sum Loss", perceptual_loss + clip_loss)
-            if log:
-                wandb.log({"Perceptual Loss": perceptual_loss})
-                wandb.log({"CLIP Loss": clip_loss})
-
-        # These gradients will be masked if attn_mask has been set
-        clip_loss.backward(retain_graph=True)
-        perceptual_loss.backward(retain_graph=True)
-
-        optim.step()
-        yield vector
+        for i in range(self.iterations):
+            optim.zero_grad()
+            transformed_img = self(vector)
+            processed_img, lpips_input, clip_input = self._get_next_inputs(
+                transformed_img
+            )
+            with torch.autocast("cuda"):
+                clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
+                print("CLIP loss", clip_loss)
+                perceptual_loss = (
+                    self.perceptual_loss(lpips_input, original_img.clone())
+                    * self.lpips_weight
+                )
+                print("LPIPS loss: ", perceptual_loss)
+                print("Sum Loss", perceptual_loss + clip_loss)
+                if log:
+                    wandb.log({"Perceptual Loss": perceptual_loss})
+                    wandb.log({"CLIP Loss": clip_loss})
+
+            # These gradients will be masked if attn_mask has been set
+            clip_loss.backward(retain_graph=True)
+            perceptual_loss.backward(retain_graph=True)
+
+            optim.step()
+            yield vector
 
     def _optimize_LPIPS(self, vector, original_img, optim):
-        optim.zero_grad()
-        transformed_img = self(vector)
-        processed_img = loop_post_process(transformed_img)  # * self.attn_mask
-        processed_img.retain_grad()
-
-        lpips_input = processed_img.clone()
-        lpips_input.register_hook(self._attn_mask_inverse)
-        lpips_input.retain_grad()
-        with torch.autocast("cuda"):
-            perceptual_loss = (
-                self.perceptual_loss(lpips_input, original_img.clone())
-                * self.lpips_weight
-            )
-        if log:
-            wandb.log({"Perceptual Loss": perceptual_loss})
-        print("LPIPS loss: ", perceptual_loss)
-        perceptual_loss.backward(retain_graph=True)
-        optim.step()
-        yield vector
+        for i in range(self.reconstruction_steps):
+            optim.zero_grad()
+            transformed_img = self(vector)
+            processed_img = loop_post_process(transformed_img)
+            processed_img.retain_grad()
+
+            lpips_input = processed_img.clone()
+            lpips_input.register_hook(self._attn_mask_inverse)
+            lpips_input.retain_grad()
+            with torch.autocast("cuda"):
+                perceptual_loss = (
+                    self.perceptual_loss(lpips_input, original_img.clone())
+                    * self.lpips_weight
+                )
+            if log:
+                wandb.log({"Perceptual Loss": perceptual_loss})
+            print("LPIPS loss: ", perceptual_loss)
+            perceptual_loss.backward(retain_graph=True)
+            optim.step()
+            yield vector
 
     def optimize(self, latent, pos_prompts, neg_prompts):
         self.set_latent(latent)
@@ -209,10 +211,10 @@ class ImagePromptEditor(nn.Module):
         vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
         optim = torch.optim.Adam([vector], lr=self.lr)
 
-        for i in tqdm(range(self.iterations)):
-            yield self._optimize_CLIP_LPIPS(optim, original_img, vector, pos_prompts, neg_prompts)
+        for transform in self._optimize_CLIP_LPIPS(optim, original_img, vector, pos_prompts, neg_prompts):
+            yield transform
 
         print("Running LPIPS optim only")
-        for i in range(self.reconstruction_steps):
-            yield self._optimize_LPIPS(vector, original_img, transformed_img, optim)
+        for transform in self._optimize_LPIPS(vector, original_img, optim):
+            yield transform
         yield vector if self.return_val == "vector" else self.latent + vector
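Why this refactor matters: both _optimize_* methods contain a yield, so calling them returns a generator. The old optimize() yielded those generator objects without ever iterating them, meaning no zero_grad/backward/step actually ran inside its loops (and the old _optimize_LPIPS call also passed a stale extra transformed_img argument that would raise a TypeError). Moving the for loops inside the worker generators and re-yielding their results fixes both problems. A self-contained sketch of the delegation pattern, where yield from is the idiomatic equivalent of "for t in gen: yield t" (hypothetical names, not this repo's code):

# Per-step loops live inside the worker generators; the outer
# generator re-yields every intermediate result.
def clip_lpips_steps(iterations):
    # stands in for _optimize_CLIP_LPIPS: one optimizer step per yield
    for i in range(iterations):
        yield f"clip+lpips step {i}"

def lpips_steps(steps):
    # stands in for _optimize_LPIPS
    for i in range(steps):
        yield f"lpips step {i}"

def optimize(iterations, reconstruction_steps):
    yield from clip_lpips_steps(iterations)
    print("Running LPIPS optim only")
    yield from lpips_steps(reconstruction_steps)
    yield "final vector"

for update in optimize(2, 1):
    print(update)

The commit's explicit "for transform in ...: yield transform" form behaves identically; yield from would merely be a further cleanup.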
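On the two-loss update inside _optimize_CLIP_LPIPS: the CLIP and LPIPS losses share one forward graph through transformed_img, so the first backward() must pass retain_graph=True to keep that graph alive for the second; both gradients accumulate in vector before the single optim.step(). A minimal standalone sketch with toy losses standing in for CLIP and LPIPS:

import torch

vector = torch.randn(4, requires_grad=True)
optim = torch.optim.Adam([vector], lr=0.1)

optim.zero_grad()
out = vector * 2                    # stands in for self(vector)
loss_a = out.sum()                  # stands in for the CLIP loss
loss_b = (out ** 2).mean()          # stands in for the LPIPS loss
loss_a.backward(retain_graph=True)  # keep the shared graph for the next pass
loss_b.backward()                   # graph may be freed after the last backward
optim.step()

The commit passes retain_graph=True on the second backward as well; that appears harmless but strictly unnecessary here, since each loop iteration rebuilds the graph with a fresh forward pass.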