Diffusers
Safetensors
shunk031 commited on
Commit
03482fb
·
verified ·
1 Parent(s): c2e4aaf

Upload scheduler/scheduling_ncsn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scheduler/scheduling_ncsn.py +7 -5
scheduler/scheduling_ncsn.py CHANGED
@@ -15,11 +15,13 @@ from einops import rearrange
15
 
16
 
17
  @dataclass
18
- class AnnealedLangevinDynamicOutput(SchedulerOutput):
19
- """Annealed Langevin Dynamic output class."""
20
 
21
 
22
- class AnnealedLangevinDynamicScheduler(SchedulerMixin, ConfigMixin): # type: ignore
 
 
23
  order = 1
24
 
25
  @register_to_config
@@ -106,13 +108,13 @@ class AnnealedLangevinDynamicScheduler(SchedulerMixin, ConfigMixin): # type: ig
106
  samples: torch.Tensor,
107
  return_dict: bool = True,
108
  **kwargs,
109
- ) -> Union[AnnealedLangevinDynamicOutput, Tuple]:
110
  z = torch.randn_like(samples)
111
  step_size = self.step_size[timestep]
112
  samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
113
 
114
  if return_dict:
115
- return AnnealedLangevinDynamicOutput(prev_sample=samples)
116
  else:
117
  return (samples,)
118
 
 
15
 
16
 
17
  @dataclass
18
+ class AnnealedLangevinDynamicsOutput(SchedulerOutput):
19
+ """Annealed Langevin Dynamics output class."""
20
 
21
 
22
+ class AnnealedLangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): # type: ignore
23
+ """Annealed Langevin Dynamics scheduler for Noise Conditional Score Network (NCSN)."""
24
+
25
  order = 1
26
 
27
  @register_to_config
 
108
  samples: torch.Tensor,
109
  return_dict: bool = True,
110
  **kwargs,
111
+ ) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
112
  z = torch.randn_like(samples)
113
  step_size = self.step_size[timestep]
114
  samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
115
 
116
  if return_dict:
117
+ return AnnealedLangevinDynamicsOutput(prev_sample=samples)
118
  else:
119
  return (samples,)
120