kd5678 committed
Commit 95cf596
1 Parent(s): cfa5d1b

Delete ddim_with_prob.py

Files changed (1)
  1. ddim_with_prob.py +0 -397
ddim_with_prob.py DELETED
@@ -1,397 +0,0 @@
- # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
- # and https://github.com/hojonathanho/diffusion
-
- import math
- from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
- import numpy as np
- import torch
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.utils import BaseOutput
- from diffusers.utils.torch_utils import randn_tensor
- from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-
-
- @dataclass
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
- class DDIMSchedulerOutput(BaseOutput):
-     """
-     Output class for the scheduler's step function output.
-
-     Args:
-         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-             Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model input
-             in the denoising loop.
-         pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-             The predicted denoised sample (x_{0}) based on the model output from the current timestep.
-             `pred_original_sample` can be used to preview progress or for guidance.
-         log_prob (`torch.FloatTensor` of shape `(batch_size,)`, optional):
-             Log-probability of `prev_sample` under the scheduler's Gaussian transition, averaged over all non-batch
-             dimensions.
-     """
-
-     prev_sample: torch.FloatTensor
-     pred_original_sample: Optional[torch.FloatTensor] = None
-     log_prob: Optional[torch.FloatTensor] = None
-
-
- def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
-     """
-     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
-     (1-beta) over time from t = [0,1].
-
-     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
-     to that part of the diffusion process.
-
-     Args:
-         num_diffusion_timesteps (`int`): the number of betas to produce.
-         max_beta (`float`): the maximum beta to use; use values lower than 1 to
-             prevent singularities.
-
-     Returns:
-         betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
-     """
-
-     def alpha_bar(time_step):
-         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
-
-     betas = []
-     for i in range(num_diffusion_timesteps):
-         t1 = i / num_diffusion_timesteps
-         t2 = (i + 1) / num_diffusion_timesteps
-         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-     return torch.tensor(betas)
-
-
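- # Illustrative note (approximate values derived from alpha_bar above, not part
- # of the original logic): with num_diffusion_timesteps=1000, the betas rise
- # monotonically from roughly 4e-5 at i=0 toward the max_beta cap of 0.999 at
- # the final step, since alpha_bar(t) decays from ~1 at t=0 to 0 at t=1.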
- class DDIMSchedulerCustom(SchedulerMixin, ConfigMixin):
-     """
-     Denoising diffusion implicit models (DDIM) is a scheduler that extends the denoising procedure introduced in
-     denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. This custom variant additionally
-     returns the log-probability of the sampled `prev_sample` under the scheduler's Gaussian transition (see `step`).
-
-     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
-     function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-     [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`]
-     and [`~SchedulerMixin.from_pretrained`] functions.
-
-     For more details, see the original paper: https://arxiv.org/abs/2010.02502
-
-     Args:
-         num_train_timesteps (`int`): number of diffusion steps used to train the model.
-         beta_start (`float`): the starting `beta` value of inference.
-         beta_end (`float`): the final `beta` value.
-         beta_schedule (`str`):
-             the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-         trained_betas (`np.ndarray`, optional):
-             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-         clip_sample (`bool`, default `True`):
-             option to clip the predicted sample between -1 and 1 for numerical stability.
-         set_alpha_to_one (`bool`, default `True`):
-             each diffusion step uses the value of alphas product at that step and at the previous one. For the final
-             step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
-             otherwise it uses the value of alpha at step 0.
-         steps_offset (`int`, default `0`):
-             an offset added to the inference steps. You can use a combination of `offset=1` and
-             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
-             Stable Diffusion.
-         prediction_type (`str`, default `epsilon`, optional):
-             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-             process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
-             https://imagen.research.google/video/paper.pdf)
-     """
-
-     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-     order = 1
-
-     @register_to_config
-     def __init__(
-         self,
-         num_train_timesteps: int = 1000,
-         beta_start: float = 0.0001,
-         beta_end: float = 0.02,
-         beta_schedule: str = "linear",
-         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
-         clip_sample: bool = True,
-         set_alpha_to_one: bool = True,
-         steps_offset: int = 0,
-         prediction_type: str = "epsilon",
-     ):
-         if trained_betas is not None:
-             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
-         elif beta_schedule == "linear":
-             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-         elif beta_schedule == "scaled_linear":
-             # this schedule is very specific to the latent diffusion model.
-             self.betas = (
-                 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-             )
-         elif beta_schedule == "squaredcos_cap_v2":
-             # Glide cosine schedule
-             self.betas = betas_for_alpha_bar(num_train_timesteps)
-         else:
-             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
-
-         self.alphas = 1.0 - self.betas
-         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
-         # At every step in ddim, we are looking into the previous alphas_cumprod
-         # For the final step, there is no previous alphas_cumprod because we are already at 0
-         # `set_alpha_to_one` decides whether we set this parameter simply to one or
-         # whether we use the final alpha of the "non-previous" one.
-         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-
-         # standard deviation of the initial noise distribution
-         self.init_noise_sigma = 1.0
-
-         # setable values
-         self.num_inference_steps = None
-         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
-
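-     # Illustrative usage (a hedged sketch; the beta values shown are the
-     # conventional Stable Diffusion settings, not defaults of this file):
-     #     scheduler = DDIMSchedulerCustom(
-     #         num_train_timesteps=1000,
-     #         beta_start=0.00085,
-     #         beta_end=0.012,
-     #         beta_schedule="scaled_linear",
-     #         clip_sample=False,
-     #     )
-     #     scheduler.set_timesteps(50, device="cuda")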
-     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
-         """
-         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-         current timestep.
-
-         Args:
-             sample (`torch.FloatTensor`): input sample
-             timestep (`int`, optional): current timestep
-
-         Returns:
-             `torch.FloatTensor`: scaled input sample
-         """
-         return sample
-
-     def _get_variance(self, timestep, prev_timestep):
-         alpha_prod_t = self.alphas_cumprod[timestep]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
-         beta_prod_t = 1 - alpha_prod_t
-         beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-         variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
-         return variance
-
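-     # Illustrative note: this is formula (16) of the DDIM paper,
-     #     sigma_t**2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1}),
-     # which for eta = 1 recovers the DDPM posterior variance.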
-     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
-         """
-         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
-
-         Args:
-             num_inference_steps (`int`):
-                 the number of diffusion steps used when generating samples with a pre-trained model.
-         """
-
-         if num_inference_steps > self.config.num_train_timesteps:
-             raise ValueError(
-                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
-                 f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps}, as the unet model trained"
-                 f" with this scheduler can only handle at most {self.config.num_train_timesteps} timesteps."
-             )
-
-         self.num_inference_steps = num_inference_steps
-         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-         # creates integer timesteps by multiplying by ratio
-         # casting to int to avoid issues when num_inference_steps is a power of 3
-         timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
-         self.timesteps = torch.from_numpy(timesteps).to(device)
-         self.timesteps += self.config.steps_offset
-
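-     # Worked example (illustrative): with num_train_timesteps=1000,
-     # num_inference_steps=4 and steps_offset=0, step_ratio is 250 and
-     # self.timesteps becomes tensor([750, 500, 250, 0]).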
-     def step(
-         self,
-         model_output: torch.FloatTensor,
-         timestep: int,
-         sample: torch.FloatTensor,
-         eta: float = 0.0,
-         use_clipped_model_output: bool = False,
-         generator=None,
-         variance_noise: Optional[torch.FloatTensor] = None,
-         return_dict: bool = True,
-         prev_sample: Optional[torch.FloatTensor] = None,
-     ) -> Union[DDIMSchedulerOutput, Tuple]:
-         """
-         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
-         process from the learned model outputs (most often the predicted noise).
-
-         First, the model_output is used to calculate prev_sample_mean. If `prev_sample` is `None` and `eta > 0`,
-         some noise is added to produce `prev_sample` (with variance depending on `eta`). If `prev_sample` is not
-         `None`, this function essentially just calculates the log_prob of `prev_sample` given prev_sample_mean, and
-         `prev_sample` is returned unmodified.
-
-         Args:
-             model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-             timestep (`int`): current discrete timestep in the diffusion chain.
-             sample (`torch.FloatTensor`):
-                 current instance of sample being created by diffusion process.
-             eta (`float`): weight of noise for added noise in diffusion step.
-             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
-                 predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
-                 `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output`
-                 coincides with the one provided as input and `use_clipped_model_output` has no effect.
-             generator: random number generator.
-             variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
-                 can directly provide the noise for the variance itself. This is useful for methods such as
-                 CycleDiffusion (https://arxiv.org/abs/2210.05559).
-             return_dict (`bool`): option for returning a tuple rather than a DDIMSchedulerOutput class.
-             prev_sample (`torch.FloatTensor`, optional): a pre-computed previous sample whose log-probability under
-                 the current transition should be evaluated instead of drawing a new sample.
-
-         Returns:
-             [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
-             [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a
-             `(prev_sample, pred_original_sample, log_prob, log_prob_mean)` tuple.
-         """
-         if self.num_inference_steps is None:
-             raise ValueError(
-                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
-             )
-
-         # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-         # Ideally, read the DDIM paper in detail to understand the following.
-
-         # Notation (<variable name> -> <name in paper>):
-         # - pred_noise_t -> e_theta(x_t, t)
-         # - pred_original_sample -> f_theta(x_t, t) or x_0
-         # - std_dev_t -> sigma_t
-         # - eta -> η
-         # - pred_sample_direction -> "direction pointing to x_t"
-         # - pred_prev_sample -> "x_t-1"
-
-         # 1. get previous step value (=t-1)
-         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
-
-         # 2. compute alphas, betas
-         alpha_prod_t = self.alphas_cumprod[timestep]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
-
-         beta_prod_t = 1 - alpha_prod_t
-
-         # 3. compute predicted original sample from predicted noise also called
-         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         if self.config.prediction_type == "epsilon":
-             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-         elif self.config.prediction_type == "sample":
-             pred_original_sample = model_output
-         elif self.config.prediction_type == "v_prediction":
-             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
-             # predict V
-             model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
-         else:
-             raise ValueError(
-                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                 " `v_prediction`"
-             )
-
-         # 4. Clip "predicted x_0"
-         if self.config.clip_sample:
-             pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
-
-         # 5. compute variance: "sigma_t(η)" -> see formula (16)
-         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-         variance = self._get_variance(timestep, prev_timestep)
-         std_dev_t = eta * variance ** (0.5)
-
-         if use_clipped_model_output:
-             # the model_output is always re-derived from the clipped x_0 in Glide
-             model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
-         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
-
-         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-
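-         # Illustrative note: prev_sample_mean is exactly the deterministic DDIM
-         # update of formula (12); the branch below assumes eta > 0 (or a caller-
-         # supplied prev_sample), since the Gaussian log_prob degenerates when
-         # std_dev_t == 0.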
-         if prev_sample is None and eta > 0:
-             device = model_output.device
-             if variance_noise is not None and generator is not None:
-                 raise ValueError(
-                     "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
-                     " `variance_noise` stays `None`."
-                 )
-
-             if variance_noise is None:
-                 variance_noise = randn_tensor(
-                     model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                 )
-
-             prev_sample = prev_sample_mean + std_dev_t * variance_noise
-
-         # per-element Gaussian log-density of prev_sample under N(prev_sample_mean, std_dev_t**2)
-         # std_dev_t = torch.clip(std_dev_t, min=1e-6)
-         log_prob = (
-             -((prev_sample - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
-             - math.log(std_dev_t)
-             - math.log(math.sqrt(2 * math.pi))
-         )
-
-         # average over all non-batch dimensions to get one log-probability per sample
-         log_prob_mean = torch.mean(log_prob, axis=tuple(range(1, log_prob.ndim)))
-
-         if not return_dict:
-             return (prev_sample, pred_original_sample, log_prob, log_prob_mean)
-
-         return DDIMSchedulerOutput(
-             prev_sample=prev_sample, pred_original_sample=pred_original_sample, log_prob=log_prob_mean
-         )
-
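-     # Illustrative usage (a hedged sketch; `unet`, `latents` and `t` are
-     # placeholders, not names defined in this file):
-     #     noise_pred = unet(latents, t).sample
-     #     out = scheduler.step(noise_pred, t, latents, eta=1.0)
-     #     latents, logp = out.prev_sample, out.log_prob
-     # Re-passing a stored trajectory via `prev_sample=` instead re-evaluates its
-     # log-probability under the current model without drawing new noise.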
-     def add_noise(
-         self,
-         original_samples: torch.FloatTensor,
-         noise: torch.FloatTensor,
-         timesteps: torch.IntTensor,
-     ) -> torch.FloatTensor:
-         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-         timesteps = timesteps.to(original_samples.device)
-
-         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-         return noisy_samples
-
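-     # Illustrative note: this implements the standard forward (noising) process
-     #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.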
-     def get_velocity(
-         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-     ) -> torch.FloatTensor:
-         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
-         timesteps = timesteps.to(sample.device)
-
-         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-         while len(sqrt_alpha_prod.shape) < len(sample.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
-         return velocity
-
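-     # Illustrative note: this is the v-prediction target
-     #     v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
-     # used by the `v_prediction` mode documented in the class docstring.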
-     def __len__(self):
-         return self.config.num_train_timesteps