Lawrence-cj committed on
Commit 19ac58b
1 Parent(s): 0f6fb46

update app.py


1. add height-width control;
2. add resolution bin;
3. replace samples;
4. support SA-Solver
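
In practice these four items show up in app.py as newly visible width/height sliders, a use_resolution_binning flag threaded through to the pipeline call, a refreshed example list, and a sampler switch between DPM-Solver and SA-Solver. A minimal sketch of the new sampling path, assuming a CUDA machine, the diffusers build pinned in requirements.txt, and the SASolverScheduler added below in sa_solver_diffusers.py:

import torch
from diffusers import PixArtAlphaPipeline
from sa_solver_diffusers import SASolverScheduler

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

# SA-Solver path (item 4): tau_func injects stochasticity only for timesteps
# in [200, 800], matching the demo's generate() function.
pipe.scheduler = SASolverScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="data_prediction",
    tau_func=lambda t: 1 if 200 <= t <= 800 else 0,
    predictor_order=2,
    corrector_order=2,
)

image = pipe(
    "A small cactus with a happy face in the Sahara desert.",
    width=1024,                   # height-width control (item 1)
    height=1024,
    use_resolution_binning=True,  # resolution bin (item 2)
    guidance_scale=3,
    num_inference_steps=25,
).images[0]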

Files changed (3)
  1. app.py +84 -35
  2. requirements.txt +1 -1
  3. sa_solver_diffusers.py +856 -0
app.py CHANGED
@@ -8,23 +8,25 @@ import uuid
 
 import gradio as gr
 import numpy as np
-import PIL.Image
+from PIL import Image
 import torch
-
-from diffusers import AutoencoderKL, PixArtAlphaPipeline
+from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
+from sa_solver_diffusers import SASolverScheduler
 
 DESCRIPTION = """![Logo](https://raw.githubusercontent.com/PixArt-alpha/PixArt-alpha.github.io/master/static/images/logo.png)
 # PixArt-Alpha 1024px
 #### [PixArt-Alpha 1024px](https://github.com/PixArt-alpha/PixArt-alpha) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint.
 #### English prompts ONLY; 提示词仅限英文
-Don't want to queue? Try [Google Colab Demo](https://colab.research.google.com/drive/1jZ5UZXk7tcpTfVwnX33dDuefNMcnW9ME?usp=sharing). It's slower but still free.
+### <span style='color: red;'>You may increase the DPM-Solver inference steps from 10 to 20 if you don't get satisfactory results.
+<span style='color: green;'>Don't want to queue? Try [OpenXLab](https://openxlab.org.cn/apps/detail/PixArt-alpha/PixArt-alpha) or [Google Colab Demo](https://colab.research.google.com/drive/1jZ5UZXk7tcpTfVwnX33dDuefNMcnW9ME?usp=sharing).
 """
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
 MAX_SEED = np.iinfo(np.int32).max
 CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
-MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
@@ -86,6 +88,9 @@ style_list = [
 styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
 STYLE_NAMES = list(styles.keys())
 DEFAULT_STYLE_NAME = "(No style)"
+SCHEDULE_NAME = ["DPM-Solver", "SA-Solver"]
+DEFAULT_SCHEDULE_NAME = "DPM-Solver"
+NUM_IMAGES_PER_PROMPT = 1
 
 
 def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
@@ -99,7 +104,6 @@ if torch.cuda.is_available():
     pipe = PixArtAlphaPipeline.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS",
         torch_dtype=torch.float16,
-        variant="fp16",
         use_safetensors=True,
     )
@@ -113,9 +117,7 @@ if torch.cuda.is_available():
     pipe.text_encoder.to_bettertransformer()
 
     if USE_TORCH_COMPILE:
-        pipe.transformer = torch.compile(
-            pipe.transformer, mode="reduce-overhead", fullgraph=True
-        )
+        pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
         print("Model Compiled!")
@@ -139,42 +141,62 @@ def generate(
     seed: int = 0,
     width: int = 1024,
     height: int = 1024,
-    guidance_scale: float = 4.5,
-    num_inference_steps: int = 20,
+    schedule: str = 'DPM-Solver',
+    dpms_guidance_scale: float = 4.5,
+    sas_guidance_scale: float = 3,
+    dpms_inference_steps: int = 20,
+    sas_inference_steps: int = 25,
     randomize_seed: bool = False,
+    use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True),
 ):
-    seed = randomize_seed_fn(seed, randomize_seed)
+    seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator().manual_seed(seed)
 
+    if schedule == 'DPM-Solver':
+        if not isinstance(pipe.scheduler, DPMSolverMultistepScheduler):
+            pipe.scheduler = DPMSolverMultistepScheduler()
+        num_inference_steps = dpms_inference_steps
+        guidance_scale = dpms_guidance_scale
+    elif schedule == "SA-Solver":
+        if not isinstance(pipe.scheduler, SASolverScheduler):
+            pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config, algorithm_type='data_prediction', tau_func=lambda t: 1 if 200 <= t <= 800 else 0, predictor_order=2, corrector_order=2)
+        num_inference_steps = sas_inference_steps
+        guidance_scale = sas_guidance_scale
+    else:
+        raise ValueError(f"Unknown schedule: {schedule}")
+
     if not use_negative_prompt:
         negative_prompt = None  # type: ignore
     prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
-    image = pipe(
+
+    images = pipe(
         prompt=prompt,
-        negative_prompt=negative_prompt,
         width=width,
         height=height,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
        generator=generator,
+        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
+        use_resolution_binning=use_resolution_binning,
        output_type="pil",
-    ).images[0]
+    ).images
 
-    image_path = save_image(image)
-    print(image_path)
-    return [image_path], seed
+    image_paths = [save_image(img) for img in images]
+    print(image_paths)
+    return image_paths, seed
 
 
 examples = [
     "A small cactus with a happy face in the Sahara desert.",
+    "an astronaut sitting in a diner, eating fries, cinematic, analog film",
     "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
     "stars, water, brilliantly, gorgeous large scale scene, a little girl, in the style of dreamy realism, light gold and amber, blue and pink, brilliantly illuminated in the background.",
-    "3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background",
-    "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
     "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
-    "an astronaut sitting in a diner, eating fries, cinematic, analog film",
-    "Albert Einstein in a surrealist Cyberpunk 2077 world, hyperrealistic",
+    "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
+    "Spectacular Tiny World in the Transparent Jar On the Table, interior of the Great Hall, Elaborate, Carved Architecture, Anatomy, Symetrical, Geometric and Parameteric Details, Precision Flat line Details, Pattern, Dark fantasy, Dark errie mood and ineffably mysterious mood, Technical design, Intricate Ultra Detail, Ornate Detail, Stylized and Futuristic and Biomorphic Details, Architectural Concept, Low contrast Details, Cinematic Lighting, 8k, by moebius, Fullshot, Epic, Fullshot, Octane render, Unreal ,Photorealistic, Hyperrealism",
+    "anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur",
+    "The parametric hotel lobby is a sleek and modern space with plenty of natural light. The lobby is spacious and open with a variety of seating options. The front desk is a sleek white counter with a parametric design. The walls are a light blue color with parametric patterns. The floor is a light wood color with a parametric design. There are plenty of plants and flowers throughout the space. The overall effect is a calm and relaxing space. occlusion, moody, sunset, concept art, octane rendering, 8k, highly detailed, concept art, highly detailed, beautiful scenery, cinematic, beautiful light, hyperreal, octane render, hdr, long exposure, 8K, realistic, fog, moody, fire and explosions, smoke, 50mm f2.8",
 ]
@@ -194,10 +216,19 @@ with gr.Blocks(css="style.css") as demo:
                 container=False,
             )
             run_button = gr.Button("Run", scale=0)
-        result = gr.Gallery(label="Result", columns=1, show_label=False)
+        result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)
         with gr.Accordion("Advanced options", open=False):
             with gr.Row():
                 use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
+            schedule = gr.Radio(
+                show_label=True,
+                container=True,
+                interactive=True,
+                choices=SCHEDULE_NAME,
+                value=DEFAULT_SCHEDULE_NAME,
+                label="Sampler Schedule",
+                visible=True,
+            )
             style_selection = gr.Radio(
                 show_label=True,
                 container=True,
@@ -220,7 +251,7 @@
                 value=0,
             )
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            with gr.Row(visible=False):
+            with gr.Row(visible=True):
                 width = gr.Slider(
                     label="Width",
                     minimum=256,
@@ -236,19 +267,34 @@
                     value=1024,
                 )
             with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
+                dpms_guidance_scale = gr.Slider(
+                    label="DPM-Solver Guidance scale",
                     minimum=1,
-                    maximum=20,
+                    maximum=10,
                     step=0.1,
                     value=4.5,
                 )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
+                dpms_inference_steps = gr.Slider(
+                    label="DPM-Solver inference steps",
+                    minimum=5,
+                    maximum=40,
+                    step=1,
+                    value=10,
+                )
+            with gr.Row():
+                sas_guidance_scale = gr.Slider(
+                    label="SA-Solver Guidance scale",
+                    minimum=1,
+                    maximum=10,
+                    step=0.1,
+                    value=3,
+                )
+                sas_inference_steps = gr.Slider(
+                    label="SA-Solver inference steps",
                     minimum=10,
-                    maximum=100,
+                    maximum=40,
                     step=1,
-                    value=20,
+                    value=25,
                 )
 
     gr.Examples(
@@ -281,8 +327,11 @@
             seed,
             width,
             height,
-            guidance_scale,
-            num_inference_steps,
+            schedule,
+            dpms_guidance_scale,
+            sas_guidance_scale,
+            dpms_inference_steps,
+            sas_inference_steps,
             randomize_seed,
         ],
         outputs=[result, seed],
@@ -290,5 +339,5 @@
     )
 
 if __name__ == "__main__":
-    # demo.queue(max_size=20).launch()
-    demo.launch(share=True)
+    demo.queue(max_size=20).launch()
+    # demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=11900, debug=True)
 
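The use_resolution_binning=True argument threaded through generate() above is what makes the now-visible width/height sliders safe to expose: instead of sampling at an arbitrary size the model never saw, the pipeline snaps the request to the nearest trained aspect-ratio bucket and resizes the output back. An illustrative sketch of the idea, with a toy bucket table rather than the pipeline's full one:

ASPECT_RATIO_1024_BIN = {  # a few sample entries only; the real table is larger
    "0.5": (704, 1408),
    "1.0": (1024, 1024),
    "2.0": (1408, 704),
}

def closest_bin(width: int, height: int) -> tuple:
    # Snap a requested size to the trained bucket with the nearest aspect ratio.
    ratio = width / height
    key = min(ASPECT_RATIO_1024_BIN, key=lambda k: abs(float(k) - ratio))
    return ASPECT_RATIO_1024_BIN[key]

print(closest_bin(900, 1600))  # -> (704, 1408): generate there, then resize back
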
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-diffusers>=0.22.0
+git+https://github.com/huggingface/diffusers
 accelerate
 transformers
 gradio==4.1.1
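
The pin moves from the 0.22 release to diffusers' git main because the demo depends on pipeline behavior (for example, PixArtAlphaPipeline accepting use_resolution_binning) that had not shipped in a tagged release at the time of this commit. A quick, hypothetical post-install sanity check:

import inspect
from diffusers import PixArtAlphaPipeline

# The demo passes use_resolution_binning to the pipeline call; older releases
# would reject the unknown keyword argument.
params = inspect.signature(PixArtAlphaPipeline.__call__).parameters
assert "use_resolution_binning" in params, "install diffusers from git main"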
sa_solver_diffusers.py ADDED
@@ -0,0 +1,856 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://arxiv.org/abs/2309.05019
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union, Callable

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

class SASolverScheduler(SchedulerMixin, ConfigMixin):
    """
    `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        predictor_order (`int`, defaults to 2):
            The predictor order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `predictor_order=2` for
            guided sampling, and `predictor_order=3` for unconditional sampling.
        corrector_order (`int`, defaults to 2):
            The corrector order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `corrector_order=2` for
            guided sampling, and `corrector_order=3` for unconditional sampling.
        predictor_corrector_mode (`str`, defaults to `PEC`):
            The predictor-corrector mode, which can be `PEC` or `PECE`. It is recommended to use `PEC` mode for fast
            sampling, and `PECE` for high-quality sampling (PECE needs around twice as many model evaluations as PEC).
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models
            such as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="data_prediction"`.
        algorithm_type (`str`, defaults to `data_prediction`):
            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
            `data_prediction` with `predictor_order=2` for guided sampling like in Stable Diffusion.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's
            output contains the predicted Gaussian variance.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        predictor_order: int = 2,
        corrector_order: int = 2,
        predictor_corrector_mode: str = 'PEC',
        prediction_type: str = "epsilon",
        tau_func: Callable = lambda t: 1 if t >= 200 and t <= 800 else 0,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "data_prediction",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if algorithm_type not in ["data_prediction", "noise_prediction"]:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.timestep_list = [None] * max(predictor_order, corrector_order - 1)
        self.model_outputs = [None] * max(predictor_order, corrector_order - 1)

        self.tau_func = tau_func
        self.predict_x0 = algorithm_type == "data_prediction"
        self.lower_order_nums = 0
        self.last_sample = None
    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_steps is a power of 3
            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            timesteps = np.flip(timesteps).copy().astype(np.int64)

        self.sigmas = torch.from_numpy(sigmas)

        # when num_inference_steps == num_train_timesteps, we can end up with
        # duplicates in timesteps.
        _, unique_indices = np.unique(timesteps, return_index=True)
        timesteps = timesteps[np.sort(unique_indices)]

        self.timesteps = torch.from_numpy(timesteps).to(device)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * max(self.config.predictor_order, self.config.corrector_order - 1)
        self.lower_order_nums = 0
        self.last_sample = None

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize
        an integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.FloatTensor`:
                The converted model output.
        """

        # SA-Solver_data_prediction needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["data_prediction"]:
            if self.config.prediction_type == "epsilon":
                # SA-Solver only needs the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the SASolverScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["noise_prediction"]:
            if self.config.prediction_type == "epsilon":
                # SA-Solver only needs the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the SASolverScheduler."
                )

            if self.config.thresholding:
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon
    def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
        """
        Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
        """
        assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

        if order == 0:
            return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
        elif order == 1:
            return torch.exp(-interval_end) * (
                (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1)
            )
        elif order == 2:
            return torch.exp(-interval_end) * (
                (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start)
                - (interval_end ** 2 + 2 * interval_end + 2)
            )
        elif order == 3:
            return torch.exp(-interval_end) * (
                (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6)
                * torch.exp(interval_end - interval_start)
                - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6)
            )

    def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
        """
        Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
        """
        assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

        # after change of variable (cov)
        interval_end_cov = (1 + tau ** 2) * interval_end
        interval_start_cov = (1 + tau ** 2) * interval_start

        if order == 0:
            return (
                torch.exp(interval_end_cov)
                * (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
                / (1 + tau ** 2)
            )
        elif order == 1:
            return (
                torch.exp(interval_end_cov)
                * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)))
                / ((1 + tau ** 2) ** 2)
            )
        elif order == 2:
            return (
                torch.exp(interval_end_cov)
                * (
                    (interval_end_cov ** 2 - 2 * interval_end_cov + 2)
                    - (interval_start_cov ** 2 - 2 * interval_start_cov + 2)
                    * torch.exp(-(interval_end_cov - interval_start_cov))
                )
                / ((1 + tau ** 2) ** 3)
            )
        elif order == 3:
            return (
                torch.exp(interval_end_cov)
                * (
                    (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6)
                    - (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6)
                    * torch.exp(-(interval_end_cov - interval_start_cov))
                )
                / ((1 + tau ** 2) ** 4)
            )

    def lagrange_polynomial_coefficient(self, order, lambda_list):
        """
        Calculate the coefficients of the Lagrange polynomial
        """
        assert order in [0, 1, 2, 3]
        assert order == len(lambda_list) - 1
        if order == 0:
            return [[1]]
        elif order == 1:
            return [
                [1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
                [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])],
            ]
        elif order == 2:
            denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
            denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
            denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
            return [
                [
                    1 / denominator1,
                    (-lambda_list[1] - lambda_list[2]) / denominator1,
                    lambda_list[1] * lambda_list[2] / denominator1,
                ],
                [
                    1 / denominator2,
                    (-lambda_list[0] - lambda_list[2]) / denominator2,
                    lambda_list[0] * lambda_list[2] / denominator2,
                ],
                [
                    1 / denominator3,
                    (-lambda_list[0] - lambda_list[1]) / denominator3,
                    lambda_list[0] * lambda_list[1] / denominator3,
                ],
            ]
        elif order == 3:
            denominator1 = (
                (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (lambda_list[0] - lambda_list[3])
            )
            denominator2 = (
                (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (lambda_list[1] - lambda_list[3])
            )
            denominator3 = (
                (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (lambda_list[2] - lambda_list[3])
            )
            denominator4 = (
                (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (lambda_list[3] - lambda_list[2])
            )
            return [
                [
                    1 / denominator1,
                    (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
                    (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator1,
                    (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1,
                ],
                [
                    1 / denominator2,
                    (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
                    (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator2,
                    (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2,
                ],
                [
                    1 / denominator3,
                    (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
                    (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3,
                    (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3,
                ],
                [
                    1 / denominator4,
                    (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
                    (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4,
                    (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4,
                ],
            ]

    def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
        assert order in [1, 2, 3, 4]
        assert order == len(lambda_list), "the length of lambda list must be equal to the order"
        coefficients = []
        lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
        for i in range(order):
            coefficient = 0
            for j in range(order):
                if self.predict_x0:
                    coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_positive(
                        order - 1 - j, interval_start, interval_end, tau
                    )
                else:
                    coefficient += lagrange_coefficient[i][j] * self.get_coefficients_exponential_negative(
                        order - 1 - j, interval_start, interval_end
                    )
            coefficients.append(coefficient)
        assert len(coefficients) == order, "the length of coefficients does not match the order"
        return coefficients
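
    # The two integral helpers above give closed forms for
    #     ∫ x^n · exp(-x) dx                 (noise prediction)
    #     ∫ x^n · exp((1 + tau^2) · x) dx    (data prediction)
    # over [interval_start, interval_end] for n in {0, 1, 2, 3}.
    # get_coefficients_fn weights them with the Lagrange basis built from the
    # last `order` lambda values, producing the stochastic Adams-Bashforth
    # (predictor) and Adams-Moulton (corrector) coefficients used below.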
    def stochastic_adams_bashforth_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        noise: torch.FloatTensor,
        order: int,
        tau: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the SA-Predictor.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            noise (`torch.FloatTensor`):
                The noise tensor injected at this step.
            order (`int`):
                The order of SA-Predictor at this timestep.
            tau (`torch.FloatTensor`):
                The stochasticity parameter tau(t) returned by `tau_func`.

        Returns:
            `torch.FloatTensor`:
                The sample tensor at the previous timestep.
        """

        assert noise is not None
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
        s0, t = self.timestep_list[-1], prev_timestep
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        gradient_part = torch.zeros_like(sample)
        h = lambda_t - lambda_s0
        lambda_list = []

        for i in range(order):
            lambda_list.append(self.lambda_t[timestep_list[-(i + 1)]])

        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)

        x = sample

        if self.predict_x0:
            if order == 2:
                # If order == 2, apply a modification (similar to UniPC) that does not affect the convergence
                # order. Note: this is used only for few-step sampling.
                # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
                # ODE case:
                # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
                # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
                gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                    h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2)
                ) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[timestep_list[-2]])
                gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                    h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2)
                ) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[timestep_list[-2]])

        for i in range(order):
            if self.predict_x0:
                gradient_part += (
                    (1 + tau ** 2)
                    * sigma_t
                    * torch.exp(-tau ** 2 * lambda_t)
                    * gradient_coefficients[i]
                    * model_output_list[-(i + 1)]
                )
            else:
                gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)]

        if self.predict_x0:
            noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
        else:
            noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise

        if self.predict_x0:
            x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
        else:
            x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part

        x_t = x_t.to(x.dtype)
        return x_t
    def stochastic_adams_moulton_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        last_noise: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
        tau: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the SA-Corrector.

        Args:
            this_model_output (`torch.FloatTensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.FloatTensor`):
                The generated sample before the last predictor `x_{t-1}`.
            last_noise (`torch.FloatTensor`):
                The noise used in the last predictor step.
            this_sample (`torch.FloatTensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The order of SA-Corrector at this step.
            tau (`torch.FloatTensor`):
                The stochasticity parameter tau(t) returned by `tau_func`.

        Returns:
            `torch.FloatTensor`:
                The corrected sample tensor at the current timestep.
        """

        assert last_noise is not None
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
        s0, t = self.timestep_list[-1], this_timestep
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        gradient_part = torch.zeros_like(this_sample)
        h = lambda_t - lambda_s0
        t_list = timestep_list + [this_timestep]
        lambda_list = []
        for i in range(order):
            lambda_list.append(self.lambda_t[t_list[-(i + 1)]])

        model_prev_list = model_output_list + [this_model_output]

        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)

        x = last_sample

        if self.predict_x0:
            if order == 2:
                # If order == 2, apply a modification (similar to UniPC) that does not affect the convergence
                # order. Note: this is used only for few-step sampling.
                # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
                # ODE case:
                # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
                # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
                gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                    h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2 * h)
                )
                gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                    h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2 * h)
                )

        for i in range(order):
            if self.predict_x0:
                gradient_part += (
                    (1 + tau ** 2)
                    * sigma_t
                    * torch.exp(-tau ** 2 * lambda_t)
                    * gradient_coefficients[i]
                    * model_prev_list[-(i + 1)]
                )
            else:
                gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]

        if self.predict_x0:
            noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * last_noise
        else:
            noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise

        if self.predict_x0:
            x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
        else:
            x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part

        x_t = x_t.to(x.dtype)
        return x_t
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the SA-Solver.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero()
        if len(step_index) == 0:
            step_index = len(self.timesteps) - 1
        else:
            step_index = step_index.item()

        use_corrector = step_index > 0 and self.last_sample is not None

        model_output_convert = self.convert_model_output(model_output, timestep, sample)

        if use_corrector:
            current_tau = self.tau_func(self.timestep_list[-1])
            sample = self.stochastic_adams_moulton_update(
                this_model_output=model_output_convert,
                this_timestep=timestep,
                last_sample=self.last_sample,
                last_noise=self.last_noise,
                this_sample=sample,
                order=self.this_corrector_order,
                tau=current_tau,
            )

        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]

        for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )

        if self.config.lower_order_final:
            this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - step_index)
            this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - step_index + 1)
        else:
            this_predictor_order = self.config.predictor_order
            this_corrector_order = self.config.corrector_order

        self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1)  # warmup for multistep
        self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2)  # warmup for multistep
        assert self.this_predictor_order > 0
        assert self.this_corrector_order > 0

        self.last_sample = sample
        self.last_noise = noise

        current_tau = self.tau_func(self.timestep_list[-1])
        prev_sample = self.stochastic_adams_bashforth_update(
            model_output=model_output_convert,
            prev_timestep=prev_timestep,
            sample=sample,
            noise=noise,
            order=self.this_predictor_order,
            tau=current_tau,
        )

        if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1):
            self.lower_order_nums += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)
    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
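
As a usage note, the scheduler can also be constructed directly rather than via from_config. A minimal sketch showing how tau_func trades determinism for stochasticity: a tau that is 0 everywhere makes the SDE solver behave like its deterministic ODE counterpart, while the demo's interval function injects noise only at mid-range timesteps.

from sa_solver_diffusers import SASolverScheduler

scheduler = SASolverScheduler(
    predictor_corrector_mode="PEC",  # fast mode; "PECE" costs roughly 2x model evaluations
    predictor_order=2,               # recommended for guided sampling
    corrector_order=2,
    tau_func=lambda t: 1 if 200 <= t <= 800 else 0,  # the demo's choice
)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps)  # the 25 discrete timesteps the solver will visit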