mix060514 commited on
Commit
e1170e2
1 Parent(s): fa8bae4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # ## Gathered Notebook
3
+ #
4
+ # This notebook was generated by the Gather Extension. The intent is that it contains only the code and cells required to produce the same results as the cell originally selected for gathering. Please note that the Python analysis is quite conservative, so if it is unsure whether a line of code is necessary for execution, it will err on the side of including it.
5
+ #
6
+ # **Please let us know if you are satisfied with what was gathered [here](https://aka.ms/gatherfeedback).**
7
+ #
8
+ # Thanks
9
+
10
+ # %%
11
+ import gradio as gr
12
+ from PIL import Image, ImageColor
13
+ import numpy as np
14
+ import torch
15
+ import torchvision
16
+ from diffusers import DDPMPipeline, DDIMScheduler
17
+ from tqdm.auto import tqdm
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # %%
21
+ pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
22
+ image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
23
+ scheduler = DDIMScheduler.from_pretrained(pipeline_name)
24
+ scheduler.set_timesteps(num_inference_steps=40)
25
+
26
+ # %%
27
+ def color_loss(images, target_color=(.1, .9, .5)):
28
+ target = (torch.tensor(target_color).to(images.device) * 2 - 1)
29
+ target = target[None, :, None, None]
30
+ error = torch.abs(images - target).mean()
31
+ return error
32
+
33
+ # %%
34
+ def generate(color, guidance_loss_scale):
35
+ print(color, guidance_loss_scale)
36
+ target_color = ImageColor.getcolor(color, "RGB")
37
+ target_color = [a / 255 for a in target_color]
38
+ x = torch.randn(1, 3, 256, 256).to(device)
39
+ for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
40
+ model_input = scheduler.scale_model_input(x, t)
41
+ with torch.no_grad():
42
+ noised_pred = image_pipe.unet(model_input, t)['sample']
43
+ x = x.detach().requires_grad_()
44
+ x0 = scheduler.step(noised_pred, t, x).pred_original_sample
45
+ loss = color_loss(x0, target_color) * guidance_loss_scale
46
+ cond_grad = - torch.autograd.grad(loss, x)[0]
47
+ x = x.detach() + cond_grad
48
+ x = scheduler.step(noised_pred, t, x).prev_sample
49
+ grid = torchvision.utils.make_grid(x, nrow=4)
50
+ im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * .5 + .5
51
+ im = Image.fromarray((im.numpy() * 255).astype(np.uint8))
52
+ return im
53
+ inputs = [
54
+ gr.ColorPicker(label = 'color',value="#55FFAA"),
55
+ gr.Slider(minimum=0, maximum=30, value=3, label="Guidance Loss Scale"),
56
+ ]
57
+ outputs = gr.Image(label="result")
58
+ demo = gr.Interface(
59
+ fn=generate,
60
+ inputs=inputs,
61
+ outputs=outputs,
62
+ examples=[
63
+ ["#BB2266", 3],
64
+ ["#44CCAA", 5],
65
+ ])
66
+ demo.launch(debug=True)
67
+
68
+
69
+