Commit 552aef7, committed by prateekrao
1 Parent(s): 81b02c4

Create scheduling_ddim.py

Files changed (1):
  1. scheduling_ddim.py +405 -0
scheduling_ddim.py ADDED
@@ -0,0 +1,405 @@
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model input
            in the denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
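
# Example (illustrative sketch, not part of the original file): for the default
# 1000-step cosine schedule, the betas start near zero and grow toward the
# max_beta cap of 0.999 at the end of the schedule:
#
#     betas = betas_for_alpha_bar(1000)
#     assert betas.shape == (1000,)
#     assert betas.max() <= 0.999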


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models (DDIM) is a scheduler that extends the denoising procedure introduced in
    denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`]
    and [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip the predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of the alphas product at that step and at the previous one. For the
            final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to
            `1`, otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in DDIM, we look at the previous alphas_cumprod.
        # For the final step, there is no previous alphas_cumprod because we are already at 0.
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance
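
    # Note (added for clarity): the expression above is the square of formula (16) of the
    # DDIM paper at eta = 1, i.e.
    #     sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1}),
    # which equals the variance of the DDPM posterior q(x_{t-1} | x_t, x_0).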

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
        dynamic_max_val = (
            sample.flatten(1)
            .abs()
            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
            .clamp_min(self.config.sample_max_value)
            .view(-1, *([1] * (sample.ndim - 1)))
        )
        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
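
    # Example (illustrative sketch, not part of the original file): with the default
    # dynamic_thresholding_ratio=0.995, each sample in the batch is clipped to its own
    # 99.5th-percentile absolute value s (with s >= sample_max_value) and then rescaled
    # to [-1, 1] by dividing by s, instead of hard-clipping every pixel to [-1, 1].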

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                f" with this scheduler can only handle a maximum of {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by the ratio
        # casting to int to avoid issues when num_inference_steps is a power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.config.steps_offset
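
    # Example (illustrative sketch, not part of the original file): with the default
    # num_train_timesteps=1000 and steps_offset=0, calling set_timesteps(50) gives
    # step_ratio = 20 and self.timesteps = [980, 960, ..., 20, 0], i.e. 50 evenly
    # spaced timesteps traversed from high noise to low noise.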

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from the learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of the sample being created by the diffusion process.
            eta (`float`): weight of noise for added noise in the diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output`
                coincides with the one provided as input and `use_clipped_model_output` has no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion (https://arxiv.org/abs/2210.05559).
            return_dict (`bool`): option for returning a tuple rather than a DDIMSchedulerOutput class.

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail for a full understanding.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise, also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
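
    # Example (illustrative sketch, not part of the original file; `unet` and the
    # initial latent `sample` are assumed stand-ins for a real noise-prediction
    # model and starting noise):
    #
    #     scheduler.set_timesteps(50)
    #     for t in scheduler.timesteps:
    #         model_output = unet(sample, t).sample
    #         sample = scheduler.step(model_output, t, sample, eta=0.0).prev_sample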

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
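
    # Note (added for clarity): this implements the closed-form forward process
    # q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # which is how noisy training inputs are typically constructed.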

    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timesteps have the same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity
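
    # Note (added for clarity): this returns the v-prediction training target
    # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample, as described in
    # section 2.4 of https://imagen.research.google/video/paper.pdf; a model configured
    # with prediction_type="v_prediction" is trained to regress this quantity.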

    def __len__(self):
        return self.config.num_train_timesteps
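

# End-to-end usage example (illustrative sketch, not part of the original file;
# shown as comments because this module uses relative imports and is meant to be
# used through the `diffusers` package):
#
#     import torch
#     from diffusers import DDIMScheduler
#
#     scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
#     scheduler.set_timesteps(50)
#     sample = torch.randn(1, 3, 64, 64)  # stand-in for initial Gaussian noise
#     for t in scheduler.timesteps:
#         model_output = torch.randn_like(sample)  # stand-in for a real unet prediction
#         sample = scheduler.step(model_output, t, sample).prev_sample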