|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
from PIL import Image, ImageColor |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from diffusers import DDPMPipeline, DDIMScheduler |
|
from tqdm.auto import tqdm |
|
# Prefer the GPU when one is visible to torch; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub repository used for both the sampling pipeline and the scheduler config.
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"

image_pipe = DDPMPipeline.from_pretrained(pipeline_name)
image_pipe = image_pipe.to(device)

# Sampling is driven by a DDIM scheduler (loaded from the same repo) with a
# fixed budget of 40 inference steps, shared by every request.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
|
|
|
|
|
def color_loss(images, target_color=(.1, .9, .5)):
    """Return the mean absolute difference between ``images`` and a flat colour.

    Args:
        images: Image tensor in the [-1, 1] pixel range. The target is shaped
            ``(1, C, 1, 1)``, so channels are expected on dim 1 of a 4-D
            tensor (e.g. ``(B, C, H, W)``).
        target_color: Per-channel colour with components in [0, 1].

    Returns:
        Scalar tensor: mean of ``|images - target|`` over every element.
    """
    # Rescale the [0, 1] colour to [-1, 1] so it is comparable to the images.
    scaled = torch.tensor(target_color).to(images.device) * 2 - 1
    # Insert batch/height/width axes so the colour broadcasts per channel.
    broadcastable = scaled[None, :, None, None]
    return (images - broadcastable).abs().mean()
|
|
|
|
|
def _color_to_unit_rgb(color):
    """Convert a colour string (e.g. ``"#55FFAA"``) to an RGB list in [0, 1]."""
    return [channel / 255 for channel in ImageColor.getcolor(color, "RGB")]


def _color_guided_sample(target_color, guidance_loss_scale):
    """Run the DDIM sampling loop, steering every step toward ``target_color``.

    At each timestep the UNet's noise prediction (computed without gradients)
    is used to estimate the clean image ``x0``. The gradient of the scaled
    colour loss on ``x0`` is taken with respect to the current sample ``x``,
    and ``x`` is moved against it before the ordinary scheduler step.

    Returns:
        The final sample tensor of shape ``(1, 3, 256, 256)``.
    """
    # Create the starting noise directly on the target device (no host copy).
    x = torch.randn(1, 3, 256, 256, device=device)
    for t in tqdm(scheduler.timesteps):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)['sample']

        # Track gradients on x only: the loss gradient flows through the
        # scheduler's x0 estimate but not back through the UNet.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = color_loss(x0, target_color) * guidance_loss_scale
        cond_grad = -torch.autograd.grad(loss, x)[0]

        # Gradient-descent nudge on x, then the regular denoising step.
        x = x.detach() + cond_grad
        x = scheduler.step(noise_pred, t, x).prev_sample
    return x


def _tensor_to_pil(x):
    """Tile a [-1, 1] image batch into one grid and convert it to a PIL image."""
    grid = torchvision.utils.make_grid(x, nrow=4)
    # CHW -> HWC, clamp to the valid range, then map [-1, 1] -> [0, 1].
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * .5 + .5
    return Image.fromarray((im.numpy() * 255).astype(np.uint8))


def generate(color, guidance_loss_scale):
    """Gradio handler: sample one image whose colours are steered toward ``color``.

    Args:
        color: Colour string from the colour picker, e.g. ``"#55FFAA"``
            (any format accepted by ``PIL.ImageColor.getcolor``).
        guidance_loss_scale: Multiplier on the colour loss; with 0 the
            guidance gradient is zero and plain DDIM sampling results.

    Returns:
        A 256x256 RGB ``PIL.Image``.
    """
    print(color, guidance_loss_scale)
    target_color = _color_to_unit_rgb(color)
    x = _color_guided_sample(target_color, guidance_loss_scale)
    return _tensor_to_pil(x)
|
# Input widgets, in the positional order generate(color, guidance_loss_scale)
# expects.
color_picker = gr.ColorPicker(label='color', value="#55FFAA")
scale_slider = gr.Slider(minimum=0, maximum=30, value=3, label="Guidance Loss Scale")
inputs = [color_picker, scale_slider]

outputs = gr.Image(label="result")

demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    examples=[],  # no canned examples are provided
)
demo.launch(debug=True)
|
|
|
|
|
|
|
|