erwann committed on
Commit
9902d5e
1 Parent(s): 71b70df

update config and rm discriminator

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. app_backend.py +18 -19
  3. configs.py +1 -1
app.py CHANGED
@@ -139,7 +139,7 @@ with gr.Blocks(css="styles.css") as demo:
139
  with gr.Column():
140
  major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
141
  iterations = gr.Slider(minimum=10,
142
- maximum=300,
143
  step=1,
144
  value=20,
145
  label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
 
139
  with gr.Column():
140
  major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
141
  iterations = gr.Slider(minimum=10,
142
+ maximum=60,
143
  step=1,
144
  value=20,
145
  label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
app_backend.py CHANGED
@@ -81,7 +81,6 @@ class ImagePromptOptimizer(nn.Module):
81
  self.make_grid = make_grid
82
  self.return_val = return_val
83
  self.quantize = quantize
84
- self.disc = load_disc(self.device)
85
  self.lpips_weight = lpips_weight
86
  self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
87
  def disc_loss_fn(self, logits):
@@ -175,19 +174,19 @@ class ImagePromptOptimizer(nn.Module):
175
  clip_clone = processed_img.clone()
176
  clip_clone.register_hook(self.attn_masking)
177
  clip_clone.retain_grad()
178
- with torch.autocast("cuda"):
179
- clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
180
- print("CLIP loss", clip_loss)
181
- perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
182
- print("LPIPS loss: ", perceptual_loss)
183
- with torch.no_grad():
184
- disc_logits = self.disc(transformed_img)
185
- disc_loss = self.disc_loss_fn(disc_logits)
186
- print(f"disc_loss = {disc_loss}")
187
- disc_loss2 = self.disc(processed_img)
188
  if log:
189
  wandb.log({"Perceptual Loss": perceptual_loss})
190
- wandb.log({"Discriminator Loss": disc_loss})
191
  wandb.log({"CLIP Loss": clip_loss})
192
  clip_loss.backward(retain_graph=True)
193
  perceptual_loss.backward(retain_graph=True)
@@ -209,13 +208,13 @@ class ImagePromptOptimizer(nn.Module):
209
  lpips_input = processed_img.clone()
210
  lpips_input.register_hook(self.attn_masking2)
211
  lpips_input.retain_grad()
212
- with torch.autocast("cuda"):
213
- perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
214
- with torch.no_grad():
215
- disc_logits = self.disc(transformed_img)
216
- disc_loss = self.disc_loss_fn(disc_logits)
217
- print(f"disc_loss = {disc_loss}")
218
- disc_loss2 = self.disc(processed_img)
219
  # print(f"disc_loss2 = {disc_loss2}")
220
  if log:
221
  wandb.log({"Perceptual Loss": perceptual_loss})
 
81
  self.make_grid = make_grid
82
  self.return_val = return_val
83
  self.quantize = quantize
 
84
  self.lpips_weight = lpips_weight
85
  self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
86
  def disc_loss_fn(self, logits):
 
174
  clip_clone = processed_img.clone()
175
  clip_clone.register_hook(self.attn_masking)
176
  clip_clone.retain_grad()
177
+ # with torch.autocast("cuda"):
178
+ clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
179
+ print("CLIP loss", clip_loss)
180
+ perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
181
+ print("LPIPS loss: ", perceptual_loss)
182
+ # with torch.no_grad():
183
+ # disc_logits = self.disc(transformed_img)
184
+ # disc_loss = self.disc_loss_fn(disc_logits)
185
+ # print(f"disc_loss = {disc_loss}")
186
+ # disc_loss2 = self.disc(processed_img)
187
  if log:
188
  wandb.log({"Perceptual Loss": perceptual_loss})
189
+ # wandb.log({"Discriminator Loss": disc_loss})
190
  wandb.log({"CLIP Loss": clip_loss})
191
  clip_loss.backward(retain_graph=True)
192
  perceptual_loss.backward(retain_graph=True)
 
208
  lpips_input = processed_img.clone()
209
  lpips_input.register_hook(self.attn_masking2)
210
  lpips_input.retain_grad()
211
+ # with torch.autocast("cuda"):
212
+ perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
213
+ # with torch.no_grad():
214
+ # disc_logits = self.disc(transformed_img)
215
+ # disc_loss = self.disc_loss_fn(disc_logits)
216
+ # print(f"disc_loss = {disc_loss}")
217
+ # disc_loss2 = self.disc(processed_img)
218
  # print(f"disc_loss2 = {disc_loss2}")
219
  if log:
220
  wandb.log({"Perceptual Loss": perceptual_loss})
configs.py CHANGED
@@ -2,6 +2,6 @@ import gradio as gr
2
  def set_small_local():
3
  return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
4
  def set_major_local():
5
- return (gr.Slider.update(value=25), gr.Slider.update(value=0.25), gr.Slider.update(value=35), gr.Slider.update(value=10))
6
  def set_major_global():
7
  return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
 
2
  def set_small_local():
3
  return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
4
  def set_major_local():
5
+ return (gr.Slider.update(value=25), gr.Slider.update(value=0.2), gr.Slider.update(value=36.6), gr.Slider.update(value=10))
6
  def set_major_global():
7
  return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))