Georgiy Grigorev commited on
Commit
7cb6d07
·
1 Parent(s): 763fd8a

Create scheduling_ddim.py

Browse files
Files changed (1) hide show
  1. scheduling_ddim.py +404 -0
scheduling_ddim.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adapted and updated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py
2
+
3
+ # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
18
+ # and https://github.com/hojonathanho/diffusion
19
+
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ import torch
26
+ import matplotlib.pyplot as plt
27
+ from tqdm.auto import tqdm
28
+ from PIL import Image
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.utils import BaseOutput, deprecate
32
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
33
+
34
+
35
+ @dataclass
36
+ class DDIMSchedulerOutput(BaseOutput):
37
+ """
38
+ Output class for the scheduler's step function output.
39
+
40
+ Args:
41
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
42
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
43
+ denoising loop.
44
+ next_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
45
+ Computed sample (x_{t+1}) of previous timestep. `next_sample` should be used as next model input in the
46
+ reverse denoising loop.
47
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
48
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
49
+ `pred_original_sample` can be used to preview progress or for guidance.
50
+ """
51
+
52
+ prev_sample: Optional[torch.FloatTensor] = None
53
+ next_sample: Optional[torch.FloatTensor] = None
54
+ pred_original_sample: Optional[torch.FloatTensor] = None
55
+
56
+
57
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
58
+ """
59
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
60
+ (1-beta) over time from t = [0,1].
61
+
62
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
63
+ to that part of the diffusion process.
64
+
65
+
66
+ Args:
67
+ num_diffusion_timesteps (`int`): the number of betas to produce.
68
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
69
+ prevent singularities.
70
+
71
+ Returns:
72
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
73
+ """
74
+
75
+ def alpha_bar(time_step):
76
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
77
+
78
+ betas = []
79
+ for i in range(num_diffusion_timesteps):
80
+ t1 = i / num_diffusion_timesteps
81
+ t2 = (i + 1) / num_diffusion_timesteps
82
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
83
+ return torch.tensor(betas)
84
+
85
+
86
+ class DDIMScheduler(SchedulerMixin, ConfigMixin):
87
+ """
88
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
89
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
90
+
91
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
92
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
93
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
94
+ [`~ConfigMixin.from_config`] functions.
95
+
96
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
97
+
98
+ Args:
99
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
100
+ beta_start (`float`): the starting `beta` value of inference.
101
+ beta_end (`float`): the final `beta` value.
102
+ beta_schedule (`str`):
103
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
104
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
105
+ trained_betas (`np.ndarray`, optional):
106
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
107
+ clip_sample (`bool`, default `True`):
108
+ option to clip predicted sample between -1 and 1 for numerical stability.
109
+ set_alpha_to_one (`bool`, default `True`):
110
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
111
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
112
+ otherwise it uses the value of alpha at step 0.
113
+ steps_offset (`int`, default `0`):
114
+ an offset added to the inference steps. You can use a combination of `offset=1` and
115
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
116
+ stable diffusion.
117
+
118
+ """
119
+
120
+ @register_to_config
121
+ def __init__(
122
+ self,
123
+ num_train_timesteps: int = 1000,
124
+ beta_start: float = 0.0001,
125
+ beta_end: float = 0.02,
126
+ beta_schedule: str = "linear",
127
+ trained_betas: Optional[np.ndarray] = None,
128
+ clip_sample: bool = True,
129
+ set_alpha_to_one: bool = True,
130
+ steps_offset: int = 0,
131
+ ):
132
+ if trained_betas is not None:
133
+ self.betas = torch.from_numpy(trained_betas)
134
+ elif beta_schedule == "linear":
135
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
136
+ elif beta_schedule == "scaled_linear":
137
+ # this schedule is very specific to the latent diffusion model.
138
+ self.betas = (
139
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
140
+ )
141
+ elif beta_schedule == "squaredcos_cap_v2":
142
+ # Glide cosine schedule
143
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
144
+ else:
145
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
146
+
147
+ self.alphas = 1.0 - self.betas
148
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
149
+
150
+ # At every step in ddim, we are looking into the previous alphas_cumprod
151
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
152
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
153
+ # whether we use the final alpha of the "non-previous" one.
154
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
155
+
156
+ # standard deviation of the initial noise distribution
157
+ self.init_noise_sigma = 1.0
158
+
159
+ # setable values
160
+ self.num_inference_steps = None
161
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
162
+
163
+ def _get_variance(self, timestep, prev_timestep):
164
+ alpha_prod_t = self.alphas_cumprod[timestep]
165
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
166
+ beta_prod_t = 1 - alpha_prod_t
167
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
168
+
169
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
170
+
171
+ return variance
172
+
173
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
174
+ """
175
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
176
+ Args:
177
+ num_inference_steps (`int`):
178
+ the number of diffusion steps used when generating samples with a pre-trained model.
179
+ """
180
+ self.num_inference_steps = num_inference_steps
181
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
182
+ # creates integer timesteps by multiplying by ratio
183
+ # casting to int to avoid issues when num_inference_step is power of 3
184
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
185
+ self.timesteps = torch.from_numpy(timesteps).to(device)
186
+ self.timesteps += self.config.steps_offset
187
+
188
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
189
+ """
190
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
191
+ current timestep.
192
+ Args:
193
+ sample (`torch.FloatTensor`): input sample
194
+ timestep (`int`, optional): current timestep
195
+ Returns:
196
+ `torch.FloatTensor`: scaled input sample
197
+ """
198
+ return sample
199
+
200
+ def step(
201
+ self,
202
+ model_output: torch.FloatTensor,
203
+ timestep: int,
204
+ sample: torch.FloatTensor,
205
+ eta: float = 0.0,
206
+ use_clipped_model_output: bool = False,
207
+ generator=None,
208
+ return_dict: bool = True,
209
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
210
+ """
211
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
212
+ process from the learned model outputs (most often the predicted noise).
213
+
214
+ Args:
215
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
216
+ timestep (`int`): current discrete timestep in the diffusion chain.
217
+ sample (`torch.FloatTensor`):
218
+ current instance of sample being created by diffusion process.
219
+ eta (`float`): weight of noise for added noise in diffusion step.
220
+ use_clipped_model_output (`bool`): TODO
221
+ generator: random number generator.
222
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
223
+
224
+ Returns:
225
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
226
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
227
+ returning a tuple, the first element is the sample tensor.
228
+
229
+ """
230
+ if self.num_inference_steps is None:
231
+ raise ValueError(
232
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
233
+ )
234
+
235
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
236
+ # Ideally, read DDIM paper in-detail understanding
237
+
238
+ # Notation ( ->
239
+ # - pred_noise_t -> e_theta(x_t, t)
240
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
241
+ # - std_dev_t -> sigma_t
242
+ # - eta -> η
243
+ # - pred_sample_direction -> "direction pointing to x_t"
244
+ # - pred_prev_sample -> "x_t-1"
245
+
246
+ # 1. get previous step value (=t-1)
247
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
248
+
249
+ # 2. compute alphas, betas
250
+ alpha_prod_t = self.alphas_cumprod[timestep]
251
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
252
+
253
+ beta_prod_t = 1 - alpha_prod_t
254
+
255
+ # 3. compute predicted original sample from predicted noise also called
256
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
257
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
258
+
259
+ # 4. Clip "predicted x_0"
260
+ if self.config.clip_sample:
261
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
262
+
263
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
264
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
265
+ variance = self._get_variance(timestep, prev_timestep)
266
+ std_dev_t = eta * variance ** (0.5)
267
+
268
+ if use_clipped_model_output:
269
+ # the model_output is always re-derived from the clipped x_0 in Glide
270
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
271
+
272
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
273
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
274
+
275
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
276
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
277
+
278
+ if eta > 0:
279
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
280
+ noise = torch.randn(model_output.shape, generator=generator).to(device)
281
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
282
+
283
+ prev_sample = prev_sample + variance
284
+
285
+ if not return_dict:
286
+ return (prev_sample,)
287
+
288
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
289
+
290
+ def reverse_step(
291
+ self,
292
+ model_output: torch.FloatTensor,
293
+ timestep: int,
294
+ sample: torch.FloatTensor,
295
+ eta: float = 0.0,
296
+ use_clipped_model_output: bool = False,
297
+ generator=None,
298
+ return_dict: bool = True,
299
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
300
+ """
301
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
302
+ process from the learned model outputs (most often the predicted noise).
303
+
304
+ Args:
305
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
306
+ timestep (`int`): current discrete timestep in the diffusion chain.
307
+ sample (`torch.FloatTensor`):
308
+ current instance of sample being created by diffusion process.
309
+ eta (`float`): weight of noise for added noise in diffusion step.
310
+ use_clipped_model_output (`bool`): TODO
311
+ generator: random number generator.
312
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
313
+
314
+ Returns:
315
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
316
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
317
+ returning a tuple, the first element is the sample tensor.
318
+
319
+ """
320
+ if self.num_inference_steps is None:
321
+ raise ValueError(
322
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
323
+ )
324
+
325
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
326
+ # Ideally, read DDIM paper in-detail understanding
327
+
328
+ # Notation ( ->
329
+ # - pred_noise_t -> e_theta(x_t, t)
330
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
331
+ # - std_dev_t -> sigma_t
332
+ # - eta -> η
333
+ # - pred_sample_direction -> "direction pointing to x_t"
334
+ # - pred_prev_sample -> "x_t-1"
335
+
336
+ # 1. get previous step value (=t-1)
337
+ next_timestep = min(self.config.num_train_timesteps - 2,
338
+ timestep + self.config.num_train_timesteps // self.num_inference_steps)
339
+
340
+ # 2. compute alphas, betas
341
+ alpha_prod_t = self.alphas_cumprod[timestep]
342
+ alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep >= 0 else self.final_alpha_cumprod
343
+
344
+ beta_prod_t = 1 - alpha_prod_t
345
+
346
+ # 3. compute predicted original sample from predicted noise also called
347
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
348
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
349
+
350
+ # 4. Clip "predicted x_0"
351
+ if self.config.clip_sample:
352
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
353
+
354
+ # 5. TODO: simple noising implementatiom
355
+ next_sample = self.add_noise(pred_original_sample,
356
+ model_output,
357
+ torch.LongTensor([next_timestep]))
358
+
359
+ # # 5. compute variance: "sigma_t(η)" -> see formula (16)
360
+ # # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
361
+ # variance = self._get_variance(next_timestep, timestep)
362
+ # std_dev_t = eta * variance ** (0.5)
363
+
364
+ # if use_clipped_model_output:
365
+ # # the model_output is always re-derived from the clipped x_0 in Glide
366
+ # model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
367
+
368
+ # # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
369
+ # pred_sample_direction = (1 - alpha_prod_t_next - std_dev_t**2) ** (0.5) * model_output
370
+
371
+ # # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
372
+ # next_sample = alpha_prod_t_next ** (0.5) * pred_original_sample + pred_sample_direction
373
+
374
+ if not return_dict:
375
+ return (next_sample,)
376
+
377
+ return DDIMSchedulerOutput(next_sample=next_sample, pred_original_sample=pred_original_sample)
378
+
379
+ def add_noise(
380
+ self,
381
+ original_samples: torch.FloatTensor,
382
+ noise: torch.FloatTensor,
383
+ timesteps: torch.IntTensor,
384
+ ) -> torch.FloatTensor:
385
+ if self.alphas_cumprod.device != original_samples.device:
386
+ self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
387
+ if timesteps.device != original_samples.device:
388
+ timesteps = timesteps.to(original_samples.device)
389
+
390
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
391
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
392
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
393
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
394
+
395
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
396
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
397
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
398
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
399
+
400
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
401
+ return noisy_samples
402
+
403
+ def __len__(self):
404
+ return self.config.num_train_timesteps