Upload scheduler/scheduling_ncsn.py with huggingface_hub
Browse files
scheduler/scheduling_ncsn.py
CHANGED
@@ -15,11 +15,13 @@ from einops import rearrange
|
|
15 |
|
16 |
|
17 |
@dataclass
|
18 |
-
class
|
19 |
-
"""Annealed Langevin
|
20 |
|
21 |
|
22 |
-
class
|
|
|
|
|
23 |
order = 1
|
24 |
|
25 |
@register_to_config
|
@@ -106,13 +108,13 @@ class AnnealedLangevinDynamicScheduler(SchedulerMixin, ConfigMixin): # type: ig
|
|
106 |
samples: torch.Tensor,
|
107 |
return_dict: bool = True,
|
108 |
**kwargs,
|
109 |
-
) -> Union[
|
110 |
z = torch.randn_like(samples)
|
111 |
step_size = self.step_size[timestep]
|
112 |
samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
|
113 |
|
114 |
if return_dict:
|
115 |
-
return
|
116 |
else:
|
117 |
return (samples,)
|
118 |
|
|
|
15 |
|
16 |
|
17 |
@dataclass
|
18 |
+
class AnnealedLangevinDynamicsOutput(SchedulerOutput):
|
19 |
+
"""Annealed Langevin Dynamics output class."""
|
20 |
|
21 |
|
22 |
+
class AnnealedLangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): # type: ignore
|
23 |
+
"""Annealed Langevin Dynamics scheduler for Noise Conditional Score Network (NCSN)."""
|
24 |
+
|
25 |
order = 1
|
26 |
|
27 |
@register_to_config
|
|
|
108 |
samples: torch.Tensor,
|
109 |
return_dict: bool = True,
|
110 |
**kwargs,
|
111 |
+
) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
|
112 |
z = torch.randn_like(samples)
|
113 |
step_size = self.step_size[timestep]
|
114 |
samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
|
115 |
|
116 |
if return_dict:
|
117 |
+
return AnnealedLangevinDynamicsOutput(prev_sample=samples)
|
118 |
else:
|
119 |
return (samples,)
|
120 |
|