Spaces:
Runtime error
Create app.py with Ukiyo postal generator service!
app.py
ADDED
@@ -0,0 +1,193 @@
import open_clip
import gradio as gr
import numpy as np
import torch
import torchvision

from tqdm.auto import tqdm
from PIL import Image, ImageColor
from torchvision import transforms
from diffusers import DDIMScheduler, DDPMPipeline


device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

# Load the pretrained pipeline
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)


# Color guidance
#-------------------------------------------------------------------------------
# Color guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Given a target color (R, G, B) return a loss for how far away on average
    the images' pixels are from that color. Defaults to a light teal: (0.1, 0.9, 0.5)"""
    target = (
        torch.tensor(target_color).to(images.device) * 2 - 1
    )  # Map target color to (-1, 1)
    target = target[
        None, :, None, None
    ]  # Get shape right to work with the images (b, c, h, w)
    error = torch.abs(
        images - target
    ).mean()  # Mean absolute difference between the image pixels and the target color
    return error
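
# Quick sanity check on the scale of this loss: for a batch of pure-white images
# (every pixel +1 in the (-1, 1) range), the default teal target maps to
# (-0.8, 0.8, 0.0), so the per-channel errors are (1.8, 0.2, 1.0) and the mean
# absolute error is about 1.0. Values therefore sit roughly in [0, 2].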


# CLIP guidance
#-------------------------------------------------------------------------------
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # Random CROP each time
        transforms.RandomAffine(
            5
        ),  # One possible random augmentation: skews the image
        transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


# CLIP guidance function
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(
        tfms(image)
    )  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared Great Circle Distance
    return dists.mean()
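
# Since both embeddings are unit-normalized, ||a - b|| is the chord between two
# points on the unit sphere and 2 * arcsin(||a - b|| / 2) is the angle between
# them, so this loss is proportional to the squared angular (great-circle)
# distance between the image and text features.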


# Sample generator loop
#-------------------------------------------------------------------------------
def generate(
    color,
    color_loss_scale,
    num_examples=4,
    seed=None,
    prompt=None,
    prompt_loss_scale=None,
    prompt_n_cuts=None,
    inference_steps=50,
):
    scheduler.set_timesteps(num_inference_steps=inference_steps)

    if seed:
        torch.manual_seed(seed)

    if prompt:
        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)

    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a / 255 for a in target_color]  # Rescale from (0, 255) to (0, 1)

    x = torch.randn(num_examples, 3, 256, 256).to(device)

    for i, t in tqdm(enumerate(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        # Color loss
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]
        # Modify x based solely on the color gradient -> x_cond
        x_cond = x.detach() + cond_color_grad

        # Prompt loss (modify x_cond with cond_prompt_grad) based on
        # the original x (not modified previously with cond_color_grad)
        if prompt:
            cond_prompt_grad = 0
            for cut in range(prompt_n_cuts):
                # Set requires_grad on x
                x = x.detach().requires_grad_()
                # Get the predicted x0:
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample
                # Calculate loss
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                # Get gradient (scale by n_cuts since we want the average)
                cond_prompt_grad -= (
                    torch.autograd.grad(prompt_loss, x, retain_graph=True)[0]
                    / prompt_n_cuts
                )
            # Modify x based on this gradient
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = (
                x_cond + cond_prompt_grad * alpha_bar.sqrt()
            )  # Note the additional scaling factor here!

        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    im.save("test.jpeg")
    return im


# GRADIO Interface
#-------------------------------------------------------------------------------
TITLE = "Ukiyo-e postal generator service 🎴!"
DESCRIPTION = (
    "This model is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨.\n"
    "The model was trained by fine-tuning the google/ddpm-celebahq-256 pretrained model on the dataset: "
    "https://huggingface.co/datasets/huggan/ukiyoe2photo"
)
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"

# See the gradio docs for the types of inputs and outputs available
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),  # Add any inputs you need here
    gr.Slider(label="color_guidance_scale (how strongly to blend the color)", minimum=0, maximum=30, value=6.7),
    gr.Slider(label="num_examples (# images generated)", minimum=4, maximum=12, value=8, step=4),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (...)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]

outputs = gr.Image(label="result")

# And the minimal interface
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    examples=[
        ["#DF5C16", 6.7, 12, 666, None, None, None, 40],
        ["#C01660", 13.5, 12, 1990, None, None, None, 40],
        ["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
        ["#39A291", 5.0, 12, 666, "A sakura tree", 60, 8, 52],
        ["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
        ["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 8, 47],
    ],
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    demo.launch(enable_queue=True)
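
The Gradio interface is the only entry point the Space itself exposes, but for quick local experiments `generate` can also be called directly. Below is a minimal sketch, assuming the file above is saved as `app.py` and imported from the same directory; the argument values and the output filename are illustrative, not part of the Space.

# Minimal sketch: call the generator without going through the Gradio UI.
# Importing app loads the pipeline and CLIP model at module level, which can
# take a while on first run.
from app import generate

im = generate(
    color="#44CCAA",        # hex color, as produced by the ColorPicker input
    color_loss_scale=8.9,   # how strongly to pull pixels toward that color
    num_examples=4,         # number of images arranged in the output grid
    seed=1512,              # fixed seed for reproducibility
    inference_steps=40,     # DDIM sampling steps
)
im.save("ukiyo_sample.jpeg")  # generate() also writes test.jpeg as a side effect

The returned object is the PIL image of the sample grid; passing a `prompt` together with `prompt_loss_scale` and `prompt_n_cuts` additionally enables the CLIP guidance branch, as in the last three example rows of the interface.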