alkzar90 committed
Commit
1a635ad
1 Parent(s): e021f7b

Create app.py with Ukiyo-e postcard generator service!

Files changed (1):
  app.py  +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
import open_clip
import gradio as gr
import numpy as np
import torch
import torchvision

from tqdm.auto import tqdm
from PIL import Image, ImageColor
from torchvision import transforms
from diffusers import DDIMScheduler, DDPMPipeline


# Pick the best available device: Apple's MPS, then CUDA, then CPU
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

# Load the pretrained pipeline
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample images with a DDIM scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
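

# For reference: a minimal *unguided* DDIM sampling loop (an illustrative
# sketch, defined but never called by the app). The guided loop in generate()
# below adds the color/CLIP gradient steps on top of exactly this structure.
def _sample_unguided(num_examples=4):
    x = torch.randn(num_examples, 3, 256, 256).to(device)
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = scheduler.step(noise_pred, t, x).prev_sample
    return x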


# Color guidance
# ------------------------------------------------------------------------------
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Given a target color (R, G, B), return a loss for how far away on average
    the images' pixels are from that color. Defaults to a light teal: (0.1, 0.9, 0.5)."""
    target = (
        torch.tensor(target_color).to(images.device) * 2 - 1
    )  # Map target color to (-1, 1)
    target = target[
        None, :, None, None
    ]  # Get shape right to work with the images (b, c, h, w)
    error = torch.abs(
        images - target
    ).mean()  # Mean absolute difference between the image pixels and the target color
    return error
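

# A quick sanity check for color_loss (an illustrative sketch; this helper is
# defined but never called by the app): a batch of images that are exactly the
# target color, mapped to the model's (-1, 1) range, should give ~zero loss.
def _color_loss_sanity_check():
    teal = torch.tensor((0.1, 0.9, 0.5)) * 2 - 1          # Default target in (-1, 1)
    batch = teal[None, :, None, None].expand(2, 3, 8, 8)  # Two tiny solid-color "images"
    assert color_loss(batch).item() < 1e-6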


# CLIP guidance
# ------------------------------------------------------------------------------
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # Random crop each time
        transforms.RandomAffine(5),  # One possible random augmentation: skews the image
        transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


# CLIP guidance function
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(tfms(image))  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared Great Circle Distance
    return dists.mean()
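

# An illustrative sketch (defined but never called by the app): scoring a dummy
# batch against a prompt with clip_loss. Lower loss means the images sit closer
# to the prompt in CLIP's embedding space; generate() below descends this loss.
def _clip_loss_example(prompt="A sakura tree"):
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text)
        images = torch.rand(2, 3, 256, 256).to(device)  # Dummy batch in (0, 1)
        return clip_loss(images, text_features).item()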


# Sample generator loop
# ------------------------------------------------------------------------------
def generate(
    color,
    color_loss_scale,
    num_examples=4,
    seed=None,
    prompt=None,
    prompt_loss_scale=None,
    prompt_n_cuts=None,
    inference_steps=50,
):
    scheduler.set_timesteps(num_inference_steps=inference_steps)

    if seed:
        torch.manual_seed(seed)

    if prompt:
        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)

    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a / 255 for a in target_color]  # Rescale from (0, 255) to (0, 1)

    x = torch.randn(num_examples, 3, 256, 256).to(device)

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        # Color loss
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]
        # Modify x based solely on the color gradient -> x_cond
        x_cond = x.detach() + cond_color_grad

        # Prompt loss (modify x_cond with cond_prompt_grad) based on
        # the original x (not modified previously with cond_color_grad)
        if prompt:
            cond_prompt_grad = 0
            for cut in range(prompt_n_cuts):
                # Set requires_grad on x
                x = x.detach().requires_grad_()
                # Get the predicted x0:
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample
                # Calculate loss
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                # Get gradient (scale by n_cuts since we want the average)
                cond_prompt_grad -= (
                    torch.autograd.grad(prompt_loss, x, retain_graph=True)[0]
                    / prompt_n_cuts
                )
            # Modify x_cond based on this gradient
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = (
                x_cond + cond_prompt_grad * alpha_bar.sqrt()
            )  # Note the additional scaling factor here!

        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    im.save("test.jpeg")
    return im
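

# An illustrative sketch (defined but never called by the app): the same call
# Gradio makes for the fourth example row below, run outside the interface.
def _generate_example():
    return generate(
        "#39A291", 5.0, num_examples=12, seed=666, prompt="A sakura tree",
        prompt_loss_scale=60, prompt_n_cuts=8, inference_steps=52,
    )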


# GRADIO Interface
# ------------------------------------------------------------------------------
TITLE = "Ukiyo-e postcard generator service 🎴!"
DESCRIPTION = (
    "This model is a diffusion model for unconditional generation of Ukiyo-e images ✍ 🎨. \n"
    "The model was trained by fine-tuning the google/ddpm-celebahq-256 pretrained model "
    "on the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo"
)
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"

# See the Gradio docs for the types of inputs and outputs available
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),
    gr.Slider(label="color_guidance_scale (how strongly to blend the color)", minimum=0, maximum=30, value=6.7),
    gr.Slider(label="num_examples (# images generated)", minimum=4, maximum=12, value=8, step=4),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (how strongly to steer toward the prompt)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]

outputs = gr.Image(label="result")

# And the minimal interface
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    examples=[
        ["#DF5C16", 6.7, 12, 666, None, None, None, 40],
        ["#C01660", 13.5, 12, 1990, None, None, None, 40],
        ["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
        ["#39A291", 5.0, 12, 666, "A sakura tree", 60, 8, 52],
        ["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
        ["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 8, 47],
    ],
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    demo.launch(enable_queue=True)
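
To try this Space locally (a sketch of the environment; the commit does not pin
package versions):

    pip install torch torchvision diffusers gradio open_clip_torch tqdm Pillow numpy
    python app.py

Gradio then serves the interface on http://localhost:7860 by default.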