renyuxi commited on
Commit
fcf30a0
β€’
1 Parent(s): 99ea941

Upload scheduling_tcd.py

Browse files
Files changed (1) hide show
  1. scheduling_tcd.py +686 -0
scheduling_tcd.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class TCDSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
44
+ The predicted noised sample `(x_{s})` based on the model output from the current timestep.
45
+ """
46
+
47
+ prev_sample: torch.FloatTensor
48
+ pred_noised_sample: Optional[torch.FloatTensor] = None
49
+
50
+
51
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
52
+ def betas_for_alpha_bar(
53
+ num_diffusion_timesteps,
54
+ max_beta=0.999,
55
+ alpha_transform_type="cosine",
56
+ ):
57
+ """
58
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
59
+ (1-beta) over time from t = [0,1].
60
+
61
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
62
+ to that part of the diffusion process.
63
+
64
+
65
+ Args:
66
+ num_diffusion_timesteps (`int`): the number of betas to produce.
67
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
68
+ prevent singularities.
69
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
70
+ Choose from `cosine` or `exp`
71
+
72
+ Returns:
73
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
74
+ """
75
+ if alpha_transform_type == "cosine":
76
+
77
+ def alpha_bar_fn(t):
78
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
79
+
80
+ elif alpha_transform_type == "exp":
81
+
82
+ def alpha_bar_fn(t):
83
+ return math.exp(t * -12.0)
84
+
85
+ else:
86
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
87
+
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
93
+ return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97
+ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
98
+ """
99
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
100
+
101
+
102
+ Args:
103
+ betas (`torch.FloatTensor`):
104
+ the betas that the scheduler is being initialized with.
105
+
106
+ Returns:
107
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
108
+ """
109
+ # Convert betas to alphas_bar_sqrt
110
+ alphas = 1.0 - betas
111
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
112
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
113
+
114
+ # Store old values.
115
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
116
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
117
+
118
+ # Shift so the last timestep is zero.
119
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
120
+
121
+ # Scale so the first timestep is back to the old value.
122
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
123
+
124
+ # Convert alphas_bar_sqrt to betas
125
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
126
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
127
+ alphas = torch.cat([alphas_bar[0:1], alphas])
128
+ betas = 1 - alphas
129
+
130
+ return betas
131
+
132
+
133
+ class TCDScheduler(SchedulerMixin, ConfigMixin):
134
+ """
135
+ `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
136
+ extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
137
+
138
+ This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD).
139
+
140
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
141
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
142
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
143
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
144
+
145
+ Args:
146
+ num_train_timesteps (`int`, defaults to 1000):
147
+ The number of diffusion steps to train the model.
148
+ beta_start (`float`, defaults to 0.0001):
149
+ The starting `beta` value of inference.
150
+ beta_end (`float`, defaults to 0.02):
151
+ The final `beta` value.
152
+ beta_schedule (`str`, defaults to `"linear"`):
153
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
154
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
155
+ trained_betas (`np.ndarray`, *optional*):
156
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
157
+ original_inference_steps (`int`, *optional*, defaults to 50):
158
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
159
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
160
+ clip_sample (`bool`, defaults to `True`):
161
+ Clip the predicted sample for numerical stability.
162
+ clip_sample_range (`float`, defaults to 1.0):
163
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
164
+ set_alpha_to_one (`bool`, defaults to `True`):
165
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
166
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
167
+ otherwise it uses the alpha value at step 0.
168
+ steps_offset (`int`, defaults to 0):
169
+ An offset added to the inference steps, as required by some model families.
170
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
171
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
172
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
173
+ Video](https://imagen.research.google/video/paper.pdf) paper).
174
+ thresholding (`bool`, defaults to `False`):
175
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
176
+ as Stable Diffusion.
177
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
178
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
179
+ sample_max_value (`float`, defaults to 1.0):
180
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
181
+ timestep_spacing (`str`, defaults to `"leading"`):
182
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
183
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
184
+ timestep_scaling (`float`, defaults to 10.0):
185
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
186
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
187
+ error at the default of `10.0` is already pretty small).
188
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
189
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
190
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
191
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
192
+ """
193
+
194
+ order = 1
195
+
196
+ @register_to_config
197
+ def __init__(
198
+ self,
199
+ num_train_timesteps: int = 1000,
200
+ beta_start: float = 0.00085,
201
+ beta_end: float = 0.012,
202
+ beta_schedule: str = "scaled_linear",
203
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
204
+ original_inference_steps: int = 50,
205
+ clip_sample: bool = False,
206
+ clip_sample_range: float = 1.0,
207
+ set_alpha_to_one: bool = True,
208
+ steps_offset: int = 0,
209
+ prediction_type: str = "epsilon",
210
+ thresholding: bool = False,
211
+ dynamic_thresholding_ratio: float = 0.995,
212
+ sample_max_value: float = 1.0,
213
+ timestep_spacing: str = "leading",
214
+ timestep_scaling: float = 10.0,
215
+ rescale_betas_zero_snr: bool = False,
216
+ ):
217
+ if trained_betas is not None:
218
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
219
+ elif beta_schedule == "linear":
220
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
221
+ elif beta_schedule == "scaled_linear":
222
+ # this schedule is very specific to the latent diffusion model.
223
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
224
+ elif beta_schedule == "squaredcos_cap_v2":
225
+ # Glide cosine schedule
226
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
227
+ else:
228
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
229
+
230
+ # Rescale for zero SNR
231
+ if rescale_betas_zero_snr:
232
+ self.betas = rescale_zero_terminal_snr(self.betas)
233
+
234
+ self.alphas = 1.0 - self.betas
235
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
236
+
237
+ # At every step in ddim, we are looking into the previous alphas_cumprod
238
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
239
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
240
+ # whether we use the final alpha of the "non-previous" one.
241
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
242
+
243
+ # standard deviation of the initial noise distribution
244
+ self.init_noise_sigma = 1.0
245
+
246
+ # setable values
247
+ self.num_inference_steps = None
248
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
249
+ self.custom_timesteps = False
250
+
251
+ self._step_index = None
252
+ self._begin_index = None
253
+
254
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
255
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
256
+ if schedule_timesteps is None:
257
+ schedule_timesteps = self.timesteps
258
+
259
+ indices = (schedule_timesteps == timestep).nonzero()
260
+
261
+ # The sigma index that is taken for the **very** first `step`
262
+ # is always the second index (or the last index if there is only 1)
263
+ # This way we can ensure we don't accidentally skip a sigma in
264
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
265
+ pos = 1 if len(indices) > 1 else 0
266
+
267
+ return indices[pos].item()
268
+
269
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
270
+ def _init_step_index(self, timestep):
271
+ if self.begin_index is None:
272
+ if isinstance(timestep, torch.Tensor):
273
+ timestep = timestep.to(self.timesteps.device)
274
+ self._step_index = self.index_for_timestep(timestep)
275
+ else:
276
+ self._step_index = self._begin_index
277
+
278
+ @property
279
+ def step_index(self):
280
+ return self._step_index
281
+
282
+ @property
283
+ def begin_index(self):
284
+ """
285
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
286
+ """
287
+ return self._begin_index
288
+
289
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
290
+ def set_begin_index(self, begin_index: int = 0):
291
+ """
292
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
293
+
294
+ Args:
295
+ begin_index (`int`):
296
+ The begin index for the scheduler.
297
+ """
298
+ self._begin_index = begin_index
299
+
300
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
301
+ """
302
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
303
+ current timestep.
304
+
305
+ Args:
306
+ sample (`torch.FloatTensor`):
307
+ The input sample.
308
+ timestep (`int`, *optional*):
309
+ The current timestep in the diffusion chain.
310
+ Returns:
311
+ `torch.FloatTensor`:
312
+ A scaled input sample.
313
+ """
314
+ return sample
315
+
316
+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
317
+ def _get_variance(self, timestep, prev_timestep):
318
+ alpha_prod_t = self.alphas_cumprod[timestep]
319
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
320
+ beta_prod_t = 1 - alpha_prod_t
321
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
322
+
323
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
324
+
325
+ return variance
326
+
327
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
328
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
329
+ """
330
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
331
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
332
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
333
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
334
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
335
+
336
+ https://arxiv.org/abs/2205.11487
337
+ """
338
+ dtype = sample.dtype
339
+ batch_size, channels, *remaining_dims = sample.shape
340
+
341
+ if dtype not in (torch.float32, torch.float64):
342
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
343
+
344
+ # Flatten sample for doing quantile calculation along each image
345
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
346
+
347
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
348
+
349
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
350
+ s = torch.clamp(
351
+ s, min=1, max=self.config.sample_max_value
352
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
353
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
354
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
355
+
356
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
357
+ sample = sample.to(dtype)
358
+
359
+ return sample
360
+
361
+ def set_timesteps(
362
+ self,
363
+ num_inference_steps: Optional[int] = None,
364
+ device: Union[str, torch.device] = None,
365
+ original_inference_steps: Optional[int] = None,
366
+ timesteps: Optional[List[int]] = None,
367
+ strength: int = 1.0,
368
+ ):
369
+ """
370
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
371
+
372
+ Args:
373
+ num_inference_steps (`int`, *optional*):
374
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
375
+ `timesteps` must be `None`.
376
+ device (`str` or `torch.device`, *optional*):
377
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
378
+ original_inference_steps (`int`, *optional*):
379
+ The original number of inference steps, which will be used to generate a linearly-spaced timestep
380
+ schedule (which is different from the standard `diffusers` implementation). We will then take
381
+ `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
382
+ our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
383
+ timesteps (`List[int]`, *optional*):
384
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
385
+ timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
386
+ schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
387
+ """
388
+ # 0. Check inputs
389
+ if num_inference_steps is None and timesteps is None:
390
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
391
+
392
+ if num_inference_steps is not None and timesteps is not None:
393
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
394
+
395
+ # 1. Calculate the TCD original training/distillation timestep schedule.
396
+ original_steps = (
397
+ original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
398
+ )
399
+
400
+ if original_inference_steps is None:
401
+ # default option, timesteps align with discrete inference steps
402
+ if original_steps > self.config.num_train_timesteps:
403
+ raise ValueError(
404
+ f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
405
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
406
+ f" maximal {self.config.num_train_timesteps} timesteps."
407
+ )
408
+ # TCD Timesteps Setting
409
+ # The skipping step parameter k from the paper.
410
+ k = self.config.num_train_timesteps // original_steps
411
+ # TCD Training/Distillation Steps Schedule
412
+ tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
413
+ else:
414
+ # customised option, sampled timesteps can be any arbitrary value
415
+ tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps * strength))))
416
+
417
+ # 2. Calculate the TCD inference timestep schedule.
418
+ if timesteps is not None:
419
+ # 2.1 Handle custom timestep schedules.
420
+ train_timesteps = set(tcd_origin_timesteps)
421
+ non_train_timesteps = []
422
+ for i in range(1, len(timesteps)):
423
+ if timesteps[i] >= timesteps[i - 1]:
424
+ raise ValueError("`custom_timesteps` must be in descending order.")
425
+
426
+ if timesteps[i] not in train_timesteps:
427
+ non_train_timesteps.append(timesteps[i])
428
+
429
+ if timesteps[0] >= self.config.num_train_timesteps:
430
+ raise ValueError(
431
+ f"`timesteps` must start before `self.config.train_timesteps`:"
432
+ f" {self.config.num_train_timesteps}."
433
+ )
434
+
435
+ # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
436
+ if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
437
+ logger.warning(
438
+ f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
439
+ f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
440
+ f" unexpected results when using this timestep schedule."
441
+ )
442
+
443
+ # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
444
+ if non_train_timesteps:
445
+ logger.warning(
446
+ f"The custom timestep schedule contains the following timesteps which are not on the original"
447
+ f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
448
+ f" when using this timestep schedule."
449
+ )
450
+
451
+ # Raise warning if custom timestep schedule is longer than original_steps
452
+ if original_steps is not None:
453
+ if len(timesteps) > original_steps:
454
+ logger.warning(
455
+ f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
456
+ f" the length of the timestep schedule used for training: {original_steps}. You may get some"
457
+ f" unexpected results when using this timestep schedule."
458
+ )
459
+ else:
460
+ if len(timesteps) > self.config.num_train_timesteps:
461
+ logger.warning(
462
+ f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
463
+ f" the length of the timestep schedule used for training: {self.config.num_train_timesteps}. You may get some"
464
+ f" unexpected results when using this timestep schedule."
465
+ )
466
+
467
+ timesteps = np.array(timesteps, dtype=np.int64)
468
+ self.num_inference_steps = len(timesteps)
469
+ self.custom_timesteps = True
470
+
471
+ # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
472
+ init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
473
+ t_start = max(self.num_inference_steps - init_timestep, 0)
474
+ timesteps = timesteps[t_start * self.order :]
475
+ # TODO: also reset self.num_inference_steps?
476
+ else:
477
+ # 2.2 Create the "standard" TCD inference timestep schedule.
478
+ if num_inference_steps > self.config.num_train_timesteps:
479
+ raise ValueError(
480
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
481
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
482
+ f" maximal {self.config.num_train_timesteps} timesteps."
483
+ )
484
+
485
+ if original_steps is not None:
486
+ skipping_step = len(tcd_origin_timesteps) // num_inference_steps
487
+
488
+ if skipping_step < 1:
489
+ raise ValueError(
490
+ f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
491
+ )
492
+
493
+ self.num_inference_steps = num_inference_steps
494
+
495
+ if original_steps is not None:
496
+ if num_inference_steps > original_steps:
497
+ raise ValueError(
498
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
499
+ f" {original_steps} because the final timestep schedule will be a subset of the"
500
+ f" `original_inference_steps`-sized initial timestep schedule."
501
+ )
502
+ else:
503
+ if num_inference_steps > self.config.num_train_timesteps:
504
+ raise ValueError(
505
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:"
506
+ f" {self.config.num_train_timesteps} because the final timestep schedule will be a subset of the"
507
+ f" `num_train_timesteps`-sized initial timestep schedule."
508
+ )
509
+
510
+ # TCD Inference Steps Schedule
511
+ tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
512
+ # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
513
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
514
+ inference_indices = np.floor(inference_indices).astype(np.int64)
515
+ timesteps = tcd_origin_timesteps[inference_indices]
516
+
517
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
518
+
519
+ self._step_index = None
520
+ self._begin_index = None
521
+
522
+ def step(
523
+ self,
524
+ model_output: torch.FloatTensor,
525
+ timestep: int,
526
+ sample: torch.FloatTensor,
527
+ eta: float = 0.3,
528
+ generator: Optional[torch.Generator] = None,
529
+ return_dict: bool = True,
530
+ ) -> Union[TCDSchedulerOutput, Tuple]:
531
+ """
532
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
533
+ process from the learned model outputs (most often the predicted noise).
534
+
535
+ Args:
536
+ model_output (`torch.FloatTensor`):
537
+ The direct output from learned diffusion model.
538
+ timestep (`int`):
539
+ The current discrete timestep in the diffusion chain.
540
+ sample (`torch.FloatTensor`):
541
+ A current instance of a sample created by the diffusion process.
542
+ eta (`float`):
543
+ A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every step.
544
+ When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
545
+ generator (`torch.Generator`, *optional*):
546
+ A random number generator.
547
+ return_dict (`bool`, *optional*, defaults to `True`):
548
+ Whether or not to return a [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`.
549
+ Returns:
550
+ [`~schedulers.scheduling_utils.TCDSchedulerOutput`] or `tuple`:
551
+ If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a
552
+ tuple is returned where the first element is the sample tensor.
553
+ """
554
+ if self.num_inference_steps is None:
555
+ raise ValueError(
556
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
557
+ )
558
+
559
+ if self.step_index is None:
560
+ self._init_step_index(timestep)
561
+
562
+ assert 0 <= eta <= 1.0, "gamma must be less than or equal to 1.0"
563
+
564
+ # 1. get previous step value
565
+ prev_step_index = self.step_index + 1
566
+ if prev_step_index < len(self.timesteps):
567
+ prev_timestep = self.timesteps[prev_step_index]
568
+ else:
569
+ prev_timestep = torch.tensor(0)
570
+
571
+ timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long)
572
+
573
+ # 2. compute alphas, betas
574
+ alpha_prod_t = self.alphas_cumprod[timestep]
575
+ beta_prod_t = 1 - alpha_prod_t
576
+
577
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
578
+
579
+ alpha_prod_s = self.alphas_cumprod[timestep_s]
580
+ beta_prod_s = 1 - alpha_prod_s
581
+
582
+ # 3. Compute the predicted noised sample x_s based on the model parameterization
583
+ if self.config.prediction_type == "epsilon": # noise-prediction
584
+ pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
585
+ pred_epsilon = model_output
586
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
587
+ elif self.config.prediction_type == "sample": # x-prediction
588
+ pred_original_sample = model_output
589
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
590
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
591
+ elif self.config.prediction_type == "v_prediction": # v-prediction
592
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
593
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
594
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
595
+ else:
596
+ raise ValueError(
597
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
598
+ " `v_prediction` for `TCDScheduler`."
599
+ )
600
+
601
+ # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference
602
+ # Noise is not used on the final timestep of the timestep schedule.
603
+ # This also means that noise is not used for one-step sampling.
604
+ # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step.
605
+ # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
606
+ if eta > 0:
607
+ if self.step_index != self.num_inference_steps - 1:
608
+ noise = randn_tensor(
609
+ model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype
610
+ )
611
+ prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
612
+ 1 - alpha_prod_t_prev / alpha_prod_s
613
+ ).sqrt() * noise
614
+ else:
615
+ prev_sample = pred_noised_sample
616
+ else:
617
+ prev_sample = pred_noised_sample
618
+
619
+ # upon completion increase step index by one
620
+ self._step_index += 1
621
+
622
+ if not return_dict:
623
+ return (prev_sample, pred_noised_sample)
624
+
625
+ return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
626
+
627
+ def add_noise(
628
+ self,
629
+ original_samples: torch.FloatTensor,
630
+ noise: torch.FloatTensor,
631
+ timesteps: torch.IntTensor,
632
+ ) -> torch.FloatTensor:
633
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
634
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
635
+ timesteps = timesteps.to(original_samples.device)
636
+
637
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
638
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
639
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
640
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
641
+
642
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
643
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
644
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
645
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
646
+
647
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
648
+ return noisy_samples
649
+
650
+ def get_velocity(
651
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
652
+ ) -> torch.FloatTensor:
653
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
654
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
655
+ timesteps = timesteps.to(sample.device)
656
+
657
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
658
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
659
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
660
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
661
+
662
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
663
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
664
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
665
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
666
+
667
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
668
+ return velocity
669
+
670
+ def __len__(self):
671
+ return self.config.num_train_timesteps
672
+
673
+ def previous_timestep(self, timestep):
674
+ if self.custom_timesteps:
675
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
676
+ if index == self.timesteps.shape[0] - 1:
677
+ prev_t = torch.tensor(-1)
678
+ else:
679
+ prev_t = self.timesteps[index + 1]
680
+ else:
681
+ num_inference_steps = (
682
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
683
+ )
684
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
685
+
686
+ return prev_t