Commit 67e6974 by wjh
Parent(s): c464026

init

Files changed:
- app.py (+432, -0)
- interpolation.py (+388, -0)
- pipeline_interpolated_sdxl.py (+0, -0)
- pipeline_interpolated_stable_diffusion.py (+584, -0)
- prior.py (+168, -0)
- requirements.txt (+65, -0)
- style.css (+71, -0)
- utils.py (+189, -0)
app.py
ADDED
@@ -0,0 +1,432 @@
import os
from typing import Optional

import gradio as gr
import numpy as np
import pandas as pd
import torch
import user_history
from PIL import Image
from scipy.stats import beta as beta_distribution

from pipeline_interpolated_sdxl import InterpolationStableDiffusionXLPipeline
from pipeline_interpolated_stable_diffusion import InterpolationStableDiffusionPipeline

os.environ["TOKENIZERS_PARALLELISM"] = "false"

title = r"""
<h1 align="center">PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</b></a>.<br>
How to use:<br>
1. Input prompt 1 and prompt 2.
2. (Optional) Input the guidance prompt and negative prompt.
3. (Optional) Change the interpolation parameters and check the Beta distribution.
4. Click the <b>Generate</b> button to begin generating images.
5. Enjoy! 😊"""

article = r"""
---
✒️ **Citation**
<br>
If you found this demo/our paper useful, please consider citing:
```bibtex
@article{he2024paid,
  title={PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion},
  author={He, Qiyuan and Wang, Jinghao and Liu, Ziwei and Yao, Angela},
  journal={},
  year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue in our <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>Github Repo</b></a> or reach out to us directly at <b>qhe@u.nus.edu.sg</b>.
"""

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = False
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = False

dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = InterpolationStableDiffusionPipeline(
    repo_name="runwayml/stable-diffusion-v1-5",
    guidance_scale=10.0,
    scheduler_name="unipc",
)
pipeline.to(device, dtype=dtype)


def change_model_fn(model_name: str) -> None:
    global pipeline
    name_mapping = {
        "SD1.4-512": "CompVis/stable-diffusion-v1-4",
        "SD1.5-512": "runwayml/stable-diffusion-v1-5",
        "SD2.1-768": "stabilityai/stable-diffusion-2-1",
        "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
    }
    if "XL" not in model_name:
        pipeline = InterpolationStableDiffusionPipeline(
            repo_name=name_mapping[model_name],
            guidance_scale=10.0,
            scheduler_name="unipc",
        )
        pipeline.to(device, dtype=dtype)
    else:
        pipeline = InterpolationStableDiffusionXLPipeline.from_pretrained(
            name_mapping[model_name]
        )
        pipeline.to(device, dtype=dtype)


def save_image(img, index):
    unique_name = f"{index}.png"
    img = Image.fromarray(img)
    img.save(unique_name)
    return unique_name


def generate_beta_tensor(
    size: int, alpha: float = 3.0, beta: float = 3.0
) -> torch.FloatTensor:
    # Evenly spaced probabilities mapped through the inverse CDF of Beta(alpha, beta).
    prob_values = [i / (size - 1) for i in range(size)]
    inverse_cdf_values = beta_distribution.ppf(prob_values, alpha, beta)
    return inverse_cdf_values


def plot_gemma_fn(alpha: float, beta: float, size: int) -> pd.DataFrame:
    beta_ppf = generate_beta_tensor(size=size, alpha=int(alpha), beta=int(beta))
    return pd.DataFrame(
        {
            "interpolation index": [i for i in range(size)],
            "coefficient": beta_ppf.tolist(),
        }
    )


def get_example() -> list:
    case = [
        [
            "A photo of dog, best quality, extremely detailed",
            "A photo of car, best quality, extremely detailed",
            3,
            6,
            3,
            "A photo of a dog driving a car, logical, best quality, extremely detailed",
            "monochrome, lowres, bad anatomy, worst quality, low quality",
            "SD1.5-512",
            6.1 / 50,
            10,
            50,
            "fused_inner",
            "self",
            1002,
            True,
        ]
    ]
    return case


def dynamic_gallery_fn(interpolation_size: int):
    return gr.Gallery(
        label="Result", show_label=False, rows=1, columns=interpolation_size
    )


@torch.no_grad()
def generate(
    prompt1: str,
    prompt2: str,
    guidance_prompt: Optional[str] = None,
    negative_prompt: str = "",
    warmup_ratio: int = 8,
    guidance_scale: float = 10,
    early: str = "fused_outer",
    late: str = "self",
    alpha: float = 4.0,
    beta: float = 4.0,
    interpolation_size: int = 3,
    seed: int = 0,
    same_latent: bool = True,
    num_inference_steps: int = 50,
    progress=gr.Progress(),
) -> np.ndarray:
    global pipeline
    # Note: torch.cuda.manual_seed seeds the CUDA RNG and returns None, so on
    # CUDA machines `generator` is None and generate_latent falls back to the
    # global CPU RNG.
    generator = (
        torch.cuda.manual_seed(seed)
        if torch.cuda.is_available()
        else torch.manual_seed(seed)
    )
    latent1 = pipeline.generate_latent(generator=generator)
    latent1 = latent1.to(device=pipeline.unet.device, dtype=pipeline.unet.dtype)
    if same_latent:
        latent2 = latent1.clone()
    else:
        latent2 = pipeline.generate_latent(generator=generator)
        latent2 = latent2.to(device=pipeline.unet.device, dtype=pipeline.unet.dtype)
    betas = generate_beta_tensor(size=interpolation_size, alpha=alpha, beta=beta)
    for i in progress.tqdm(
        range(interpolation_size - 2),
        desc=(
            f"Generating {interpolation_size-2} images"
            if interpolation_size > 3
            else "Generating 1 image"
        ),
    ):
        it = betas[i + 1].item()
        images = pipeline.interpolate_single(
            it,
            latent_start=latent1,
            latent_end=latent2,
            prompt_start=prompt1,
            prompt_end=prompt2,
            guide_prompt=guidance_prompt,
            num_inference_steps=num_inference_steps,
            warmup_ratio=warmup_ratio,
            early=early,
            late=late,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
        )
        # Stitch per-step triplets together: keep both endpoints once and each
        # interior frame once.
        if interpolation_size == 3:
            final_images = images
            break
        if i == 0:
            final_images = images[:2]
        elif i == interpolation_size - 3:
            final_images = np.concatenate([final_images, images[1:]], axis=0)
        else:
            final_images = np.concatenate([final_images, images[1:2]], axis=0)
    return final_images


interpolation_size = None

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Group():
        prompt1 = gr.Text(
            label="Prompt 1",
            max_lines=3,
            placeholder="Enter the First Prompt",
            interactive=True,
            value="A photo of dog, best quality, extremely detailed",
        )
        prompt2 = gr.Text(
            label="Prompt 2",
            max_lines=3,
            placeholder="Enter the Second prompt",
            interactive=True,
            value="A photo of car, best quality, extremely detailed",
        )
    result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
    generate_button = gr.Button("Generate", variant="primary")
    with gr.Accordion("Advanced options", open=True):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    interpolation_size = gr.Slider(
                        label="Interpolation Size",
                        minimum=3,
                        maximum=15,
                        step=1,
                        value=3,
                        info="Interpolation size includes the start and end images",
                    )
                    alpha = gr.Slider(
                        label="alpha",
                        minimum=1,
                        maximum=50,
                        step=0.1,
                        value=6.0,
                    )
                    beta = gr.Slider(
                        label="beta",
                        minimum=1,
                        maximum=50,
                        step=0.1,
                        value=3.0,
                    )
                gamma_plot = gr.LinePlot(
                    x="interpolation index",
                    y="coefficient",
                    title="Beta Distribution with Sampled Points",
                    height=500,
                    width=400,
                    overlay_point=True,
                    tooltip=["coefficient", "interpolation index"],
                    interactive=False,
                    show_label=False,
                )
                gamma_plot.change(
                    plot_gemma_fn,
                    inputs=[
                        alpha,
                        beta,
                        interpolation_size,
                    ],
                    outputs=gamma_plot,
                )
        with gr.Group():
            guidance_prompt = gr.Text(
                label="Guidance prompt",
                max_lines=3,
                placeholder="Enter a Guidance Prompt",
                interactive=True,
                value="A photo of a dog driving a car, logical, best quality, extremely detailed",
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=3,
                placeholder="Enter a Negative Prompt",
                interactive=True,
                value="monochrome, lowres, bad anatomy, worst quality, low quality",
            )
        with gr.Row():
            with gr.Column():
                warmup_ratio = gr.Slider(
                    label="Warmup Ratio",
                    minimum=0.02,
                    maximum=1,
                    step=0.01,
                    value=0.122,
                    interactive=True,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0,
                    maximum=50,
                    step=0.1,
                    value=10,
                    interactive=True,
                )
            with gr.Column():
                early = gr.Dropdown(
                    label="Early stage attention type",
                    choices=[
                        "pure_inner",
                        "fused_inner",
                        "pure_outer",
                        "fused_outer",
                        "self",
                    ],
                    value="fused_outer",
                    type="value",
                    interactive=True,
                )
                late = gr.Dropdown(
                    label="Late stage attention type",
                    choices=[
                        "pure_inner",
                        "fused_inner",
                        "pure_outer",
                        "fused_outer",
                        "self",
                    ],
                    value="self",
                    type="value",
                    interactive=True,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=25,
                    maximum=50,
                    step=1,
                    value=50,
                    interactive=True,
                )
        with gr.Row():
            model_choice = gr.Dropdown(
                ["SD1.4-512", "SD1.5-512", "SD2.1-768", "SDXL-1024"],
                label="Model",
                value="SD1.5-512",
                interactive=True,
            )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=1002,
                )
                same_latent = gr.Checkbox(
                    label="Same latent",
                    value=True,
                    info="Use the same latent for start and end images",
                    show_label=True,
                )

    gr.Examples(
        examples=get_example(),
        inputs=[
            prompt1,
            prompt2,
            interpolation_size,
            alpha,
            beta,
            guidance_prompt,
            negative_prompt,
            model_choice,
            warmup_ratio,
            guidance_scale,
            num_inference_steps,
            early,
            late,
            seed,
            same_latent,
        ],
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    alpha.change(
        fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
    )
    beta.change(
        fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
    )
    interpolation_size.change(
        fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
    )
    model_choice.change(fn=change_model_fn, inputs=model_choice)
    inputs = [
        prompt1,
        prompt2,
        guidance_prompt,
        negative_prompt,
        warmup_ratio,
        guidance_scale,
        early,
        late,
        alpha,
        beta,
        interpolation_size,
        seed,
        same_latent,
        num_inference_steps,
    ]
    generate_button.click(
        fn=dynamic_gallery_fn,
        inputs=interpolation_size,
        outputs=result,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
    gr.Markdown(article)

with gr.Blocks(css="style.css") as demo_with_history:
    with gr.Tab("App"):
        demo.render()

if __name__ == "__main__":
    demo_with_history.queue(max_size=20).launch()
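As a usage sketch (not part of the commit): `generate_beta_tensor` above spaces the interpolation coefficients by pushing evenly spaced probabilities through the inverse CDF (`ppf`) of a Beta(alpha, beta) distribution, which is also what the demo's line plot visualizes. The minimal, self-contained example below reproduces that computation with only NumPy and SciPy; the function name `beta_spaced_coefficients` is ours, not the repository's.

```python
import numpy as np
from scipy.stats import beta as beta_distribution


def beta_spaced_coefficients(size: int, alpha: float, beta: float) -> np.ndarray:
    # Evenly spaced probabilities in [0, 1], warped by the Beta inverse CDF.
    probs = np.linspace(0.0, 1.0, size)
    return beta_distribution.ppf(probs, alpha, beta)


if __name__ == "__main__":
    # alpha = beta = 1 gives plain uniform spacing; larger values pull the
    # interior interpolation points toward t = 0.5.
    print(beta_spaced_coefficients(5, 1.0, 1.0))  # [0.   0.25 0.5  0.75 1.  ]
    print(beta_spaced_coefficients(5, 6.0, 6.0))  # interior points cluster near 0.5
```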
interpolation.py
ADDED
@@ -0,0 +1,388 @@
from typing import Optional

import torch
from torch import FloatTensor, LongTensor, Size, Tensor

from prior import generate_beta_tensor


class OuterInterpolatedAttnProcessor:
    r"""
    Personalized processor for performing outer attention interpolation.

    The attention output of the interpolated image is obtained by:
        (1 - t) * Q_t * K_1 * V_1 + t * Q_t * K_m * V_m;
    If fused with self-attention:
        (1 - t) * Q_t * [K_1, K_t] * [V_1, V_t] + t * Q_t * [K_m, K_t] * [V_m, V_t];
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
    ):
        """
        t: float, interpolation point between 0 and 1; if specified, size is set to 3
        """
        if t is None:
            ts = generate_beta_tensor(size, alpha=alpha, beta=beta)
            ts[0], ts[-1] = 0, 1
        else:
            assert t > 0 and t < 1, "t must be between 0 and 1"
            ts = [0, t, 1]
            ts = torch.tensor(ts)
            size = 3

        self.size = size
        self.coef = ts
        self.is_fused = is_fused

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Specify the first and last key and value
        key_begin = key[0:1]
        key_end = key[-1:]
        value_begin = value[0:1]
        value_end = value[-1:]

        key_begin = torch.cat([key_begin] * (self.size))
        key_end = torch.cat([key_end] * (self.size))
        value_begin = torch.cat([value_begin] * (self.size))
        value_end = torch.cat([value_end] * (self.size))

        key_begin = attn.head_to_batch_dim(key_begin)
        value_begin = attn.head_to_batch_dim(value_begin)
        key_end = attn.head_to_batch_dim(key_end)
        value_end = attn.head_to_batch_dim(value_end)

        # Fused with self-attention
        if self.is_fused:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key_end = torch.cat([key, key_end], dim=-2)
            value_end = torch.cat([value, value_end], dim=-2)
            key_begin = torch.cat([key, key_begin], dim=-2)
            value_begin = torch.cat([value, value_begin], dim=-2)

        attention_probs_end = attn.get_attention_scores(query, key_end, attention_mask)
        hidden_states_end = torch.bmm(attention_probs_end, value_end)
        hidden_states_end = attn.batch_to_head_dim(hidden_states_end)

        attention_probs_begin = attn.get_attention_scores(
            query, key_begin, attention_mask
        )
        hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
        hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)

        # Apply outer interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class InnerInterpolatedAttnProcessor:
    r"""
    Personalized processor for performing inner attention interpolation.

    The keys and values are interpolated first, so the attention output of the
    interpolated image is obtained by:
        Q_t * ((1 - t) * K_1 + t * K_m) * ((1 - t) * V_1 + t * V_m);
    If fused with self-attention:
        Q_t * [K_t, (1 - t) * K_1 + t * K_m] * [V_t, (1 - t) * V_1 + t * V_m];
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
    ):
        """
        t: float, interpolation point between 0 and 1; if specified, size is set to 3
        """
        if t is None:
            ts = generate_beta_tensor(size, alpha=alpha, beta=beta)
            ts[0], ts[-1] = 0, 1
        else:
            assert t > 0 and t < 1, "t must be between 0 and 1"
            ts = [0, t, 1]
            ts = torch.tensor(ts)
            size = 3

        self.size = size
        self.coef = ts
        self.is_fused = is_fused

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Specify the first and last key and value
        key_start = key[0:1]
        key_end = key[-1:]
        value_start = value[0:1]
        value_end = value[-1:]

        key_start = torch.cat([key_start] * (self.size))
        key_end = torch.cat([key_end] * (self.size))
        value_start = torch.cat([value_start] * (self.size))
        value_end = torch.cat([value_end] * (self.size))

        # Apply inner interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        key_cross = (1 - coef) * key_start + coef * key_end
        value_cross = (1 - coef) * value_start + coef * value_end

        key_cross = attn.head_to_batch_dim(key_cross)
        value_cross = attn.head_to_batch_dim(value_cross)

        # Fused with self-attention
        if self.is_fused:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key_cross = torch.cat([key, key_cross], dim=-2)
            value_cross = torch.cat([value, value_cross], dim=-2)

        attention_probs = attn.get_attention_scores(query, key_cross, attention_mask)

        hidden_states = torch.bmm(attention_probs, value_cross)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def linear_interpolation(
    l1: FloatTensor, l2: FloatTensor, ts: Optional[FloatTensor] = None, size: int = 5
) -> FloatTensor:
    """
    Linear interpolation

    Args:
        l1: Starting vector: (1, *)
        l2: Final vector: (1, *)
        ts: FloatTensor, interpolation points between 0 and 1
        size: int, number of interpolation points including l1 and l2

    Returns:
        Interpolated vectors: (size, *)
    """
    assert l1.shape == l2.shape, "shapes of l1 and l2 must match"

    res = []
    if ts is not None:
        for t in ts:
            li = torch.lerp(l1, l2, t)
            res.append(li)
    else:
        for i in range(size):
            t = i / (size - 1)
            li = torch.lerp(l1, l2, t)
            res.append(li)
    res = torch.cat(res, dim=0)
    return res


def spherical_interpolation(l1: FloatTensor, l2: FloatTensor, size=5) -> FloatTensor:
    """
    Spherical interpolation

    Args:
        l1: Starting vector: (1, *)
        l2: Final vector: (1, *)
        size: int, number of interpolation points including l1 and l2

    Returns:
        Interpolated vectors: (size, *)
    """
    assert l1.shape == l2.shape, "shapes of l1 and l2 must match"

    res = []
    for i in range(size):
        t = i / (size - 1)
        li = slerp(l1, l2, t)
        res.append(li)
    res = torch.cat(res, dim=0)
    return res


def slerp(v0: FloatTensor, v1: FloatTensor, t, threshold=0.9995):
    """
    Spherical linear interpolation
    Args:
        v0: Starting vector
        v1: Final vector
        t: Float value between 0.0 and 1.0
        threshold: Threshold for considering the two vectors as
                   colinear. Not recommended to alter this.
    Returns:
        Interpolation vector between v0 and v1
    """
    assert v0.shape == v1.shape, "shapes of v0 and v1 must match"

    # Normalize the vectors to get the directions and angles
    v0_norm: FloatTensor = torch.norm(v0, dim=-1)
    v1_norm: FloatTensor = torch.norm(v1, dim=-1)

    v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1)
    v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1)

    # Dot product with the normalized vectors
    dot: FloatTensor = (v0_normed * v1_normed).sum(-1)
    dot_mag: FloatTensor = dot.abs()

    # if dot is NaN, it's because the v0 or v1 row was filled with 0s
    # If absolute value of dot product is almost 1, vectors are ~colinear, so use torch.lerp
    gotta_lerp: LongTensor = dot_mag.isnan() | (dot_mag > threshold)
    can_slerp: LongTensor = ~gotta_lerp

    t_batch_dim_count: int = max(0, t.dim() - v0.dim()) if isinstance(t, Tensor) else 0
    t_batch_dims: Size = (
        t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([])
    )
    out: FloatTensor = torch.zeros_like(v0.expand(*t_batch_dims, *[-1] * v0.dim()))

    # if no elements are lerpable, our vectors become 0-dimensional, preventing broadcasting
    if gotta_lerp.any():
        lerped: FloatTensor = torch.lerp(v0, v1, t)

        out: FloatTensor = lerped.where(gotta_lerp.unsqueeze(-1), out)

    # if no elements are slerpable, our vectors become 0-dimensional, preventing broadcasting
    if can_slerp.any():

        # Calculate initial angle between v0 and v1
        theta_0: FloatTensor = dot.arccos().unsqueeze(-1)
        sin_theta_0: FloatTensor = theta_0.sin()
        # Angle at timestep t
        theta_t: FloatTensor = theta_0 * t
        sin_theta_t: FloatTensor = theta_t.sin()
        # Finish the slerp algorithm
        s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0
        s1: FloatTensor = sin_theta_t / sin_theta_0
        slerped: FloatTensor = s0 * v0 + s1 * v1

        out: FloatTensor = slerped.where(can_slerp.unsqueeze(-1), out)

    return out
pipeline_interpolated_sdxl.py
ADDED
The diff for this file is too large to render.
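Since the SDXL variant's diff is not rendered here, the only grounded usage is the call in app.py's `change_model_fn`; the sketch below simply mirrors that call path (model id taken from app.py's `name_mapping`) and makes no assumption about the class internals.

```python
import torch

from pipeline_interpolated_sdxl import InterpolationStableDiffusionXLPipeline

# Load the SDXL interpolation pipeline exactly as app.py does when the
# "SDXL-1024" model is selected.
pipeline = InterpolationStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)
pipeline.to(
    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float32,
)
```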
pipeline_interpolated_stable_diffusion.py
ADDED
@@ -0,0 +1,584 @@
from typing import Optional

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    SchedulerMixin,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from interpolation import (
    InnerInterpolatedAttnProcessor,
    OuterInterpolatedAttnProcessor,
    generate_beta_tensor,
    linear_interpolation,
    slerp,
    spherical_interpolation,
)


class InterpolationStableDiffusionPipeline:
    """
    Diffusion pipeline that generates interpolated images.
    """

    def __init__(
        self,
        repo_name: str = "CompVis/stable-diffusion-v1-4",
        scheduler_name: str = "ddim",
        frozen: bool = True,
        guidance_scale: float = 7.5,
        scheduler: Optional[SchedulerMixin] = None,
        cache_dir: Optional[str] = None,
    ):
        # Initialize the generator
        self.vae = AutoencoderKL.from_pretrained(
            repo_name, subfolder="vae", use_safetensors=True, cache_dir=cache_dir
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            repo_name, subfolder="tokenizer", cache_dir=cache_dir
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            repo_name,
            subfolder="text_encoder",
            use_safetensors=True,
            cache_dir=cache_dir,
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            repo_name, subfolder="unet", use_safetensors=True, cache_dir=cache_dir
        )

        # Initialize the scheduler
        if scheduler is not None:
            self.scheduler = scheduler
        elif scheduler_name == "ddim":
            self.scheduler = DDIMScheduler.from_pretrained(
                repo_name, subfolder="scheduler", cache_dir=cache_dir
            )
        elif scheduler_name == "unipc":
            self.scheduler = UniPCMultistepScheduler.from_pretrained(
                repo_name, subfolder="scheduler", cache_dir=cache_dir
            )
        else:
            raise ValueError(
                "Invalid scheduler name (expected 'ddim' or 'unipc') and no scheduler instance was provided."
            )

        # Setup device

        self.guidance_scale = guidance_scale  # Scale for classifier-free guidance

        if frozen:
            for param in self.unet.parameters():
                param.requires_grad = False

            for param in self.text_encoder.parameters():
                param.requires_grad = False

            for param in self.vae.parameters():
                param.requires_grad = False

    def to(self, *args, **kwargs):
        self.vae.to(*args, **kwargs)
        self.text_encoder.to(*args, **kwargs)
        self.unet.to(*args, **kwargs)

    def generate_latent(
        self, generator: Optional[torch.Generator] = None, torch_device: str = "cpu"
    ) -> torch.FloatTensor:
        """
        Generates a random latent tensor.

        Args:
            generator (Optional[torch.Generator], optional): Generator for random number generation. Defaults to None.
            torch_device (str, optional): Device to store the tensor. Defaults to "cpu".

        Returns:
            torch.FloatTensor: Random latent tensor.
        """
        channel = self.unet.config.in_channels
        height = self.unet.config.sample_size
        width = self.unet.config.sample_size
        if generator is None:
            latent = torch.randn(
                (1, channel, height, width),
                device=torch_device,
            )
        else:
            latent = torch.randn(
                (1, channel, height, width),
                generator=generator,
                device=torch_device,
            )
        return latent

    @torch.no_grad()
    def prompt_to_embedding(
        self, prompt: str, negative_prompt: str = ""
    ) -> torch.FloatTensor:
        """
        Prepare the text prompt for the diffusion process

        Args:
            prompt: str, text prompt
            negative_prompt: str, negative text prompt

        Returns:
            FloatTensor, text embeddings
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embeddings = self.text_encoder(
            text_input.input_ids.to(self.torch_device)
        )[0]

        uncond_input = self.tokenizer(
            negative_prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.torch_device)
        )[0]

        text_embeddings = torch.cat([text_embeddings, uncond_embeddings])
        return text_embeddings

    @torch.no_grad()
    def interpolate(
        self,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        size: int = 7,
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolate between two generations.

        Args:
            latent_start: FloatTensor, latent vector of the first image
            latent_end: FloatTensor, latent vector of the second image
            prompt_start: str, text prompt of the first image
            prompt_end: str, text prompt of the second image
            guide_prompt: str, text prompt for the interpolation
            negative_prompt: str, negative text prompt
            size: int, number of interpolations including starting and ending points
            num_inference_steps: int, number of inference steps in scheduler
            warmup_ratio: float, ratio of warmup steps
            early: str, warmup interpolation method
            late: str, late interpolation method
            alpha: float, alpha parameter for beta distribution
            beta: float, beta parameter for beta distribution
            guidance_scale: Optional[float], scale for classifier-free guidance
        Returns:
            Numpy array of interpolated images, shape (size, H, W, 3)
        """
        # Specify alpha and beta
        self.torch_device = self.unet.device
        if alpha is None:
            alpha = num_inference_steps
        if beta is None:
            beta = num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
        self.scheduler.set_timesteps(num_inference_steps)

        # Prepare interpolated latents and embeddings
        latents = spherical_interpolation(latent_start, latent_end, size)
        embs_start = self.prompt_to_embedding(prompt_start, negative_prompt)
        emb_start = embs_start[0:1]
        uncond_emb_start = embs_start[1:2]
        embs_end = self.prompt_to_embedding(prompt_end, negative_prompt)
        emb_end = embs_end[0:1]
        uncond_emb_end = embs_end[1:2]

        # Perform prompt guidance if it is specified
        if guide_prompt is not None:
            guide_embs = self.prompt_to_embedding(guide_prompt, negative_prompt)
            guide_emb = guide_embs[0:1]
            uncond_guide_emb = guide_embs[1:2]
            embs = torch.cat([emb_start] + [guide_emb] * (size - 2) + [emb_end], dim=0)
            uncond_embs = torch.cat(
                [uncond_emb_start] + [uncond_guide_emb] * (size - 2) + [uncond_emb_end],
                dim=0,
            )
        else:
            embs = linear_interpolation(emb_start, emb_end, size=size)
            uncond_embs = linear_interpolation(
                uncond_emb_start, uncond_emb_end, size=size
            )

        # Specify the interpolation methods
        pure_inner_attn_proc = InnerInterpolatedAttnProcessor(
            size=size,
            is_fused=False,
            alpha=alpha,
            beta=beta,
        )
        fused_inner_attn_proc = InnerInterpolatedAttnProcessor(
            size=size,
            is_fused=True,
            alpha=alpha,
            beta=beta,
        )
        pure_outer_attn_proc = OuterInterpolatedAttnProcessor(
            size=size,
            is_fused=False,
            alpha=alpha,
            beta=beta,
        )
        fused_outer_attn_proc = OuterInterpolatedAttnProcessor(
            size=size,
            is_fused=True,
            alpha=alpha,
            beta=beta,
        )
        self_attn_proc = AttnProcessor2_0()
        procs_dict = {
            "pure_inner": pure_inner_attn_proc,
            "fused_inner": fused_inner_attn_proc,
            "pure_outer": pure_outer_attn_proc,
            "fused_outer": fused_outer_attn_proc,
            "self": self_attn_proc,
        }

        # Denoising process
        i = 0
        warmup_step = int(num_inference_steps * warmup_ratio)
        for t in tqdm(self.scheduler.timesteps):
            i += 1
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            with torch.no_grad():
                # Change attention module
                if i < warmup_step:
                    interpolate_attn_proc = procs_dict[early]
                else:
                    interpolate_attn_proc = procs_dict[late]
                self.unet.set_attn_processor(processor=interpolate_attn_proc)

                # Predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=embs
                ).sample
                attn_proc = AttnProcessor()
                self.unet.set_attn_processor(processor=attn_proc)
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_embs
                ).sample
            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images

    @torch.no_grad()
    def interpolate_save_gpu(
        self,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        size: int = 7,
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        init: str = "linear",
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolate between two generations, producing one frame at a time to save GPU memory.

        Args:
            latent_start: FloatTensor, latent vector of the first image
            latent_end: FloatTensor, latent vector of the second image
            prompt_start: str, text prompt of the first image
            prompt_end: str, text prompt of the second image
            guide_prompt: str, text prompt for the interpolation
            negative_prompt: str, negative text prompt
            size: int, number of interpolations including starting and ending points
            num_inference_steps: int, number of inference steps in scheduler
            warmup_ratio: float, ratio of warmup steps
            early: str, warmup interpolation method
            late: str, late interpolation method
            alpha: float, alpha parameter for beta distribution
            beta: float, beta parameter for beta distribution
            init: str, interpolation initialization method

        Returns:
            Numpy array of interpolated images, shape (size, H, W, 3)
        """
        self.torch_device = self.unet.device
        # Specify alpha and beta
        if alpha is None:
            alpha = num_inference_steps
        if beta is None:
            beta = num_inference_steps
        betas = generate_beta_tensor(size, alpha=alpha, beta=beta)
        final_images = None

        # Generate interpolated images one by one
        for i in range(size - 2):
            it = betas[i + 1].item()
            if init == "denoising":
                images = self.denoising_interpolate(
                    latent_start,
                    prompt_start,
                    prompt_end,
                    negative_prompt,
                    interpolated_ratio=it,
                    timesteps=num_inference_steps,
                )
            else:
                images = self.interpolate_single(
                    it,
                    latent_start,
                    latent_end,
                    prompt_start,
                    prompt_end,
                    guide_prompt=guide_prompt,
                    num_inference_steps=num_inference_steps,
                    warmup_ratio=warmup_ratio,
                    early=early,
                    late=late,
                    negative_prompt=negative_prompt,
                    init=init,
                    guidance_scale=guidance_scale,
                )
            if size == 3:
                return images
            if i == 0:
                final_images = images[:2]
            elif i == size - 3:
                final_images = np.concatenate([final_images, images[1:]], axis=0)
            else:
                final_images = np.concatenate([final_images, images[1:2]], axis=0)
        return final_images

    def interpolate_single(
        self,
        it,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: str = None,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        init="linear",
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolates between two latent vectors and generates a sequence of images.

        Args:
            it (float): Interpolation factor between latent_start and latent_end.
            latent_start (torch.FloatTensor): Starting latent vector.
            latent_end (torch.FloatTensor): Ending latent vector.
            prompt_start (str): Starting prompt for text conditioning.
            prompt_end (str): Ending prompt for text conditioning.
            guide_prompt (str, optional): Guiding prompt for text conditioning. Defaults to None.
            negative_prompt (str, optional): Negative prompt for text conditioning. Defaults to "".
            num_inference_steps (int, optional): Number of inference steps. Defaults to 25.
            warmup_ratio (float, optional): Ratio of warm-up steps. Defaults to 0.5.
            early (str, optional): Early attention processing method. Defaults to "fused_outer".
            late (str, optional): Late attention processing method. Defaults to "self".
            init (str, optional): Initialization method for interpolation. Defaults to "linear".
            guidance_scale (Optional[float], optional): Scale for classifier-free guidance. Defaults to None.
        Returns:
            numpy.ndarray: Sequence of generated images.
        """
        self.torch_device = self.unet.device
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        # Prepare interpolated inputs
        self.scheduler.set_timesteps(num_inference_steps)

        embs_start = self.prompt_to_embedding(prompt_start, negative_prompt)
        emb_start = embs_start[0:1]
        uncond_emb_start = embs_start[1:2]
        embs_end = self.prompt_to_embedding(prompt_end, negative_prompt)
        emb_end = embs_end[0:1]
        uncond_emb_end = embs_end[1:2]

        latent_t = slerp(latent_start, latent_end, it)
        if guide_prompt is not None:
            embs_guide = self.prompt_to_embedding(guide_prompt, negative_prompt)
            emb_t = embs_guide[0:1]
        else:
            if init == "linear":
                emb_t = torch.lerp(emb_start, emb_end, it)
            else:
                emb_t = slerp(emb_start, emb_end, it)
        if init == "linear":
            uncond_emb_t = torch.lerp(uncond_emb_start, uncond_emb_end, it)
        else:
            uncond_emb_t = slerp(uncond_emb_start, uncond_emb_end, it)

        latents = torch.cat([latent_start, latent_t, latent_end], dim=0)
        embs = torch.cat([emb_start, emb_t, emb_end], dim=0)
        uncond_embs = torch.cat([uncond_emb_start, uncond_emb_t, uncond_emb_end], dim=0)

        # Specify the attention processors
        pure_inner_attn_proc = InnerInterpolatedAttnProcessor(
            t=it,
            is_fused=False,
        )
        fused_inner_attn_proc = InnerInterpolatedAttnProcessor(
            t=it,
            is_fused=True,
        )
        pure_outer_attn_proc = OuterInterpolatedAttnProcessor(
            t=it,
            is_fused=False,
        )
        fused_outer_attn_proc = OuterInterpolatedAttnProcessor(
            t=it,
            is_fused=True,
        )
        self_attn_proc = AttnProcessor()
        procs_dict = {
            "pure_inner": pure_inner_attn_proc,
            "fused_inner": fused_inner_attn_proc,
            "pure_outer": pure_outer_attn_proc,
            "fused_outer": fused_outer_attn_proc,
            "self": self_attn_proc,
        }

        i = 0
        warmup_step = int(num_inference_steps * warmup_ratio)
        for t in tqdm(self.scheduler.timesteps):
            i += 1
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            # predict the noise residual
            with torch.no_grad():
                # Warmup
                if i < warmup_step:
                    interpolate_attn_proc = procs_dict[early]
                else:
                    interpolate_attn_proc = procs_dict[late]
                self.unet.set_attn_processor(processor=interpolate_attn_proc)
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=embs
                ).sample
                attn_proc = AttnProcessor()
                self.unet.set_attn_processor(processor=attn_proc)
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_embs
                ).sample
            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images

    def denoising_interpolate(
        self,
        latents: torch.FloatTensor,
        text_1: str,
        text_2: str,
        negative_prompt: str = "",
        interpolated_ratio: float = 1,
        timesteps: int = 25,
    ) -> np.ndarray:
        """
        Performs denoising interpolation on the given latents.

        Args:
            latents (torch.Tensor): The input latents.
            text_1 (str): The first text prompt.
            text_2 (str): The second text prompt.
            negative_prompt (str, optional): The negative text prompt. Defaults to "".
            interpolated_ratio (float, optional): The ratio of interpolation between text_1 and text_2. Defaults to 1.
            timesteps (int, optional): The number of timesteps for diffusion. Defaults to 25.

        Returns:
            numpy.ndarray: The interpolated images.
        """
        self.unet.set_attn_processor(processor=AttnProcessor())
        start_emb = self.prompt_to_embedding(text_1)
        end_emb = self.prompt_to_embedding(text_2)
        neg_emb = self.prompt_to_embedding(negative_prompt)
        uncond_emb = neg_emb[0:1]
        emb_1 = start_emb[0:1]
        emb_2 = end_emb[0:1]
        self.scheduler.set_timesteps(timesteps)
        i = 0
        for t in tqdm(self.scheduler.timesteps):
            i += 1
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            # predict the noise residual
            with torch.no_grad():
                if i < timesteps * interpolated_ratio:
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=emb_1
                    ).sample
                else:
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=emb_2
                    ).sample
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_emb
                ).sample
            # perform guidance
            noise_pred = noise_uncond + self.guidance_scale * (
                noise_pred - noise_uncond
            )
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images
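A minimal end-to-end sketch (not part of the commit) of how this pipeline is driven, mirroring what app.py does: build the pipeline, sample a latent, reuse it for both endpoints, and request a guided interpolation sequence. It assumes the Hugging Face checkpoints can be downloaded and that enough GPU memory is available; all prompt strings and parameter values are taken from the demo defaults.

```python
import torch

from pipeline_interpolated_stable_diffusion import InterpolationStableDiffusionPipeline

pipe = InterpolationStableDiffusionPipeline(
    repo_name="runwayml/stable-diffusion-v1-5",
    scheduler_name="unipc",
    guidance_scale=10.0,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe.to(device, dtype=torch.float32)

generator = torch.Generator().manual_seed(1002)
latent1 = pipe.generate_latent(generator=generator).to(device)
latent2 = latent1.clone()  # same starting noise, like the demo's "Same latent" option

images = pipe.interpolate_save_gpu(
    latent1,
    latent2,
    "A photo of dog, best quality, extremely detailed",
    "A photo of car, best quality, extremely detailed",
    guide_prompt="A photo of a dog driving a car, logical, best quality, extremely detailed",
    size=5,
    num_inference_steps=25,
    warmup_ratio=0.122,
    early="fused_outer",
    late="self",
)
print(images.shape)  # (5, H, W, 3) uint8 frames
```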
prior.py
ADDED
@@ -0,0 +1,168 @@
import torch
from bayes_opt import BayesianOptimization, SequentialDomainReductionTransformer
from lpips import LPIPS
from scipy.stats import beta as beta_distribution

from utils import compute_lpips, compute_smoothness_and_consistency


def bayesian_prior_selection(
    interpolation_pipe,
    latent1: torch.FloatTensor,
    latent2: torch.FloatTensor,
    prompt1: str,
    prompt2: str,
    lpips_model: LPIPS,
    guide_prompt: str | None = None,
    negative_prompt: str = "",
    size: int = 3,
    num_inference_steps: int = 25,
    warmup_ratio: float = 1,
    early: str = "vfused",
    late: str = "self",
    target_score: float = 0.9,
    n_iter: int = 15,
    p_min: float | None = None,
    p_max: float | None = None,
) -> tuple:
    """
    Select the alpha and beta parameters for the interpolation using Bayesian optimization.

    Args:
        interpolation_pipe (any): The interpolation pipeline.
        latent1 (torch.FloatTensor): The first source latent vector.
        latent2 (torch.FloatTensor): The second source latent vector.
        prompt1 (str): The first source prompt.
        prompt2 (str): The second source prompt.
        lpips_model (LPIPS): The LPIPS model used to compute perceptual distances.
        guide_prompt (str | None, optional): The guide prompt for the interpolation, if any. Defaults to None.
        negative_prompt (str, optional): The negative prompt for the interpolation. Defaults to "".
        size (int, optional): The size of the interpolation sequence. Defaults to 3.
        num_inference_steps (int, optional): The number of inference steps. Defaults to 25.
        warmup_ratio (float, optional): The warmup ratio. Defaults to 1.
        early (str, optional): The early fusion method. Defaults to "vfused".
        late (str, optional): The late fusion method. Defaults to "self".
        target_score (float, optional): The target score. Defaults to 0.9.
        n_iter (int, optional): The maximum number of iterations. Defaults to 15.
        p_min (float | None, optional): The minimum value of alpha and beta. Defaults to None.
        p_max (float | None, optional): The maximum value of alpha and beta. Defaults to None.

    Returns:
        tuple: A tuple containing the selected alpha and beta parameters.
    """

    def get_smoothness(alpha, beta):
        """
        Black-box objective function for Bayesian optimization.
        Returns the smoothness of the interpolated sequence generated with the given alpha and beta.
        """
        if alpha < beta and large_alpha_prior:
            return 0
        if alpha > beta and not large_alpha_prior:
            return 0
        if alpha == beta:
            return init_smoothness
        interpolation_sequence = interpolation_pipe.interpolate_save_gpu(
            latent1,
            latent2,
            prompt1,
            prompt2,
            guide_prompt=guide_prompt,
            negative_prompt=negative_prompt,
            size=size,
            num_inference_steps=num_inference_steps,
            warmup_ratio=warmup_ratio,
            early=early,
            late=late,
            alpha=alpha,
            beta=beta,
        )
        smoothness, _, _ = compute_smoothness_and_consistency(
            interpolation_sequence, lpips_model
        )
        return smoothness

    # Add prior into the selection of alpha and beta:
    # first compute the interpolated image at t=0.5
    images = interpolation_pipe.interpolate_single(
        0.5,
        latent1,
        latent2,
        prompt1,
        prompt2,
        guide_prompt=guide_prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        warmup_ratio=warmup_ratio,
        early=early,
        late=late,
    )
    # Compute the perceptual distances of the interpolated image (t=0.5) to the source images
    distances = compute_lpips(images, lpips_model)
    # Compute init_smoothness as the smoothness when alpha == beta, to avoid recomputation
    init_smoothness, _, _ = compute_smoothness_and_consistency(images, lpips_model)
    # If the perceptual distance to the first source image is smaller, alpha should be larger than beta
    large_alpha_prior = distances[0] < distances[1]

    # Bayesian optimization configuration
    num_warmup_steps = warmup_ratio * num_inference_steps
    if p_min is None:
        p_min = 1
    if p_max is None:
        p_max = num_warmup_steps
    pbounds = {"alpha": (p_min, p_max), "beta": (p_min, p_max)}
    bounds_transformer = SequentialDomainReductionTransformer(minimum_window=0.1)
    optimizer = BayesianOptimization(
        f=get_smoothness,
        pbounds=pbounds,
        random_state=1,
        bounds_transformer=bounds_transformer,
        allow_duplicate_points=True,
    )
    alpha_init = [p_min, (p_min + p_max) / 2, p_max]
    beta_init = [p_min, (p_min + p_max) / 2, p_max]

    # Initial probing
    for alpha in alpha_init:
        for beta in beta_init:
            optimizer.probe(params={"alpha": alpha, "beta": beta}, lazy=False)
            latest_result = optimizer.res[-1]  # Get the last result
            latest_score = latest_result["target"]
            if latest_score >= target_score:
                return alpha, beta

    # Start optimization
    for _ in range(n_iter):  # Max iterations
        optimizer.maximize(init_points=0, n_iter=1)  # One iteration at a time
        max_score = optimizer.max["target"]  # Get the highest score so far
        if max_score >= target_score:
            print(f"Stopping early, target of {target_score} reached.")
            break  # Exit the loop if the target is reached or exceeded

    results = optimizer.max
    alpha = results["params"]["alpha"]
    beta = results["params"]["beta"]
    return alpha, beta


def generate_beta_tensor(
    size: int, alpha: float = 3, beta: float = 3
) -> torch.FloatTensor:
    """
    Assume size is n.
    Generates a PyTorch tensor of values [x0, x1, ..., xn-1] from the Beta distribution,
    where each xi satisfies F(xi) = i/(n-1) for the CDF F of the Beta distribution.

    Args:
        size (int): The number of values to generate.
        alpha (float): The alpha parameter of the Beta distribution.
        beta (float): The beta parameter of the Beta distribution.

    Returns:
        torch.FloatTensor: A tensor of the inverse CDF values of the Beta distribution.
    """
    # Generate the inverse CDF values
    prob_values = [i / (size - 1) for i in range(size)]
    inverse_cdf_values = beta_distribution.ppf(prob_values, alpha, beta)

    # Convert to a PyTorch tensor
    return torch.tensor(inverse_cdf_values, dtype=torch.float32)
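`generate_beta_tensor` is the piece that turns a chosen `(alpha, beta)` pair into the interpolation coefficients, so a small self-contained example may be useful; the printed values are approximate and only meant to show that Beta(3, 3) yields a spacing that is symmetric around 0.5 with endpoints at 0 and 1.

```python
from prior import generate_beta_tensor

# Five coefficients whose Beta(3, 3) CDF values are evenly spaced over [0, 1].
ts = generate_beta_tensor(size=5, alpha=3, beta=3)
print(ts)  # roughly tensor([0.0000, 0.3595, 0.5000, 0.6405, 1.0000])
```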
requirements.txt
ADDED
@@ -0,0 +1,65 @@
absl-py==2.1.0
accelerate==0.27.2
addict==2.4.0
antlr4-python3-runtime==4.9.3
bayesian-optimization==1.4.3
clean-fid==0.1.35
clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
colorama==0.4.6
contourpy==1.2.0
cycler==0.12.1
diffusers==0.27.1
einops==0.7.0
facexlib==0.3.0
filterpy==1.4.5
fonttools==4.49.0
fsspec==2024.2.0
ftfy==6.1.3
future==1.0.0
grpcio==1.62.0
huggingface-hub==0.20.3
imageio==2.34.0
imgaug==0.4.0
joblib==1.3.2
kiwisolver==1.4.5
lazy_loader==0.3
llvmlite==0.42.0
lmdb==1.4.1
lpips==0.1.4
Markdown==3.5.2
matplotlib==3.8.3
mkl-service==2.4.0
numba==0.59.0
numpy==1.24.4
omegaconf==2.3.0
openai-clip==1.0.1
opencv-python==4.9.0.80
pandas==2.2.0
protobuf==4.25.3
pyiqa==0.1.10
pyparsing==3.1.1
python-dateutil==2.8.2
pytorch-fid==0.3.0
pytz==2024.1
regex==2023.12.25
safetensors==0.4.2
scikit-image==0.22.0
scikit-learn==1.4.1.post1
scipy==1.9.1
shapely==2.0.3
tensorboard==2.16.2
tensorboard-data-server==0.7.2
threadpoolctl==3.3.0
tifffile==2024.2.12
timm==0.9.16
tokenizers==0.15.2
tomli==2.0.1
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tqdm==4.66.2
transformers==4.38.2
triton==2.1.0
tzdata==2024.1
Werkzeug==3.0.1
yapf==0.40.2
style.css
ADDED
@@ -0,0 +1,71 @@
h1 {
  text-align: center;
  justify-content: center;
}

[role="tabpanel"] {
  border: 0;
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}

.gradio-container {
  max-width: 690px !important;
}

#share-btn-container {
  padding-left: 0.5rem !important;
  padding-right: 0.5rem !important;
  background-color: #000000;
  justify-content: center;
  align-items: center;
  border-radius: 9999px !important;
  max-width: 13rem;
  margin-left: auto;
  margin-top: 0.35em;
}

div#share-btn-container > div {
  flex-direction: row;
  background: black;
  align-items: center;
}

#share-btn-container:hover {
  background-color: #060606;
}

#share-btn {
  all: initial;
  color: #ffffff;
  font-weight: 600;
  cursor: pointer;
  font-family: 'IBM Plex Sans', sans-serif;
  margin-left: 0.5rem !important;
  padding-top: 0.5rem !important;
  padding-bottom: 0.5rem !important;
  right: 0;
  font-size: 15px;
}

#share-btn * {
  all: unset;
}

#share-btn-container div:nth-child(-n+2) {
  width: auto !important;
  min-height: 0px !important;
}

#share-btn-container .wrap {
  display: none !important;
}

#share-btn-container.hidden {
  display: none !important;
}
utils.py
ADDED
@@ -0,0 +1,189 @@
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from lpips import LPIPS
from PIL import Image
from torchvision.transforms import Normalize


def show_images_horizontally(
    list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
) -> None:
    """
    Visualize the list of images horizontally and save the figure as PNG.

    Args:
        list_of_files: The list of images as numpy array with shape (N, H, W, C).
        output_file: The output file path to save the figure as PNG.
        interact: Whether to show the figure interactively (e.g. in a Jupyter Notebook) instead of saving it.
    """
    number_of_files = len(list_of_files)

    heights = [a[0].shape[0] for a in list_of_files]
    widths = [a.shape[1] for a in list_of_files[0]]

    fig_width = 8.0  # inches
    fig_height = fig_width * sum(heights) / sum(widths)

    # Create a figure with subplots
    _, axs = plt.subplots(
        1, number_of_files, figsize=(fig_width * number_of_files, fig_height)
    )
    plt.tight_layout()
    for i in range(number_of_files):
        _image = list_of_files[i]
        axs[i].imshow(_image)
        axs[i].axis("off")

    # Save the figure as PNG
    if interact:
        plt.show()
    else:
        plt.savefig(output_file, bbox_inches="tight", pad_inches=0.25)


def save_image(image: np.array, file_name: str) -> None:
    """
    Save the image as JPG.

    Args:
        image: The input image as numpy array with shape (H, W, C).
        file_name: The file name to save the image.
    """
    image = Image.fromarray(image)
    image.save(file_name)


def load_and_process_images(load_dir: str) -> np.array:
    """
    Load and process the images into numpy arrays from the directory.

    Args:
        load_dir: The directory to load the images from.

    Returns:
        images: The images as a list of numpy arrays with shape (H, W, C), scaled to [0, 1].
    """
    images = []
    print(load_dir)
    filenames = sorted(
        os.listdir(load_dir), key=lambda x: int(x.split(".")[0])
    )  # Ensure the files are sorted numerically
    for filename in filenames:
        if filename.endswith(".jpg"):
            img = Image.open(os.path.join(load_dir, filename))
            img_array = (
                np.asarray(img) / 255.0
            )  # Convert to numpy array and scale pixel values to [0, 1]
            images.append(img_array)
    return images


def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
    """
    Compute the LPIPS distances between consecutive input images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        distances: The LPIPS distances of consecutive images as numpy array with shape (N-1,).
    """
    # Get the device of lpips_model
    device = next(lpips_model.parameters()).device
    device = str(device)

    # Convert the input images into a tensor
    images = torch.tensor(images).to(device).float()
    images = torch.permute(images, (0, 3, 1, 2))
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    images = normalize(images)

    # Compute the LPIPS between each pair of adjacent images
    distances = []
    for i in range(images.shape[0]):
        if i == images.shape[0] - 1:
            break
        img1 = images[i].unsqueeze(0)
        img2 = images[i + 1].unsqueeze(0)
        loss = lpips_model(img1, img2)
        distances.append(loss.item())
    distances = np.array(distances)
    return distances


def compute_gini(distances: np.array) -> float:
    """
    Compute the Gini index of the input distances.

    Args:
        distances: The input distances as numpy array.

    Returns:
        gini: The Gini index of the input distances.
    """
    if len(distances) < 2:
        return 0.0  # Gini index is 0 for fewer than two elements

    # Sort the list of distances
    sorted_distances = sorted(distances)
    n = len(sorted_distances)
    mean_distance = sum(sorted_distances) / n

    # Compute the sum of absolute differences
    sum_of_differences = 0
    for di in sorted_distances:
        for dj in sorted_distances:
            sum_of_differences += abs(di - dj)

    # Normalize the sum of differences by the mean and the number of elements
    gini = sum_of_differences / (2 * n * n * mean_distance)
    return gini


def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
    """
    Compute the smoothness and consistency of the input images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        smoothness: One minus the Gini index of the LPIPS distances of consecutive images.
        consistency: The mean LPIPS distance of consecutive images.
        max_inception_distance: The maximum LPIPS distance of consecutive images.
    """
    distances = compute_lpips(images, lpips_model)
    smoothness = 1 - compute_gini(distances)
    consistency = np.mean(distances)
    max_inception_distance = np.max(distances)
    return smoothness, consistency, max_inception_distance


def separate_source_and_interpolated_images(images: np.array) -> tuple:
    """
    Separate the input images into source and interpolated images.
    The sources are the first and last images; the interpolated images are the rest.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).

    Returns:
        source: The source images as numpy array with shape (2, H, W, C).
        interpolation: The interpolated images as numpy array with shape (N-2, H, W, C).
    """
    # Check if the array has at least two elements
    if len(images) < 2:
        raise ValueError("The input array should have at least two elements.")

    # Separate the array into two parts:
    # the first part takes the first and last elements
    source = np.array([images[0], images[-1]])
    # the second part takes the rest of the elements
    interpolation = images[1:-1]
    return source, interpolation
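As a closing sketch of how the metric helpers fit together: smoothness is one minus the Gini index of the consecutive LPIPS distances, consistency is their mean, and the maximum distance flags the largest perceptual jump in the sequence. The random frames and the VGG LPIPS backbone below are assumptions for illustration only; in the demo the frames would come from the interpolation pipeline instead.

```python
import numpy as np
from lpips import LPIPS

from utils import compute_smoothness_and_consistency

# Hypothetical interpolation sequence: five RGB frames in [0, 1] with shape (N, H, W, C).
frames = np.random.rand(5, 256, 256, 3).astype(np.float32)

lpips_model = LPIPS(net="vgg")  # backbone choice is an assumption, not fixed by this repo
smoothness, consistency, max_dist = compute_smoothness_and_consistency(frames, lpips_model)
print(f"smoothness={smoothness:.3f}, mean LPIPS={consistency:.3f}, max LPIPS={max_dist:.3f}")
```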