from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput

from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # NOTE: the @dataclass decorator is required — diffusers' BaseOutput relies on
    # the dataclass-generated __init__/__post_init__ to populate its dict/attribute
    # views. The `dataclass` import was present but the decorator had been dropped.
    prev_sample: torch.FloatTensor
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
    """Flow-matching scheduler for MeanFlow-style models.

    Training noise injection is a plain linear interpolation between the clean
    sample and noise (`add_noise`); sampling is a single Euler step
    ``x0 = x1 - u`` (`step`), because a MeanFlow model predicts the average
    velocity over the whole trajectory (Eq. 5 of the MeanFlow formulation).

    NOTE(review): the timestep grid is hard-coded to the 1..1000 range
    throughout this class (it does not read ``self.config.num_train_timesteps``)
    — confirm callers always use the 1000-step convention.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flow-matching inputs are already unit-variance; no extra scaling.
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"
        with torch.no_grad():
            # Precompute the linear training timestep grid, 1000 -> 1 inclusive.
            num_timesteps = 1000
            self.linear_timesteps = torch.linspace(
                1000, 1, num_timesteps, device="cpu"
            )

    def get_weights_for_timesteps(
        self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
    ) -> torch.Tensor:
        """Return per-timestep loss weights for a batch of timesteps.

        Args:
            timesteps: 1-D tensor of timestep values drawn from `self.timesteps`.
            v2: unused; kept for interface compatibility with callers.
            timestep_type: "weighted" looks each timestep up in
                `default_weighing_scheme`; anything else yields a uniform
                scalar weight of 1.0 (which broadcasts safely in loss math).

        Returns:
            A tensor of weights (one per timestep) for "weighted", otherwise
            the scalar 1.0.
        """
        if timestep_type == "weighted":
            # Map each timestep value back to its index on the training grid.
            # Only computed when actually needed: the lookup costs a nonzero()
            # scan per element and raises if a value is not on the grid.
            step_indices = [
                (self.timesteps == t).nonzero().item() for t in timesteps
            ]
            return torch.tensor(
                [default_weighing_scheme[i] for i in step_indices],
                device=timesteps.device,
                dtype=timesteps.dtype,
            )
        return 1.0

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Linearly interpolate clean samples toward noise.

        ``x_t = (1 - t/1000) * x0 + (t/1000) * noise`` — at t=1000 the result
        is pure noise, at t=0 the clean sample.

        Args:
            original_samples: clean samples, e.g. shape ``(B, C, H, W)``.
            noise: noise tensor with the same shape as `original_samples`.
            timesteps: per-sample timesteps in [0, 1000], shape ``(B,)``
                (scalars and pre-expanded shapes are also accepted).

        Returns:
            The noised samples, same shape as `original_samples`.
        """
        t_01 = (timesteps / 1000).to(original_samples.device)
        # Fix: a (B,)-shaped timestep tensor does not broadcast against
        # (B, C, H, W) samples in torch; expand trailing singleton dims so
        # each sample is scaled by its own timestep. No-op for scalars or
        # already-expanded timesteps.
        while t_01.dim() < original_samples.dim():
            t_01 = t_01.unsqueeze(-1)
        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
        return noisy_model_input

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """Identity — flow-matching model inputs need no sigma scaling."""
        return sample

    def set_train_timesteps(self, num_timesteps, device, **kwargs):
        """Install a linear training grid of `num_timesteps` steps (1000 -> 1).

        Extra kwargs are accepted for interface compatibility and ignored.
        """
        timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
        self.timesteps = timesteps
        return timesteps

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """One-shot MeanFlow sampling step.

        A MeanFlow model predicts the average velocity over the whole
        trajectory, so a single Euler step recovers the clean sample
        (Eq. 5 => x0 = x1 - u). `timestep` is accepted for scheduler-API
        compatibility but is not used.

        Returns:
            `FlowMatchEulerDiscreteSchedulerOutput` with `prev_sample`, or a
            1-tuple when `return_dict` is False.
        """
        output = sample - model_output
        if not return_dict:
            return (output,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)