Create app.py
app.py
ADDED
@@ -0,0 +1,69 @@
# %% [markdown]
# ## Gathered Notebook
#
# This notebook was generated by the Gather Extension. The intent is that it contains only the code and cells required to produce the same results as the cell originally selected for gathering. Please note that the Python analysis is quite conservative, so if it is unsure whether a line of code is necessary for execution, it will err on the side of including it.
#
# **Please let us know if you are satisfied with what was gathered [here](https://aka.ms/gatherfeedback).**
#
# Thanks

# %%
import gradio as gr
import numpy as np
import torch
import torchvision
from diffusers import DDPMPipeline, DDIMScheduler
from PIL import Image, ImageColor
from tqdm.auto import tqdm

# Use the GPU when available; sampling on CPU works but is very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

# %%
# Load the fine-tuned unconditional diffusion model and pair it with a DDIM
# scheduler, so sampling takes 40 steps rather than the full training schedule.
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
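
# %% [markdown]
# A quick, illustrative check (not part of the original app): after
# `set_timesteps`, the scheduler exposes 40 decreasing timestep indices that
# the sampling loop below walks through.

# %%
print(len(scheduler.timesteps), scheduler.timesteps[:5])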

# %%
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Mean absolute error between the images and a target RGB color."""
    # Map the [0, 1] RGB target into the model's [-1, 1] pixel range.
    target = torch.tensor(target_color).to(images.device) * 2 - 1
    target = target[None, :, None, None]  # broadcast over batch, height, width
    return torch.abs(images - target).mean()
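
# %% [markdown]
# A small sanity check (illustrative, not in the original file): a batch filled
# with the mapped target color should score near zero, anything else higher.

# %%
_target = (0.1, 0.9, 0.5)
_match = (torch.tensor(_target) * 2 - 1)[None, :, None, None].expand(1, 3, 8, 8).to(device)
print(color_loss(_match, _target).item())  # ~0.0
print(color_loss(torch.zeros(1, 3, 8, 8).to(device), _target).item())  # clearly > 0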

# %%
def generate(color, guidance_loss_scale):
    print(color, guidance_loss_scale)
    target_color = ImageColor.getcolor(color, "RGB")  # hex string -> (R, G, B)
    target_color = [a / 255 for a in target_color]  # rescale to [0, 1]
    x = torch.randn(1, 3, 256, 256).to(device)  # start from pure noise
    for t in tqdm(scheduler.timesteps):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        # Re-attach gradients so the loss can be differentiated w.r.t. x
        x = x.detach().requires_grad_()
        # The scheduler's estimate of the fully denoised image at this step
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = color_loss(x0, target_color) * guidance_loss_scale
        # Nudge x down the loss gradient before taking the scheduler step
        cond_grad = -torch.autograd.grad(loss, x)[0]
        x = x.detach() + cond_grad
        x = scheduler.step(noise_pred, t, x).prev_sample
    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5  # [-1, 1] -> [0, 1]
    im = Image.fromarray((im.numpy() * 255).astype(np.uint8))
    return im
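
# %% [markdown]
# Optional smoke test (not in the original file): call `generate` directly
# before wiring up the UI. Expect a few seconds per image on GPU and far longer
# on CPU; uncomment to run.

# %%
# preview = generate("#55FFAA", 3)
# preview.save("preview.png")  # hypothetical output path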

# %%
# Expose the guided sampler through a simple Gradio interface.
inputs = [
    gr.ColorPicker(label="color", value="#55FFAA"),
    gr.Slider(minimum=0, maximum=30, value=3, label="Guidance Loss Scale"),
]
outputs = gr.Image(label="result")
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["#BB2266", 3],
        ["#44CCAA", 5],
    ],
)
demo.launch(debug=True)  # debug=True keeps the process attached and surfaces errors