srikanthp07 commited on
Commit
1d4d479
1 Parent(s): a840b20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -89,16 +89,16 @@ def get_style_embeddings(style_file):
89
 
90
  import torch
91
 
92
- def vibrance_loss(image):
93
- # Calculate the standard deviation of color channels
94
- std_dev = torch.std(image, dim=(2, 3)) # Compute standard deviation over height and width
 
95
  # Calculate the mean standard deviation across the batch
96
- mean_std_dev = torch.mean(std_dev)
97
- # You can adjust a scale factor to control the strength of vibrance regularization
98
- scale_factor = 100.0
99
- # Calculate the vibrance loss
100
- loss = -scale_factor * mean_std_dev
101
- return loss
102
 
103
 
104
  from torchvision.transforms import ToTensor
@@ -183,6 +183,7 @@ def generate_with_embs(text_embeddings, max_length, random_seed, loss_fn = None)
183
  if loss_fn is not None:
184
  if i%2 == 0:
185
  latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn)
 
186
 
187
  # compute the previous noisy sample x_t -> x_t-1
188
  latents = scheduler.step(noise_pred, t, latents).prev_sample
@@ -255,7 +256,7 @@ def image_generator(prompt = "dog", loss_function=None):
255
  images_without_loss = []
256
  images_with_loss = []
257
  if loss_function == "Yes":
258
- loss_function = vibrance_loss
259
  else:
260
  loss_function = None
261
 
@@ -277,7 +278,7 @@ def image_generator(prompt = "dog", loss_function=None):
277
 
278
  return display_images_in_rows(generated_sd_images, titles)
279
 
280
- description = "Generate an image with a prompt"
281
 
282
  demo = gr.Interface(image_generator,
283
  inputs=[gr.Textbox(label="prompt", type="text", value="cat fight"),
 
89
 
90
  import torch
91
 
92
+ def sharpness_loss(image):
93
+ threshold = torch.tensor(0.5)
94
+ std_dev = torch.std(image, dim=[2, 3])
95
+
96
  # Calculate the mean standard deviation across the batch
97
+ mean_std_dev = torch.mean(std_dev, dim=0)
98
+
99
+ # Check if the mean standard deviation is below the threshold
100
+ loss = torch.mean(mean_std_dev - threshold)*50
101
+ return -loss
 
102
 
103
 
104
  from torchvision.transforms import ToTensor
 
183
  if loss_fn is not None:
184
  if i%2 == 0:
185
  latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn)
186
+ print('loss: ',custom_loss.item())
187
 
188
  # compute the previous noisy sample x_t -> x_t-1
189
  latents = scheduler.step(noise_pred, t, latents).prev_sample
 
256
  images_without_loss = []
257
  images_with_loss = []
258
  if loss_function == "Yes":
259
+ loss_function = sharpness_loss
260
  else:
261
  loss_function = None
262
 
 
278
 
279
  return display_images_in_rows(generated_sd_images, titles)
280
 
281
+ description = "Generate an image with a prompt (takes time)"
282
 
283
  demo = gr.Interface(image_generator,
284
  inputs=[gr.Textbox(label="prompt", type="text", value="cat fight"),