Spaces:
Sleeping
Sleeping
srikanthp07
commited on
Commit
•
f840fe7
1
Parent(s):
1d4d479
Update app.py
Browse files
app.py
CHANGED
@@ -90,15 +90,15 @@ def get_style_embeddings(style_file):
|
|
90 |
import torch
|
91 |
|
92 |
def sharpness_loss(image):
|
93 |
-
threshold = torch.tensor(0.
|
94 |
std_dev = torch.std(image, dim=[2, 3])
|
95 |
|
96 |
# Calculate the mean standard deviation across the batch
|
97 |
mean_std_dev = torch.mean(std_dev, dim=0)
|
98 |
|
99 |
# Check if the mean standard deviation is below the threshold
|
100 |
-
loss = torch.mean(mean_std_dev - threshold)*50
|
101 |
-
return
|
102 |
|
103 |
|
104 |
from torchvision.transforms import ToTensor
|
@@ -278,7 +278,7 @@ def image_generator(prompt = "dog", loss_function=None):
|
|
278 |
|
279 |
return display_images_in_rows(generated_sd_images, titles)
|
280 |
|
281 |
-
description = "Generate an image with a prompt
|
282 |
|
283 |
demo = gr.Interface(image_generator,
|
284 |
inputs=[gr.Textbox(label="prompt", type="text", value="cat fight"),
|
|
|
90 |
import torch
|
91 |
|
92 |
def sharpness_loss(image):
|
93 |
+
threshold = torch.tensor(0.9)
|
94 |
std_dev = torch.std(image, dim=[2, 3])
|
95 |
|
96 |
# Calculate the mean standard deviation across the batch
|
97 |
mean_std_dev = torch.mean(std_dev, dim=0)
|
98 |
|
99 |
# Check if the mean standard deviation is below the threshold
|
100 |
+
loss = torch.abs(torch.mean(mean_std_dev - threshold))*50
|
101 |
+
return loss
|
102 |
|
103 |
|
104 |
from torchvision.transforms import ToTensor
|
|
|
278 |
|
279 |
return display_images_in_rows(generated_sd_images, titles)
|
280 |
|
281 |
+
description = "Generate an image with a prompt"
|
282 |
|
283 |
demo = gr.Interface(image_generator,
|
284 |
inputs=[gr.Textbox(label="prompt", type="text", value="cat fight"),
|