garibida commited on
Commit
d65c9b3
1 Parent(s): 9d151b6

Upload Files

Browse files
example_images/kitten.jpg ADDED
example_images/lion.jpeg ADDED
example_images/monkey.jpeg ADDED
gradio_app.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import torch
6
+
7
+ from src.eunms import Model_Type, Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type
8
+ from src.enums_utils import model_type_to_size, get_pipes
9
+ from src.config import RunConfig
10
+ from main import run as run_model
11
+
12
+
13
+ DESCRIPTION = '''# ReNoise: Real Image Inversion Through Iterative Noising
14
+ This is a demo for our ''ReNoise: Real Image Inversion Through Iterative Noising'' [paper](https://garibida.github.io/ReNoise-Inversion/). Code is available [here](https://github.com/garibida/ReNoise-Inversion)
15
+ Our ReNoise inversion technique can be applied to various diffusion models, including recent few-step ones such as SDXL-Turbo.
16
+ This demo preform real image editing using our ReNoise inversion. The input image is resize to size of 512x512, the optimal size of SDXL Turbo.
17
+ '''
18
+
19
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ model_type = Model_Type.SDXL_Turbo
21
+ scheduler_type = Scheduler_Type.EULER
22
+ image_size = model_type_to_size(Model_Type.SDXL_Turbo)
23
+ pipe_inversion, pipe_inference = get_pipes(model_type, scheduler_type, device=device)
24
+
25
+ cache_size = 10
26
+ prev_configs = [None for i in range(cache_size)]
27
+ prev_inv_latents = [None for i in range(cache_size)]
28
+ prev_images = [None for i in range(cache_size)]
29
+ prev_noises = [None for i in range(cache_size)]
30
+
31
+ def main_pipeline(
32
+ input_image: str,
33
+ src_prompt: str,
34
+ tgt_prompt: str,
35
+ edit_cfg: float,
36
+ number_of_renoising_iterations: int,
37
+ inersion_strength: float,
38
+ avg_gradients: bool,
39
+ first_step_range_start: int,
40
+ first_step_range_end: int,
41
+ rest_step_range_start: int,
42
+ rest_step_range_end: int,
43
+ lambda_ac: float,
44
+ lambda_kl: float,
45
+ noise_correction: bool):
46
+
47
+ global prev_configs, prev_inv_latents, prev_images, prev_noises
48
+
49
+ update_epsilon_type = Epsilon_Update_Type.OPTIMIZE if noise_correction else Epsilon_Update_Type.NONE
50
+ avg_gradients_type = Gradient_Averaging_Type.ON_END if avg_gradients else Gradient_Averaging_Type.NONE
51
+
52
+ first_step_range = (first_step_range_start, first_step_range_end)
53
+ rest_step_range = (rest_step_range_start, rest_step_range_end)
54
+
55
+ config = RunConfig(model_type = model_type,
56
+ num_inference_steps = 4,
57
+ num_inversion_steps = 4,
58
+ guidance_scale = 0.0,
59
+ max_num_aprox_steps_first_step = first_step_range_end+1,
60
+ num_aprox_steps = number_of_renoising_iterations,
61
+ inversion_max_step = inersion_strength,
62
+ gradient_averaging_type = avg_gradients_type,
63
+ gradient_averaging_first_step_range = first_step_range,
64
+ gradient_averaging_step_range = rest_step_range,
65
+ scheduler_type = scheduler_type,
66
+ num_reg_steps = 4,
67
+ num_ac_rolls = 5,
68
+ lambda_ac = lambda_ac,
69
+ lambda_kl = lambda_kl,
70
+ update_epsilon_type = update_epsilon_type,
71
+ do_reconstruction = True)
72
+ config.prompt = src_prompt
73
+
74
+ inv_latent = None
75
+ noise_list = None
76
+ for i in range(cache_size):
77
+ if prev_configs[i] is not None and prev_configs[i] == config and prev_images[i] == input_image:
78
+ print(f"Using cache for config #{i}")
79
+ inv_latent = prev_inv_latents[i]
80
+ noise_list = prev_noises[i]
81
+ prev_configs.pop(i)
82
+ prev_inv_latents.pop(i)
83
+ prev_images.pop(i)
84
+ prev_noises.pop(i)
85
+ break
86
+
87
+ original_image = Image.open(input_image).convert("RGB").resize(image_size)
88
+
89
+ res_image, inv_latent, noise, all_latents = run_model(original_image,
90
+ config,
91
+ latents=inv_latent,
92
+ pipe_inversion=pipe_inversion,
93
+ pipe_inference=pipe_inference,
94
+ edit_prompt=tgt_prompt,
95
+ noise=noise_list,
96
+ edit_cfg=edit_cfg)
97
+
98
+ prev_configs.append(config)
99
+ prev_inv_latents.append(inv_latent)
100
+ prev_images.append(input_image)
101
+ prev_noises.append(noise)
102
+
103
+ if len(prev_configs) > cache_size:
104
+ print("Popping cache")
105
+ prev_configs.pop(0)
106
+ prev_inv_latents.pop(0)
107
+ prev_images.pop(0)
108
+ prev_noises.pop(0)
109
+
110
+ return res_image
111
+
112
+
113
+ with gr.Blocks(css='style.css') as demo:
114
+ gr.Markdown(DESCRIPTION)
115
+
116
+ gr.HTML(
117
+ '''<a href="https://huggingface.co/spaces/orpatashnik/local-prompt-mixing?duplicate=true">
118
+ <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to run privately without waiting in queue''')
119
+
120
+ with gr.Row():
121
+ with gr.Column():
122
+ input_image = gr.Image(
123
+ label="Input image",
124
+ type="filepath",
125
+ height=image_size[0],
126
+ width=image_size[1]
127
+ )
128
+ src_prompt = gr.Text(
129
+ label='Source Prompt',
130
+ max_lines=1,
131
+ placeholder='A kitten is sitting in a basket on a branch',
132
+ )
133
+ tgt_prompt = gr.Text(
134
+ label='Target Prompt',
135
+ max_lines=1,
136
+ placeholder='A plush toy kitten is sitting in a basket on a branch',
137
+ )
138
+ with gr.Accordion("Advanced Options", open=False):
139
+ edit_cfg = gr.Slider(
140
+ label='Denoise Classifier-Free Guidence Scale',
141
+ minimum=1.0,
142
+ maximum=3.5,
143
+ value=1.0,
144
+ step=0.1
145
+ )
146
+ number_of_renoising_iterations = gr.Slider(
147
+ label='Number of ReNoise Iterations',
148
+ minimum=0,
149
+ maximum=20,
150
+ value=9,
151
+ step=1
152
+ )
153
+ inersion_strength = gr.Slider(
154
+ label='Inversion Strength',
155
+ minimum=0.0,
156
+ maximum=1.0,
157
+ value=1.0,
158
+ step=0.25
159
+ )
160
+ avg_gradients = gr.Checkbox(
161
+ label="Preform Estimation Averaging"
162
+ )
163
+ first_step_range_start = gr.Slider(
164
+ label='First Estimation in Average (t < 250)',
165
+ minimum=0,
166
+ maximum=21,
167
+ value=0,
168
+ step=1
169
+ )
170
+ first_step_range_end = gr.Slider(
171
+ label='Last Estimation in Average (t < 250)',
172
+ minimum=0,
173
+ maximum=21,
174
+ value=5,
175
+ step=1
176
+ )
177
+ rest_step_range_start = gr.Slider(
178
+ label='First Estimation in Average (t > 250)',
179
+ minimum=0,
180
+ maximum=21,
181
+ value=8,
182
+ step=1
183
+ )
184
+ rest_step_range_end = gr.Slider(
185
+ label='Last Estimation in Average (t > 250)',
186
+ minimum=0,
187
+ maximum=21,
188
+ value=10,
189
+ step=1
190
+ )
191
+ num_reg_steps = 4
192
+ num_ac_rolls = 5
193
+ lambda_ac = gr.Slider(
194
+ label='Labmda AC',
195
+ minimum=0.0,
196
+ maximum=50.0,
197
+ value=20.0,
198
+ step=1.0
199
+ )
200
+ lambda_kl = gr.Slider(
201
+ label='Labmda Patch KL',
202
+ minimum=0.0,
203
+ maximum=0.4,
204
+ value=0.065,
205
+ step=0.005
206
+ )
207
+ noise_correction = gr.Checkbox(
208
+ label="Preform Noise Correction"
209
+ )
210
+
211
+ run_button = gr.Button('Edit')
212
+ with gr.Column():
213
+ # result = gr.Gallery(label='Result')
214
+ result = gr.Image(
215
+ label="Result",
216
+ type="pil",
217
+ height=image_size[0],
218
+ width=image_size[1]
219
+ )
220
+
221
+ examples = [
222
+ [
223
+ "example_images/kitten.jpg", #input_image
224
+ "A kitten is sitting in a basket on a branch", #src_prompt
225
+ "a lego kitten is sitting in a basket on a branch", #tgt_prompt
226
+ 1.0, #edit_cfg
227
+ 9, #number_of_renoising_iterations
228
+ 1.0, #inersion_strength
229
+ True, #avg_gradients
230
+ 0, #first_step_range_start
231
+ 5, #first_step_range_end
232
+ 8, #rest_step_range_start
233
+ 10, #rest_step_range_end
234
+ 20.0, #lambda_ac
235
+ 0.055, #lambda_kl
236
+ False #noise_correction
237
+ ],
238
+ [
239
+ "example_images/kitten.jpg", #input_image
240
+ "A kitten is sitting in a basket on a branch", #src_prompt
241
+ "a brokkoli is sitting in a basket on a branch", #tgt_prompt
242
+ 1.0, #edit_cfg
243
+ 9, #number_of_renoising_iterations
244
+ 1.0, #inersion_strength
245
+ True, #avg_gradients
246
+ 0, #first_step_range_start
247
+ 5, #first_step_range_end
248
+ 8, #rest_step_range_start
249
+ 10, #rest_step_range_end
250
+ 20.0, #lambda_ac
251
+ 0.055, #lambda_kl
252
+ False #noise_correction
253
+ ],
254
+ [
255
+ "example_images/kitten.jpg", #input_image
256
+ "A kitten is sitting in a basket on a branch", #src_prompt
257
+ "a dog is sitting in a basket on a branch", #tgt_prompt
258
+ 1.0, #edit_cfg
259
+ 9, #number_of_renoising_iterations
260
+ 1.0, #inersion_strength
261
+ True, #avg_gradients
262
+ 0, #first_step_range_start
263
+ 5, #first_step_range_end
264
+ 8, #rest_step_range_start
265
+ 10, #rest_step_range_end
266
+ 20.0, #lambda_ac
267
+ 0.055, #lambda_kl
268
+ False #noise_correction
269
+ ],
270
+ [
271
+ "example_images/monkey.jpeg", #input_image
272
+ "a monkey sitting on a tree branch in the forest", #src_prompt
273
+ "a beaver sitting on a tree branch in the forest", #tgt_prompt
274
+ 1.0, #edit_cfg
275
+ 9, #number_of_renoising_iterations
276
+ 1.0, #inersion_strength
277
+ True, #avg_gradients
278
+ 0, #first_step_range_start
279
+ 5, #first_step_range_end
280
+ 8, #rest_step_range_start
281
+ 10, #rest_step_range_end
282
+ 20.0, #lambda_ac
283
+ 0.055, #lambda_kl
284
+ True #noise_correction
285
+ ],
286
+ [
287
+ "example_images/monkey.jpeg", #input_image
288
+ "a monkey sitting on a tree branch in the forest", #src_prompt
289
+ "a raccoon sitting on a tree branch in the forest", #tgt_prompt
290
+ 1.0, #edit_cfg
291
+ 9, #number_of_renoising_iterations
292
+ 1.0, #inersion_strength
293
+ True, #avg_gradients
294
+ 0, #first_step_range_start
295
+ 5, #first_step_range_end
296
+ 8, #rest_step_range_start
297
+ 10, #rest_step_range_end
298
+ 20.0, #lambda_ac
299
+ 0.055, #lambda_kl
300
+ True #noise_correction
301
+ ],
302
+ [
303
+ "example_images/lion.jpeg", #input_image
304
+ "a lion is sitting in the grass at sunset", #src_prompt
305
+ "a tiger is sitting in the grass at sunset", #tgt_prompt
306
+ 1.0, #edit_cfg
307
+ 9, #number_of_renoising_iterations
308
+ 1.0, #inersion_strength
309
+ True, #avg_gradients
310
+ 0, #first_step_range_start
311
+ 5, #first_step_range_end
312
+ 8, #rest_step_range_start
313
+ 10, #rest_step_range_end
314
+ 20.0, #lambda_ac
315
+ 0.055, #lambda_kl
316
+ True #noise_correction
317
+ ]
318
+ ]
319
+
320
+ gr.Examples(examples=examples,
321
+ inputs=[
322
+ input_image,
323
+ src_prompt,
324
+ tgt_prompt,
325
+ edit_cfg,
326
+ number_of_renoising_iterations,
327
+ inersion_strength,
328
+ avg_gradients,
329
+ first_step_range_start,
330
+ first_step_range_end,
331
+ rest_step_range_start,
332
+ rest_step_range_end,
333
+ lambda_ac,
334
+ lambda_kl,
335
+ noise_correction
336
+ ],
337
+ outputs=[
338
+ result
339
+ ],
340
+ fn=main_pipeline,
341
+ cache_examples=True)
342
+
343
+
344
+ inputs = [
345
+ input_image,
346
+ src_prompt,
347
+ tgt_prompt,
348
+ edit_cfg,
349
+ number_of_renoising_iterations,
350
+ inersion_strength,
351
+ avg_gradients,
352
+ first_step_range_start,
353
+ first_step_range_end,
354
+ rest_step_range_start,
355
+ rest_step_range_end,
356
+ lambda_ac,
357
+ lambda_kl,
358
+ noise_correction
359
+ ]
360
+ outputs = [
361
+ result
362
+ ]
363
+ run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
364
+
365
+ demo.queue(max_size=50).launch(share=True)
main.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pyrallis
2
+ import torch
3
+ from PIL import Image
4
+ from diffusers.utils.torch_utils import randn_tensor
5
+
6
+ from src.config import RunConfig, Scheduler_Type
7
+ from src.enums_utils import model_type_to_size
8
+
9
+ @pyrallis.wrap()
10
+ def main(cfg: RunConfig):
11
+ run(cfg)
12
+
13
+ def inversion_callback(pipe, step, timestep, callback_kwargs):
14
+ return callback_kwargs
15
+
16
+ def inference_callback(pipe, step, timestep, callback_kwargs):
17
+ return callback_kwargs
18
+
19
+ def run(init_image: Image, cfg: RunConfig, pipe_inversion, pipe_inference, latents = None, edit_prompt = None, edit_cfg = 1.0, noise = None):
20
+ # pyrallis.dump(cfg, open(cfg.output_path / 'config.yaml', 'w'))
21
+
22
+ if latents is None and cfg.scheduler_type == Scheduler_Type.EULER or cfg.scheduler_type == Scheduler_Type.LCM or cfg.scheduler_type == Scheduler_Type.DDPM:
23
+ g_cpu = torch.Generator().manual_seed(7865)
24
+ img_size = model_type_to_size(cfg.model_type)
25
+ VQAE_SCALE = 8
26
+ latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE)
27
+ noise = [randn_tensor(latents_size, dtype=torch.float16, device=torch.device("cuda:0"), generator=g_cpu) for i in range(cfg.num_inversion_steps)]
28
+ pipe_inversion.scheduler.set_noise_list(noise)
29
+ pipe_inference.scheduler.set_noise_list(noise)
30
+ pipe_inversion.scheduler_inference.set_noise_list(noise)
31
+
32
+ if latents is not None and cfg.scheduler_type == Scheduler_Type.EULER or cfg.scheduler_type == Scheduler_Type.LCM or cfg.scheduler_type == Scheduler_Type.DDPM:
33
+ pipe_inversion.scheduler.set_noise_list(noise)
34
+ pipe_inference.scheduler.set_noise_list(noise)
35
+ pipe_inversion.scheduler_inference.set_noise_list(noise)
36
+
37
+
38
+ pipe_inversion.cfg = cfg
39
+ pipe_inference.cfg = cfg
40
+ all_latents = None
41
+
42
+ if latents is None:
43
+ print("Inverting...")
44
+ if cfg.save_gpu_mem:
45
+ pipe_inference.to("cpu")
46
+ pipe_inversion.to("cuda")
47
+ res = pipe_inversion(prompt = cfg.prompt,
48
+ num_inversion_steps = cfg.num_inversion_steps,
49
+ num_inference_steps = cfg.num_inference_steps,
50
+ image = init_image,
51
+ guidance_scale = cfg.guidance_scale,
52
+ opt_iters = cfg.opt_iters,
53
+ opt_lr = cfg.opt_lr,
54
+ callback_on_step_end = inversion_callback,
55
+ strength = cfg.inversion_max_step,
56
+ denoising_start = 1.0-cfg.inversion_max_step,
57
+ opt_loss_kl_lambda = cfg.loss_kl_lambda,
58
+ num_aprox_steps = cfg.num_aprox_steps)
59
+ latents = res[0][0]
60
+ all_latents = res[1]
61
+
62
+ inv_latent = latents.clone()
63
+
64
+ if cfg.do_reconstruction:
65
+ print("Generating...")
66
+ edit_prompt = cfg.prompt if edit_prompt is None else edit_prompt
67
+ guidance_scale = edit_cfg
68
+ if cfg.save_gpu_mem:
69
+ pipe_inversion.to("cpu")
70
+ pipe_inference.to("cuda")
71
+ img = pipe_inference(prompt = edit_prompt,
72
+ num_inference_steps = cfg.num_inference_steps,
73
+ negative_prompt = cfg.prompt,
74
+ callback_on_step_end = inference_callback,
75
+ image = latents,
76
+ strength = cfg.inversion_max_step,
77
+ denoising_start = 1.0-cfg.inversion_max_step,
78
+ guidance_scale = guidance_scale).images[0]
79
+ else:
80
+ img = None
81
+
82
+ return img, inv_latent, noise, all_latents
83
+
84
+ if __name__ == "__main__":
85
+ main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch==2.2.1
3
+ torchvision==0.17.1
4
+ diffusers==0.24.0
5
+ transformers==4.32.1
6
+ pyrallis==0.3.1
7
+ accelerate==0.25.0
src/config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import NamedTuple
4
+
5
+ from src.eunms import Model_Type, Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type
6
+
7
+ @dataclass
8
+ class RunConfig:
9
+ model_type : Model_Type = Model_Type.SDXL_Turbo
10
+
11
+ scheduler_type : Scheduler_Type = Scheduler_Type.EULER
12
+
13
+ prompt: str = ""
14
+
15
+ num_inference_steps: int = 4
16
+
17
+ num_inversion_steps: int = 100
18
+
19
+ opt_lr: float = 0.1
20
+
21
+ opt_iters: int = 0
22
+
23
+ opt_none_inference_steps: bool = False
24
+
25
+ guidance_scale: float = 0.0
26
+
27
+ # pipe_inversion: DiffusionPipeline = None
28
+
29
+ # pipe_inference: DiffusionPipeline = None
30
+
31
+ save_gpu_mem: bool = False
32
+
33
+ do_reconstruction: bool = True
34
+
35
+ loss_kl_lambda: float = 10.0
36
+
37
+ max_num_aprox_steps_first_step: int = 1
38
+
39
+ num_aprox_steps: int = 10
40
+
41
+ inversion_max_step: float = 1.0
42
+
43
+ gradient_averaging_type: Gradient_Averaging_Type = Gradient_Averaging_Type.NONE
44
+
45
+ gradient_averaging_first_step_range: tuple = (0, 10)
46
+
47
+ gradient_averaging_step_range: tuple = (0, 10)
48
+
49
+ noise_friendly_inversion: bool = False
50
+
51
+ update_epsilon_type: Epsilon_Update_Type = Gradient_Averaging_Type.NONE
52
+
53
+ #pip2pip zero
54
+
55
+ lambda_ac: float = 20.0
56
+
57
+ lambda_kl: float = 20.0
58
+
59
+ num_reg_steps: int = 5
60
+
61
+ num_ac_rolls: int = 5
62
+
63
+ def __post_init__(self):
64
+ pass
src/ddpm_scheduler.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DDPMScheduler, LCMScheduler
2
+ from diffusers.utils import BaseOutput
3
+ from diffusers.utils.torch_utils import randn_tensor
4
+ import torch
5
+ from typing import List, Optional, Tuple, Union
6
+ import numpy as np
7
+
8
+ class DDPMSchedulerOutput(BaseOutput):
9
+ """
10
+ Output class for the scheduler's `step` function output.
11
+
12
+ Args:
13
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
14
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
15
+ denoising loop.
16
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
18
+ `pred_original_sample` can be used to preview progress or for guidance.
19
+ """
20
+
21
+ prev_sample: torch.FloatTensor
22
+ pred_original_sample: Optional[torch.FloatTensor] = None
23
+
24
+ class MyDDPMScheduler(DDPMScheduler):
25
+ def set_noise_list(self, noise_list):
26
+ self.noise_list = noise_list
27
+
28
+ def step_and_update(
29
+ self,
30
+ model_output: torch.FloatTensor,
31
+ timestep: int,
32
+ sample: torch.FloatTensor,
33
+ next_sample: torch.FloatTensor = None,
34
+ generator=None,
35
+ return_dict: bool = True,
36
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
37
+ """
38
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
39
+ process from the learned model outputs (most often the predicted noise).
40
+
41
+ Args:
42
+ model_output (`torch.FloatTensor`):
43
+ The direct output from learned diffusion model.
44
+ timestep (`float`):
45
+ The current discrete timestep in the diffusion chain.
46
+ sample (`torch.FloatTensor`):
47
+ A current instance of a sample created by the diffusion process.
48
+ generator (`torch.Generator`, *optional*):
49
+ A random number generator.
50
+ return_dict (`bool`, *optional*, defaults to `True`):
51
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
52
+
53
+ Returns:
54
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
55
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
56
+ tuple is returned where the first element is the sample tensor.
57
+
58
+ """
59
+ t = timestep
60
+
61
+ prev_t = self.previous_timestep(t)
62
+
63
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
64
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
65
+ else:
66
+ predicted_variance = None
67
+
68
+ # 1. compute alphas, betas
69
+ alpha_prod_t = self.alphas_cumprod[t]
70
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
71
+ beta_prod_t = 1 - alpha_prod_t
72
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
73
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
74
+ current_beta_t = 1 - current_alpha_t
75
+
76
+ # 2. compute predicted original sample from predicted noise also called
77
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
78
+ if self.config.prediction_type == "epsilon":
79
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
80
+ elif self.config.prediction_type == "sample":
81
+ pred_original_sample = model_output
82
+ elif self.config.prediction_type == "v_prediction":
83
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
84
+ else:
85
+ raise ValueError(
86
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
87
+ " `v_prediction` for the DDPMScheduler."
88
+ )
89
+
90
+ # 3. Clip or threshold "predicted x_0"
91
+ if self.config.thresholding:
92
+ pred_original_sample = self._threshold_sample(pred_original_sample)
93
+ elif self.config.clip_sample:
94
+ pred_original_sample = pred_original_sample.clamp(
95
+ -self.config.clip_sample_range, self.config.clip_sample_range
96
+ )
97
+
98
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
99
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
100
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
101
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
102
+
103
+ # 5. Compute predicted previous sample µ_t
104
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
105
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
106
+
107
+ # 6. Add noise
108
+ variance = 0
109
+ if t > 0:
110
+ v = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5)
111
+ if v > 1e-9:
112
+ self.noise_list[int(t.item() // (1000 // self.num_inference_steps))] = (next_sample - pred_prev_sample) / v
113
+ variance_noise = self.noise_list[int(t.item() // (1000 // self.num_inference_steps))]
114
+ variance = v * variance_noise
115
+
116
+ pred_prev_sample = pred_prev_sample + variance
117
+
118
+ if not return_dict:
119
+ return (pred_prev_sample,)
120
+
121
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
122
+
123
+ def step(
124
+ self,
125
+ model_output: torch.FloatTensor,
126
+ timestep: int,
127
+ sample: torch.FloatTensor,
128
+ generator=None,
129
+ return_dict: bool = True,
130
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
131
+ """
132
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
133
+ process from the learned model outputs (most often the predicted noise).
134
+
135
+ Args:
136
+ model_output (`torch.FloatTensor`):
137
+ The direct output from learned diffusion model.
138
+ timestep (`float`):
139
+ The current discrete timestep in the diffusion chain.
140
+ sample (`torch.FloatTensor`):
141
+ A current instance of a sample created by the diffusion process.
142
+ generator (`torch.Generator`, *optional*):
143
+ A random number generator.
144
+ return_dict (`bool`, *optional*, defaults to `True`):
145
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
146
+
147
+ Returns:
148
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
149
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
150
+ tuple is returned where the first element is the sample tensor.
151
+
152
+ """
153
+ t = timestep
154
+
155
+ prev_t = self.previous_timestep(t)
156
+
157
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
158
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
159
+ else:
160
+ predicted_variance = None
161
+
162
+ # 1. compute alphas, betas
163
+ alpha_prod_t = self.alphas_cumprod[t]
164
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
165
+ beta_prod_t = 1 - alpha_prod_t
166
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
167
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
168
+ current_beta_t = 1 - current_alpha_t
169
+
170
+ # 2. compute predicted original sample from predicted noise also called
171
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
172
+ if self.config.prediction_type == "epsilon":
173
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
174
+ elif self.config.prediction_type == "sample":
175
+ pred_original_sample = model_output
176
+ elif self.config.prediction_type == "v_prediction":
177
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
178
+ else:
179
+ raise ValueError(
180
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
181
+ " `v_prediction` for the DDPMScheduler."
182
+ )
183
+
184
+ # 3. Clip or threshold "predicted x_0"
185
+ if self.config.thresholding:
186
+ pred_original_sample = self._threshold_sample(pred_original_sample)
187
+ elif self.config.clip_sample:
188
+ pred_original_sample = pred_original_sample.clamp(
189
+ -self.config.clip_sample_range, self.config.clip_sample_range
190
+ )
191
+
192
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
193
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
194
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
195
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
196
+
197
+ # 5. Compute predicted previous sample µ_t
198
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
199
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
200
+
201
+ # 6. Add noise
202
+ variance = 0
203
+ if t > 0:
204
+ device = model_output.device
205
+ variance_noise = self.noise_list[int(t.item() // (1000 // self.num_inference_steps))]
206
+ if self.variance_type == "fixed_small_log":
207
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
208
+ elif self.variance_type == "learned_range":
209
+ variance = self._get_variance(t, predicted_variance=predicted_variance)
210
+ variance = torch.exp(0.5 * variance) * variance_noise
211
+ else:
212
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
213
+
214
+ pred_prev_sample = pred_prev_sample + variance
215
+
216
+ if not return_dict:
217
+ return (pred_prev_sample,)
218
+
219
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
src/enums_utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image
4
+
5
+ from src.eunms import Model_Type, Scheduler_Type
6
+ from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
7
+ from src.lcm_scheduler import MyLCMScheduler
8
+ from src.ddpm_scheduler import MyDDPMScheduler
9
+ from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
10
+ from src.sd_inversion_pipeline import SDDDIMPipeline
11
+
12
+ def scheduler_type_to_class(scheduler_type):
13
+ if scheduler_type == Scheduler_Type.DDIM:
14
+ return DDIMScheduler
15
+ elif scheduler_type == Scheduler_Type.EULER:
16
+ return MyEulerAncestralDiscreteScheduler
17
+ elif scheduler_type == Scheduler_Type.LCM:
18
+ return MyLCMScheduler
19
+ elif scheduler_type == Scheduler_Type.DDPM:
20
+ return MyDDPMScheduler
21
+ else:
22
+ raise ValueError("Unknown scheduler type")
23
+
24
+ def model_type_to_class(model_type):
25
+ if model_type == Model_Type.SDXL:
26
+ return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline
27
+ elif model_type == Model_Type.SDXL_Turbo:
28
+ return AutoPipelineForImage2Image, SDXLDDIMPipeline
29
+ elif model_type == Model_Type.LCM_SDXL:
30
+ return AutoPipelineForImage2Image, SDXLDDIMPipeline
31
+ elif model_type == Model_Type.SD15:
32
+ return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
33
+ elif model_type == Model_Type.SD14:
34
+ return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
35
+ elif model_type == Model_Type.SD21:
36
+ return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
37
+ elif model_type == Model_Type.SD21_Turbo:
38
+ return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
39
+ else:
40
+ raise ValueError("Unknown model type")
41
+
42
+ def model_type_to_model_name(model_type):
43
+ if model_type == Model_Type.SDXL:
44
+ return "stabilityai/stable-diffusion-xl-base-1.0"
45
+ elif model_type == Model_Type.SDXL_Turbo:
46
+ return "stabilityai/sdxl-turbo"
47
+ elif model_type == Model_Type.LCM_SDXL:
48
+ return "stabilityai/stable-diffusion-xl-base-1.0"
49
+ elif model_type == Model_Type.SD15:
50
+ return "runwayml/stable-diffusion-v1-5"
51
+ elif model_type == Model_Type.SD14:
52
+ return "CompVis/stable-diffusion-v1-4"
53
+ elif model_type == Model_Type.SD21:
54
+ return "stabilityai/stable-diffusion-2-1"
55
+ elif model_type == Model_Type.SD21_Turbo:
56
+ return "stabilityai/sd-turbo"
57
+ else:
58
+ raise ValueError("Unknown model type")
59
+
60
+
61
+ def model_type_to_size(model_type):
62
+ if model_type == Model_Type.SDXL:
63
+ return (1024, 1024)
64
+ elif model_type == Model_Type.SDXL_Turbo:
65
+ return (512, 512)
66
+ elif model_type == Model_Type.LCM_SDXL:
67
+ return (768, 768) #TODO: check
68
+ elif model_type == Model_Type.SD15:
69
+ return (512, 512)
70
+ elif model_type == Model_Type.SD14:
71
+ return (512, 512)
72
+ elif model_type == Model_Type.SD21:
73
+ return (512, 512)
74
+ elif model_type == Model_Type.SD21_Turbo:
75
+ return (512, 512)
76
+ else:
77
+ raise ValueError("Unknown model type")
78
+
79
+ def is_float16(model_type):
80
+ if model_type == Model_Type.SDXL:
81
+ return True
82
+ elif model_type == Model_Type.SDXL_Turbo:
83
+ return True
84
+ elif model_type == Model_Type.LCM_SDXL:
85
+ return True
86
+ elif model_type == Model_Type.SD15:
87
+ return False
88
+ elif model_type == Model_Type.SD14:
89
+ return False
90
+ elif model_type == Model_Type.SD21:
91
+ return False
92
+ elif model_type == Model_Type.SD21_Turbo:
93
+ return False
94
+ else:
95
+ raise ValueError("Unknown model type")
96
+
97
+ def is_sd(model_type):
98
+ if model_type == Model_Type.SDXL:
99
+ return False
100
+ elif model_type == Model_Type.SDXL_Turbo:
101
+ return False
102
+ elif model_type == Model_Type.LCM_SDXL:
103
+ return False
104
+ elif model_type == Model_Type.SD15:
105
+ return True
106
+ elif model_type == Model_Type.SD14:
107
+ return True
108
+ elif model_type == Model_Type.SD21:
109
+ return True
110
+ elif model_type == Model_Type.SD21_Turbo:
111
+ return True
112
+ else:
113
+ raise ValueError("Unknown model type")
114
+
115
+ def _get_pipes(model_type, device):
116
+ model_name = model_type_to_model_name(model_type)
117
+ pipeline_inf, pipeline_inv = model_type_to_class(model_type)
118
+
119
+ if is_float16(model_type):
120
+ pipe_inversion = pipeline_inv.from_pretrained(
121
+ model_name,
122
+ torch_dtype=torch.float16,
123
+ use_safetensors=True,
124
+ variant="fp16",
125
+ safety_checker = None
126
+ ).to(device)
127
+
128
+ pipe_inference = pipeline_inf.from_pretrained(
129
+ model_name,
130
+ torch_dtype=torch.float16,
131
+ use_safetensors=True,
132
+ variant="fp16",
133
+ safety_checker = None
134
+ ).to(device)
135
+ else:
136
+ pipe_inversion = pipeline_inv.from_pretrained(
137
+ model_name,
138
+ use_safetensors=True,
139
+ safety_checker = None
140
+ ).to(device)
141
+
142
+ pipe_inference = pipeline_inf.from_pretrained(
143
+ model_name,
144
+ use_safetensors=True,
145
+ safety_checker = None
146
+ ).to(device)
147
+
148
+ return pipe_inversion, pipe_inference
149
+
150
+ def get_pipes(model_type, scheduler_type, device="cuda"):
151
+ # model_name = model_type_to_model_name(model_type)
152
+ # pipeline_inf, pipeline_inv = model_type_to_class(model_type)
153
+ scheduler_class = scheduler_type_to_class(scheduler_type)
154
+
155
+ pipe_inversion, pipe_inference = _get_pipes(model_type, device)
156
+
157
+ # pipe_inversion = pipeline_inv.from_pretrained(
158
+ # model_name,
159
+ # # torch_dtype=torch.float16,
160
+ # use_safetensors=True,
161
+ # # variant="fp16",
162
+ # safety_checker = None
163
+ # ).to("cuda")
164
+
165
+ # pipe_inference = pipeline_inf.from_pretrained(
166
+ # model_name,
167
+ # # torch_dtype=torch.float16,
168
+ # use_safetensors=True,
169
+ # # variant="fp16",
170
+ # safety_checker = None
171
+ # ).to("cuda")
172
+
173
+ pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
174
+ pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
175
+ pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
176
+
177
+ if is_sd(model_type):
178
+ pipe_inference.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
179
+ pipe_inversion.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
180
+ pipe_inversion.scheduler_inference.add_noise = lambda init_latents, noise, timestep: init_latents
181
+
182
+ if model_type == Model_Type.LCM_SDXL:
183
+ adapter_id = "latent-consistency/lcm-lora-sdxl"
184
+ # load and fuse lcm lora
185
+ pipe_inversion.load_lora_weights(adapter_id)
186
+ # pipe_inversion.fuse_lora()
187
+ pipe_inference.load_lora_weights(adapter_id)
188
+ # pipe_inference.fuse_lora()
189
+
190
+ return pipe_inversion, pipe_inference
src/euler_scheduler.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import EulerAncestralDiscreteScheduler, LCMScheduler
2
+ from diffusers.utils import BaseOutput
3
+ from diffusers.utils.torch_utils import randn_tensor
4
+ import torch
5
+ from typing import List, Optional, Tuple, Union
6
+ import numpy as np
7
+
8
+ from src.eunms import Epsilon_Update_Type
9
+
10
+ # g_cpu = torch.Generator().manual_seed(7865)
11
+ # noise = [randn_tensor((1, 4, 64, 64), dtype=torch.float16, device=torch.device("cuda:0"), generator=g_cpu) for i in range(4)]
12
+ # for i, n in enumerate(noise):
13
+ # torch.save(n, f"noise_{i}.pt")
14
+
15
+ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
16
+ """
17
+ Output class for the scheduler's `step` function output.
18
+
19
+ Args:
20
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
21
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
22
+ denoising loop.
23
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
24
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
25
+ `pred_original_sample` can be used to preview progress or for guidance.
26
+ """
27
+
28
+ prev_sample: torch.FloatTensor
29
+ pred_original_sample: Optional[torch.FloatTensor] = None
30
+
31
+ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
32
+ def set_noise_list(self, noise_list):
33
+ self.noise_list = noise_list
34
+
35
+ def get_noise_to_remove(self):
36
+ sigma_from = self.sigmas[self.step_index]
37
+ sigma_to = self.sigmas[self.step_index + 1]
38
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
39
+
40
+ return self.noise_list[self.step_index] * sigma_up\
41
+
42
+ def scale_model_input(
43
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
44
+ ) -> torch.FloatTensor:
45
+ """
46
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
47
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
48
+
49
+ Args:
50
+ sample (`torch.FloatTensor`):
51
+ The input sample.
52
+ timestep (`int`, *optional*):
53
+ The current timestep in the diffusion chain.
54
+
55
+ Returns:
56
+ `torch.FloatTensor`:
57
+ A scaled input sample.
58
+ """
59
+
60
+ self._init_step_index(timestep.view((1)))
61
+ return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)
62
+
63
+
64
+ def step(
65
+ self,
66
+ model_output: torch.FloatTensor,
67
+ timestep: Union[float, torch.FloatTensor],
68
+ sample: torch.FloatTensor,
69
+ generator: Optional[torch.Generator] = None,
70
+ return_dict: bool = True,
71
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
72
+ """
73
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
74
+ process from the learned model outputs (most often the predicted noise).
75
+
76
+ Args:
77
+ model_output (`torch.FloatTensor`):
78
+ The direct output from learned diffusion model.
79
+ timestep (`float`):
80
+ The current discrete timestep in the diffusion chain.
81
+ sample (`torch.FloatTensor`):
82
+ A current instance of a sample created by the diffusion process.
83
+ generator (`torch.Generator`, *optional*):
84
+ A random number generator.
85
+ return_dict (`bool`):
86
+ Whether or not to return a
87
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
88
+
89
+ Returns:
90
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
91
+ If return_dict is `True`,
92
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
93
+ otherwise a tuple is returned where the first element is the sample tensor.
94
+
95
+ """
96
+
97
+ if (
98
+ isinstance(timestep, int)
99
+ or isinstance(timestep, torch.IntTensor)
100
+ or isinstance(timestep, torch.LongTensor)
101
+ ):
102
+ raise ValueError(
103
+ (
104
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
105
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
106
+ " one of the `scheduler.timesteps` as a timestep."
107
+ ),
108
+ )
109
+
110
+ if not self.is_scale_input_called:
111
+ logger.warning(
112
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
113
+ "See `StableDiffusionPipeline` for a usage example."
114
+ )
115
+
116
+ self._init_step_index(timestep.view((1)))
117
+
118
+ sigma = self.sigmas[self.step_index]
119
+
120
+ # Upcast to avoid precision issues when computing prev_sample
121
+ sample = sample.to(torch.float32)
122
+
123
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
124
+ if self.config.prediction_type == "epsilon":
125
+ pred_original_sample = sample - sigma * model_output
126
+ elif self.config.prediction_type == "v_prediction":
127
+ # * c_out + input * c_skip
128
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
129
+ elif self.config.prediction_type == "sample":
130
+ raise NotImplementedError("prediction_type not implemented yet: sample")
131
+ else:
132
+ raise ValueError(
133
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
134
+ )
135
+
136
+ sigma_from = self.sigmas[self.step_index]
137
+ sigma_to = self.sigmas[self.step_index + 1]
138
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
139
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
140
+
141
+ # 2. Convert to an ODE derivative
142
+ # derivative = (sample - pred_original_sample) / sigma
143
+ derivative = model_output
144
+
145
+ dt = sigma_down - sigma
146
+
147
+ prev_sample = sample + derivative * dt
148
+
149
+ device = model_output.device
150
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
151
+ # prev_sample = prev_sample + noise * sigma_up
152
+
153
+ prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
154
+
155
+ # Cast sample back to model compatible dtype
156
+ prev_sample = prev_sample.to(model_output.dtype)
157
+
158
+ # upon completion increase step index by one
159
+ self._step_index += 1
160
+
161
+ if not return_dict:
162
+ return (prev_sample,)
163
+
164
+ return EulerAncestralDiscreteSchedulerOutput(
165
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
166
+ )
167
+
168
+ def step_and_update_noise(
169
+ self,
170
+ model_output: torch.FloatTensor,
171
+ timestep: Union[float, torch.FloatTensor],
172
+ sample: torch.FloatTensor,
173
+ expected_prev_sample: torch.FloatTensor,
174
+ update_epsilon_type=Epsilon_Update_Type.OVERRIDE,
175
+ generator: Optional[torch.Generator] = None,
176
+ return_dict: bool = True,
177
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
178
+ """
179
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
180
+ process from the learned model outputs (most often the predicted noise).
181
+
182
+ Args:
183
+ model_output (`torch.FloatTensor`):
184
+ The direct output from learned diffusion model.
185
+ timestep (`float`):
186
+ The current discrete timestep in the diffusion chain.
187
+ sample (`torch.FloatTensor`):
188
+ A current instance of a sample created by the diffusion process.
189
+ generator (`torch.Generator`, *optional*):
190
+ A random number generator.
191
+ return_dict (`bool`):
192
+ Whether or not to return a
193
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
194
+
195
+ Returns:
196
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
197
+ If return_dict is `True`,
198
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
199
+ otherwise a tuple is returned where the first element is the sample tensor.
200
+
201
+ """
202
+
203
+ if (
204
+ isinstance(timestep, int)
205
+ or isinstance(timestep, torch.IntTensor)
206
+ or isinstance(timestep, torch.LongTensor)
207
+ ):
208
+ raise ValueError(
209
+ (
210
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
211
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
212
+ " one of the `scheduler.timesteps` as a timestep."
213
+ ),
214
+ )
215
+
216
+ if not self.is_scale_input_called:
217
+ logger.warning(
218
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
219
+ "See `StableDiffusionPipeline` for a usage example."
220
+ )
221
+
222
+ self._init_step_index(timestep.view((1)))
223
+
224
+ sigma = self.sigmas[self.step_index]
225
+
226
+ # Upcast to avoid precision issues when computing prev_sample
227
+ sample = sample.to(torch.float32)
228
+
229
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
230
+ if self.config.prediction_type == "epsilon":
231
+ pred_original_sample = sample - sigma * model_output
232
+ elif self.config.prediction_type == "v_prediction":
233
+ # * c_out + input * c_skip
234
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
235
+ elif self.config.prediction_type == "sample":
236
+ raise NotImplementedError("prediction_type not implemented yet: sample")
237
+ else:
238
+ raise ValueError(
239
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
240
+ )
241
+
242
+ sigma_from = self.sigmas[self.step_index]
243
+ sigma_to = self.sigmas[self.step_index + 1]
244
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
245
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
246
+
247
+ # 2. Convert to an ODE derivative
248
+ # derivative = (sample - pred_original_sample) / sigma
249
+ derivative = model_output
250
+
251
+ dt = sigma_down - sigma
252
+
253
+ prev_sample = sample + derivative * dt
254
+
255
+ device = model_output.device
256
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
257
+ # prev_sample = prev_sample + noise * sigma_up
258
+
259
+ if sigma_up > 0:
260
+ req_noise = (expected_prev_sample - prev_sample) / sigma_up
261
+ if update_epsilon_type == Epsilon_Update_Type.OVERRIDE:
262
+ self.noise_list[self.step_index] = req_noise
263
+ else:
264
+ for i in range(10):
265
+ n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True)
266
+ loss = torch.norm(n - req_noise.detach())
267
+ loss.backward()
268
+ self.noise_list[self.step_index] -= n.grad.detach() * 1.8
269
+
270
+
271
+ prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
272
+
273
+ # Cast sample back to model compatible dtype
274
+ prev_sample = prev_sample.to(model_output.dtype)
275
+
276
+ # upon completion increase step index by one
277
+ self._step_index += 1
278
+
279
+ if not return_dict:
280
+ return (prev_sample,)
281
+
282
+ return EulerAncestralDiscreteSchedulerOutput(
283
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
284
+ )
285
+
286
+ def inv_step(
287
+ self,
288
+ model_output: torch.FloatTensor,
289
+ timestep: Union[float, torch.FloatTensor],
290
+ sample: torch.FloatTensor,
291
+ generator: Optional[torch.Generator] = None,
292
+ return_dict: bool = True,
293
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
294
+ """
295
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
296
+ process from the learned model outputs (most often the predicted noise).
297
+
298
+ Args:
299
+ model_output (`torch.FloatTensor`):
300
+ The direct output from learned diffusion model.
301
+ timestep (`float`):
302
+ The current discrete timestep in the diffusion chain.
303
+ sample (`torch.FloatTensor`):
304
+ A current instance of a sample created by the diffusion process.
305
+ generator (`torch.Generator`, *optional*):
306
+ A random number generator.
307
+ return_dict (`bool`):
308
+ Whether or not to return a
309
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
310
+
311
+ Returns:
312
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
313
+ If return_dict is `True`,
314
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
315
+ otherwise a tuple is returned where the first element is the sample tensor.
316
+
317
+ """
318
+
319
+ if (
320
+ isinstance(timestep, int)
321
+ or isinstance(timestep, torch.IntTensor)
322
+ or isinstance(timestep, torch.LongTensor)
323
+ ):
324
+ raise ValueError(
325
+ (
326
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
327
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
328
+ " one of the `scheduler.timesteps` as a timestep."
329
+ ),
330
+ )
331
+
332
+ if not self.is_scale_input_called:
333
+ logger.warning(
334
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
335
+ "See `StableDiffusionPipeline` for a usage example."
336
+ )
337
+
338
+ self._init_step_index(timestep.view((1)))
339
+
340
+ sigma = self.sigmas[self.step_index]
341
+
342
+ # Upcast to avoid precision issues when computing prev_sample
343
+ sample = sample.to(torch.float32)
344
+
345
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
346
+ if self.config.prediction_type == "epsilon":
347
+ pred_original_sample = sample - sigma * model_output
348
+ elif self.config.prediction_type == "v_prediction":
349
+ # * c_out + input * c_skip
350
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
351
+ elif self.config.prediction_type == "sample":
352
+ raise NotImplementedError("prediction_type not implemented yet: sample")
353
+ else:
354
+ raise ValueError(
355
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
356
+ )
357
+
358
+ sigma_from = self.sigmas[self.step_index]
359
+ sigma_to = self.sigmas[self.step_index+1]
360
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
361
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
362
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
363
+ sigma_down = sigma_to**2 / sigma_from
364
+
365
+ # 2. Convert to an ODE derivative
366
+ # derivative = (sample - pred_original_sample) / sigma
367
+ derivative = model_output
368
+
369
+ dt = sigma_down - sigma
370
+ # dt = sigma_down - sigma_from
371
+
372
+ prev_sample = sample - derivative * dt
373
+
374
+ device = model_output.device
375
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
376
+ # prev_sample = prev_sample + noise * sigma_up
377
+
378
+ prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
379
+
380
+ # Cast sample back to model compatible dtype
381
+ prev_sample = prev_sample.to(model_output.dtype)
382
+
383
+ # upon completion increase step index by one
384
+ self._step_index += 1
385
+
386
+ if not return_dict:
387
+ return (prev_sample,)
388
+
389
+ return EulerAncestralDiscreteSchedulerOutput(
390
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
391
+ )
392
+
393
+ def get_all_sigmas(self) -> torch.FloatTensor:
394
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
395
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
396
+ return torch.from_numpy(sigmas)
397
+
398
+ def add_noise_off_schedule(
399
+ self,
400
+ original_samples: torch.FloatTensor,
401
+ noise: torch.FloatTensor,
402
+ timesteps: torch.FloatTensor,
403
+ ) -> torch.FloatTensor:
404
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
405
+ sigmas = self.get_all_sigmas()
406
+ sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
407
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
408
+ # mps does not support float64
409
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
410
+ else:
411
+ timesteps = timesteps.to(original_samples.device)
412
+
413
+ step_indices = 1000 - int(timesteps.item())
414
+
415
+ sigma = sigmas[step_indices].flatten()
416
+ while len(sigma.shape) < len(original_samples.shape):
417
+ sigma = sigma.unsqueeze(-1)
418
+
419
+ noisy_samples = original_samples + noise * sigma
420
+ return noisy_samples
421
+
422
+ # def update_noise_for_friendly_inversion(
423
+ # self,
424
+ # model_output: torch.FloatTensor,
425
+ # timestep: Union[float, torch.FloatTensor],
426
+ # z_t: torch.FloatTensor,
427
+ # z_tp1: torch.FloatTensor,
428
+ # return_dict: bool = True,
429
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
430
+ # if (
431
+ # isinstance(timestep, int)
432
+ # or isinstance(timestep, torch.IntTensor)
433
+ # or isinstance(timestep, torch.LongTensor)
434
+ # ):
435
+ # raise ValueError(
436
+ # (
437
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
438
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
439
+ # " one of the `scheduler.timesteps` as a timestep."
440
+ # ),
441
+ # )
442
+
443
+ # if not self.is_scale_input_called:
444
+ # logger.warning(
445
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
446
+ # "See `StableDiffusionPipeline` for a usage example."
447
+ # )
448
+
449
+ # self._init_step_index(timestep.view((1)))
450
+
451
+ # sigma = self.sigmas[self.step_index]
452
+
453
+ # sigma_from = self.sigmas[self.step_index]
454
+ # sigma_to = self.sigmas[self.step_index+1]
455
+ # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
456
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
457
+ # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
458
+ # sigma_down = sigma_to**2 / sigma_from
459
+
460
+ # # 2. Conv = (sample - pred_original_sample) / sigma
461
+ # derivative = model_output
462
+
463
+ # dt = sigma_down - sigma
464
+ # # dt = sigma_down - sigma_from
465
+
466
+ # prev_sample = z_t - derivative * dt
467
+
468
+ # if sigma_up > 0:
469
+ # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up
470
+
471
+ # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
472
+
473
+
474
+ # if not return_dict:
475
+ # return (prev_sample,)
476
+
477
+ # return EulerAncestralDiscreteSchedulerOutput(
478
+ # prev_sample=prev_sample, pred_original_sample=None
479
+ # )
480
+
481
+
482
+ # def step_friendly_inversion(
483
+ # self,
484
+ # model_output: torch.FloatTensor,
485
+ # timestep: Union[float, torch.FloatTensor],
486
+ # sample: torch.FloatTensor,
487
+ # generator: Optional[torch.Generator] = None,
488
+ # return_dict: bool = True,
489
+ # expected_next_sample: torch.FloatTensor = None,
490
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
491
+ # """
492
+ # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
493
+ # process from the learned model outputs (most often the predicted noise).
494
+
495
+ # Args:
496
+ # model_output (`torch.FloatTensor`):
497
+ # The direct output from learned diffusion model.
498
+ # timestep (`float`):
499
+ # The current discrete timestep in the diffusion chain.
500
+ # sample (`torch.FloatTensor`):
501
+ # A current instance of a sample created by the diffusion process.
502
+ # generator (`torch.Generator`, *optional*):
503
+ # A random number generator.
504
+ # return_dict (`bool`):
505
+ # Whether or not to return a
506
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
507
+
508
+ # Returns:
509
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
510
+ # If return_dict is `True`,
511
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
512
+ # otherwise a tuple is returned where the first element is the sample tensor.
513
+
514
+ # """
515
+
516
+ # if (
517
+ # isinstance(timestep, int)
518
+ # or isinstance(timestep, torch.IntTensor)
519
+ # or isinstance(timestep, torch.LongTensor)
520
+ # ):
521
+ # raise ValueError(
522
+ # (
523
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
524
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
525
+ # " one of the `scheduler.timesteps` as a timestep."
526
+ # ),
527
+ # )
528
+
529
+ # if not self.is_scale_input_called:
530
+ # logger.warning(
531
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
532
+ # "See `StableDiffusionPipeline` for a usage example."
533
+ # )
534
+
535
+ # self._init_step_index(timestep.view((1)))
536
+
537
+ # sigma = self.sigmas[self.step_index]
538
+
539
+ # # Upcast to avoid precision issues when computing prev_sample
540
+ # sample = sample.to(torch.float32)
541
+
542
+ # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
543
+ # if self.config.prediction_type == "epsilon":
544
+ # pred_original_sample = sample - sigma * model_output
545
+ # elif self.config.prediction_type == "v_prediction":
546
+ # # * c_out + input * c_skip
547
+ # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
548
+ # elif self.config.prediction_type == "sample":
549
+ # raise NotImplementedError("prediction_type not implemented yet: sample")
550
+ # else:
551
+ # raise ValueError(
552
+ # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
553
+ # )
554
+
555
+ # sigma_from = self.sigmas[self.step_index]
556
+ # sigma_to = self.sigmas[self.step_index + 1]
557
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
558
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
559
+
560
+ # # 2. Convert to an ODE derivative
561
+ # # derivative = (sample - pred_original_sample) / sigma
562
+ # derivative = model_output
563
+
564
+ # dt = sigma_down - sigma
565
+
566
+ # prev_sample = sample + derivative * dt
567
+
568
+ # device = model_output.device
569
+ # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
570
+ # # prev_sample = prev_sample + noise * sigma_up
571
+
572
+ # if sigma_up > 0:
573
+ # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up
574
+
575
+ # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
576
+
577
+ # # Cast sample back to model compatible dtype
578
+ # prev_sample = prev_sample.to(model_output.dtype)
579
+
580
+ # # upon completion increase step index by one
581
+ # self._step_index += 1
582
+
583
+ # if not return_dict:
584
+ # return (prev_sample,)
585
+
586
+ # return EulerAncestralDiscreteSchedulerOutput(
587
+ # prev_sample=prev_sample, pred_original_sample=pred_original_sample
588
+ # )
src/eunms.py ADDED
@@ -0,0 +1,26 @@
1
+ from enum import Enum
2
+
3
+ class Scheduler_Type(Enum):
4
+ DDIM = 1
5
+ EULER = 2
6
+ LCM = 3
7
+ DDPM = 4
8
+
9
+ class Model_Type(Enum):
10
+ SDXL = 1
11
+ SDXL_Turbo = 2
12
+ LCM_SDXL = 3
13
+ SD15 = 4
14
+ SD21 = 5
15
+ SD21_Turbo = 6
16
+ SD14 = 7
17
+
18
+ class Gradient_Averaging_Type(Enum):
19
+ NONE = 1
20
+ EACH_ITER = 2
21
+ ON_END = 3
22
+
23
+ class Epsilon_Update_Type(Enum):
24
+ NONE = 1
25
+ OVERRIDE = 2
26
+ OPTIMIZE = 3
src/images_utils.py ADDED
@@ -0,0 +1,74 @@
1
+ from PIL import Image
2
+ import os
3
+ import torch
4
+
5
+ def read_images_in_path(path, size = (512,512)):
6
+ image_paths = []
7
+ for filename in os.listdir(path):
8
+ if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg"):
9
+ image_path = os.path.join(path, filename)
10
+ image_paths.append(image_path)
11
+ image_paths = sorted(image_paths)
12
+ return [Image.open(image_path).convert("RGB").resize(size) for image_path in image_paths]
13
+
14
+ def concatenate_images(image_lists, return_list = False):
15
+ num_rows = len(image_lists[0])
16
+ num_columns = len(image_lists)
17
+ image_width = image_lists[0][0].width
18
+ image_height = image_lists[0][0].height
19
+
20
+ grid_width = num_columns * image_width
21
+ grid_height = num_rows * image_height if not return_list else image_height
22
+ if not return_list:
23
+ grid_image = [Image.new('RGB', (grid_width, grid_height))]
24
+ else:
25
+ grid_image = [Image.new('RGB', (grid_width, grid_height)) for i in range(num_rows)]
26
+
27
+ for i in range(num_rows):
28
+ row_index = i if return_list else 0
29
+ for j in range(num_columns):
30
+ image = image_lists[j][i]
31
+ x_offset = j * image_width
32
+ y_offset = i * image_height if not return_list else 0
33
+ grid_image[row_index].paste(image, (x_offset, y_offset))
34
+
35
+ return grid_image if return_list else grid_image[0]
36
+
37
+ def concatenate_images_single(image_lists):
38
+ num_columns = len(image_lists)
39
+ image_width = image_lists[0].width
40
+ image_height = image_lists[0].height
41
+
42
+ grid_width = num_columns * image_width
43
+ grid_height = image_height
44
+ grid_image = Image.new('RGB', (grid_width, grid_height))
45
+
46
+ for j in range(num_columns):
47
+ image = image_lists[j]
48
+ x_offset = j * image_width
49
+ y_offset = 0
50
+ grid_image.paste(image, (x_offset, y_offset))
51
+
52
+ return grid_image
53
+
54
+ def get_captions_for_images(images, device):
55
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
56
+
57
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
58
+ model = Blip2ForConditionalGeneration.from_pretrained(
59
+ "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
60
+ ) # doctest: +IGNORE_RESULT
61
+
62
+ res = []
63
+
64
+ for image in images:
65
+ inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
66
+
67
+ generated_ids = model.generate(**inputs)
68
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
69
+ res.append(generated_text)
70
+
71
+ del processor
72
+ del model
73
+
74
+ return res
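A minimal usage sketch for the helpers above (the two-column grid chosen here is illustrative; the `example_images/` folder ships with this upload):

    from src.images_utils import read_images_in_path, concatenate_images

    # Load every .png/.jpg/.jpeg in the folder, resized to 512x512.
    images = read_images_in_path("example_images", size=(512, 512))
    # Two identical columns, one row per image, pasted into a single grid image.
    grid = concatenate_images([images, images])
    grid.save("grid.jpg")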
src/inversion_utils.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from random import randrange
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def noise_regularization(
7
+ e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls
8
+ ):
9
+ for _outer in range(num_reg_steps):
10
+ if lambda_kl > 0:
11
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
12
+ l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
13
+ l_kld.backward()
14
+ _grad = _var.grad.detach()
15
+ _grad = torch.clip(_grad, -100, 100)
16
+ e_t = e_t - lambda_kl * _grad
17
+ if lambda_ac > 0:
18
+ for _inner in range(num_ac_rolls):
19
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
20
+ l_ac = auto_corr_loss(_var)
21
+ l_ac.backward()
22
+ _grad = _var.grad.detach() / num_ac_rolls
23
+ e_t = e_t - lambda_ac * _grad
24
+ e_t = e_t.detach()
25
+
26
+ return e_t
27
+
28
+
29
+ def auto_corr_loss(x, random_shift=True):
30
+ B, C, H, W = x.shape
31
+ assert B == 1
32
+ x = x.squeeze(0)
33
+ # x must be shape [C,H,W] now
34
+ reg_loss = 0.0
35
+ for ch_idx in range(x.shape[0]):
36
+ noise = x[ch_idx][None, None, :, :]
37
+ while True:
38
+ if random_shift:
39
+ roll_amount = randrange(noise.shape[2] // 2)
40
+ else:
41
+ roll_amount = 1
42
+ reg_loss += (
43
+ noise * torch.roll(noise, shifts=roll_amount, dims=2)
44
+ ).mean() ** 2
45
+ reg_loss += (
46
+ noise * torch.roll(noise, shifts=roll_amount, dims=3)
47
+ ).mean() ** 2
48
+ if noise.shape[2] <= 8:
49
+ break
50
+ noise = F.avg_pool2d(noise, kernel_size=2)
51
+ return reg_loss
52
+
53
+
54
+ def patchify_latents_kl_divergence(x0, x1, patch_size=4, num_channels=4):
55
+
56
+ def patchify_tensor(input_tensor):
57
+ patches = (
58
+ input_tensor.unfold(1, patch_size, patch_size)
59
+ .unfold(2, patch_size, patch_size)
60
+ .unfold(3, patch_size, patch_size)
61
+ )
62
+ patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size)
63
+ return patches
64
+
65
+ x0 = patchify_tensor(x0)
66
+ x1 = patchify_tensor(x1)
67
+
68
+ kl = latents_kl_divergence(x0, x1).sum()
69
+ return kl
70
+
71
+
72
+ def latents_kl_divergence(x0, x1):
73
+ EPSILON = 1e-6
74
+ x0 = x0.view(x0.shape[0], x0.shape[1], -1)
75
+ x1 = x1.view(x1.shape[0], x1.shape[1], -1)
76
+ mu0 = x0.mean(dim=-1)
77
+ mu1 = x1.mean(dim=-1)
78
+ var0 = x0.var(dim=-1)
79
+ var1 = x1.var(dim=-1)
80
+ kl = (
81
+ torch.log((var1 + EPSILON) / (var0 + EPSILON))
82
+ + (var0 + (mu0 - mu1) ** 2) / (var1 + EPSILON)
83
+ - 1
84
+ )
85
+ kl = torch.abs(kl).sum(dim=-1)
86
+ return kl
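noise_regularization above nudges a predicted noise map toward the statistics of Gaussian noise: a patch-wise KL term pulls it toward a reference noise prediction, and an auto-correlation term penalizes spatial correlations. A minimal sketch on random latents (the lambda weights and the 64x64 latent shape are illustrative, not defaults taken from this repo):

    import torch
    from src.inversion_utils import noise_regularization

    e_t = torch.randn(1, 4, 64, 64)    # current noise prediction (SD latent layout)
    e_ref = torch.randn(1, 4, 64, 64)  # reference prediction from the forward (noising) pass
    e_t = noise_regularization(
        e_t, e_ref,
        lambda_kl=0.065, lambda_ac=20.0,  # illustrative weights
        num_reg_steps=4, num_ac_rolls=5,
    )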
src/lcm_scheduler.py ADDED
@@ -0,0 +1,196 @@
1
+ from diffusers import LCMScheduler
2
+ from diffusers.utils import BaseOutput
3
+ from diffusers.utils.torch_utils import randn_tensor
4
+ import torch
5
+ from typing import List, Optional, Tuple, Union
6
+ import numpy as np
7
+
8
+ class LCMSchedulerOutput(BaseOutput):
9
+ """
10
+ Output class for the scheduler's `step` function output.
11
+
12
+ Args:
13
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
14
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
15
+ denoising loop.
16
+ denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
18
+ `denoised` can be used to preview progress or for guidance.
19
+ """
20
+
21
+ prev_sample: torch.FloatTensor
22
+ denoised: Optional[torch.FloatTensor] = None
23
+
24
+ class MyLCMScheduler(LCMScheduler):
25
+
26
+ def set_noise_list(self, noise_list):
27
+ self.noise_list = noise_list
28
+
29
+ def step(
30
+ self,
31
+ model_output: torch.FloatTensor,
32
+ timestep: int,
33
+ sample: torch.FloatTensor,
34
+ generator: Optional[torch.Generator] = None,
35
+ return_dict: bool = True,
36
+ ) -> Union[LCMSchedulerOutput, Tuple]:
37
+ """
38
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
39
+ process from the learned model outputs (most often the predicted noise).
40
+
41
+ Args:
42
+ model_output (`torch.FloatTensor`):
43
+ The direct output from learned diffusion model.
44
+ timestep (`float`):
45
+ The current discrete timestep in the diffusion chain.
46
+ sample (`torch.FloatTensor`):
47
+ A current instance of a sample created by the diffusion process.
48
+ generator (`torch.Generator`, *optional*):
49
+ A random number generator.
50
+ return_dict (`bool`, *optional*, defaults to `True`):
51
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
52
+ Returns:
53
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
54
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
55
+ tuple is returned where the first element is the sample tensor.
56
+ """
57
+ if self.num_inference_steps is None:
58
+ raise ValueError(
59
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
60
+ )
61
+
62
+ self._init_step_index(timestep)
63
+
64
+ # 1. get previous step value
65
+ prev_step_index = self.step_index + 1
66
+ if prev_step_index < len(self.timesteps):
67
+ prev_timestep = self.timesteps[prev_step_index]
68
+ else:
69
+ prev_timestep = timestep
70
+
71
+ # 2. compute alphas, betas
72
+ alpha_prod_t = self.alphas_cumprod[timestep]
73
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
74
+
75
+ beta_prod_t = 1 - alpha_prod_t
76
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
77
+
78
+ # 3. Get scalings for boundary conditions
79
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
80
+
81
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
82
+ if self.config.prediction_type == "epsilon": # noise-prediction
83
+ predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
84
+ elif self.config.prediction_type == "sample": # x-prediction
85
+ predicted_original_sample = model_output
86
+ elif self.config.prediction_type == "v_prediction": # v-prediction
87
+ predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
88
+ else:
89
+ raise ValueError(
90
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
91
+ " `v_prediction` for `LCMScheduler`."
92
+ )
93
+
94
+ # 5. Clip or threshold "predicted x_0"
95
+ if self.config.thresholding:
96
+ predicted_original_sample = self._threshold_sample(predicted_original_sample)
97
+ elif self.config.clip_sample:
98
+ predicted_original_sample = predicted_original_sample.clamp(
99
+ -self.config.clip_sample_range, self.config.clip_sample_range
100
+ )
101
+
102
+ # 6. Denoise model output using boundary conditions
103
+ denoised = c_out * predicted_original_sample + c_skip * sample
104
+
105
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
106
+ # Noise is not used on the final timestep of the timestep schedule.
107
+ # This also means that noise is not used for one-step sampling.
108
+ if self.step_index != self.num_inference_steps - 1:
109
+ noise = self.noise_list[self.step_index]
110
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
111
+ else:
112
+ prev_sample = denoised
113
+
114
+ # upon completion increase step index by one
115
+ self._step_index += 1
116
+
117
+ if not return_dict:
118
+ return (prev_sample, denoised)
119
+
120
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
121
+
122
+
123
+ def inv_step(
124
+ self,
125
+ model_output: torch.FloatTensor,
126
+ timestep: int,
127
+ sample: torch.FloatTensor,
128
+ generator: Optional[torch.Generator] = None,
129
+ return_dict: bool = True,
130
+ ) -> Union[LCMSchedulerOutput, Tuple]:
131
+ """
132
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
133
+ process from the learned model outputs (most often the predicted noise).
134
+
135
+ Args:
136
+ model_output (`torch.FloatTensor`):
137
+ The direct output from learned diffusion model.
138
+ timestep (`float`):
139
+ The current discrete timestep in the diffusion chain.
140
+ sample (`torch.FloatTensor`):
141
+ A current instance of a sample created by the diffusion process.
142
+ generator (`torch.Generator`, *optional*):
143
+ A random number generator.
144
+ return_dict (`bool`, *optional*, defaults to `True`):
145
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
146
+ Returns:
147
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
148
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
149
+ tuple is returned where the first element is the sample tensor.
150
+ """
151
+ if self.num_inference_steps is None:
152
+ raise ValueError(
153
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
154
+ )
155
+
156
+ self._init_step_index(timestep)
157
+
158
+ # 1. get previous step value
159
+ prev_step_index = self.step_index + 1
160
+ if prev_step_index < len(self.timesteps):
161
+ prev_timestep = self.timesteps[prev_step_index]
162
+ else:
163
+ prev_timestep = timestep
164
+
165
+ # 2. compute alphas, betas
166
+ alpha_prod_t = self.alphas_cumprod[timestep]
167
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
168
+
169
+ beta_prod_t = 1 - alpha_prod_t
170
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
171
+
172
+ # 3. Get scalings for boundary conditions
173
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
174
+
175
+ if self.step_index != self.num_inference_steps - 1:
176
+ c_skip_actual = c_skip * alpha_prod_t_prev.sqrt()
177
+ c_out_actual = c_out * alpha_prod_t_prev.sqrt()
178
+ noise = self.noise_list[self.step_index] * beta_prod_t_prev.sqrt()
179
+ else:
180
+ c_skip_actual = c_skip
181
+ c_out_actual = c_out
182
+ noise = 0
183
+
184
+
185
+ dem = c_out_actual / (alpha_prod_t.sqrt()) + c_skip
186
+ eps_mul = beta_prod_t.sqrt() * c_out_actual / (alpha_prod_t.sqrt())
187
+
188
+ prev_sample = (sample + eps_mul * model_output - noise) / dem
189
+
190
+ # upon completion increase step index by one
191
+ self._step_index += 1
192
+
193
+ if not return_dict:
194
+ return (prev_sample, prev_sample)
195
+
196
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=prev_sample)
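MyLCMScheduler above replaces the random noise injected at each multi-step LCM step with entries from a caller-supplied noise_list, so denoising can replay the exact noise recovered during inversion (and inv_step runs the same update in reverse). A minimal wiring sketch; the constructor arguments and latent shape are illustrative, and in this repo the actual setup is done by get_pipes in src/enums_utils.py:

    import torch
    from src.lcm_scheduler import MyLCMScheduler

    scheduler = MyLCMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=4)
    # One noise tensor per timestep; step() reads noise_list[step_index] instead of sampling.
    scheduler.set_noise_list([torch.randn(1, 4, 64, 64) for _ in scheduler.timesteps])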
src/lpips.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from PIL import Image
4
+ from itertools import chain
5
+ from torchvision import models
6
+ from typing import Sequence
7
+ from collections import OrderedDict
8
+
9
+ def get_network(net_type: str = 'vgg'):
10
+ if net_type == 'alex':
11
+ return AlexNet()
12
+ elif net_type == 'squeeze':
13
+ return SqueezeNet()
14
+ elif net_type == 'vgg':
15
+ return VGG16()
16
+ else:
17
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
18
+
19
+ def normalize_activation(x, eps=1e-10):
20
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
21
+ return x / (norm_factor + eps)
22
+
23
+ class BaseNet(nn.Module):
24
+ def __init__(self):
25
+ super(BaseNet, self).__init__()
26
+
27
+ # register buffer
28
+ self.register_buffer(
29
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
30
+ self.register_buffer(
31
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
32
+
33
+ def set_requires_grad(self, state: bool):
34
+ for param in chain(self.parameters(), self.buffers()):
35
+ param.requires_grad = state
36
+
37
+ def z_score(self, x: torch.Tensor):
38
+ return (x - self.mean) / self.std
39
+
40
+ def forward(self, x: torch.Tensor):
41
+ x = self.z_score(x)
42
+
43
+ output = []
44
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
45
+ x = layer(x)
46
+ if i in self.target_layers:
47
+ output.append(normalize_activation(x))
48
+ if len(output) == len(self.target_layers):
49
+ break
50
+ return output
51
+
52
+
53
+ class SqueezeNet(BaseNet):
54
+ def __init__(self):
55
+ super(SqueezeNet, self).__init__()
56
+
57
+ self.layers = models.squeezenet1_1(True).features
58
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
59
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
60
+
61
+ self.set_requires_grad(False)
62
+
63
+
64
+ class AlexNet(BaseNet):
65
+ def __init__(self):
66
+ super(AlexNet, self).__init__()
67
+
68
+ self.layers = models.alexnet(True).features
69
+ self.target_layers = [2, 5, 8, 10, 12]
70
+ self.n_channels_list = [64, 192, 384, 256, 256]
71
+
72
+ self.set_requires_grad(False)
73
+
74
+
75
+ class VGG16(BaseNet):
76
+ def __init__(self):
77
+ super(VGG16, self).__init__()
78
+
79
+ self.layers = models.vgg16(True).features
80
+ self.target_layers = [4, 9, 16, 23, 30]
81
+ self.n_channels_list = [64, 128, 256, 512, 512]
82
+
83
+ self.set_requires_grad(False)
84
+
85
+ class LinLayers(nn.ModuleList):
86
+ def __init__(self, n_channels_list: Sequence[int]):
87
+ super(LinLayers, self).__init__([
88
+ nn.Sequential(
89
+ nn.Identity(),
90
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
91
+ ) for nc in n_channels_list
92
+ ])
93
+
94
+ for param in self.parameters():
95
+ param.requires_grad = False
96
+
97
+
98
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
99
+ # build url
100
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
101
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
102
+
103
+ # download
104
+ old_state_dict = torch.hub.load_state_dict_from_url(
105
+ url, progress=True,
106
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
107
+ )
108
+
109
+ # rename keys
110
+ new_state_dict = OrderedDict()
111
+ for key, val in old_state_dict.items():
112
+ new_key = key
113
+ new_key = new_key.replace('lin', '')
114
+ new_key = new_key.replace('model.', '')
115
+ new_state_dict[new_key] = val
116
+
117
+ return new_state_dict
118
+
119
+ class LPIPS(nn.Module):
120
+ r"""Creates a criterion that measures
121
+ Learned Perceptual Image Patch Similarity (LPIPS).
122
+ Arguments:
123
+ net_type (str): the network type to compare the features:
124
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
125
+ version (str): the version of LPIPS. Default: 0.1.
126
+ """
127
+ def __init__(self, net_type: str = 'vgg', version: str = '0.1'):
128
+
129
+ assert version in ['0.1'], 'v0.1 is only supported now'
130
+
131
+ super(LPIPS, self).__init__()
132
+
133
+ # pretrained network
134
+ self.net = get_network(net_type).to("cuda")
135
+
136
+ # linear layers
137
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
138
+ self.lin.load_state_dict(get_state_dict(net_type, version))
139
+
140
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
141
+ feat_x, feat_y = self.net(x), self.net(y)
142
+
143
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
144
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
145
+
146
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
147
+
src/metric_util.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+
5
+ from src.lpips import LPIPS
6
+ import torch.nn as nn
7
+
8
+ dev = 'cuda'
9
+ to_tensor_transform = transforms.Compose([transforms.ToTensor()])
10
+ mse_loss = nn.MSELoss()
11
+
12
+ def calculate_l2_difference(image1, image2, device = 'cuda'):
13
+ if isinstance(image1, Image.Image):
14
+ image1 = to_tensor_transform(image1).to(device)
15
+ if isinstance(image2, Image.Image):
16
+ image2 = to_tensor_transform(image2).to(device)
17
+
18
+ mse = mse_loss(image1, image2).item()
19
+ return mse
20
+
21
+ def calculate_psnr(image1, image2, device = 'cuda'):
22
+ max_value = 1.0
23
+ if isinstance(image1, Image.Image):
24
+ image1 = to_tensor_transform(image1).to(device)
25
+ if isinstance(image2, Image.Image):
26
+ image2 = to_tensor_transform(image2).to(device)
27
+
28
+ mse = mse_loss(image1, image2)
29
+ psnr = 10 * torch.log10(max_value**2 / mse).item()
30
+ return psnr
31
+
32
+
33
+ loss_fn = LPIPS(net_type='vgg').to(dev).eval()
34
+
35
+ def calculate_lpips(image1, image2, device = 'cuda'):
36
+ if isinstance(image1, Image.Image):
37
+ image1 = to_tensor_transform(image1).to(device)
38
+ if isinstance(image2, Image.Image):
39
+ image2 = to_tensor_transform(image2).to(device)
40
+
41
+ loss = loss_fn(image1, image2).item()
42
+ return loss
43
+
44
+ def calculate_metrics(image1, image2, device = 'cuda', size=(512, 512)):
45
+ if isinstance(image1, Image.Image):
46
+ image1 = image1.resize(size)
47
+ image1 = to_tensor_transform(image1).to(device)
48
+ if isinstance(image2, Image.Image):
49
+ image2 = image2.resize(size)
50
+ image2 = to_tensor_transform(image2).to(device)
51
+
52
+ l2 = calculate_l2_difference(image1, image2, device)
53
+ psnr = calculate_psnr(image1, image2, device)
54
+ lpips = calculate_lpips(image1, image2, device)
55
+ return {"l2": l2, "psnr": psnr, "lpips": lpips}
56
+
57
+ def get_empty_metrics():
58
+ return {"l2": 0, "psnr": 0, "lpips": 0}
59
+
60
+ def print_results(results):
61
+ print(f"Reconstruction Metrics: L2: {results['l2']},\t PSNR: {results['psnr']},\t LPIPS: {results['lpips']}")
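The metrics above accept PIL images or tensors in [0, 1]; note that importing the module builds its LPIPS criterion on CUDA, so a GPU is assumed. A short usage sketch with the example images shipped in this upload:

    from PIL import Image
    from src.metric_util import calculate_metrics, print_results

    img_a = Image.open("example_images/lion.jpeg").convert("RGB")
    img_b = Image.open("example_images/kitten.jpg").convert("RGB")
    # Both images are resized to 512x512 before computing L2, PSNR and LPIPS.
    print_results(calculate_metrics(img_a, img_b))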
src/sd_inversion_pipeline.py ADDED
@@ -0,0 +1,634 @@
1
+ # Plug&Play Feature Injection
2
+
3
+ import torch
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+ from random import randrange
6
+ import PIL
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ from torch.cuda.amp import custom_bwd, custom_fwd
10
+ import torch.nn.functional as F
11
+
12
+
13
+ from diffusers import (
14
+ StableDiffusionPipeline,
15
+ StableDiffusionImg2ImgPipeline,
16
+ DDIMScheduler,
17
+ )
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+
20
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
21
+ StableDiffusionPipelineOutput,
22
+ retrieve_timesteps,
23
+ PipelineImageInput
24
+ )
25
+
26
+ from src.eunms import Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type
27
+
28
+ def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
29
+ """
30
+ let a = alpha_t, b = alpha_{t - 1}
31
+ We have a > b,
32
+ x_{t} - x_{t - 1} = sqrt(a) ((sqrt(1/b) - sqrt(1/a)) * x_{t-1} + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps_{t-1})
33
+ From https://arxiv.org/pdf/2105.05233.pdf, section F.
34
+ """
35
+
36
+ a, b = alpha_t, alpha_tm1
37
+ sa = a**0.5
38
+ sb = b**0.5
39
+
40
+ return sa * ((1 / sb) * x_tm1 + ((1 / a - 1) ** 0.5 - (1 / b - 1) ** 0.5) * eps_xt)
41
+
42
+
43
+ class SDDDIMPipeline(StableDiffusionImg2ImgPipeline):
44
+ # @torch.no_grad()
45
+ def __call__(
46
+ self,
47
+ prompt: Union[str, List[str]] = None,
48
+ image: PipelineImageInput = None,
49
+ strength: float = 1.0,
50
+ num_inversion_steps: Optional[int] = 50,
51
+ timesteps: List[int] = None,
52
+ guidance_scale: Optional[float] = 7.5,
53
+ negative_prompt: Optional[Union[str, List[str]]] = None,
54
+ num_images_per_prompt: Optional[int] = 1,
55
+ eta: Optional[float] = 0.0,
56
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
57
+ prompt_embeds: Optional[torch.FloatTensor] = None,
58
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
59
+ ip_adapter_image: Optional[PipelineImageInput] = None,
60
+ output_type: Optional[str] = "pil",
61
+ return_dict: bool = True,
62
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
63
+ clip_skip: int = None,
64
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
65
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
66
+ opt_lr: float = 0.001,
67
+ opt_iters: int = 1,
68
+ opt_none_inference_steps: bool = False,
69
+ opt_loss_kl_lambda: float = 10.0,
70
+ num_inference_steps: int = 50,
71
+ num_aprox_steps: int = 100,
72
+ **kwargs,
73
+ ):
74
+ callback = kwargs.pop("callback", None)
75
+ callback_steps = kwargs.pop("callback_steps", None)
76
+
77
+ if callback is not None:
78
+ deprecate(
79
+ "callback",
80
+ "1.0.0",
81
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
82
+ )
83
+ if callback_steps is not None:
84
+ deprecate(
85
+ "callback_steps",
86
+ "1.0.0",
87
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
88
+ )
89
+
90
+ # 1. Check inputs. Raise error if not correct
91
+ self.check_inputs(
92
+ prompt,
93
+ strength,
94
+ callback_steps,
95
+ negative_prompt,
96
+ prompt_embeds,
97
+ negative_prompt_embeds,
98
+ callback_on_step_end_tensor_inputs,
99
+ )
100
+
101
+ self._guidance_scale = guidance_scale
102
+ self._clip_skip = clip_skip
103
+ self._cross_attention_kwargs = cross_attention_kwargs
104
+
105
+ # 2. Define call parameters
106
+ if prompt is not None and isinstance(prompt, str):
107
+ batch_size = 1
108
+ elif prompt is not None and isinstance(prompt, list):
109
+ batch_size = len(prompt)
110
+ else:
111
+ batch_size = prompt_embeds.shape[0]
112
+
113
+ device = self._execution_device
114
+
115
+ # 3. Encode input prompt
116
+ text_encoder_lora_scale = (
117
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
118
+ )
119
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
120
+ prompt,
121
+ device,
122
+ num_images_per_prompt,
123
+ self.do_classifier_free_guidance,
124
+ negative_prompt,
125
+ prompt_embeds=prompt_embeds,
126
+ negative_prompt_embeds=negative_prompt_embeds,
127
+ lora_scale=text_encoder_lora_scale,
128
+ clip_skip=self.clip_skip,
129
+ )
130
+ # For classifier free guidance, we need to do two forward passes.
131
+ # Here we concatenate the unconditional and text embeddings into a single batch
132
+ # to avoid doing two forward passes
133
+ if self.do_classifier_free_guidance:
134
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
135
+
136
+ if ip_adapter_image is not None:
137
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
138
+ if self.do_classifier_free_guidance:
139
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
140
+
141
+ # 4. Preprocess image
142
+ image = self.image_processor.preprocess(image)
143
+
144
+ # 5. set timesteps
145
+ timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
146
+ timesteps, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength, device)
147
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
148
+ _, num_inference_steps = retrieve_timesteps(self.scheduler_inference, num_inference_steps, device, None)
149
+
150
+ # 6. Prepare latent variables
151
+ with torch.no_grad():
152
+ latents = self.prepare_latents(
153
+ image,
154
+ latent_timestep,
155
+ batch_size,
156
+ num_images_per_prompt,
157
+ prompt_embeds.dtype,
158
+ device,
159
+ generator,
160
+ )
161
+
162
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
163
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
164
+
165
+ # 7.1 Add image embeds for IP-Adapter
166
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
167
+
168
+ # 7.2 Optionally get Guidance Scale Embedding
169
+ timestep_cond = None
170
+ if self.unet.config.time_cond_proj_dim is not None:
171
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
172
+ timestep_cond = self.get_guidance_scale_embedding(
173
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
174
+ ).to(device=device, dtype=latents.dtype)
175
+
176
+ # 8. Denoising loop
177
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
178
+ self._num_timesteps = len(timesteps)
179
+ prev_timestep = None
180
+ self.prev_z = torch.clone(latents)
181
+ self.prev_z4 = torch.clone(latents)
182
+ self.z_0 = torch.clone(latents)
183
+ g_cpu = torch.Generator().manual_seed(7865)
184
+ self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)
185
+
186
+
187
+ all_latents = [latents.clone()]
188
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
189
+ for i, t in enumerate(reversed(timesteps)):
190
+
191
+ z_tp1 = self.inversion_step(latents,
192
+ t,
193
+ prompt_embeds,
194
+ added_cond_kwargs,
195
+ prev_timestep=prev_timestep,
196
+ num_aprox_steps=num_aprox_steps)
197
+
198
+ if t in self.scheduler_inference.timesteps:
199
+ z_tp1 = self.optimize_z_tp1(z_tp1,
200
+ latents,
201
+ t,
202
+ prompt_embeds,
203
+ added_cond_kwargs,
204
+ nom_opt_iters=opt_iters,
205
+ lr=opt_lr,
206
+ opt_loss_kl_lambda=opt_loss_kl_lambda)
207
+
208
+ prev_timestep = t
209
+ latents = z_tp1
210
+
211
+ all_latents.append(latents.clone())
212
+
213
+ if callback_on_step_end is not None:
214
+ callback_kwargs = {}
215
+ for k in callback_on_step_end_tensor_inputs:
216
+ callback_kwargs[k] = locals()[k]
217
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
218
+
219
+ latents = callback_outputs.pop("latents", latents)
220
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
221
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
222
+
223
+ # call the callback, if provided
224
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
225
+ progress_bar.update()
226
+ if callback is not None and i % callback_steps == 0:
227
+ step_idx = i // getattr(self.scheduler, "order", 1)
228
+ callback(step_idx, t, latents)
229
+
230
+ image = latents
231
+
232
+ # Offload all models
233
+ self.maybe_free_model_hooks()
234
+
235
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), all_latents
236
+
237
+ def noise_regularization(self, e_t, noise_pred_optimal):
238
+ for _outer in range(self.cfg.num_reg_steps):
239
+ if self.cfg.lambda_kl>0:
240
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
241
+ # l_kld = self.kl_divergence(_var)
242
+ l_kld = self.patchify_latents_kl_divergence(_var, noise_pred_optimal)
243
+ l_kld.backward()
244
+ _grad = _var.grad.detach()
245
+ _grad = torch.clip(_grad, -100, 100)
246
+ e_t = e_t - self.cfg.lambda_kl*_grad
247
+ if self.cfg.lambda_ac>0:
248
+ for _inner in range(self.cfg.num_ac_rolls):
249
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
250
+ l_ac = self.auto_corr_loss(_var)
251
+ l_ac.backward()
252
+ _grad = _var.grad.detach()/self.cfg.num_ac_rolls
253
+ e_t = e_t - self.cfg.lambda_ac*_grad
254
+ e_t = e_t.detach()
255
+
256
+ return e_t
257
+
258
+ def auto_corr_loss(self, x, random_shift=True):
259
+ B,C,H,W = x.shape
260
+ assert B==1
261
+ x = x.squeeze(0)
262
+ # x must be shape [C,H,W] now
263
+ reg_loss = 0.0
264
+ for ch_idx in range(x.shape[0]):
265
+ noise = x[ch_idx][None, None,:,:]
266
+ while True:
267
+ if random_shift: roll_amount = randrange(noise.shape[2]//2)
268
+ else: roll_amount = 1
269
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
270
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
271
+ if noise.shape[2] <= 8:
272
+ break
273
+ noise = F.avg_pool2d(noise, kernel_size=2)
274
+ return reg_loss
275
+
276
+ def kl_divergence(self, x):
277
+ _mu = x.mean()
278
+ _var = x.var()
279
+ return _var + _mu**2 - 1 - torch.log(_var+1e-7)
280
+
281
+ # @torch.no_grad()
282
+ def inversion_step(
283
+ self,
284
+ z_t: torch.tensor,
285
+ t: torch.tensor,
286
+ prompt_embeds,
287
+ added_cond_kwargs,
288
+ prev_timestep: Optional[torch.tensor] = None,
289
+ num_aprox_steps: int = 100
290
+ ) -> torch.tensor:
291
+ extra_step_kwargs = {}
292
+
293
+ avg_range = self.cfg.gradient_averaging_first_step_range if t.item() < 250 else self.cfg.gradient_averaging_step_range
294
+
295
+ # Doing more than one approximation step in the first step adds artifacts
296
+ if t.item() < 250:
297
+ num_aprox_steps = min(self.cfg.max_num_aprox_steps_first_step, num_aprox_steps)
298
+
299
+ approximated_z_tp1 = z_t.clone()
300
+ nosie_pred_avg = None
301
+
302
+ if self.cfg.num_reg_steps > 0:
303
+ z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
304
+ latent_model_input = torch.cat([z_tp1_forward] * 2) if self.do_classifier_free_guidance else z_tp1_forward
305
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
306
+
307
+ with torch.no_grad():
308
+ # predict the noise residual
309
+ noise_pred_optimal = self.unet(
310
+ latent_model_input,
311
+ t,
312
+ encoder_hidden_states=prompt_embeds,
313
+ timestep_cond=None,
314
+ cross_attention_kwargs=self.cross_attention_kwargs,
315
+ added_cond_kwargs=added_cond_kwargs,
316
+ return_dict=False,
317
+ )[0].detach()
318
+ else:
319
+ noise_pred_optimal = None
320
+
321
+ for i in range(num_aprox_steps + 1):
322
+ latent_model_input = torch.cat([approximated_z_tp1] * 2) if self.do_classifier_free_guidance else approximated_z_tp1
323
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
324
+
325
+ with torch.no_grad():
326
+ # predict the noise residual
327
+ noise_pred = self.unet(
328
+ latent_model_input,
329
+ t,
330
+ encoder_hidden_states=prompt_embeds,
331
+ timestep_cond=None,
332
+ cross_attention_kwargs=self.cross_attention_kwargs,
333
+ added_cond_kwargs=added_cond_kwargs,
334
+ return_dict=False,
335
+ )[0]
336
+
337
+ if i >= avg_range[0] and i < avg_range[1]:
338
+ j = i - avg_range[0]
339
+ if nosie_pred_avg is None:
340
+ nosie_pred_avg = noise_pred.clone()
341
+ else:
342
+ nosie_pred_avg = j * nosie_pred_avg / (j + 1) + noise_pred / (j + 1)
343
+ if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.EACH_ITER:
344
+ noise_pred = nosie_pred_avg.clone()
345
+
346
+ # perform guidance
347
+ if self.do_classifier_free_guidance:
348
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
349
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
350
+
351
+ if i >= avg_range[0] or (self.cfg.gradient_averaging_type == Gradient_Averaging_Type.NONE and i > 0):
352
+ noise_pred = self.noise_regularization(noise_pred, noise_pred_optimal)
353
+
354
+ if self.cfg.scheduler_type == Scheduler_Type.EULER:
355
+ approximated_z_tp1 = self.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
356
+ else:
357
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
358
+ alpha_prod_t_prev = (
359
+ self.scheduler.alphas_cumprod[prev_timestep]
360
+ if prev_timestep is not None
361
+ else self.scheduler.final_alpha_cumprod
362
+ )
363
+ approximated_z_tp1 = _backward_ddim(
364
+ x_tm1=z_t,
365
+ alpha_t=alpha_prod_t,
366
+ alpha_tm1=alpha_prod_t_prev,
367
+ eps_xt=noise_pred,
368
+ )
369
+
370
+ if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.ON_END and nosie_pred_avg is not None:
371
+
372
+ nosie_pred_avg = self.noise_regularization(nosie_pred_avg, noise_pred_optimal)
373
+ if self.cfg.scheduler_type == Scheduler_Type.EULER:
374
+ approximated_z_tp1 = self.scheduler.inv_step(nosie_pred_avg, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
375
+ else:
376
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
377
+ alpha_prod_t_prev = (
378
+ self.scheduler.alphas_cumprod[prev_timestep]
379
+ if prev_timestep is not None
380
+ else self.scheduler.final_alpha_cumprod
381
+ )
382
+ approximated_z_tp1 = _backward_ddim(
383
+ x_tm1=z_t,
384
+ alpha_t=alpha_prod_t,
385
+ alpha_tm1=alpha_prod_t_prev,
386
+ eps_xt=nosie_pred_avg,
387
+ )
388
+
389
+ if self.cfg.update_epsilon_type != Epsilon_Update_Type.NONE:
390
+ latent_model_input = torch.cat([approximated_z_tp1] * 2) if self.do_classifier_free_guidance else approximated_z_tp1
391
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
392
+
393
+ with torch.no_grad():
394
+ # predict the noise residual
395
+ noise_pred = self.unet(
396
+ latent_model_input,
397
+ t,
398
+ encoder_hidden_states=prompt_embeds,
399
+ timestep_cond=None,
400
+ cross_attention_kwargs=self.cross_attention_kwargs,
401
+ added_cond_kwargs=added_cond_kwargs,
402
+ return_dict=False,
403
+ )[0]
404
+
405
+ # perform guidance
406
+ if self.do_classifier_free_guidance:
407
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
408
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
409
+
410
+ self.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, update_epsilon_type=self.cfg.update_epsilon_type)
411
+
412
+ return approximated_z_tp1
413
+
414
+ def detach_before_opt(self, z_tp1, t, prompt_embeds, added_cond_kwargs):
415
+ z_tp1 = z_tp1.detach()
416
+ t = t.detach()
417
+ prompt_embeds = prompt_embeds.detach()
418
+ return z_tp1, t, prompt_embeds, added_cond_kwargs
419
+
420
+ def opt_z_tp1_single_step(
421
+ self,
422
+ z_tp1,
423
+ z_t,
424
+ t,
425
+ prompt_embeds,
426
+ added_cond_kwargs,
427
+ lr=0.001,
428
+ opt_loss_kl_lambda=10.0,
429
+ ):
430
+ l1_loss = torch.nn.L1Loss(reduction='sum')
431
+ mse = torch.nn.MSELoss(reduction='sum')
432
+ extra_step_kwargs = {}
433
+
434
+ self.unet.requires_grad_(False)
435
+ z_tp1, t, prompt_embeds, added_cond_kwargs = self.detach_before_opt(z_tp1, t, prompt_embeds, added_cond_kwargs)
436
+
437
+ z_tp1 = torch.nn.Parameter(z_tp1, requires_grad=True)
438
+ optimizer = torch.optim.SGD([z_tp1], lr=lr, momentum=0.9)
439
+
440
+ optimizer.zero_grad()
441
+ self.unet.zero_grad()
442
+ latent_model_input = torch.cat([z_tp1] * 2) if self.do_classifier_free_guidance else z_tp1
443
+ latent_model_input = self.scheduler_inference.scale_model_input(latent_model_input, t)
444
+
445
+ noise_pred = self.unet(
446
+ latent_model_input,
447
+ t,
448
+ encoder_hidden_states=prompt_embeds,
449
+ timestep_cond=None,
450
+ cross_attention_kwargs=self.cross_attention_kwargs,
451
+ added_cond_kwargs=added_cond_kwargs,
452
+ return_dict=False,
453
+ )[0]
454
+
455
+ # perform guidance
456
+ if self.do_classifier_free_guidance:
457
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
458
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
459
+
460
+ # # compute the previous noisy sample x_t -> x_t-1
461
+ z_t_hat = self.scheduler_inference.step(noise_pred, t, z_tp1, **extra_step_kwargs, return_dict=False)[0]
462
+
463
+ direct_loss = 0.5 * mse(z_t_hat, z_t.detach()) + 0.5 * l1_loss(z_t_hat, z_t.detach())
464
+ kl_loss = torch.tensor([0]).to(z_t.device)
465
+ loss = 1.0 * direct_loss + opt_loss_kl_lambda * kl_loss
466
+
467
+ loss.backward()
468
+ optimizer.step()
469
+ print(f't: {t}\t total_loss: {format(loss.item(), ".3f")}\t\t direct_loss: {format(direct_loss.item(), ".3f")}\t\t kl_loss: {format(kl_loss.item(), ".3f")}')
470
+
471
+ return z_tp1.detach()
472
+
473
+ def optimize_z_tp1(
474
+ self,
475
+ z_tp1,
476
+ z_t,
477
+ t,
478
+ prompt_embeds,
479
+ added_cond_kwargs,
480
+ nom_opt_iters=1,
481
+ lr=0.001,
482
+ opt_loss_kl_lambda=10.0,
483
+ ):
484
+ l1_loss = torch.nn.L1Loss(reduction='sum')
485
+ mse = torch.nn.MSELoss(reduction='sum')
486
+ extra_step_kwargs = {}
487
+
488
+ self.unet.requires_grad_(False)
489
+ z_tp1, t, prompt_embeds, added_cond_kwargs = self.detach_before_opt(z_tp1, t, prompt_embeds, added_cond_kwargs)
490
+
491
+ z_tp1 = torch.nn.Parameter(z_tp1, requires_grad=True)
492
+ optimizer = torch.optim.SGD([z_tp1], lr=lr, momentum=0.9)
493
+ lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, verbose=True, patience=5, cooldown=3)
494
+ max_loss = 99999999999999
495
+
496
+ z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
497
+ z_tp1_best = None
498
+ for i in range(nom_opt_iters):
499
+ optimizer.zero_grad()
500
+ self.unet.zero_grad()
501
+ latent_model_input = torch.cat([z_tp1] * 2) if self.do_classifier_free_guidance else z_tp1
502
+ latent_model_input = self.scheduler_inference.scale_model_input(latent_model_input, t)
503
+
504
+ noise_pred = self.unet(
505
+ latent_model_input,
506
+ t,
507
+ encoder_hidden_states=prompt_embeds,
508
+ timestep_cond=None,
509
+ cross_attention_kwargs=self.cross_attention_kwargs,
510
+ added_cond_kwargs=added_cond_kwargs,
511
+ return_dict=False,
512
+ )[0]
513
+
514
+ # perform guidance
515
+ if self.do_classifier_free_guidance:
516
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
517
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
518
+
519
+ # # compute the previous noisy sample x_t -> x_t-1
520
+ z_t_hat = self.scheduler_inference.step(noise_pred, t, z_tp1, **extra_step_kwargs, return_dict=False)[0]
521
+
522
+ direct_loss = 0.5 * mse(z_t_hat, z_t.detach()) + 0.5 * l1_loss(z_t_hat, z_t.detach())
523
+ kl_loss = self.patchify_latents_kl_divergence(z_tp1, z_tp1_forward)
524
+ loss = 1.0 * direct_loss + opt_loss_kl_lambda * kl_loss
525
+
526
+ loss.backward()
527
+ best = False
528
+ if loss < max_loss:
529
+ max_loss = loss
530
+ z_tp1_best = torch.clone(z_tp1)
531
+ best = True
532
+ lr_scheduler.step(loss)
533
+ if optimizer.param_groups[0]['lr'] < 9e-06:
534
+ break
535
+ optimizer.step()
536
+ print(f't: {t}\t\t iter: {i}\t total_loss: {format(loss.item(), ".3f")}\t\t direct_loss: {format(direct_loss.item(), ".3f")}\t\t kl_loss: {format(kl_loss.item(), ".3f")}\t\t best: {best}')
537
+
538
+ if z_tp1_best is not None:
539
+ z_tp1 = z_tp1_best
540
+
541
+ self.prev_z4 = torch.clone(z_tp1)
542
+
543
+ return z_tp1.detach()
544
+
545
+ def opt_inv(self,
546
+ z_t,
547
+ t,
548
+ prompt_embeds,
549
+ added_cond_kwargs,
550
+ prev_timestep,
551
+ nom_opt_iters=1,
552
+ lr=0.001,
553
+ opt_none_inference_steps=False,
554
+ opt_loss_kl_lambda=10.0,
555
+ num_aprox_steps=100):
556
+
557
+ z_tp1 = self.inversion_step(z_t, t, prompt_embeds, added_cond_kwargs, num_aprox_steps=num_aprox_steps)
558
+
559
+ if t in self.scheduler_inference.timesteps:
560
+ z_tp1 = self.optimize_z_tp1(z_tp1, z_t, t, prompt_embeds, added_cond_kwargs, nom_opt_iters=nom_opt_iters, lr=lr, opt_loss_kl_lambda=opt_loss_kl_lambda)
561
+
562
+ return z_tp1
563
+
564
+ def latent2image(self, latents):
565
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
566
+
567
+ if needs_upcasting:
568
+ self.upcast_vae()
569
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
570
+
571
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
572
+
573
+ # cast back to fp16 if needed
574
+ # if needs_upcasting:
575
+ # self.vae.to(dtype=torch.float16)
576
+
577
+ return image
578
+
579
+ def patchify_latents_kl_divergence(self, x0, x1):
580
+ # divide x0 and x1 into patches (4x64x64) -> (4x4x4)
581
+ PATCH_SIZE = 4
582
+ NUM_CHANNELS = 4
583
+
584
+ def patchify_tensor(input_tensor):
585
+ patches = input_tensor.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
586
+ patches = patches.contiguous().view(-1, NUM_CHANNELS, PATCH_SIZE, PATCH_SIZE)
587
+ return patches
588
+
589
+ x0 = patchify_tensor(x0)
590
+ x1 = patchify_tensor(x1)
591
+
592
+ kl = self.latents_kl_divergence(x0, x1).sum()
593
+ # for i in range(x0.shape[0]):
594
+ # kl += self.latents_kl_divergence(x0[i], x1[i])
595
+ return kl
596
+
597
+
598
+ def latents_kl_divergence(self, x0, x1):
599
+ EPSILON = 1e-6
600
+
601
+ # Gaussian KL divergence: D_KL(N_0 || N_1) = 1/2 * ( tr(Sigma_1^{-1} Sigma_0) - k + (mu_1 - mu_0)^T Sigma_1^{-1} (mu_1 - mu_0) + ln(det(Sigma_1) / det(Sigma_0)) )
602
+ x0 = x0.view(x0.shape[0], x0.shape[1], -1)
603
+ x1 = x1.view(x1.shape[0], x1.shape[1], -1)
604
+ mu0 = x0.mean(dim=-1)
605
+ mu1 = x1.mean(dim=-1)
606
+ var0 = x0.var(dim=-1)
607
+ var1 = x1.var(dim=-1)
608
+ kl = torch.log((var1 + EPSILON) / (var0 + EPSILON)) + (var0 + (mu0 - mu1)**2) / (var1 + EPSILON) - 1
609
+ kl = torch.abs(kl).sum(dim=-1)
610
+ # kl = torch.linalg.norm(mu0 - mu1) + torch.linalg.norm(var0 - var1)
611
+ # kl *= 1000
612
+ # sigma0 = torch.cov(x0)
613
+ # sigma1 = torch.cov(x1)
614
+ # inv_sigma1 = torch.inverse(sigma1.to(dtype=torch.float64)).to(dtype=x0.dtype)
615
+ # k = x0.shape[1]
616
+ # kl = 0.5 * (torch.trace(inv_sigma1 @ sigma0) - k + (mu1 - mu0).T @ inv_sigma1 @ (mu1 - mu0) + torch.log(torch.det(sigma1) / torch.det(sigma0)))
617
+ return kl
618
+
619
+
620
+ class SpecifyGradient(torch.autograd.Function):
621
+ @staticmethod
622
+ @custom_fwd
623
+ def forward(ctx, input_tensor, gt_grad):
624
+ ctx.save_for_backward(gt_grad)
625
+
626
+ # dummy loss value
627
+ return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype)
628
+
629
+ @staticmethod
630
+ @custom_bwd
631
+ def backward(ctx, grad):
632
+ gt_grad, = ctx.saved_tensors
633
+ batch_size = len(gt_grad)
634
+ return gt_grad / batch_size, None
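Stripped of gradient averaging, noise regularization and the epsilon update, inversion_step above reduces to a fixed-point iteration: re-predict the noise at the current estimate of z_{t+1}, then redo the inverted scheduler step from the original z_t. A condensed sketch of that core loop; predict_noise is a hypothetical stand-in for the UNet call with prompt conditioning and guidance:

    import torch

    def renoise_inversion_step(scheduler, predict_noise, z_t, t, num_renoise_steps=9):
        # Start from z_t as the initial guess for z_{t+1}.
        approximated_z_tp1 = z_t.clone()
        for _ in range(num_renoise_steps + 1):
            # Re-estimate the noise at the current guess ...
            noise_pred = predict_noise(approximated_z_tp1, t)
            # ... and re-apply the inverted scheduler step from the original z_t.
            approximated_z_tp1 = scheduler.inv_step(
                noise_pred, t, z_t, return_dict=False
            )[0].detach()
        return approximated_z_tp1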
src/sdxl_inversion_pipeline.py ADDED
@@ -0,0 +1,430 @@
1
+ # Plug&Play Feature Injection
2
+
3
+ import torch
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+ from random import randrange
6
+ import PIL
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ from torch.cuda.amp import custom_bwd, custom_fwd
10
+ import torch.nn.functional as F
11
+
12
+
13
+ from diffusers import (
14
+ StableDiffusionXLPipeline,
15
+ StableDiffusionXLImg2ImgPipeline,
16
+ DDIMScheduler,
17
+ )
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+
20
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
21
+ rescale_noise_cfg,
22
+ StableDiffusionXLPipelineOutput,
23
+ retrieve_timesteps,
24
+ PipelineImageInput
25
+ )
26
+
27
+ from src.eunms import Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type
28
+ from src.inversion_utils import noise_regularization
29
+
30
+ def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
31
+ """
32
+ let a = alpha_t, b = alpha_{t - 1}
33
+ We have a > b,
34
+ x_{t} - x_{t - 1} = sqrt(a) ((sqrt(1/b) - sqrt(1/a)) * x_{t-1} + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps_{t-1})
35
+ From https://arxiv.org/pdf/2105.05233.pdf, section F.
36
+ """
37
+
38
+ a, b = alpha_t, alpha_tm1
39
+ sa = a**0.5
40
+ sb = b**0.5
41
+
42
+ return sa * ((1 / sb) * x_tm1 + ((1 / a - 1) ** 0.5 - (1 / b - 1) ** 0.5) * eps_xt)
43
+
44
+
45
+ class SDXLDDIMPipeline(StableDiffusionXLImg2ImgPipeline):
46
+ # @torch.no_grad()
47
+ def __call__(
48
+ self,
49
+ prompt: Union[str, List[str]] = None,
50
+ prompt_2: Optional[Union[str, List[str]]] = None,
51
+ image: PipelineImageInput = None,
52
+ strength: float = 0.3,
53
+ num_inversion_steps: int = 50,
54
+ timesteps: List[int] = None,
55
+ denoising_start: Optional[float] = None,
56
+ denoising_end: Optional[float] = None,
57
+ guidance_scale: float = 1.0,
58
+ negative_prompt: Optional[Union[str, List[str]]] = None,
59
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
60
+ num_images_per_prompt: Optional[int] = 1,
61
+ eta: float = 0.0,
62
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
63
+ latents: Optional[torch.FloatTensor] = None,
64
+ prompt_embeds: Optional[torch.FloatTensor] = None,
65
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
66
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
67
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
68
+ ip_adapter_image: Optional[PipelineImageInput] = None,
69
+ output_type: Optional[str] = "pil",
70
+ return_dict: bool = True,
71
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
72
+ guidance_rescale: float = 0.0,
73
+ original_size: Tuple[int, int] = None,
74
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
75
+ target_size: Tuple[int, int] = None,
76
+ negative_original_size: Optional[Tuple[int, int]] = None,
77
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
78
+ negative_target_size: Optional[Tuple[int, int]] = None,
79
+ aesthetic_score: float = 6.0,
80
+ negative_aesthetic_score: float = 2.5,
81
+ clip_skip: Optional[int] = None,
82
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
83
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
84
+ opt_lr: float = 0.001,
85
+ opt_iters: int = 1,
86
+ opt_none_inference_steps: bool = False,
87
+ opt_loss_kl_lambda: float = 10.0,
88
+ num_inference_steps: int = 50,
89
+ num_aprox_steps: int = 100,
90
+ **kwargs,
91
+ ):
92
+ callback = kwargs.pop("callback", None)
93
+ callback_steps = kwargs.pop("callback_steps", None)
94
+
95
+ if callback is not None:
96
+ deprecate(
97
+ "callback",
98
+ "1.0.0",
99
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
100
+ )
101
+ if callback_steps is not None:
102
+ deprecate(
103
+ "callback_steps",
104
+ "1.0.0",
105
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
106
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ num_inversion_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ denoising_start_fr = 1.0 - denoising_start
+ denoising_start = 0.0 if self.cfg.noise_friendly_inversion else denoising_start
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 5. Prepare timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(self.denoising_end, float) and 0 < dnv < 1
+
+ timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
+ timesteps_num_inference_steps, num_inference_steps = retrieve_timesteps(self.scheduler_inference, num_inference_steps, device, None)
+
+ timesteps, num_inversion_steps = self.get_timesteps(
+ num_inversion_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid else None,
+ )
+ # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # add_noise = True if self.denoising_start is None else False
+ # 6. Prepare latent variables
+ with torch.no_grad():
+ latents = self.prepare_latents(
+ image,
+ None,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ False,
+ )
+ # 7. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 8. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None:
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+
+ # 9. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inversion_steps * self.scheduler.order, 0)
+ prev_timestep = None
+
+ self._num_timesteps = len(timesteps)
+ self.prev_z = torch.clone(latents)
+ self.prev_z4 = torch.clone(latents)
+ self.z_0 = torch.clone(latents)
+ g_cpu = torch.Generator().manual_seed(7865)
+ self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)
+
+ # Friendly inversion params
+ timesteps_for = timesteps if self.cfg.noise_friendly_inversion else reversed(timesteps)
+ noise = randn_tensor(latents.shape, generator=g_cpu, device=latents.device, dtype=latents.dtype)
+ latents = self.scheduler.add_noise(self.z_0, noise, timesteps_for[0].view((1))).detach() if self.cfg.noise_friendly_inversion else latents
+ z_T = latents.clone()
+
+ all_latents = [latents.clone()]
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
+ for i, t in enumerate(timesteps_for):
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ z_tp1 = self.inversion_step(latents,
+ t,
+ prompt_embeds,
+ added_cond_kwargs,
+ prev_timestep=prev_timestep,
+ num_aprox_steps=num_aprox_steps)
+
+ prev_timestep = t
+ latents = z_tp1
+
+ all_latents.append(latents.clone())
+
+ if self.cfg.noise_friendly_inversion and t.item() > 1000 * denoising_start_fr:
+ z_T = latents.clone()
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if self.cfg.noise_friendly_inversion:
+ latents = z_T
+
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return StableDiffusionXLPipelineOutput(images=image), all_latents
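+ # A minimal usage sketch for the call above (illustrative only; it assumes this
+ # pipeline is the `pipe_inversion` object built by get_pipes() in gradio_app.py and
+ # that the concrete values follow the SDXL-Turbo demo configuration):
+ #
+ # res, all_latents = pipe_inversion(
+ # prompt="a photo of a kitten",
+ # image=input_image, # 512x512 PIL image
+ # num_inversion_steps=4,
+ # num_inference_steps=4,
+ # num_aprox_steps=9, # renoising iterations per step
+ # guidance_scale=0.0,
+ # )
+ # inv_latent = all_latents[-1] # last latent produced by the inversion loop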
+
+ # @torch.no_grad()
+ def inversion_step(
+ self,
+ z_t: torch.tensor,
+ t: torch.tensor,
+ prompt_embeds,
+ added_cond_kwargs,
+ prev_timestep: Optional[torch.tensor] = None,
+ num_aprox_steps: int = 100
+ ) -> torch.tensor:
+ extra_step_kwargs = {}
+
+ avg_range = self.cfg.gradient_averaging_first_step_range if t.item() < 250 else self.cfg.gradient_averaging_step_range
+ num_aprox_steps = min(self.cfg.max_num_aprox_steps_first_step, num_aprox_steps) if t.item() < 250 else num_aprox_steps
+
+ noise_pred_avg = None
+ z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
+ noise_pred_optimal = None
+
+ approximated_z_tp1 = z_t.clone()
+ for i in range(num_aprox_steps + 1):
+
+ with torch.no_grad():
+ if self.cfg.num_reg_steps > 0 and i == 0:
+ approximated_z_tp1 = torch.cat([z_tp1_forward, approximated_z_tp1])
+ prompt_embeds_in = torch.cat([prompt_embeds, prompt_embeds])
+ added_cond_kwargs_in = {}
+ added_cond_kwargs_in['text_embeds'] = torch.cat([added_cond_kwargs['text_embeds'], added_cond_kwargs['text_embeds']])
+ added_cond_kwargs_in['time_ids'] = torch.cat([added_cond_kwargs['time_ids'], added_cond_kwargs['time_ids']])
+ else:
+ prompt_embeds_in = prompt_embeds
+ added_cond_kwargs_in = added_cond_kwargs
+
+ noise_pred = self.unet_pass(approximated_z_tp1, t, prompt_embeds_in, added_cond_kwargs_in)
+
+ if self.cfg.num_reg_steps > 0 and i == 0:
+ noise_pred_optimal, noise_pred = noise_pred.chunk(2)
+ noise_pred_optimal = noise_pred_optimal.detach()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # Calculate average noise
+ if i >= avg_range[0] and i < avg_range[1]:
+ j = i - avg_range[0]
+ if noise_pred_avg is None:
+ noise_pred_avg = noise_pred.clone()
+ else:
+ noise_pred_avg = j * noise_pred_avg / (j + 1) + noise_pred / (j + 1)
+
+ if i >= avg_range[0] or (self.cfg.gradient_averaging_type == Gradient_Averaging_Type.NONE and i > 0):
+ noise_pred = noise_regularization(noise_pred, noise_pred_optimal, lambda_kl=self.cfg.lambda_kl, lambda_ac=self.cfg.lambda_ac, num_reg_steps=self.cfg.num_reg_steps, num_ac_rolls=self.cfg.num_ac_rolls)
+
+ approximated_z_tp1 = self.backward_step(noise_pred, t, z_t, prev_timestep)
+
+ if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.ON_END and noise_pred_avg is not None:
+
+ noise_pred_avg = noise_regularization(noise_pred_avg, noise_pred_optimal, lambda_kl=self.cfg.lambda_kl, lambda_ac=self.cfg.lambda_ac, num_reg_steps=self.cfg.num_reg_steps, num_ac_rolls=self.cfg.num_ac_rolls)
+ approximated_z_tp1 = self.backward_step(noise_pred_avg, t, z_t, prev_timestep)
+
+ if self.cfg.update_epsilon_type != Epsilon_Update_Type.NONE:
+ noise_pred = self.unet_pass(approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ self.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, update_epsilon_type=self.cfg.update_epsilon_type)
+
+ return approximated_z_tp1
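+ # The loop above can be read as a fixed-point iteration (a paraphrase of the code,
+ # not the authors' notation): starting from z_t, each pass re-estimates the noise at
+ # the current guess of z_{t+1} and re-applies the inverse scheduler step,
+ #
+ # z_tp1^(0) = z_t
+ # z_tp1^(k+1) = backward_step(unet_pass(z_tp1^(k), t), t, z_t)
+ #
+ # with optional averaging of the noise predictions over `avg_range` and optional
+ # noise_regularization before the final estimate is taken.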
+
+ @torch.no_grad()
+ def unet_pass(self, z_t, t, prompt_embeds, added_cond_kwargs):
+ latent_model_input = torch.cat([z_t] * 2) if self.do_classifier_free_guidance else z_t
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ return self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=None,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
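+ # unet_pass duplicates the latent along the batch dimension when classifier-free
+ # guidance is enabled, so a single UNet forward yields both the unconditional and the
+ # text-conditional noise predictions; callers split them with noise_pred.chunk(2) as
+ # in inversion_step above. In the SDXL-Turbo demo configuration (guidance_scale 0.0)
+ # this duplication is expected to be inactive.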
+
+ @torch.no_grad()
+ def backward_step(self, noise_pred, t, z_t, prev_timestep):
+ extra_step_kwargs = {}
+ if self.cfg.scheduler_type == Scheduler_Type.EULER or self.cfg.scheduler_type == Scheduler_Type.LCM:
+ return self.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
+ else:
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
+ alpha_prod_t_prev = (
+ self.scheduler.alphas_cumprod[prev_timestep]
+ if prev_timestep is not None
+ else self.scheduler.final_alpha_cumprod
+ )
+ return _backward_ddim(
+ x_tm1=z_t,
+ alpha_t=alpha_prod_t,
+ alpha_tm1=alpha_prod_t_prev,
+ eps_xt=noise_pred,
+ )
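+ # For the non-Euler/LCM branch, _backward_ddim (defined elsewhere in this repo) is
+ # presumably the inverse of the deterministic DDIM update, i.e. roughly
+ #
+ # z_t = sqrt(alpha_t) * (z_tm1 - sqrt(1 - alpha_tm1) * eps) / sqrt(alpha_tm1)
+ # + sqrt(1 - alpha_t) * eps
+ #
+ # where eps is the predicted noise and alpha_* are cumulative alphas; this reading is
+ # inferred from the argument names rather than from the function body.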
+
+
style.css ADDED
@@ -0,0 +1,4 @@
+ h1 {
+ text-align: center;
+ }
+