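# Page-header residue from the hosting repository, kept as comments so the
# file remains valid Python:
#   mix060514's picture
#   Update app.py
#   8b04042 verified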
# %% [markdown]
# ## Gathered Notebook
#
# This notebook was generated by the Gather Extension. The intent is that it contains only the code and cells required to produce the same results as the cell originally selected for gathering. Please note that the Python analysis is quite conservative, so if it is unsure whether a line of code is necessary for execution, it will err on the side of including it.
#
# **Please let us know if you are satisfied with what was gathered [here](https://aka.ms/gatherfeedback).**
#
# Thanks
# %%
import gradio as gr
from PIL import Image, ImageColor
import numpy as np
import torch
import torchvision
from diffusers import DDPMPipeline, DDIMScheduler
from tqdm.auto import tqdm
# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# %%
# Hugging Face Hub repo that holds the pretrained DDPM pipeline (UNet + config).
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name)
image_pipe = image_pipe.to(device)
# Sample with a DDIM scheduler loaded from the same repo, configured for
# 40 denoising steps instead of the model's full training schedule.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
# %%
def color_loss(images, target_color=(.1, .9, .5)):
    """Return the mean absolute distance between ``images`` and a flat colour.

    Args:
        images: Image tensor whose channel axis is dim 1 and whose values are
            in the model's [-1, 1] range (assumed NCHW — confirm at callers).
        target_color: Per-channel colour in [0, 1], e.g. an (R, G, B) tuple.

    Returns:
        A scalar tensor: ``|images - target|`` averaged over every element.
    """
    # Move the colour from [0, 1] into the same [-1, 1] range as the images.
    shifted = torch.tensor(target_color).to(images.device) * 2 - 1
    # Shape (1, C, 1, 1) so the colour broadcasts across batch, height, width.
    per_channel = shifted[None, :, None, None]
    return (images - per_channel).abs().mean()
# %%
def _sample_to_image(sample):
    """Convert the final sampled tensor into an 8-bit PIL image.

    The batch is tiled with ``torchvision.utils.make_grid`` (4 per row),
    moved to channels-last, clamped to [-1, 1], mapped to [0, 1], then scaled
    to 0-255 and truncated to ``uint8``.
    """
    tiled = torchvision.utils.make_grid(sample, nrow=4)
    channels_last = tiled.permute(1, 2, 0).cpu().clip(-1, 1) * .5 + .5
    return Image.fromarray((channels_last.numpy() * 255).astype(np.uint8))


def generate(color, guidance_loss_scale):
    """Sample one image from the diffusion pipeline, steered toward a colour.

    Runs the DDIM loop over ``scheduler.timesteps``. At each step the model's
    noise prediction is turned into a predicted clean image (x0), that x0 is
    scored with ``color_loss`` against ``color``, and the current sample is
    moved one step down the gradient of that (scaled) loss before the regular
    scheduler update.

    Args:
        color: Colour string understood by ``PIL.ImageColor.getcolor``
            (the UI's ColorPicker default is a hex string like ``"#55FFAA"``).
        guidance_loss_scale: Multiplier on the colour loss; larger values push
            the sample harder toward ``color``.

    Returns:
        A ``PIL.Image.Image`` of the generated sample.
    """
    print(color, guidance_loss_scale)
    # color_loss expects channels in [0, 1], so normalise the 0-255 triple.
    rgb = [channel / 255 for channel in ImageColor.getcolor(color, "RGB")]

    sample = torch.randn(1, 3, 256, 256).to(device)
    timesteps = scheduler.timesteps
    for t in tqdm(timesteps, total=len(timesteps)):
        scaled = scheduler.scale_model_input(sample, t)
        # The UNet itself is never differentiated, so skip building its graph.
        with torch.no_grad():
            noise_pred = image_pipe.unet(scaled, t)['sample']

        # Re-enable autograd on the sample only, so the colour loss on the
        # predicted x0 can be differentiated with respect to it.
        sample = sample.detach().requires_grad_()
        predicted_x0 = scheduler.step(noise_pred, t, sample).pred_original_sample
        guidance_loss = color_loss(predicted_x0, rgb) * guidance_loss_scale
        (grad,) = torch.autograd.grad(guidance_loss, sample)

        # Descend the loss gradient, then take the ordinary scheduler step
        # reusing the noise prediction computed before the nudge.
        sample = sample.detach() - grad
        sample = scheduler.step(noise_pred, t, sample).prev_sample

    return _sample_to_image(sample)
# Gradio UI: a colour picker and a guidance-strength slider feed generate(),
# and the returned PIL image is shown in the output panel.
inputs = [
    gr.ColorPicker(label='color', value="#55FFAA"),
    gr.Slider(minimum=0, maximum=30, value=3, label="Guidance Loss Scale"),
]
outputs = gr.Image(label="result")
# No example presets are configured; the previous `examples=[...]` list held
# only commented-out entries, so it evaluated to an empty list.
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
)

# Start the server only when this file is executed as a script. Importing the
# module (e.g. Gradio's reload mode, which looks up `demo` itself, or tests)
# then no longer launches a blocking server as a side effect.
# NOTE(review): assumes the host runs this file as `python app.py` — confirm.
if __name__ == "__main__":
    demo.launch(debug=True)