import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from diffusers.configuration_utils import (
    ConfigMixin,
    register_to_config,
)
from diffusers.schedulers.scheduling_utils import (
    SchedulerMixin,
    SchedulerOutput,
)
from einops import rearrange


@dataclass
class AnnealedLangevinDynamicsOutput(SchedulerOutput):
    """Annealed Langevin Dynamics output class."""


class AnnealedLangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):  # type: ignore
    """Annealed Langevin Dynamics scheduler for Noise Conditional Score Network (NCSN)."""

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int,
        num_annealed_steps: int,
        sigma_min: float,
        sigma_max: float,
        sampling_eps: float,
    ) -> None:
        self.num_train_timesteps = num_train_timesteps
        self.num_annealed_steps = num_annealed_steps

        self._sigma_min = sigma_min
        self._sigma_max = sigma_max
        self._sampling_eps = sampling_eps

        self._sigmas: Optional[torch.Tensor] = None
        self._step_size: Optional[torch.Tensor] = None
        self._timesteps: Optional[torch.Tensor] = None

        self.set_sigmas(num_inference_steps=num_train_timesteps)

    @property
    def sigmas(self) -> torch.Tensor:
        assert self._sigmas is not None
        return self._sigmas

    @property
    def step_size(self) -> torch.Tensor:
        assert self._step_size is not None
        return self._step_size

    @property
    def timesteps(self) -> torch.Tensor:
        assert self._timesteps is not None
        return self._timesteps

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Optional[int] = None
    ) -> torch.Tensor:
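        # NCSN does not rescale model inputs, so the sample is returned unchanged.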
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        sampling_eps: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        # `sampling_eps` is accepted for API compatibility with other diffusers
        # schedulers; the integer timestep indices themselves do not depend on it.
        sampling_eps = sampling_eps or self._sampling_eps
        self._timesteps = torch.arange(start=0, end=num_inference_steps, device=device)

    def set_sigmas(
        self,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ) -> None:
        if self._timesteps is None:
            self.set_timesteps(
                num_inference_steps=num_inference_steps,
                sampling_eps=sampling_eps,
            )

        sigma_min = sigma_min or self._sigma_min
        sigma_max = sigma_max or self._sigma_max
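        # Geometric noise schedule: sigmas are log-linearly spaced from
        # sigma_max down to sigma_min.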
        self._sigmas = torch.exp(
            torch.linspace(
                start=math.log(sigma_max),
                end=math.log(sigma_min),
                steps=num_inference_steps,
            )
        )

        sampling_eps = sampling_eps or self._sampling_eps
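        # Per-level Langevin step size: alpha_i = eps * (sigma_i / sigma_L) ** 2,
        # where sigma_L = sigmas[-1] is the smallest noise level.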
        self._step_size = sampling_eps * (self.sigmas / self.sigmas[-1]) ** 2

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        samples: torch.Tensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
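        # One annealed Langevin dynamics update:
        #   x <- x + (alpha_t / 2) * score + sqrt(alpha_t) * z,  z ~ N(0, I)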
        z = torch.randn_like(samples)
        step_size = self.step_size[timestep]
        samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z

        if return_dict:
            return AnnealedLangevinDynamicsOutput(prev_sample=samples)
        else:
            return (samples,)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
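        # Perturb clean samples with the noise level of each timestep:
        # x_t = x_0 + sigma_t * z.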
        timesteps = timesteps.to(original_samples.device)
        sigmas = self.sigmas.to(original_samples.device)[timesteps]
        sigmas = rearrange(sigmas, "b -> b 1 1 1")
        noisy_samples = original_samples + noise * sigmas
        return noisy_samples
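

if __name__ == "__main__":
    # Minimal usage sketch (assumptions, not part of the original scheduler): the
    # `score_model` below is a hypothetical stand-in for a trained NCSN that
    # returns a score estimate s(x, sigma); swap in a real network to sample data.
    scheduler = AnnealedLangevinDynamicsScheduler(
        num_train_timesteps=10,
        num_annealed_steps=100,
        sigma_min=0.005,
        sigma_max=10.0,
        sampling_eps=1e-5,
    )

    def score_model(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # Placeholder score of a simple Gaussian prior; a trained network goes here.
        return -x / (sigma**2 + 1.0)

    samples = torch.randn(4, 3, 32, 32)
    for t in scheduler.timesteps:
        sigma = scheduler.sigmas[t]
        for _ in range(scheduler.num_annealed_steps):
            score = score_model(samples, sigma)
            samples = scheduler.step(score, int(t), samples).prev_sample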