EeeeeeeH committed on
Commit 2dccad5
1 Parent(s): 7f0bd74
Files changed (2)
  1. app.py +59 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ import torch, torchvision
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image, ImageColor
+ from diffusers import DDPMPipeline
+ from diffusers import DDIMScheduler
+
+ # Pick the best available device: Apple Silicon (MPS), CUDA GPU, or CPU.
+ device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Unconditional diffusion model (fine-tuned on WikiArt), loaded as a DDPM pipeline.
+ pipeline_name = 'johnowhitaker/sd-class-wikiart-from-bedrooms'
+ image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
+
+ # DDIM scheduler so sampling takes only 20 steps instead of the full DDPM schedule.
+ scheduler = DDIMScheduler.from_pretrained(pipeline_name)
+ scheduler.set_timesteps(num_inference_steps=20)
+
+ def color_loss(images, target_color=(0.1, 0.9, 0.5)):
+     """Mean absolute error between image pixels and a target colour (images are in [-1, 1])."""
+     target = torch.tensor(target_color).to(images.device) * 2 - 1  # map [0, 1] -> [-1, 1]
+     target = target[None, :, None, None]  # broadcast over batch, height and width
+     error = torch.abs(images - target).mean()
+     return error
+
+ def generate(color, guidance_loss_scale):
+     target_color = ImageColor.getcolor(color, "RGB")  # hex string -> (R, G, B) in 0-255
+     target_color = [a / 255 for a in target_color]
+     x = torch.randn(1, 3, 256, 256).to(device)
+     for i, t in enumerate(scheduler.timesteps):
+         model_input = scheduler.scale_model_input(x, t)
+         with torch.no_grad():
+             noise_pred = image_pipe.unet(model_input, t)["sample"]
+         # Guidance step: nudge x towards the target colour using the gradient of the
+         # colour loss computed on the predicted denoised image x0.
+         x = x.detach().requires_grad_()
+         x0 = scheduler.step(noise_pred, t, x).pred_original_sample
+         loss = color_loss(x0, target_color) * guidance_loss_scale
+         cond_grad = -torch.autograd.grad(loss, x)[0]
+         x = x.detach() + cond_grad
+         x = scheduler.step(noise_pred, t, x).prev_sample
+     grid = torchvision.utils.make_grid(x, nrow=4)
+     im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5  # [-1, 1] -> [0, 1]
+     im = Image.fromarray(np.array(im * 255).astype(np.uint8))
+     im.save("test.jpeg")
+     return im
+
+ inputs = [
+     gr.ColorPicker(label="Color", value='#55FFAA'),  # '#' prefix required by ImageColor.getcolor
+     gr.Slider(label="Guidance Loss Scale", minimum=0, maximum=30, value=1)
+ ]
+ outputs = gr.Image(label="result")
+
+ demo = gr.Interface(
+     fn=generate,
+     inputs=inputs,
+     outputs=outputs,
+     examples=[
+         ["#BB2266", 3], ["#44CCAA", 5]
+     ],
+ )
+
+ if __name__ == "__main__":
+     demo.launch(enable_queue=True)  # enable_queue is a Gradio 3.x argument; Gradio 4+ uses demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ diffusers
+ torch
+ numpy
+ torchvision
+ gradio
+ pillow