srikanthp07 commited on
Commit
f840fe7
1 Parent(s): 1d4d479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -90,15 +90,15 @@ def get_style_embeddings(style_file):
90
  import torch
91
 
92
  def sharpness_loss(image):
93
- threshold = torch.tensor(0.5)
94
  std_dev = torch.std(image, dim=[2, 3])
95
 
96
  # Calculate the mean standard deviation across the batch
97
  mean_std_dev = torch.mean(std_dev, dim=0)
98
 
99
  # Check if the mean standard deviation is below the threshold
100
- loss = torch.mean(mean_std_dev - threshold)*50
101
- return -loss
102
 
103
 
104
  from torchvision.transforms import ToTensor
@@ -278,7 +278,7 @@ def image_generator(prompt = "dog", loss_function=None):
278
 
279
  return display_images_in_rows(generated_sd_images, titles)
280
 
281
- description = "Generate an image with a prompt (takes time)"
282
 
283
  demo = gr.Interface(image_generator,
284
  inputs=[gr.Textbox(label="prompt", type="text", value="cat fight"),
 
90
  import torch
91
 
92
  def sharpness_loss(image):
93
+ threshold = torch.tensor(0.9)
94
  std_dev = torch.std(image, dim=[2, 3])
95
 
96
  # Calculate the mean standard deviation across the batch
97
  mean_std_dev = torch.mean(std_dev, dim=0)
98
 
99
  # Check if the mean standard deviation is below the threshold
100
+ loss = torch.abs(torch.mean(mean_std_dev - threshold))*50
101
+ return loss
102
 
103
 
104
  from torchvision.transforms import ToTensor
 
278
 
279
  return display_images_in_rows(generated_sd_images, titles)
280
 
281
+ description = "Generate an image with a prompt"
282
 
283
  demo = gr.Interface(image_generator,
284
  inputs=[gr.Textbox(label="prompt", type="text", value="cat fight"),