# Spaces: Running on Zero
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass | |
from typing import Optional, Tuple, Union | |
import numpy as np | |
import torch | |
from ..configuration_utils import ConfigMixin, register_to_config | |
from ..utils import BaseOutput, logging | |
from ..utils.torch_utils import randn_tensor | |
from .scheduling_utils import SchedulerMixin | |
# Module-level logger obtained from the package's `..utils.logging` helpers.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Declared as a dataclass so that `prev_sample` is a real field of the `BaseOutput` container (`dataclass` is
    imported at the top of this module for exactly this purpose).

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun (second-order) scheduler for flow-matching models.

    Each denoising interval is covered by two calls to `step`: a first-order (Euler) stage that stores its derivative,
    step size and input sample, followed by a correction stage that averages that derivative with the one evaluated at
    the Euler prediction. `set_timesteps` therefore repeats every timestep after the first so that an ordinary
    pipeline loop drives both stages.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule. Sigmas are remapped as
            `shift * sigma / (1 + (shift - 1) * sigma)`, so `shift > 1` raises every sigma in (0, 1) and spends more
            of the schedule at high noise levels.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        # Training schedule: timesteps num_train_timesteps, ..., 1 (descending), mapped to sigmas in (0, 1].
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        # Heun state carried from the first-order stage to the correction stage; `None` means "first order next".
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: linearly interpolate between `sample` and `noise`.

        If the step index has not been initialised yet, this call initialises it from `timestep` (or from the begin
        index), and subsequent `step` calls continue from that index.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain; used to locate the step index when it is not yet set.
            noise (`torch.FloatTensor`, *optional*):
                The noise to mix in. When `None`, standard Gaussian noise with the shape, dtype and device of
                `sample` is drawn.

        Returns:
            `torch.FloatTensor`:
                `sigma * noise + (1 - sigma) * sample`, with `sigma` taken at the current step index.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if noise is None:
            # `noise` is optional in the signature, so draw it here instead of failing on `None` arithmetic.
            noise = randn_tensor(sample.shape, dtype=sample.dtype, device=sample.device)

        sigma = self.sigmas[self.step_index]
        return sigma * noise + (1.0 - sigma) * sample

    def _sigma_to_t(self, sigma):
        """Map a sigma in [0, 1] to the training timestep scale."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # NOTE(review): `sigma_min`/`sigma_max` were already shifted in `__init__`, and the shift is applied again
        # below, so the low-noise end is effectively shifted twice. Kept as-is to preserve the existing schedule.
        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        # Every timestep after the first is visited twice: Euler stage, then Heun correction.
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        # Append a terminal sigma of 0 and duplicate the interior sigmas to line up with the doubled timesteps;
        # this keeps `sigmas[step_index + 1]` in range for every first-order stage.
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # empty dt, derivative and stored sample (puts the scheduler in "first order mode")
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the index of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`).

        Raises `IndexError` if `timestep` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the only index if there is just 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from the begin index if set, otherwise by locating `timestep` in the schedule."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        """Whether the next `step` call is a first-order (Euler) stage, i.e. no Heun correction is pending."""
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample along the flow-matching ODE by one Heun stage.

        Calls alternate between a first-order (Euler) stage, which stores the derivative, step size and input sample,
        and a correction stage, which averages the stored derivative with the one at the Euler prediction and
        re-integrates from the stored sample.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned model. It is used as the derivative of the sample with respect to
                sigma: without churn, the first-order stage returns `sample + model_output * (sigma_next - sigma)`.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain; must be one of `scheduler.timesteps`.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Amount of added stochasticity; the noise-level bump is
                `gamma = min(s_churn / (len(self.sigmas) - 1), sqrt(2) - 1)`.
            s_tmin (`float`):
                Churn is only applied when `s_tmin <= sigma <= s_tmax`.
            s_tmax (`float`):
                Upper bound of the churn range; see `s_tmin`.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer index rather than one of `scheduler.timesteps`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method: reuse the interval of the preceding first-order stage
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Noise is drawn on every call (even when gamma == 0), so the generator advances once per step.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from the sigma-scaled model output
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from the sigma-scaled model output
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method: average the two derivatives
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps