Spaces:
Sleeping
Sleeping
muneebable
commited on
Commit
•
740be46
1
Parent(s):
f345a6e
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
|
|
3 |
import torch.nn.functional as F
|
4 |
import numpy as np
|
5 |
from PIL import Image, ImageColor
|
@@ -27,24 +28,25 @@ def color_loss(images, target_color=(0.1, 0.9, 0.5)):
|
|
27 |
|
28 |
# And the core function to generate an image given the relevant inputs
|
29 |
def generate(color, guidance_loss_scale):
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
48 |
|
49 |
# See the gradio docs for the types of inputs and outputs available
|
50 |
inputs = [
|
@@ -65,5 +67,4 @@ demo = gr.Interface(
|
|
65 |
|
66 |
# And launching
|
67 |
if __name__ == "__main__":
|
68 |
-
demo.launch(enable_queue=True
|
69 |
-
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torchvision
|
4 |
import torch.nn.functional as F
|
5 |
import numpy as np
|
6 |
from PIL import Image, ImageColor
|
|
|
28 |
|
29 |
# And the core function to generate an image given the relevant inputs
|
30 |
def generate(color, guidance_loss_scale):
|
31 |
+
target_color = ImageColor.getcolor(color, "RGB") # Target color as RGB
|
32 |
+
target_color = [a/255 for a in target_color] # Rescale from (0, 255) to (0, 1)
|
33 |
+
x = torch.randn(1, 3, 256, 256).to(device)
|
34 |
+
for i, t in enumerate(scheduler.timesteps):
|
35 |
+
model_input = scheduler.scale_model_input(x, t)
|
36 |
+
with torch.no_grad():
|
37 |
+
noise_pred = image_pipe.unet(model_input, t)["sample"]
|
38 |
+
x = x.detach().requires_grad_()
|
39 |
+
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
|
40 |
+
loss = color_loss(x0, target_color) * guidance_loss_scale
|
41 |
+
cond_grad = -torch.autograd.grad(loss, x)[0]
|
42 |
+
x = x.detach() + cond_grad
|
43 |
+
x = scheduler.step(noise_pred, t, x).prev_sample
|
44 |
+
grid = torchvision.utils.make_grid(x, nrow=4)
|
45 |
+
im = grid.permute(1, 2, 0).cpu().clip(-1, 1)*0.5 + 0.5
|
46 |
+
im = (im * 255).byte().numpy() # Convert to uint8 numpy array
|
47 |
+
im = Image.fromarray(im)
|
48 |
+
im.save('test.jpeg')
|
49 |
+
return im
|
50 |
|
51 |
# See the gradio docs for the types of inputs and outputs available
|
52 |
inputs = [
|
|
|
67 |
|
68 |
# And launching
|
69 |
if __name__ == "__main__":
|
70 |
+
demo.launch() # Removed enable_queue=True
|
|