Spaces:
Paused
Paused
| # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # // | |
| # // Licensed under the Apache License, Version 2.0 (the "License"); | |
| # // you may not use this file except in compliance with the License. | |
| # // You may obtain a copy of the License at | |
| # // | |
| # // http://www.apache.org/licenses/LICENSE-2.0 | |
| # // | |
| # // Unless required by applicable law or agreed to in writing, software | |
| # // distributed under the License is distributed on an "AS IS" BASIS, | |
| # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # // See the License for the specific language governing permissions and | |
| # // limitations under the License. | |
| """ | |
| Schedule base class. | |
| """ | |
| from abc import ABC, abstractmethod, abstractproperty | |
| from typing import Tuple, Union | |
| import torch | |
| from ..types import PredictionType | |
| from ..utils import expand_dims | |
class Schedule(ABC):
    """
    Base class for diffusion schedules.

    Diffusion schedules are uniquely defined by T, A, B:
        x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T]
    Schedules can be continuous or discrete.

    Subclasses must provide the ``T`` property and the ``A`` and ``B``
    coefficient functions. ``isnr`` is optional and raises
    ``NotImplementedError`` unless overridden.
    """

    @property
    @abstractmethod
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        Schedule is continuous if float, discrete if int.

        Exposed as a property because ``is_continuous`` inspects the type of
        the value itself (``isinstance(self.T, float)``).
        """

    @abstractmethod
    def A(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient A (the weight applied to x_0).
        Returns tensor with the same shape as t.
        """

    @abstractmethod
    def B(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient B (the weight applied to x_T).
        Returns tensor with the same shape as t.
        """

    # ----------------------------------------------------

    def snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Signal to noise ratio, computed as A(t)^2 / B(t)^2.
        Returns tensor with the same shape as t.
        """
        return (self.A(t) ** 2) / (self.B(t) ** 2)

    def isnr(self, snr: torch.Tensor) -> torch.Tensor:
        """
        Inverse signal to noise ratio.
        Returns tensor with the same shape as snr.
        Subclass may implement.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError

    # ----------------------------------------------------

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous, i.e. whether ``T`` is a float.
        """
        return isinstance(self.T, float)

    def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion forward function.

        Returns x_t = A(t) * x_0 + B(t) * x_T.
        """
        # NOTE(review): expand_dims is defined elsewhere; presumably it pads t
        # with trailing singleton dims up to x_0.ndim so that A(t) and B(t)
        # broadcast against x_0 -- confirm against ..utils.
        t = expand_dims(t, x_0.ndim)
        return self.A(t) * x_0 + self.B(t) * x_T

    def convert_from_pred(
        self, pred: torch.Tensor, pred_type: "PredictionType", x_t: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert from prediction. Return predicted x_0 and x_T.

        Args:
            pred: The model output, interpreted according to ``pred_type``.
            pred_type: Which quantity ``pred`` represents
                (x_T, x_0, v_cos or v_lerp).
            x_t: The interpolated sample at timestep ``t``.
            t: The timestep(s).

        Returns:
            A tuple ``(pred_x_0, pred_x_T)``.

        Raises:
            NotImplementedError: if ``pred_type`` is not one of the four
                handled prediction types.
        """
        t = expand_dims(t, x_t.ndim)
        A_t = self.A(t)
        B_t = self.B(t)
        if pred_type == PredictionType.x_T:
            # Solve x_t = A * x_0 + B * x_T for x_0.
            pred_x_T = pred
            pred_x_0 = (x_t - B_t * pred_x_T) / A_t
        elif pred_type == PredictionType.x_0:
            # Solve x_t = A * x_0 + B * x_T for x_T.
            pred_x_0 = pred
            pred_x_T = (x_t - A_t * pred_x_0) / B_t
        elif pred_type == PredictionType.v_cos:
            # With v = A * x_T - B * x_0 (see convert_to_pred), these two
            # expressions equal (A^2 + B^2) * x_0 and (A^2 + B^2) * x_T, so
            # they recover x_0 / x_T exactly only when A^2 + B^2 == 1.
            pred_x_0 = A_t * x_t - B_t * pred
            pred_x_T = A_t * pred + B_t * x_t
        elif pred_type == PredictionType.v_lerp:
            # With v = x_T - x_0: x_t - B * v = (A + B) * x_0 and
            # x_t + A * v = (A + B) * x_T.
            pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
            pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
        else:
            raise NotImplementedError(f"Unsupported prediction type: {pred_type}")
        return pred_x_0, pred_x_T

    def convert_to_pred(
        self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: "PredictionType"
    ) -> torch.Tensor:
        """
        Convert to prediction target given x_0 and x_T.

        Args:
            x_0: The sample at t = 0.
            x_T: The sample at t = T.
            t: The timestep(s); only used for the v_cos target.
            pred_type: Which target to produce (x_T, x_0, v_cos or v_lerp).

        Returns:
            The target tensor: ``x_T``, ``x_0``, ``A(t) * x_T - B(t) * x_0``
            (v_cos) or ``x_T - x_0`` (v_lerp).

        Raises:
            NotImplementedError: if ``pred_type`` is not one of the four
                handled prediction types.
        """
        if pred_type == PredictionType.x_T:
            return x_T
        if pred_type == PredictionType.x_0:
            return x_0
        if pred_type == PredictionType.v_cos:
            t = expand_dims(t, x_0.ndim)
            return self.A(t) * x_T - self.B(t) * x_0
        if pred_type == PredictionType.v_lerp:
            return x_T - x_0
        raise NotImplementedError(f"Unsupported prediction type: {pred_type}")