Spaces:
Runtime error
Runtime error
File size: 5,976 Bytes
7cc4b41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from dataclasses import dataclass
from abc import ABC
from typing import Optional, Union, List
@dataclass
class SchedulerConversionOutput:
    """A model output expressed in all three prediction parameterizations.

    Built by ``Scheduler.convert_output``, which derives the two missing
    quantities from whichever one the model actually predicted.
    """

    # Noise estimate (the "epsilon" prediction type).
    pred_epsilon: torch.Tensor
    # Clean-sample estimate (the "sample" prediction type).
    pred_original_sample: torch.Tensor
    # Velocity estimate (the "v_prediction" prediction type).
    pred_velocity: torch.Tensor
@dataclass
class SchedulerStepOutput:
    """Container for the result of one scheduler step.

    NOTE(review): nothing in this module constructs this class; it is
    presumably returned by the ``step`` method of concrete schedulers --
    confirm against the subclasses.
    """

    # Sample for the previous timestep.
    prev_sample: torch.Tensor
    # Optional clean-sample estimate; stays None when not supplied.
    pred_original_sample: Optional[torch.Tensor] = None
class Scheduler(ABC):
    """Base class holding the noise schedule used by diffusion samplers.

    From a 1-D ``betas`` tensor it precomputes ``alphas`` (``1 - betas``) and
    their cumulative product ``alphas_cumprod``, and it selects the inference
    timesteps stored in ``self.timesteps``. It also provides conversions
    between the "epsilon", "sample" and "v_prediction" parameterizations.
    """

    prediction_types = ["epsilon", "sample", "v_prediction"]
    timesteps_types = ["leading", "linspace", "trailing"]

    def __init__(
        self,
        num_train_timesteps: int,
        num_inference_timesteps: int,
        betas: torch.Tensor,
        inference_timesteps: Union[str, List[int]] = "trailing",
        set_alpha_to_one: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32
    ):
        """Precompute the schedule tensors and the inference timesteps.

        Args:
            num_train_timesteps: Length of the training schedule; must be
                positive and equal to ``betas.size(0)``.
            num_inference_timesteps: Number of timesteps used at inference;
                must not exceed ``num_train_timesteps``.
            betas: 1-D tensor of per-timestep beta values.
            inference_timesteps: One of "trailing", "linspace" or "leading",
                or an explicit list (or tuple) of
                ``num_inference_timesteps`` timestep indices used as given.
            set_alpha_to_one: If true, ``final_alpha_cumprod`` is 1.0;
                otherwise it is ``alphas_cumprod[0]``.
            device: Target device; defaults to the device of ``betas``.
            dtype: Floating dtype of the schedule tensors.

        Raises:
            AssertionError: If the size/shape preconditions are violated.
            NotImplementedError: If ``inference_timesteps`` is an unknown
                spacing name.
        """
        assert num_train_timesteps > 0
        assert num_train_timesteps >= num_inference_timesteps
        assert num_train_timesteps == betas.size(0)
        assert betas.ndim == 1

        self.device = device or betas.device
        self.dtype = dtype
        self.num_train_timesteps = num_train_timesteps
        self.num_inference_timesteps = num_inference_timesteps

        self.betas = betas.to(device=self.device, dtype=dtype)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.final_alpha_cumprod = (
            torch.tensor(1.0, device=self.device, dtype=self.dtype)
            if set_alpha_to_one
            else self.alphas_cumprod[0]
        )

        if isinstance(inference_timesteps, (list, tuple)):
            # If user defines a custom inference timestep, directly assign it.
            assert len(inference_timesteps) == num_inference_timesteps
            self.timesteps = torch.tensor(inference_timesteps, device=self.device, dtype=torch.int)
        elif inference_timesteps == "trailing":
            # Example 20 steps: [999, 949, 899, 849, 799, 749, 699, 649, 599, 549, 499, 449, 399, 349, 299, 249, 199, 149, 99, 49]
            # The step is a float, so rounding error in arange's length
            # computation could append one extra element (about -1); the
            # slice guarantees at most num_inference_timesteps entries.
            self.timesteps = torch.arange(
                num_train_timesteps - 1,
                -1,
                -num_train_timesteps / num_inference_timesteps,
                device=self.device,
            )[:num_inference_timesteps].round().int()
        elif inference_timesteps == "linspace":
            # Example 20 steps: [999, 946, 894, 841, 789, 736, 684, 631, 578, 526, 473, 421, 368, 315, 263, 210, 158, 105, 53, 0]
            self.timesteps = torch.linspace(
                0, num_train_timesteps - 1, num_inference_timesteps, device=self.device
            ).round().int().flip(0)
        elif inference_timesteps == "leading":
            # Original SD and DDIM paper may have a bug: <https://github.com/huggingface/diffusers/issues/2585>
            # The inference timestep does not start from 999.
            # Example 20 steps: [950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, 100, 50, 0]
            # Build exactly num_inference_timesteps multiples of the stride.
            # arange(0, num_train_timesteps, stride) would return extra
            # entries whenever the train count is not divisible by the
            # inference count (e.g. 1000 and 3 gave 4 timesteps).
            stride = num_train_timesteps // num_inference_timesteps
            indices = torch.arange(num_inference_timesteps, device=self.device, dtype=torch.int)
            self.timesteps = (indices * stride).flip(0)
        else:
            raise NotImplementedError(
                f"Unsupported inference_timesteps: {inference_timesteps!r}"
            )

    def reset(self):
        """Do nothing in the base class."""
        pass

    def _alpha_prod(self, timesteps: Union[torch.Tensor, int], ndim: int) -> torch.Tensor:
        """Look up ``alphas_cumprod`` at ``timesteps``, shaped ``(-1, 1, ..., 1)`` with ``ndim`` dims."""
        return self.alphas_cumprod[timesteps].reshape(-1, *([1] * (ndim - 1)))

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: Union[torch.Tensor, int],
    ) -> torch.Tensor:
        """Return ``sqrt(a_t) * original_samples + sqrt(1 - a_t) * noise``.

        ``a_t`` is ``alphas_cumprod`` indexed by ``timesteps`` and reshaped so
        that it broadcasts over the trailing dimensions of the samples.
        """
        alpha_prod_t = self._alpha_prod(timesteps, original_samples.ndim)
        return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise

    def convert_output(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        sample: torch.Tensor,
        timesteps: Union[torch.Tensor, int]
    ) -> "SchedulerConversionOutput":
        """Express ``model_output`` in every prediction parameterization.

        Args:
            model_output: Raw model prediction.
            model_output_type: One of ``prediction_types`` describing what
                ``model_output`` represents.
            sample: Noisy sample the prediction was made for.
            timesteps: Timestep index (or indices) of ``sample``.

        Returns:
            A ``SchedulerConversionOutput`` with the epsilon, original-sample
            and velocity predictions.

        Raises:
            AssertionError: If ``model_output_type`` is not a known type.
            ValueError: Same condition, when assertions are disabled.
        """
        assert model_output_type in self.prediction_types

        alpha_prod_t = self._alpha_prod(timesteps, sample.ndim)
        beta_prod_t = 1 - alpha_prod_t

        if model_output_type == "epsilon":
            pred_epsilon = model_output
            pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
            pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
        elif model_output_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
            pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
        elif model_output_type == "v_prediction":
            pred_velocity = model_output
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            # Only reachable when the assert above is stripped (python -O).
            raise ValueError("Unknown prediction type")

        return SchedulerConversionOutput(
            pred_epsilon=pred_epsilon,
            pred_original_sample=pred_original_sample,
            pred_velocity=pred_velocity)

    def get_velocity(
        self,
        sample: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """Return ``sqrt(a_t) * noise - sqrt(1 - a_t) * sample``.

        ``a_t`` is ``alphas_cumprod`` indexed by ``timesteps`` and reshaped to
        broadcast over the trailing dimensions of ``sample``.
        """
        alpha_prod_t = self._alpha_prod(timesteps, sample.ndim)
        return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample
|