Spaces:
Runtime error
Runtime error
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim | |
import math | |
from dataclasses import dataclass | |
from typing import Optional, Tuple, Union | |
import flax | |
import jax.numpy as jnp | |
from jax import random | |
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config | |
from ..utils import deprecate | |
from .scheduling_utils_flax import ( | |
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, | |
FlaxSchedulerMixin, | |
FlaxSchedulerOutput, | |
broadcast_to_shape_from_left, | |
) | |
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
    """
    Build a beta schedule by discretizing a cumulative-alpha curve over t in [0, 1].

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time `t`; this function uses the squared-cosine
    curve `cos(((t + 0.008) / 1.008) * pi / 2) ** 2`. With `N = num_diffusion_timesteps`, beta number `i` is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): how many betas to generate (`N`).
        max_beta (`float`): upper bound applied to every beta; keep it below 1 to avoid singularities near t = 1.

    Returns:
        betas (`jnp.ndarray`): float32 array holding `num_diffusion_timesteps` betas.
    """

    def cosine_alpha_bar(progress):
        # Squared-cosine cumulative alpha; the 0.008 offset keeps the curve away from its endpoint values.
        return math.cos((progress + 0.008) / 1.008 * math.pi / 2) ** 2

    def capped_beta(step):
        start = step / num_diffusion_timesteps
        end = (step + 1) / num_diffusion_timesteps
        return min(1 - cosine_alpha_bar(end) / cosine_alpha_bar(start), max_beta)

    # An empty range yields an empty float32 array, so N == 0 never divides by zero.
    return jnp.array([capped_beta(step) for step in range(num_diffusion_timesteps)], dtype=jnp.float32)
@flax.struct.dataclass
class DDPMSchedulerState:
    """
    Per-run values of `FlaxDDPMScheduler`.

    Declared with `flax.struct.dataclass` so instances are immutable and are updated by creating a copy with
    `state.replace(...)`, as `FlaxDDPMScheduler.set_timesteps` does.
    """

    # setable values
    timesteps: jnp.ndarray  # discrete timesteps to iterate over, in descending order
    num_inference_steps: Optional[int] = None  # filled in by `FlaxDDPMScheduler.set_timesteps`

    @classmethod
    def create(cls, num_train_timesteps: int):
        """Create a state whose `timesteps` run from `num_train_timesteps - 1` down to 0."""
        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
    """Output of `FlaxDDPMScheduler.step`: the inherited `prev_sample` field plus the scheduler `state`."""

    state: DDPMSchedulerState
class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
    Langevin dynamics sampling.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
            `v-prediction` is not supported for this scheduler.
    """

    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
    _deprecated_kwargs = ["predict_epsilon"]

    @property
    def has_state(self):
        """This scheduler keeps per-run values (timesteps, step count) in a separate `DDPMSchedulerState`."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        **kwargs,
    ):
        # `@register_to_config` records the constructor arguments so they are readable as `self.config.*`.
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
            " FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            # Map the deprecated boolean flag onto the newer `prediction_type` config entry.
            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

        if trained_betas is not None:
            self.betas = jnp.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
        # Stand-in for alpha_bar_{t-1} when t == 0 (there is no previous cumulative product).
        self.one = jnp.array(1.0)

    def create_state(self):
        """Return a fresh `DDPMSchedulerState` covering every training timestep."""
        return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

    def set_timesteps(
        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDPMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDPMSchedulerState`):
                the `FlaxDDPMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. Values above
                `num_train_timesteps` are clamped to `num_train_timesteps`.
            shape (`Tuple`, optional):
                not used by this implementation.

        Returns:
            `DDPMSchedulerState`: a copy of `state` with exactly `num_inference_steps` descending `timesteps`,
            spaced `num_train_timesteps // num_inference_steps` apart, and `num_inference_steps` updated.

        Raises:
            ValueError: if `num_inference_steps` is smaller than 1.
        """
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be at least 1, got {num_inference_steps}.")

        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # Multiplying an index range by the ratio yields exactly `num_inference_steps` entries even when
        # `num_train_timesteps` is not divisible by it (a strided `arange` would emit an extra step there).
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio)[::-1]

        return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)

    def _get_variance(self, t, predicted_variance=None, variance_type=None):
        """
        Compute the noise-variance term for the reverse step from timestep `t` to `t - 1`.

        Args:
            t: current discrete timestep (indexes `self.betas` / `self.alphas_cumprod`).
            predicted_variance: model-predicted variance values; only used by `learned` and `learned_range`.
            variance_type: overrides `self.config.variance_type` when given.

        Returns:
            A variance for `fixed_small`, `fixed_large` and `learned_range`; `predicted_variance` unchanged for
            `learned`; a *log*-variance for `fixed_small_log` and `fixed_large_log`. Any other `variance_type`
            falls through to the unclipped posterior variance.
        """
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small":
            variance = jnp.maximum(variance, 1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = jnp.log(jnp.maximum(variance, 1e-20))
        elif variance_type == "fixed_large":
            variance = self.betas[t]
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = jnp.log(self.betas[t])
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Interpolate between the two log-variance bounds, as the `min_log` / `max_log` names intend, and
            # convert back to a variance so callers can take its square root like the other non-log types.
            min_log = jnp.log(jnp.maximum(variance, 1e-20))
            max_log = jnp.log(self.betas[t])
            frac = (predicted_variance + 1) / 2
            variance = jnp.exp(frac * max_log + (1 - frac) * min_log)

        return variance

    def step(
        self,
        state: DDPMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jnp.ndarray`): a PRNG key, used to draw the noise added when `timestep > 0`.
            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
            kwargs: may carry the deprecated `predict_epsilon` flag, which overrides `prediction_type`.

        Returns:
            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
            " FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            new_config = dict(self.config)
            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
            self._internal_dict = FrozenDict(new_config)

        t = timestep

        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
            # The output stacks the prediction and the variance along axis 1. `jnp.split` takes a number of
            # sections (unlike `torch.split`, which takes a chunk size), so split into exactly two halves.
            model_output, predicted_variance = jnp.split(model_output, 2, axis=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
                " for the FlaxDDPMScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (no noise is added on the final step, t == 0)
        variance = 0
        if t > 0:
            # `random.split(key, num=1)` returns a batch holding one key; `random.normal` needs a single key.
            noise_key = random.split(key, num=1)[0]
            noise = random.normal(key=noise_key, shape=model_output.shape)
            variance_value = self._get_variance(t, predicted_variance=predicted_variance)
            if self.config.variance_type in ("fixed_small_log", "fixed_large_log"):
                # `_get_variance` returns a log-variance here; its square root would be NaN, so use
                # std = exp(0.5 * log_variance) instead.
                std = jnp.exp(0.5 * variance_value)
            else:
                std = variance_value**0.5
            variance = std * noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample, state)

        return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

    def add_noise(
        self,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Diffuse `original_samples` forward to `timesteps`: `sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`
        (formula (4) of https://arxiv.org/pdf/2006.11239.pdf).

        The per-timestep coefficients are flattened and broadcast against `original_samples` from the left, so
        `timesteps` presumably holds one entry per leading (batch) element — TODO confirm against
        `broadcast_to_shape_from_left`.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Number of training timesteps (`config.num_train_timesteps`)."""
        return self.config.num_train_timesteps