erwann committed
Commit 4e7a12f
1 Parent(s): 663705e

wip statewrapper change

Files changed (3)
  1. app.py +2 -2
  2. app_backend.py +7 -19
  3. loaders.py +1 -0
app.py CHANGED
@@ -19,7 +19,7 @@ from loaders import load_default
 from animation import create_gif
 from prompts import get_random_prompts
 
-device = "cuda"
+device = "cpu"
 vqgan = load_default(device)
 vqgan.eval()
 processor = ProcessorGradientFlow(device=device)
@@ -62,7 +62,7 @@ class StateWrapper:
         return state, *state[0].update_requant(*args, **kwargs)
 
 with gr.Blocks(css="styles.css") as demo:
-    promptoptim = gr.State([ImagePromptOptimizer(vqgan, clip, processor, quantize=True)])
+    promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
     state = gr.State([ImageState(vqgan, promptoptim)])
     with gr.Row():
         with gr.Column(scale=1):
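
Note on the promptoptim change: dropping the gr.State wrapper makes the ImagePromptOptimizer a single module-level object shared across all sessions, while state (the ImageState) remains per-session. A minimal sketch of that distinction, assuming Gradio's gr.State API; the counter names and bump handler are illustrative, not from this app:

import gradio as gr

shared = {"clicks": 0}  # module-level: one object shared by every browser session

def bump(session):
    # `session` is this visitor's own copy; `shared` is seen by all visitors
    session[0] += 1
    shared["clicks"] += 1
    return session, f"session={session[0]} shared={shared['clicks']}"

with gr.Blocks() as demo:
    session = gr.State([0])  # list-wrapped, mirroring state[0] in the diff above
    out = gr.Textbox()
    gr.Button("bump").click(bump, inputs=[session], outputs=[session, out])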
app_backend.py CHANGED
@@ -174,19 +174,13 @@ class ImagePromptOptimizer(nn.Module):
         clip_clone = processed_img.clone()
         clip_clone.register_hook(self.attn_masking)
         clip_clone.retain_grad()
-        # with torch.autocast("cuda"):
-        clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
-        print("CLIP loss", clip_loss)
-        perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
-        print("LPIPS loss: ", perceptual_loss)
-        # with torch.no_grad():
-        #     disc_logits = self.disc(transformed_img)
-        #     disc_loss = self.disc_loss_fn(disc_logits)
-        #     print(f"disc_loss = {disc_loss}")
-        #     disc_loss2 = self.disc(processed_img)
+        with torch.autocast("cuda"):
+            clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
+            print("CLIP loss", clip_loss)
+            perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
+            print("LPIPS loss: ", perceptual_loss)
         if log:
             wandb.log({"Perceptual Loss": perceptual_loss})
-            # wandb.log({"Discriminator Loss": disc_loss})
             wandb.log({"CLIP Loss": clip_loss})
         clip_loss.backward(retain_graph=True)
         perceptual_loss.backward(retain_graph=True)
@@ -208,14 +202,8 @@ class ImagePromptOptimizer(nn.Module):
         lpips_input = processed_img.clone()
         lpips_input.register_hook(self.attn_masking2)
         lpips_input.retain_grad()
-        # with torch.autocast("cuda"):
-        perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
-        # with torch.no_grad():
-        #     disc_logits = self.disc(transformed_img)
-        #     disc_loss = self.disc_loss_fn(disc_logits)
-        #     print(f"disc_loss = {disc_loss}")
-        #     disc_loss2 = self.disc(processed_img)
-        #     print(f"disc_loss2 = {disc_loss2}")
+        with torch.autocast("cuda"):
+            perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
         if log:
             wandb.log({"Perceptual Loss": perceptual_loss})
             print("LPIPS loss: ", perceptual_loss)
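
The restored autocast block runs the CLIP and LPIPS forward passes under mixed precision, while the backward() calls stay outside the context, which is the recommended ordering for torch.autocast. A minimal sketch of the pattern, assuming standard PyTorch; the linear model and tensors are placeholders, not the app's losses:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 1).to(device)
x = torch.randn(8, 16, device=device)
target = torch.randn(8, 1, device=device)

# Eligible ops inside the block run in reduced precision (float16 on CUDA,
# bfloat16 on CPU); the parameters themselves stay in float32.
with torch.autocast(device_type=device):
    loss = nn.functional.mse_loss(model(x), target)

loss.backward()  # backward is taken outside autocast, as in the diff above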
loaders.py CHANGED
@@ -36,6 +36,7 @@ def load_default(device):
     sd = torch.load("./vqgan_only.pt", map_location=device)
     model.load_state_dict(sd, strict=True)
     model.to(device)
+    del sd
     return model
 
 
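
On the del sd line: load_state_dict copies the checkpoint values into the model's own tensors, so sd keeps a second full copy of the weights alive; deleting it releases that copy right away rather than at function exit. A sketch of the pattern with optional cleanup not in this commit (load_checkpoint is a hypothetical helper; the gc.collect() and empty_cache() calls are assumptions, not part of loaders.py):

import gc
import torch

def load_checkpoint(model: torch.nn.Module, path: str, device: str) -> torch.nn.Module:
    sd = torch.load(path, map_location=device)  # checkpoint dict: a full copy of the weights
    model.load_state_dict(sd, strict=True)      # values are copied into the model's tensors
    model.to(device)
    del sd                                      # drop the duplicate copy immediately
    gc.collect()                                # optional: prompt collection of the dict
    if device == "cuda":
        torch.cuda.empty_cache()                # optional: return cached GPU blocks to the driver
    return model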