wjh committed on
Commit
67e6974
1 Parent(s): c464026
app.py ADDED
@@ -0,0 +1,432 @@
import os
from typing import Optional

import gradio as gr
import numpy as np
import pandas as pd
import torch
import user_history
from PIL import Image
from scipy.stats import beta as beta_distribution

from pipeline_interpolated_sdxl import InterpolationStableDiffusionXLPipeline
from pipeline_interpolated_stable_diffusion import InterpolationStableDiffusionPipeline

os.environ["TOKENIZERS_PARALLELISM"] = "false"

title = r"""
<h1 align="center">PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</b></a>.<br>
How to use:<br>
1. Input prompt 1 and prompt 2.
2. (Optional) Input the guidance prompt and negative prompt.
3. (Optional) Change the interpolation parameters and check the Beta distribution.
4. Click the <b>Generate</b> button to begin generating images.
5. Enjoy! 😊"""

article = r"""
---
✒️ **Citation**
<br>
If you found this demo/our paper useful, please consider citing:
```bibtex
@article{he2024paid,
  title={PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion},
  author={He, Qiyuan and Wang, Jinghao and Liu, Ziwei and Yao, Angela},
  journal={},
  year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue in our <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>Github Repo</b></a> or reach us directly at <b>qhe@u.nus.edu.sg</b>.
"""

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = False
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = False

dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = InterpolationStableDiffusionPipeline(
    repo_name="runwayml/stable-diffusion-v1-5",
    guidance_scale=10.0,
    scheduler_name="unipc",
)
pipeline.to(device, dtype=dtype)


def change_model_fn(model_name: str) -> None:
    global pipeline
    name_mapping = {
        "SD1.4-512": "CompVis/stable-diffusion-v1-4",
        "SD1.5-512": "runwayml/stable-diffusion-v1-5",
        "SD2.1-768": "stabilityai/stable-diffusion-2-1",
        "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
    }
    if "XL" not in model_name:
        pipeline = InterpolationStableDiffusionPipeline(
            repo_name=name_mapping[model_name],
            guidance_scale=10.0,
            scheduler_name="unipc",
        )
        pipeline.to(device, dtype=dtype)
    else:
        pipeline = InterpolationStableDiffusionXLPipeline.from_pretrained(
            name_mapping[model_name]
        )
        pipeline.to(device, dtype=dtype)


def save_image(img, index):
    unique_name = f"{index}.png"
    img = Image.fromarray(img)
    img.save(unique_name)
    return unique_name


def generate_beta_tensor(
    size: int, alpha: float = 3.0, beta: float = 3.0
) -> np.ndarray:
    prob_values = [i / (size - 1) for i in range(size)]
    inverse_cdf_values = beta_distribution.ppf(prob_values, alpha, beta)
    return inverse_cdf_values


def plot_gemma_fn(alpha: float, beta: float, size: int) -> pd.DataFrame:
    beta_ppf = generate_beta_tensor(size=size, alpha=int(alpha), beta=int(beta))
    return pd.DataFrame(
        {
            "interpolation index": [i for i in range(size)],
            "coefficient": beta_ppf.tolist(),
        }
    )


def get_example() -> list:
    case = [
        [
            "A photo of dog, best quality, extremely detailed",
            "A photo of car, best quality, extremely detailed",
            3,
            6,
            3,
            "A photo of a dog driving a car, logical, best quality, extremely detailed",
            "monochrome, lowres, bad anatomy, worst quality, low quality",
            "SD1.5-512",
            6.1 / 50,
            10,
            50,
            "fused_inner",
            "self",
            1002,
            True,
        ]
    ]
    return case


def dynamic_gallery_fn(interpolation_size: int):
    return gr.Gallery(
        label="Result", show_label=False, rows=1, columns=interpolation_size
    )


@torch.no_grad()
def generate(
    prompt1: str,
    prompt2: str,
    guidance_prompt: Optional[str] = None,
    negative_prompt: str = "",
    warmup_ratio: float = 0.5,
    guidance_scale: float = 10,
    early: str = "fused_outer",
    late: str = "self",
    alpha: float = 4.0,
    beta: float = 4.0,
    interpolation_size: int = 3,
    seed: int = 0,
    same_latent: bool = True,
    num_inference_steps: int = 50,
    progress=gr.Progress(),
) -> np.ndarray:
    global pipeline
    generator = (
        torch.cuda.manual_seed(seed)
        if torch.cuda.is_available()
        else torch.manual_seed(seed)
    )
    latent1 = pipeline.generate_latent(generator=generator)
    latent1 = latent1.to(device=pipeline.unet.device, dtype=pipeline.unet.dtype)
    if same_latent:
        latent2 = latent1.clone()
    else:
        latent2 = pipeline.generate_latent(generator=generator)
        latent2 = latent2.to(device=pipeline.unet.device, dtype=pipeline.unet.dtype)
    betas = generate_beta_tensor(size=interpolation_size, alpha=alpha, beta=beta)
    for i in progress.tqdm(
        range(interpolation_size - 2),
        desc=(
            f"Generating {interpolation_size - 2} images"
            if interpolation_size > 3
            else "Generating 1 image"
        ),
    ):
        it = betas[i + 1].item()
        images = pipeline.interpolate_single(
            it,
            latent_start=latent1,
            latent_end=latent2,
            prompt_start=prompt1,
            prompt_end=prompt2,
            guide_prompt=guidance_prompt,
            num_inference_steps=num_inference_steps,
            warmup_ratio=warmup_ratio,
            early=early,
            late=late,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
        )
        if interpolation_size == 3:
            final_images = images
            break
        if i == 0:
            final_images = images[:2]
        elif i == interpolation_size - 3:
            final_images = np.concatenate([final_images, images[1:]], axis=0)
        else:
            final_images = np.concatenate([final_images, images[1:2]], axis=0)
    return final_images


interpolation_size = None

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Group():
        prompt1 = gr.Text(
            label="Prompt 1",
            max_lines=3,
            placeholder="Enter the First Prompt",
            interactive=True,
            value="A photo of dog, best quality, extremely detailed",
        )
        prompt2 = gr.Text(
            label="Prompt 2",
            max_lines=3,
            placeholder="Enter the Second Prompt",
            interactive=True,
            value="A photo of car, best quality, extremely detailed",
        )
        result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
    generate_button = gr.Button("Generate", variant="primary")
    with gr.Accordion("Advanced options", open=True):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    interpolation_size = gr.Slider(
                        label="Interpolation Size",
                        minimum=3,
                        maximum=15,
                        step=1,
                        value=3,
                        info="Interpolation size includes the start and end images",
                    )
                    alpha = gr.Slider(
                        label="alpha",
                        minimum=1,
                        maximum=50,
                        step=0.1,
                        value=6.0,
                    )
                    beta = gr.Slider(
                        label="beta",
                        minimum=1,
                        maximum=50,
                        step=0.1,
                        value=3.0,
                    )
                gamma_plot = gr.LinePlot(
                    x="interpolation index",
                    y="coefficient",
                    title="Beta Distribution with Sampled Points",
                    height=500,
                    width=400,
                    overlay_point=True,
                    tooltip=["coefficient", "interpolation index"],
                    interactive=False,
                    show_label=False,
                )
                gamma_plot.change(
                    plot_gemma_fn,
                    inputs=[
                        alpha,
                        beta,
                        interpolation_size,
                    ],
                    outputs=gamma_plot,
                )
        with gr.Group():
            guidance_prompt = gr.Text(
                label="Guidance prompt",
                max_lines=3,
                placeholder="Enter a Guidance Prompt",
                interactive=True,
                value="A photo of a dog driving a car, logical, best quality, extremely detailed",
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=3,
                placeholder="Enter a Negative Prompt",
                interactive=True,
                value="monochrome, lowres, bad anatomy, worst quality, low quality",
            )
        with gr.Row():
            with gr.Column():
                warmup_ratio = gr.Slider(
                    label="Warmup Ratio",
                    minimum=0.02,
                    maximum=1,
                    step=0.01,
                    value=0.122,
                    interactive=True,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0,
                    maximum=50,
                    step=0.1,
                    value=10,
                    interactive=True,
                )
            with gr.Column():
                early = gr.Dropdown(
                    label="Early stage attention type",
                    choices=[
                        "pure_inner",
                        "fused_inner",
                        "pure_outer",
                        "fused_outer",
                        "self",
                    ],
                    value="fused_outer",
                    type="value",
                    interactive=True,
                )
                late = gr.Dropdown(
                    label="Late stage attention type",
                    choices=[
                        "pure_inner",
                        "fused_inner",
                        "pure_outer",
                        "fused_outer",
                        "self",
                    ],
                    value="self",
                    type="value",
                    interactive=True,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=25,
                    maximum=50,
                    step=1,
                    value=50,
                    interactive=True,
                )
        with gr.Row():
            model_choice = gr.Dropdown(
                ["SD1.4-512", "SD1.5-512", "SD2.1-768", "SDXL-1024"],
                label="Model",
                value="SD1.5-512",
                interactive=True,
            )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=1002,
                )
                same_latent = gr.Checkbox(
                    label="Same latent",
                    value=True,
                    info="Use the same latent for start and end images",
                    show_label=True,
                )

    gr.Examples(
        examples=get_example(),
        inputs=[
            prompt1,
            prompt2,
            interpolation_size,
            alpha,
            beta,
            guidance_prompt,
            negative_prompt,
            model_choice,
            warmup_ratio,
            guidance_scale,
            num_inference_steps,
            early,
            late,
            seed,
            same_latent,
        ],
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    alpha.change(
        fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
    )
    beta.change(
        fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
    )
    interpolation_size.change(
        fn=plot_gemma_fn, inputs=[alpha, beta, interpolation_size], outputs=gamma_plot
    )
    model_choice.change(fn=change_model_fn, inputs=model_choice)
    inputs = [
        prompt1,
        prompt2,
        guidance_prompt,
        negative_prompt,
        warmup_ratio,
        guidance_scale,
        early,
        late,
        alpha,
        beta,
        interpolation_size,
        seed,
        same_latent,
        num_inference_steps,
    ]
    generate_button.click(
        fn=dynamic_gallery_fn,
        inputs=interpolation_size,
        outputs=result,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
    gr.Markdown(article)

with gr.Blocks(css="style.css") as demo_with_history:
    with gr.Tab("App"):
        demo.render()

if __name__ == "__main__":
    demo_with_history.queue(max_size=20).launch()
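A minimal headless sketch of the same flow that `app.py` wires into Gradio, assuming the modules from this commit are importable and the model weights can be downloaded; prompts and settings mirror the demo defaults above, everything else is illustrative rather than part of the committed files.

```python
# Headless sketch of app.py's generate() flow (illustrative, not in the commit).
import torch
from PIL import Image

from pipeline_interpolated_stable_diffusion import InterpolationStableDiffusionPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = InterpolationStableDiffusionPipeline(
    repo_name="runwayml/stable-diffusion-v1-5",
    guidance_scale=10.0,
    scheduler_name="unipc",
)
pipe.to(device, dtype=torch.float32)

generator = torch.manual_seed(1002)
latent = pipe.generate_latent(generator=generator)
latent = latent.to(device=pipe.unet.device, dtype=pipe.unet.dtype)

# t=0.5 returns [start, middle, end] as a (3, H, W, 3) uint8 array.
images = pipe.interpolate_single(
    0.5,
    latent_start=latent,
    latent_end=latent.clone(),
    prompt_start="A photo of dog, best quality, extremely detailed",
    prompt_end="A photo of car, best quality, extremely detailed",
    guide_prompt="A photo of a dog driving a car, logical, best quality, extremely detailed",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    warmup_ratio=0.122,
    early="fused_outer",
    late="self",
)
for i, img in enumerate(images):
    Image.fromarray(img).save(f"interpolated_{i}.png")
```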
interpolation.py ADDED
@@ -0,0 +1,388 @@
from typing import Optional

import torch
from torch import FloatTensor, LongTensor, Size, Tensor

from prior import generate_beta_tensor


class OuterInterpolatedAttnProcessor:
    r"""
    Personalized processor for performing outer attention interpolation.

    The attention output of the interpolated image is obtained by:
        (1 - t) * Q_t * K_1 * V_1 + t * Q_t * K_m * V_m;
    If fused with self-attention:
        (1 - t) * Q_t * [K_1, K_t] * [V_1, V_t] + t * Q_t * [K_m, K_t] * [V_m, V_t];
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
    ):
        """
        t: float, interpolation point between 0 and 1; if specified, size is set to 3
        """
        if t is None:
            ts = generate_beta_tensor(size, alpha=alpha, beta=beta)
            ts[0], ts[-1] = 0, 1
        else:
            assert t > 0 and t < 1, "t must be between 0 and 1"
            ts = [0, t, 1]
            ts = torch.tensor(ts)
            size = 3

        self.size = size
        self.coef = ts
        self.is_fused = is_fused

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Specify the first and last key and value
        key_begin = key[0:1]
        key_end = key[-1:]
        value_begin = value[0:1]
        value_end = value[-1:]

        key_begin = torch.cat([key_begin] * (self.size))
        key_end = torch.cat([key_end] * (self.size))
        value_begin = torch.cat([value_begin] * (self.size))
        value_end = torch.cat([value_end] * (self.size))

        key_begin = attn.head_to_batch_dim(key_begin)
        value_begin = attn.head_to_batch_dim(value_begin)
        key_end = attn.head_to_batch_dim(key_end)
        value_end = attn.head_to_batch_dim(value_end)

        # Fused with self-attention
        if self.is_fused:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key_end = torch.cat([key, key_end], dim=-2)
            value_end = torch.cat([value, value_end], dim=-2)
            key_begin = torch.cat([key, key_begin], dim=-2)
            value_begin = torch.cat([value, value_begin], dim=-2)

        attention_probs_end = attn.get_attention_scores(query, key_end, attention_mask)
        hidden_states_end = torch.bmm(attention_probs_end, value_end)
        hidden_states_end = attn.batch_to_head_dim(hidden_states_end)

        attention_probs_begin = attn.get_attention_scores(
            query, key_begin, attention_mask
        )
        hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
        hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)

        # Apply outer interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class InnerInterpolatedAttnProcessor:
    r"""
    Personalized processor for performing inner attention interpolation.

    The keys and values are interpolated first and then attended by the query:
        Q_t * ((1 - t) * K_1 + t * K_m) with values (1 - t) * V_1 + t * V_m;
    If fused with self-attention:
        Q_t * [K_t, (1 - t) * K_1 + t * K_m] * [V_t, (1 - t) * V_1 + t * V_m];
    """

    def __init__(
        self,
        t: Optional[float] = None,
        size: int = 7,
        is_fused: bool = False,
        alpha: float = 1,
        beta: float = 1,
    ):
        """
        t: float, interpolation point between 0 and 1; if specified, size is set to 3
        """
        if t is None:
            ts = generate_beta_tensor(size, alpha=alpha, beta=beta)
            ts[0], ts[-1] = 0, 1
        else:
            assert t > 0 and t < 1, "t must be between 0 and 1"
            ts = [0, t, 1]
            ts = torch.tensor(ts)
            size = 3

        self.size = size
        self.coef = ts
        self.is_fused = is_fused

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Specify the first and last key and value
        key_start = key[0:1]
        key_end = key[-1:]
        value_start = value[0:1]
        value_end = value[-1:]

        key_start = torch.cat([key_start] * (self.size))
        key_end = torch.cat([key_end] * (self.size))
        value_start = torch.cat([value_start] * (self.size))
        value_end = torch.cat([value_end] * (self.size))

        # Apply inner interpolation on attention
        coef = self.coef.reshape(-1, 1, 1)
        coef = coef.to(key.device, key.dtype)
        key_cross = (1 - coef) * key_start + coef * key_end
        value_cross = (1 - coef) * value_start + coef * value_end

        key_cross = attn.head_to_batch_dim(key_cross)
        value_cross = attn.head_to_batch_dim(value_cross)

        # Fused with self-attention
        if self.is_fused:
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key_cross = torch.cat([key, key_cross], dim=-2)
            value_cross = torch.cat([value, value_cross], dim=-2)

        attention_probs = attn.get_attention_scores(query, key_cross, attention_mask)

        hidden_states = torch.bmm(attention_probs, value_cross)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def linear_interpolation(
    l1: FloatTensor, l2: FloatTensor, ts: Optional[FloatTensor] = None, size: int = 5
) -> FloatTensor:
    """
    Linear interpolation

    Args:
        l1: Starting vector: (1, *)
        l2: Final vector: (1, *)
        ts: FloatTensor, interpolation points between 0 and 1
        size: int, number of interpolation points including l1 and l2

    Returns:
        Interpolated vectors: (size, *)
    """
    assert l1.shape == l2.shape, "shapes of l1 and l2 must match"

    res = []
    if ts is not None:
        for t in ts:
            li = torch.lerp(l1, l2, t)
            res.append(li)
    else:
        for i in range(size):
            t = i / (size - 1)
            li = torch.lerp(l1, l2, t)
            res.append(li)
    res = torch.cat(res, dim=0)
    return res


def spherical_interpolation(l1: FloatTensor, l2: FloatTensor, size=5) -> FloatTensor:
    """
    Spherical interpolation

    Args:
        l1: Starting vector: (1, *)
        l2: Final vector: (1, *)
        size: int, number of interpolation points including l1 and l2

    Returns:
        Interpolated vectors: (size, *)
    """
    assert l1.shape == l2.shape, "shapes of l1 and l2 must match"

    res = []
    for i in range(size):
        t = i / (size - 1)
        li = slerp(l1, l2, t)
        res.append(li)
    res = torch.cat(res, dim=0)
    return res


def slerp(v0: FloatTensor, v1: FloatTensor, t, threshold=0.9995):
    """
    Spherical linear interpolation

    Args:
        v0: Starting vector
        v1: Final vector
        t: Float value between 0.0 and 1.0
        threshold: Threshold for considering the two vectors as
                   colinear. Not recommended to alter this.

    Returns:
        Interpolation vector between v0 and v1
    """
    assert v0.shape == v1.shape, "shapes of v0 and v1 must match"

    # Normalize the vectors to get the directions and angles
    v0_norm: FloatTensor = torch.norm(v0, dim=-1)
    v1_norm: FloatTensor = torch.norm(v1, dim=-1)

    v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1)
    v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1)

    # Dot product with the normalized vectors
    dot: FloatTensor = (v0_normed * v1_normed).sum(-1)
    dot_mag: FloatTensor = dot.abs()

    # If the dot product is NaN, the v0 or v1 row was filled with 0s.
    # If its absolute value is almost 1, the vectors are ~colinear, so use torch.lerp.
    gotta_lerp: LongTensor = dot_mag.isnan() | (dot_mag > threshold)
    can_slerp: LongTensor = ~gotta_lerp

    t_batch_dim_count: int = max(0, t.dim() - v0.dim()) if isinstance(t, Tensor) else 0
    t_batch_dims: Size = (
        t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([])
    )
    out: FloatTensor = torch.zeros_like(v0.expand(*t_batch_dims, *[-1] * v0.dim()))

    # if no elements are lerpable, our vectors become 0-dimensional, preventing broadcasting
    if gotta_lerp.any():
        lerped: FloatTensor = torch.lerp(v0, v1, t)

        out: FloatTensor = lerped.where(gotta_lerp.unsqueeze(-1), out)

    # if no elements are slerpable, our vectors become 0-dimensional, preventing broadcasting
    if can_slerp.any():
        # Calculate initial angle between v0 and v1
        theta_0: FloatTensor = dot.arccos().unsqueeze(-1)
        sin_theta_0: FloatTensor = theta_0.sin()
        # Angle at timestep t
        theta_t: FloatTensor = theta_0 * t
        sin_theta_t: FloatTensor = theta_t.sin()
        # Finish the slerp algorithm
        s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0
        s1: FloatTensor = sin_theta_t / sin_theta_0
        slerped: FloatTensor = s0 * v0 + s1 * v1

        out: FloatTensor = slerped.where(can_slerp.unsqueeze(-1), out)

    return out
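A self-contained toy sketch of the two schemes implemented above: outer interpolation blends the two attention *outputs*, while inner interpolation blends keys and values before a single attention call. Plain tensors stand in for diffusers' `Attention` module; names and shapes are purely illustrative.

```python
# Toy comparison of outer vs. inner attention interpolation (illustrative only).
import torch


def softmax_attn(q, k, v):
    # Scaled dot-product attention on single (seq, dim) tensors.
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    return scores.softmax(dim=-1) @ v


seq, dim, t = 4, 8, 0.3
q_t = torch.randn(seq, dim)                             # query of the interpolated image
k1, v1 = torch.randn(seq, dim), torch.randn(seq, dim)   # keys/values of source image 1
km, vm = torch.randn(seq, dim), torch.randn(seq, dim)   # keys/values of source image m

# Outer interpolation: blend the two attention outputs.
outer = (1 - t) * softmax_attn(q_t, k1, v1) + t * softmax_attn(q_t, km, vm)

# Inner interpolation: blend keys/values first, then attend once.
k_t = (1 - t) * k1 + t * km
v_t = (1 - t) * v1 + t * vm
inner = softmax_attn(q_t, k_t, v_t)

print(outer.shape, inner.shape)  # both (4, 8); the two results generally differ
```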
pipeline_interpolated_sdxl.py ADDED
The diff for this file is too large to render. See raw diff
 
pipeline_interpolated_stable_diffusion.py ADDED
@@ -0,0 +1,584 @@
from typing import Optional

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    SchedulerMixin,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from interpolation import (
    InnerInterpolatedAttnProcessor,
    OuterInterpolatedAttnProcessor,
    generate_beta_tensor,
    linear_interpolation,
    slerp,
    spherical_interpolation,
)


class InterpolationStableDiffusionPipeline:
    """
    Diffusion pipeline that generates interpolated images.
    """

    def __init__(
        self,
        repo_name: str = "CompVis/stable-diffusion-v1-4",
        scheduler_name: str = "ddim",
        frozen: bool = True,
        guidance_scale: float = 7.5,
        scheduler: Optional[SchedulerMixin] = None,
        cache_dir: Optional[str] = None,
    ):
        # Initialize the generator
        self.vae = AutoencoderKL.from_pretrained(
            repo_name, subfolder="vae", use_safetensors=True, cache_dir=cache_dir
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            repo_name, subfolder="tokenizer", cache_dir=cache_dir
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            repo_name,
            subfolder="text_encoder",
            use_safetensors=True,
            cache_dir=cache_dir,
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            repo_name, subfolder="unet", use_safetensors=True, cache_dir=cache_dir
        )

        # Initialize the scheduler
        if scheduler is not None:
            self.scheduler = scheduler
        elif scheduler_name == "ddim":
            self.scheduler = DDIMScheduler.from_pretrained(
                repo_name, subfolder="scheduler", cache_dir=cache_dir
            )
        elif scheduler_name == "unipc":
            self.scheduler = UniPCMultistepScheduler.from_pretrained(
                repo_name, subfolder="scheduler", cache_dir=cache_dir
            )
        else:
            raise ValueError(
                "Invalid scheduler name (expected 'ddim' or 'unipc') and no scheduler instance was provided."
            )

        self.guidance_scale = guidance_scale  # Scale for classifier-free guidance

        if frozen:
            for param in self.unet.parameters():
                param.requires_grad = False

            for param in self.text_encoder.parameters():
                param.requires_grad = False

            for param in self.vae.parameters():
                param.requires_grad = False

    def to(self, *args, **kwargs):
        self.vae.to(*args, **kwargs)
        self.text_encoder.to(*args, **kwargs)
        self.unet.to(*args, **kwargs)

    def generate_latent(
        self, generator: Optional[torch.Generator] = None, torch_device: str = "cpu"
    ) -> torch.FloatTensor:
        """
        Generates a random latent tensor.

        Args:
            generator (Optional[torch.Generator], optional): Generator for random number generation. Defaults to None.
            torch_device (str, optional): Device to store the tensor. Defaults to "cpu".

        Returns:
            torch.FloatTensor: Random latent tensor.
        """
        channel = self.unet.config.in_channels
        height = self.unet.config.sample_size
        width = self.unet.config.sample_size
        if generator is None:
            latent = torch.randn(
                (1, channel, height, width),
                device=torch_device,
            )
        else:
            latent = torch.randn(
                (1, channel, height, width),
                generator=generator,
                device=torch_device,
            )
        return latent

    @torch.no_grad()
    def prompt_to_embedding(
        self, prompt: str, negative_prompt: str = ""
    ) -> torch.FloatTensor:
        """
        Prepare the text prompt for the diffusion process.

        Args:
            prompt: str, text prompt
            negative_prompt: str, negative text prompt

        Returns:
            FloatTensor, text embeddings
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embeddings = self.text_encoder(
            text_input.input_ids.to(self.torch_device)
        )[0]

        uncond_input = self.tokenizer(
            negative_prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.torch_device)
        )[0]

        text_embeddings = torch.cat([text_embeddings, uncond_embeddings])
        return text_embeddings

    @torch.no_grad()
    def interpolate(
        self,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        size: int = 7,
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolate between two generations.

        Args:
            latent_start: FloatTensor, latent vector of the first image
            latent_end: FloatTensor, latent vector of the second image
            prompt_start: str, text prompt of the first image
            prompt_end: str, text prompt of the second image
            guide_prompt: str, text prompt for the interpolation
            negative_prompt: str, negative text prompt
            size: int, number of interpolations including starting and ending points
            num_inference_steps: int, number of inference steps in scheduler
            warmup_ratio: float, ratio of warmup steps
            early: str, warmup interpolation method
            late: str, late interpolation method
            alpha: float, alpha parameter for the Beta distribution
            beta: float, beta parameter for the Beta distribution
            guidance_scale: Optional[float], scale for classifier-free guidance

        Returns:
            Numpy array of interpolated images, shape (size, H, W, 3)
        """
        # Specify alpha and beta
        self.torch_device = self.unet.device
        if alpha is None:
            alpha = num_inference_steps
        if beta is None:
            beta = num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
        self.scheduler.set_timesteps(num_inference_steps)

        # Prepare interpolated latents and embeddings
        latents = spherical_interpolation(latent_start, latent_end, size)
        embs_start = self.prompt_to_embedding(prompt_start, negative_prompt)
        emb_start = embs_start[0:1]
        uncond_emb_start = embs_start[1:2]
        embs_end = self.prompt_to_embedding(prompt_end, negative_prompt)
        emb_end = embs_end[0:1]
        uncond_emb_end = embs_end[1:2]

        # Perform prompt guidance if it is specified
        if guide_prompt is not None:
            guide_embs = self.prompt_to_embedding(guide_prompt, negative_prompt)
            guide_emb = guide_embs[0:1]
            uncond_guide_emb = guide_embs[1:2]
            embs = torch.cat([emb_start] + [guide_emb] * (size - 2) + [emb_end], dim=0)
            uncond_embs = torch.cat(
                [uncond_emb_start] + [uncond_guide_emb] * (size - 2) + [uncond_emb_end],
                dim=0,
            )
        else:
            embs = linear_interpolation(emb_start, emb_end, size=size)
            uncond_embs = linear_interpolation(
                uncond_emb_start, uncond_emb_end, size=size
            )

        # Specify the interpolation methods
        pure_inner_attn_proc = InnerInterpolatedAttnProcessor(
            size=size,
            is_fused=False,
            alpha=alpha,
            beta=beta,
        )
        fused_inner_attn_proc = InnerInterpolatedAttnProcessor(
            size=size,
            is_fused=True,
            alpha=alpha,
            beta=beta,
        )
        pure_outer_attn_proc = OuterInterpolatedAttnProcessor(
            size=size,
            is_fused=False,
            alpha=alpha,
            beta=beta,
        )
        fused_outer_attn_proc = OuterInterpolatedAttnProcessor(
            size=size,
            is_fused=True,
            alpha=alpha,
            beta=beta,
        )
        self_attn_proc = AttnProcessor2_0()
        procs_dict = {
            "pure_inner": pure_inner_attn_proc,
            "fused_inner": fused_inner_attn_proc,
            "pure_outer": pure_outer_attn_proc,
            "fused_outer": fused_outer_attn_proc,
            "self": self_attn_proc,
        }

        # Denoising process
        i = 0
        warmup_step = int(num_inference_steps * warmup_ratio)
        for t in tqdm(self.scheduler.timesteps):
            i += 1
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            with torch.no_grad():
                # Change attention module
                if i < warmup_step:
                    interpolate_attn_proc = procs_dict[early]
                else:
                    interpolate_attn_proc = procs_dict[late]
                self.unet.set_attn_processor(processor=interpolate_attn_proc)

                # Predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=embs
                ).sample
                attn_proc = AttnProcessor()
                self.unet.set_attn_processor(processor=attn_proc)
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_embs
                ).sample
            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images

    @torch.no_grad()
    def interpolate_save_gpu(
        self,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        size: int = 7,
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        init: str = "linear",
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolate between two generations, one interpolated image at a time to save GPU memory.

        Args:
            latent_start: FloatTensor, latent vector of the first image
            latent_end: FloatTensor, latent vector of the second image
            prompt_start: str, text prompt of the first image
            prompt_end: str, text prompt of the second image
            guide_prompt: str, text prompt for the interpolation
            negative_prompt: str, negative text prompt
            size: int, number of interpolations including starting and ending points
            num_inference_steps: int, number of inference steps in scheduler
            warmup_ratio: float, ratio of warmup steps
            early: str, warmup interpolation method
            late: str, late interpolation method
            alpha: float, alpha parameter for the Beta distribution
            beta: float, beta parameter for the Beta distribution
            init: str, interpolation initialization method
            guidance_scale: Optional[float], scale for classifier-free guidance

        Returns:
            Numpy array of interpolated images, shape (size, H, W, 3)
        """
        self.torch_device = self.unet.device
        # Specify alpha and beta
        if alpha is None:
            alpha = num_inference_steps
        if beta is None:
            beta = num_inference_steps
        betas = generate_beta_tensor(size, alpha=alpha, beta=beta)
        final_images = None

        # Generate interpolated images one by one
        for i in range(size - 2):
            it = betas[i + 1].item()
            if init == "denoising":
                images = self.denoising_interpolate(
                    latent_start,
                    prompt_start,
                    prompt_end,
                    negative_prompt,
                    interpolated_ratio=it,
                    timesteps=num_inference_steps,
                )
            else:
                images = self.interpolate_single(
                    it,
                    latent_start,
                    latent_end,
                    prompt_start,
                    prompt_end,
                    guide_prompt=guide_prompt,
                    num_inference_steps=num_inference_steps,
                    warmup_ratio=warmup_ratio,
                    early=early,
                    late=late,
                    negative_prompt=negative_prompt,
                    init=init,
                    guidance_scale=guidance_scale,
                )
            if size == 3:
                return images
            if i == 0:
                final_images = images[:2]
            elif i == size - 3:
                final_images = np.concatenate([final_images, images[1:]], axis=0)
            else:
                final_images = np.concatenate([final_images, images[1:2]], axis=0)
        return final_images

    def interpolate_single(
        self,
        it,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: str = None,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        init="linear",
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolates between two latent vectors and generates a sequence of images.

        Args:
            it (float): Interpolation factor between latent_start and latent_end.
            latent_start (torch.FloatTensor): Starting latent vector.
            latent_end (torch.FloatTensor): Ending latent vector.
            prompt_start (str): Starting prompt for text conditioning.
            prompt_end (str): Ending prompt for text conditioning.
            guide_prompt (str, optional): Guiding prompt for text conditioning. Defaults to None.
            negative_prompt (str, optional): Negative prompt for text conditioning. Defaults to "".
            num_inference_steps (int, optional): Number of inference steps. Defaults to 25.
            warmup_ratio (float, optional): Ratio of warm-up steps. Defaults to 0.5.
            early (str, optional): Early attention processing method. Defaults to "fused_outer".
            late (str, optional): Late attention processing method. Defaults to "self".
            init (str, optional): Initialization method for interpolation. Defaults to "linear".
            guidance_scale (Optional[float], optional): Scale for classifier-free guidance. Defaults to None.

        Returns:
            numpy.ndarray: Sequence of generated images.
        """
        self.torch_device = self.unet.device
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        # Prepare interpolated inputs
        self.scheduler.set_timesteps(num_inference_steps)

        embs_start = self.prompt_to_embedding(prompt_start, negative_prompt)
        emb_start = embs_start[0:1]
        uncond_emb_start = embs_start[1:2]
        embs_end = self.prompt_to_embedding(prompt_end, negative_prompt)
        emb_end = embs_end[0:1]
        uncond_emb_end = embs_end[1:2]

        latent_t = slerp(latent_start, latent_end, it)
        if guide_prompt is not None:
            embs_guide = self.prompt_to_embedding(guide_prompt, negative_prompt)
            emb_t = embs_guide[0:1]
        else:
            if init == "linear":
                emb_t = torch.lerp(emb_start, emb_end, it)
            else:
                emb_t = slerp(emb_start, emb_end, it)
        if init == "linear":
            uncond_emb_t = torch.lerp(uncond_emb_start, uncond_emb_end, it)
        else:
            uncond_emb_t = slerp(uncond_emb_start, uncond_emb_end, it)

        latents = torch.cat([latent_start, latent_t, latent_end], dim=0)
        embs = torch.cat([emb_start, emb_t, emb_end], dim=0)
        uncond_embs = torch.cat([uncond_emb_start, uncond_emb_t, uncond_emb_end], dim=0)

        # Specify the attention processors
        pure_inner_attn_proc = InnerInterpolatedAttnProcessor(
            t=it,
            is_fused=False,
        )
        fused_inner_attn_proc = InnerInterpolatedAttnProcessor(
            t=it,
            is_fused=True,
        )
        pure_outer_attn_proc = OuterInterpolatedAttnProcessor(
            t=it,
            is_fused=False,
        )
        fused_outer_attn_proc = OuterInterpolatedAttnProcessor(
            t=it,
            is_fused=True,
        )
        self_attn_proc = AttnProcessor()
        procs_dict = {
            "pure_inner": pure_inner_attn_proc,
            "fused_inner": fused_inner_attn_proc,
            "pure_outer": pure_outer_attn_proc,
            "fused_outer": fused_outer_attn_proc,
            "self": self_attn_proc,
        }

        i = 0
        warmup_step = int(num_inference_steps * warmup_ratio)
        for t in tqdm(self.scheduler.timesteps):
            i += 1
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            # predict the noise residual
            with torch.no_grad():
                # Warmup
                if i < warmup_step:
                    interpolate_attn_proc = procs_dict[early]
                else:
                    interpolate_attn_proc = procs_dict[late]
                self.unet.set_attn_processor(processor=interpolate_attn_proc)
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=embs
                ).sample
                attn_proc = AttnProcessor()
                self.unet.set_attn_processor(processor=attn_proc)
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_embs
                ).sample
            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images

    def denoising_interpolate(
        self,
        latents: torch.FloatTensor,
        text_1: str,
        text_2: str,
        negative_prompt: str = "",
        interpolated_ratio: float = 1,
        timesteps: int = 25,
    ) -> np.ndarray:
        """
        Performs denoising interpolation on the given latents.

        Args:
            latents (torch.Tensor): The input latents.
            text_1 (str): The first text prompt.
            text_2 (str): The second text prompt.
            negative_prompt (str, optional): The negative text prompt. Defaults to "".
            interpolated_ratio (float, optional): The ratio of interpolation between text_1 and text_2. Defaults to 1.
            timesteps (int, optional): The number of timesteps for diffusion. Defaults to 25.

        Returns:
            numpy.ndarray: The interpolated images.
        """
        self.unet.set_attn_processor(processor=AttnProcessor())
        start_emb = self.prompt_to_embedding(text_1)
        end_emb = self.prompt_to_embedding(text_2)
        neg_emb = self.prompt_to_embedding(negative_prompt)
        uncond_emb = neg_emb[0:1]
        emb_1 = start_emb[0:1]
        emb_2 = end_emb[0:1]
        self.scheduler.set_timesteps(timesteps)
        i = 0
        for t in tqdm(self.scheduler.timesteps):
            i += 1
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            # predict the noise residual
            with torch.no_grad():
                if i < timesteps * interpolated_ratio:
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=emb_1
                    ).sample
                else:
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=emb_2
                    ).sample
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_emb
                ).sample
            # perform guidance
            noise_pred = noise_uncond + self.guidance_scale * (
                noise_pred - noise_uncond
            )
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images
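A hedged usage sketch of the batched `interpolate` method defined above, assuming this commit's modules are importable; the prompts are placeholders, and `alpha`/`beta` simply shape how the interpolation coefficients are spaced by the Beta prior.

```python
# Usage sketch for the batched interpolate() method (illustrative, not in the commit).
import torch

from pipeline_interpolated_stable_diffusion import InterpolationStableDiffusionPipeline

pipe = InterpolationStableDiffusionPipeline(
    repo_name="runwayml/stable-diffusion-v1-5", scheduler_name="unipc"
)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

g = torch.manual_seed(0)
latent_a = pipe.generate_latent(generator=g).to(pipe.unet.device, pipe.unet.dtype)
latent_b = pipe.generate_latent(generator=g).to(pipe.unet.device, pipe.unet.dtype)

frames = pipe.interpolate(
    latent_a,
    latent_b,
    prompt_start="a watercolor painting of a fox",   # placeholder prompt
    prompt_end="a watercolor painting of an owl",    # placeholder prompt
    size=5,
    num_inference_steps=25,
    warmup_ratio=0.5,
    early="fused_outer",
    late="self",
    alpha=8.0,
    beta=4.0,
)
print(frames.shape)  # (5, H, W, 3) uint8 frames from start to end
```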
prior.py ADDED
@@ -0,0 +1,168 @@
import torch
from bayes_opt import BayesianOptimization, SequentialDomainReductionTransformer
from lpips import LPIPS
from scipy.stats import beta as beta_distribution

from utils import compute_lpips, compute_smoothness_and_consistency


def bayesian_prior_selection(
    interpolation_pipe,
    latent1: torch.FloatTensor,
    latent2: torch.FloatTensor,
    prompt1: str,
    prompt2: str,
    lpips_model: LPIPS,
    guide_prompt: str | None = None,
    negative_prompt: str = "",
    size: int = 3,
    num_inference_steps: int = 25,
    warmup_ratio: float = 1,
    early: str = "vfused",
    late: str = "self",
    target_score: float = 0.9,
    n_iter: int = 15,
    p_min: float | None = None,
    p_max: float | None = None,
) -> tuple:
    """
    Select the alpha and beta parameters for the interpolation using Bayesian optimization.

    Args:
        interpolation_pipe (any): The interpolation pipeline.
        latent1 (torch.FloatTensor): The first source latent vector.
        latent2 (torch.FloatTensor): The second source latent vector.
        prompt1 (str): The first source prompt.
        prompt2 (str): The second source prompt.
        lpips_model (any): The LPIPS model used to compute perceptual distances.
        guide_prompt (str | None, optional): The guide prompt for the interpolation, if any. Defaults to None.
        negative_prompt (str, optional): The negative prompt for the interpolation. Defaults to "".
        size (int, optional): The size of the interpolation sequence. Defaults to 3.
        num_inference_steps (int, optional): The number of inference steps. Defaults to 25.
        warmup_ratio (float, optional): The warmup ratio. Defaults to 1.
        early (str, optional): The early fusion method. Defaults to "vfused".
        late (str, optional): The late fusion method. Defaults to "self".
        target_score (float, optional): The target score. Defaults to 0.9.
        n_iter (int, optional): The maximum number of iterations. Defaults to 15.
        p_min (float, optional): The minimum value of alpha and beta. Defaults to None.
        p_max (float, optional): The maximum value of alpha and beta. Defaults to None.

    Returns:
        tuple: A tuple containing the selected alpha and beta parameters.
    """

    def get_smoothness(alpha, beta):
        """
        Black-box objective function for Bayesian optimization.
        Get the smoothness of the interpolated sequence with the given alpha and beta.
        """
        if alpha < beta and large_alpha_prior:
            return 0
        if alpha > beta and not large_alpha_prior:
            return 0
        if alpha == beta:
            return init_smoothness
        interpolation_sequence = interpolation_pipe.interpolate_save_gpu(
            latent1,
            latent2,
            prompt1,
            prompt2,
            guide_prompt=guide_prompt,
            negative_prompt=negative_prompt,
            size=size,
            num_inference_steps=num_inference_steps,
            warmup_ratio=warmup_ratio,
            early=early,
            late=late,
            alpha=alpha,
            beta=beta,
        )
        smoothness, _, _ = compute_smoothness_and_consistency(
            interpolation_sequence, lpips_model
        )
        return smoothness

    # Add prior into selection of alpha and beta:
    # we first compute the interpolated image at t=0.5
    images = interpolation_pipe.interpolate_single(
        0.5,
        latent1,
        latent2,
        prompt1,
        prompt2,
        guide_prompt=guide_prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        warmup_ratio=warmup_ratio,
        early=early,
        late=late,
    )
    # Compute the perceptual distances of the interpolated image (t=0.5) to the source images
    distances = compute_lpips(images, lpips_model)
    # Compute init_smoothness as the smoothness when alpha=beta to avoid recomputation
    init_smoothness, _, _ = compute_smoothness_and_consistency(images, lpips_model)
    # If the perceptual distance to the first source image is smaller, alpha should be larger than beta
    large_alpha_prior = distances[0] < distances[1]

    # Bayesian optimization configuration
    num_warmup_steps = warmup_ratio * num_inference_steps
    if p_min is None:
        p_min = 1
    if p_max is None:
        p_max = num_warmup_steps
    pbounds = {"alpha": (p_min, p_max), "beta": (p_min, p_max)}
    bounds_transformer = SequentialDomainReductionTransformer(minimum_window=0.1)
    optimizer = BayesianOptimization(
        f=get_smoothness,
        pbounds=pbounds,
        random_state=1,
        bounds_transformer=bounds_transformer,
        allow_duplicate_points=True,
    )
    alpha_init = [p_min, (p_min + p_max) / 2, p_max]
    beta_init = [p_min, (p_min + p_max) / 2, p_max]

    # Initial probing
    for alpha in alpha_init:
        for beta in beta_init:
            optimizer.probe(params={"alpha": alpha, "beta": beta}, lazy=False)
            latest_result = optimizer.res[-1]  # Get the last result
            latest_score = latest_result["target"]
            if latest_score >= target_score:
                return alpha, beta

    # Start optimization
    for _ in range(n_iter):  # Max iterations
        optimizer.maximize(init_points=0, n_iter=1)  # One iteration at a time
        max_score = optimizer.max["target"]  # Get the highest score so far
        if max_score >= target_score:
            print(f"Stopping early, target of {target_score} reached.")
            break  # Exit the loop if the target is reached or exceeded

    results = optimizer.max
    alpha = results["params"]["alpha"]
    beta = results["params"]["beta"]
    return alpha, beta


def generate_beta_tensor(
    size: int, alpha: float = 3, beta: float = 3
) -> torch.FloatTensor:
    """
    Assume size is n.
    Generates a PyTorch tensor of values [x0, x1, ..., xn-1] for the Beta distribution
    where each xi satisfies F(xi) = i/(n-1) for the CDF F of the Beta distribution.

    Args:
        size (int): The number of values to generate.
        alpha (float): The alpha parameter of the Beta distribution.
        beta (float): The beta parameter of the Beta distribution.

    Returns:
        torch.Tensor: A tensor of the inverse CDF values of the Beta distribution.
    """
    # Generating the inverse CDF values
    prob_values = [i / (size - 1) for i in range(size)]
    inverse_cdf_values = beta_distribution.ppf(prob_values, alpha, beta)

    # Converting to a PyTorch tensor
    return torch.tensor(inverse_cdf_values, dtype=torch.float32)
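A quick, self-contained illustration of the Beta-prior spacing that `generate_beta_tensor` produces; the helper below mirrors its body so it can run without this repo's other dependencies, and the printed values are approximate.

```python
# Illustration of how alpha/beta shape the interpolation coefficients (illustrative only).
import torch
from scipy.stats import beta as beta_distribution


def beta_spacing(size, alpha, beta):
    # Inverse-CDF (quantile) spacing, same construction as generate_beta_tensor above.
    probs = [i / (size - 1) for i in range(size)]
    return torch.tensor(beta_distribution.ppf(probs, alpha, beta), dtype=torch.float32)


print(beta_spacing(5, 1, 1))  # ~[0.00, 0.25, 0.50, 0.75, 1.00]: uniform spacing
print(beta_spacing(5, 3, 3))  # symmetric, coefficients packed more tightly around 0.5
print(beta_spacing(5, 6, 3))  # asymmetric, coefficients pulled toward 1
```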
requirements.txt ADDED
@@ -0,0 +1,65 @@
absl-py==2.1.0
accelerate==0.27.2
addict==2.4.0
antlr4-python3-runtime==4.9.3
bayesian-optimization==1.4.3
clean-fid==0.1.35
clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
colorama==0.4.6
contourpy==1.2.0
cycler==0.12.1
diffusers==0.27.1
einops==0.7.0
facexlib==0.3.0
filterpy==1.4.5
fonttools==4.49.0
fsspec==2024.2.0
ftfy==6.1.3
future==1.0.0
grpcio==1.62.0
huggingface-hub==0.20.3
imageio==2.34.0
imgaug==0.4.0
joblib==1.3.2
kiwisolver==1.4.5
lazy_loader==0.3
llvmlite==0.42.0
lmdb==1.4.1
lpips==0.1.4
Markdown==3.5.2
matplotlib==3.8.3
mkl-service==2.4.0
numba==0.59.0
numpy==1.24.4
omegaconf==2.3.0
openai-clip==1.0.1
opencv-python==4.9.0.80
pandas==2.2.0
protobuf==4.25.3
pyiqa==0.1.10
pyparsing==3.1.1
python-dateutil==2.8.2
pytorch-fid==0.3.0
pytz==2024.1
regex==2023.12.25
safetensors==0.4.2
scikit-image==0.22.0
scikit-learn==1.4.1.post1
scipy==1.9.1
shapely==2.0.3
tensorboard==2.16.2
tensorboard-data-server==0.7.2
threadpoolctl==3.3.0
tifffile==2024.2.12
timm==0.9.16
tokenizers==0.15.2
tomli==2.0.1
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tqdm==4.66.2
transformers==4.38.2
triton==2.1.0
tzdata==2024.1
Werkzeug==3.0.1
yapf==0.40.2
style.css ADDED
@@ -0,0 +1,71 @@
h1 {
  text-align: center;
  justify-content: center;
}

[role="tabpanel"] {
  border: 0
}

#duplicate-button {
  margin: auto;
  color: #fff;
  background: #1565c0;
  border-radius: 100vh;
}

.gradio-container {
  max-width: 690px !important;
}

#share-btn-container {
  padding-left: 0.5rem !important;
  padding-right: 0.5rem !important;
  background-color: #000000;
  justify-content: center;
  align-items: center;
  border-radius: 9999px !important;
  max-width: 13rem;
  margin-left: auto;
  margin-top: 0.35em;
}

div#share-btn-container>div {
  flex-direction: row;
  background: black;
  align-items: center
}

#share-btn-container:hover {
  background-color: #060606
}

#share-btn {
  all: initial;
  color: #ffffff;
  font-weight: 600;
  cursor: pointer;
  font-family: 'IBM Plex Sans', sans-serif;
  margin-left: 0.5rem !important;
  padding-top: 0.5rem !important;
  padding-bottom: 0.5rem !important;
  right: 0;
  font-size: 15px;
}

#share-btn * {
  all: unset
}

#share-btn-container div:nth-child(-n+2) {
  width: auto !important;
  min-height: 0px !important;
}

#share-btn-container .wrap {
  display: none !important
}

#share-btn-container.hidden {
  display: none !important
}
utils.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ from typing import Optional
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from lpips import LPIPS
+ from PIL import Image
+ from torchvision.transforms import Normalize
+
+
+ def show_images_horizontally(
+     list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
+ ) -> None:
+     """
+     Visualize the list of images horizontally and save the figure as PNG.
+
+     Args:
+         list_of_files: The list of images as numpy array with shape (N, H, W, C).
+         output_file: The output file path to save the figure as PNG.
+         interact: If True, show the figure interactively (e.g. in a notebook)
+             instead of saving it to output_file.
+     """
+     number_of_files = len(list_of_files)
+
+     # Use each image's height and width to size the figure
+     heights = [image.shape[0] for image in list_of_files]
+     widths = [image.shape[1] for image in list_of_files]
+
+     fig_width = 8.0  # inches
+     fig_height = fig_width * sum(heights) / sum(widths)
+
+     # Create a figure with one subplot per image
+     _, axs = plt.subplots(
+         1, number_of_files, figsize=(fig_width * number_of_files, fig_height)
+     )
+     plt.tight_layout()
+     for i in range(number_of_files):
+         _image = list_of_files[i]
+         axs[i].imshow(_image)
+         axs[i].axis("off")
+
+     # Save the figure as PNG (or show it interactively)
+     if interact:
+         plt.show()
+     else:
+         plt.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
+
+
+ def save_image(image: np.array, file_name: str) -> None:
+     """
+     Save the image as JPG.
+
+     Args:
+         image: The input image as numpy array with shape (H, W, C).
+         file_name: The file name to save the image.
+     """
+     image = Image.fromarray(image)
+     image.save(file_name)
+
+
+ def load_and_process_images(load_dir: str) -> np.array:
+     """
+     Load the images from a directory and scale them to [0, 1].
+
+     Args:
+         load_dir: The directory to load the images from.
+
+     Returns:
+         images: A list of images as numpy arrays, each with shape (H, W, C).
+     """
+     images = []
+     print(load_dir)
+     filenames = sorted(
+         os.listdir(load_dir), key=lambda x: int(x.split(".")[0])
+     )  # Ensure the files are sorted numerically (0.jpg, 1.jpg, ...)
+     for filename in filenames:
+         if filename.endswith(".jpg"):
+             img = Image.open(os.path.join(load_dir, filename))
+             img_array = (
+                 np.asarray(img) / 255.0
+             )  # Convert to numpy array and scale pixel values to [0, 1]
+             images.append(img_array)
+     return images
+
+
+ def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
+     """
+     Compute LPIPS distances between consecutive input images.
+
+     Args:
+         images: The input images as numpy array with shape (N, H, W, C).
+         lpips_model: The LPIPS model used to compute perceptual distances.
+
+     Returns:
+         distances: The LPIPS between each pair of adjacent images, shape (N-1,).
+     """
+     # Get the device of lpips_model
+     device = next(lpips_model.parameters()).device
+     device = str(device)
+
+     # Change the input images into a tensor of shape (N, C, H, W)
+     images = torch.tensor(images).to(device).float()
+     images = torch.permute(images, (0, 3, 1, 2))
+     normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     images = normalize(images)
+
+     # Compute the LPIPS between each adjacent pair of images
+     distances = []
+     for i in range(images.shape[0]):
+         if i == images.shape[0] - 1:
+             break
+         img1 = images[i].unsqueeze(0)
+         img2 = images[i + 1].unsqueeze(0)
+         loss = lpips_model(img1, img2)
+         distances.append(loss.item())
+     distances = np.array(distances)
+     return distances
+
+
+ def compute_gini(distances: np.array) -> float:
+     """
+     Compute the Gini index of the input distances.
+
+     Args:
+         distances: The input distances as numpy array.
+
+     Returns:
+         gini: The Gini index of the input distances.
+     """
+     if len(distances) < 2:
+         return 0.0  # Gini index is 0 for less than two elements
+
+     # Sort the list of distances
+     sorted_distances = sorted(distances)
+     n = len(sorted_distances)
+     mean_distance = sum(sorted_distances) / n
+
+     # Compute the sum of absolute differences
+     sum_of_differences = 0
+     for di in sorted_distances:
+         for dj in sorted_distances:
+             sum_of_differences += abs(di - dj)
+
+     # Normalize the sum of differences by the mean and the number of elements
+     gini = sum_of_differences / (2 * n * n * mean_distance)
+     return gini
+
+
+ def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
+     """
+     Compute the smoothness and consistency of the input images.
+
+     Args:
+         images: The input images as numpy array with shape (N, H, W, C).
+         lpips_model: The LPIPS model used to compute perceptual distances.
+
+     Returns:
+         smoothness: One minus the Gini index of LPIPS of consecutive images.
+         consistency: The mean LPIPS of consecutive images.
+         max_inception_distance: The maximum LPIPS of consecutive images.
+     """
+     distances = compute_lpips(images, lpips_model)
+     smoothness = 1 - compute_gini(distances)
+     consistency = np.mean(distances)
+     max_inception_distance = np.max(distances)
+     return smoothness, consistency, max_inception_distance
+
+
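As a quick sanity check of the smoothness metric (a sketch, assuming this `utils` module is importable; the distance values are made up): perfectly even consecutive LPIPS steps give a Gini index of 0 and therefore a smoothness of 1, while a single large perceptual jump lowers it.

```python
import numpy as np

from utils import compute_gini  # the helper defined above

even = np.array([0.2, 0.2, 0.2, 0.2])       # identical consecutive LPIPS steps
uneven = np.array([0.05, 0.05, 0.05, 0.6])  # one large perceptual jump

print(1 - compute_gini(even))    # 1.0  -> maximally smooth
print(1 - compute_gini(uneven))  # < 1  -> penalized for the outlier step
```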
+ def separate_source_and_interpolated_images(images: np.array) -> tuple:
+     """
+     Separate the input images into source and interpolated images.
+     The sources are the first and last images; the interpolated images are the rest.
+
+     Args:
+         images: The input images as numpy array with shape (N, H, W, C).
+
+     Returns:
+         source: The source images as numpy array with shape (2, H, W, C).
+         interpolation: The interpolated images as numpy array with shape (N-2, H, W, C).
+     """
+     # Check if the array has at least two elements
+     if len(images) < 2:
+         raise ValueError("The input array should have at least two elements.")
+
+     # Separate the array into two parts
+     # First part takes the first and last element
+     source = np.array([images[0], images[-1]])
+     # Second part takes the rest of the elements
+     interpolation = images[1:-1]
+     return source, interpolation
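Taken together, these helpers score a saved interpolation sequence end to end. A usage sketch (the directory path is a placeholder and assumes the frames were written as `0.jpg`, `1.jpg`, ... by the generation script):

```python
import numpy as np
from lpips import LPIPS

from utils import (
    compute_smoothness_and_consistency,
    load_and_process_images,
    separate_source_and_interpolated_images,
)

lpips_model = LPIPS(net="alex")  # perceptual metric consumed by compute_lpips

# "results/example_run" is a placeholder directory containing 0.jpg, 1.jpg, ...
images = np.stack(load_and_process_images("results/example_run"))

# Source frames vs. interpolated frames, e.g. for separate reporting
source, interpolation = separate_source_and_interpolated_images(images)

smoothness, consistency, max_dist = compute_smoothness_and_consistency(
    images, lpips_model
)
print(f"smoothness={smoothness:.3f}, consistency={consistency:.3f}, max LPIPS={max_dist:.3f}")
```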