ntt123 commited on
Commit
a6e518a
·
1 Parent(s): b2bfb1d

add diffusion utils

Browse files
Files changed (8) hide show
  1. README.md +15 -1
  2. app.py +1 -1
  3. diffusion.py +46 -0
  4. diffusion_utils.py +90 -0
  5. gaussian_diffusion.py +917 -0
  6. respace.py +127 -0
  7. timestep_sampler.py +150 -0
  8. uv.lock +53 -37
README.md CHANGED
@@ -5,10 +5,24 @@ colorFrom: red
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.9.1
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
 
 
 
11
  short_description: A simple diffusion-based text to speech model
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
+ python_version: 3.11
9
  app_file: app.py
10
  pinned: false
11
  license: mit
12
+ models:
13
+ - ntt123/diffusion-speech-360h
14
+ preload_from_hub: true
15
+ - ntt123/diffusion-speech-360h acoustic_model_0140000.pt,duration_model_0120000.pt
16
  short_description: A simple diffusion-based text to speech model
17
  ---
18
 
19
+ ```
20
+ uv run synthesize.py \
21
+ --duration-model-config ./train_duration_dit_s.yaml \
22
+ --acoustic-model-config ./train_acoustic_dit_b.yaml \
23
+ --duration-model-checkpoint ./duration_model_0120000.pt \
24
+ --acoustic-model-checkpoint ./acoustic_model_0140000.pt \
25
+ --speaker-id 1914 \
26
+ --output-file ./audio.wav \
27
+ --text "Ilya has made several major contributions to the field of deep learning."
28
+ ```
app.py CHANGED
@@ -24,7 +24,7 @@ sampling_steps = [100, 250, 500, 1000]
24
  demo = gr.Interface(
25
  fn=text_to_speech,
26
  inputs=[
27
- gr.Textbox(label="Text"),
28
  gr.Dropdown(choices=speaker_ids, label="Speaker ID", value="0"),
29
  gr.Slider(minimum=0, maximum=10, value=4.0, label="CFG Scale"),
30
  gr.Dropdown(choices=sampling_steps, label="Sampling Steps", value=100),
 
24
  demo = gr.Interface(
25
  fn=text_to_speech,
26
  inputs=[
27
+ gr.Textbox(label="Text", value="Text to Speech with Diffusion Transformer"),
28
  gr.Dropdown(choices=speaker_ids, label="Speaker ID", value="0"),
29
  gr.Slider(minimum=0, maximum=10, value=4.0, label="CFG Scale"),
30
  gr.Dropdown(choices=sampling_steps, label="Sampling Steps", value=100),
diffusion.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import gaussian_diffusion as gd
7
+ from respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ rescale_learned_sigmas=False,
18
+ diffusion_steps=1000,
19
+ ):
20
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21
+ if use_kl:
22
+ loss_type = gd.LossType.RESCALED_KL
23
+ elif rescale_learned_sigmas:
24
+ loss_type = gd.LossType.RESCALED_MSE
25
+ else:
26
+ loss_type = gd.LossType.MSE
27
+ if timestep_respacing is None or timestep_respacing == "":
28
+ timestep_respacing = [diffusion_steps]
29
+ return SpacedDiffusion(
30
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31
+ betas=betas,
32
+ model_mean_type=(
33
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34
+ ),
35
+ model_var_type=(
36
+ (
37
+ gd.ModelVarType.FIXED_LARGE
38
+ if not sigma_small
39
+ else gd.ModelVarType.FIXED_SMALL
40
+ )
41
+ if not learn_sigma
42
+ else gd.ModelVarType.LEARNED_RANGE
43
+ ),
44
+ loss_type=loss_type,
45
+ # rescale_timesteps=rescale_timesteps,
46
+ )
diffusion_utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
59
+ normalized_x
60
+ )
61
+ return log_probs
62
+
63
+
64
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
65
+ """
66
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
67
+ given image.
68
+ :param x: the target images. It is assumed that this was uint8 values,
69
+ rescaled to the range [-1, 1].
70
+ :param means: the Gaussian mean Tensor.
71
+ :param log_scales: the Gaussian log stddev Tensor.
72
+ :return: a tensor like x of log probabilities (in nats).
73
+ """
74
+ assert x.shape == means.shape == log_scales.shape
75
+ centered_x = x - means
76
+ inv_stdv = th.exp(-log_scales)
77
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
78
+ cdf_plus = approx_standard_normal_cdf(plus_in)
79
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
80
+ cdf_min = approx_standard_normal_cdf(min_in)
81
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
82
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
83
+ cdf_delta = cdf_plus - cdf_min
84
+ log_probs = th.where(
85
+ x < -0.999,
86
+ log_cdf_plus,
87
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
88
+ )
89
+ assert log_probs.shape == x.shape
90
+ return log_probs
gaussian_diffusion.py ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import enum
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch as th
12
+
13
+ from diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(
62
+ beta_start, beta_end, warmup_time, dtype=np.float64
63
+ )
64
+ return betas
65
+
66
+
67
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
68
+ """
69
+ This is the deprecated API for creating beta schedules.
70
+ See get_named_beta_schedule() for the new library of schedules.
71
+ """
72
+ if beta_schedule == "quad":
73
+ betas = (
74
+ np.linspace(
75
+ beta_start**0.5,
76
+ beta_end**0.5,
77
+ num_diffusion_timesteps,
78
+ dtype=np.float64,
79
+ )
80
+ ** 2
81
+ )
82
+ elif beta_schedule == "linear":
83
+ betas = np.linspace(
84
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
85
+ )
86
+ elif beta_schedule == "warmup10":
87
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
88
+ elif beta_schedule == "warmup50":
89
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
90
+ elif beta_schedule == "const":
91
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
92
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
93
+ betas = 1.0 / np.linspace(
94
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
95
+ )
96
+ else:
97
+ raise NotImplementedError(beta_schedule)
98
+ assert betas.shape == (num_diffusion_timesteps,)
99
+ return betas
100
+
101
+
102
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
103
+ """
104
+ Get a pre-defined beta schedule for the given name.
105
+ The beta schedule library consists of beta schedules which remain similar
106
+ in the limit of num_diffusion_timesteps.
107
+ Beta schedules may be added, but should not be removed or changed once
108
+ they are committed to maintain backwards compatibility.
109
+ """
110
+ if schedule_name == "linear":
111
+ # Linear schedule from Ho et al, extended to work for any number of
112
+ # diffusion steps.
113
+ scale = 1000 / num_diffusion_timesteps
114
+ return get_beta_schedule(
115
+ "linear",
116
+ beta_start=scale * 0.0001,
117
+ beta_end=scale * 0.02,
118
+ num_diffusion_timesteps=num_diffusion_timesteps,
119
+ )
120
+ elif schedule_name == "squaredcos_cap_v2":
121
+ return betas_for_alpha_bar(
122
+ num_diffusion_timesteps,
123
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
124
+ )
125
+ else:
126
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
127
+
128
+
129
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
130
+ """
131
+ Create a beta schedule that discretizes the given alpha_t_bar function,
132
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
133
+ :param num_diffusion_timesteps: the number of betas to produce.
134
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
135
+ produces the cumulative product of (1-beta) up to that
136
+ part of the diffusion process.
137
+ :param max_beta: the maximum beta to use; use values lower than 1 to
138
+ prevent singularities.
139
+ """
140
+ betas = []
141
+ for i in range(num_diffusion_timesteps):
142
+ t1 = i / num_diffusion_timesteps
143
+ t2 = (i + 1) / num_diffusion_timesteps
144
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
145
+ return np.array(betas)
146
+
147
+
148
+ class GaussianDiffusion:
149
+ """
150
+ Utilities for training and sampling diffusion models.
151
+ Original ported from this codebase:
152
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
153
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
154
+ starting at T and going to 1.
155
+ """
156
+
157
+ def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
158
+
159
+ self.model_mean_type = model_mean_type
160
+ self.model_var_type = model_var_type
161
+ self.loss_type = loss_type
162
+
163
+ # Use float64 for accuracy.
164
+ betas = np.array(betas, dtype=np.float64)
165
+ self.betas = betas
166
+ assert len(betas.shape) == 1, "betas must be 1-D"
167
+ assert (betas > 0).all() and (betas <= 1).all()
168
+
169
+ self.num_timesteps = int(betas.shape[0])
170
+
171
+ alphas = 1.0 - betas
172
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
173
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
174
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
175
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
176
+
177
+ # calculations for diffusion q(x_t | x_{t-1}) and others
178
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
179
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
180
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
181
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
182
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
183
+
184
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
185
+ self.posterior_variance = (
186
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
187
+ )
188
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
189
+ self.posterior_log_variance_clipped = (
190
+ np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
191
+ if len(self.posterior_variance) > 1
192
+ else np.array([])
193
+ )
194
+
195
+ self.posterior_mean_coef1 = (
196
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
197
+ )
198
+ self.posterior_mean_coef2 = (
199
+ (1.0 - self.alphas_cumprod_prev)
200
+ * np.sqrt(alphas)
201
+ / (1.0 - self.alphas_cumprod)
202
+ )
203
+
204
+ # convert all numpy arrays to torch tensors
205
+ DEVICE = th.device("cpu")
206
+ self.betas = th.from_numpy(self.betas).to(DEVICE)
207
+ self.alphas_cumprod = th.from_numpy(self.alphas_cumprod).to(DEVICE)
208
+ self.alphas_cumprod_prev = th.from_numpy(self.alphas_cumprod_prev).to(DEVICE)
209
+ self.alphas_cumprod_next = th.from_numpy(self.alphas_cumprod_next).to(DEVICE)
210
+ self.sqrt_alphas_cumprod = th.from_numpy(self.sqrt_alphas_cumprod).to(DEVICE)
211
+ self.sqrt_one_minus_alphas_cumprod = th.from_numpy(self.sqrt_one_minus_alphas_cumprod).to(DEVICE)
212
+ self.log_one_minus_alphas_cumprod = th.from_numpy(self.log_one_minus_alphas_cumprod).to(DEVICE)
213
+ self.sqrt_recip_alphas_cumprod = th.from_numpy(self.sqrt_recip_alphas_cumprod).to(DEVICE)
214
+ self.sqrt_recipm1_alphas_cumprod = th.from_numpy(self.sqrt_recipm1_alphas_cumprod).to(DEVICE)
215
+ self.posterior_variance = th.from_numpy(self.posterior_variance).to(DEVICE)
216
+ self.posterior_log_variance_clipped = th.from_numpy(self.posterior_log_variance_clipped).to(DEVICE)
217
+ self.posterior_mean_coef1 = th.from_numpy(self.posterior_mean_coef1).to(DEVICE)
218
+ self.posterior_mean_coef2 = th.from_numpy(self.posterior_mean_coef2).to(DEVICE)
219
+
220
+
221
+
222
+ def q_mean_variance(self, x_start, t):
223
+ """
224
+ Get the distribution q(x_t | x_0).
225
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
226
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
227
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
228
+ """
229
+ mean = (
230
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
231
+ )
232
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
233
+ log_variance = _extract_into_tensor(
234
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
235
+ )
236
+ return mean, variance, log_variance
237
+
238
+ def q_sample(self, x_start, t, noise=None):
239
+ """
240
+ Diffuse the data for a given number of diffusion steps.
241
+ In other words, sample from q(x_t | x_0).
242
+ :param x_start: the initial data batch.
243
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
244
+ :param noise: if specified, the split-out normal noise.
245
+ :return: A noisy version of x_start.
246
+ """
247
+ if noise is None:
248
+ noise = th.randn_like(x_start)
249
+ assert noise.shape == x_start.shape
250
+ return (
251
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
252
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
253
+ * noise
254
+ )
255
+
256
+ def q_posterior_mean_variance(self, x_start, x_t, t):
257
+ """
258
+ Compute the mean and variance of the diffusion posterior:
259
+ q(x_{t-1} | x_t, x_0)
260
+ """
261
+ assert x_start.shape == x_t.shape
262
+ posterior_mean = (
263
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
264
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
265
+ )
266
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
267
+ posterior_log_variance_clipped = _extract_into_tensor(
268
+ self.posterior_log_variance_clipped, t, x_t.shape
269
+ )
270
+ assert (
271
+ posterior_mean.shape[0]
272
+ == posterior_variance.shape[0]
273
+ == posterior_log_variance_clipped.shape[0]
274
+ == x_start.shape[0]
275
+ )
276
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
277
+
278
+ def p_mean_variance(
279
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
280
+ ):
281
+ """
282
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
283
+ the initial x, x_0.
284
+ :param model: the model, which takes a signal and a batch of timesteps
285
+ as input.
286
+ :param x: the [N x C x ...] tensor at time t.
287
+ :param t: a 1-D Tensor of timesteps.
288
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
289
+ :param denoised_fn: if not None, a function which applies to the
290
+ x_start prediction before it is used to sample. Applies before
291
+ clip_denoised.
292
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
293
+ pass to the model. This can be used for conditioning.
294
+ :return: a dict with the following keys:
295
+ - 'mean': the model mean output.
296
+ - 'variance': the model variance output.
297
+ - 'log_variance': the log of 'variance'.
298
+ - 'pred_xstart': the prediction for x_0.
299
+ """
300
+ if model_kwargs is None:
301
+ model_kwargs = {}
302
+
303
+ B, C = x.shape[:2]
304
+ assert t.shape == (B,)
305
+ model_output = model(x, t, **model_kwargs)
306
+ if isinstance(model_output, tuple):
307
+ model_output, extra = model_output
308
+ else:
309
+ extra = None
310
+
311
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
312
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
313
+ model_output, model_var_values = th.split(model_output, C, dim=1)
314
+ min_log = _extract_into_tensor(
315
+ self.posterior_log_variance_clipped, t, x.shape
316
+ )
317
+ max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
318
+ # The model_var_values is [-1, 1] for [min_var, max_var].
319
+ frac = (model_var_values + 1) / 2
320
+ model_log_variance = frac * max_log + (1 - frac) * min_log
321
+ model_variance = th.exp(model_log_variance)
322
+ else:
323
+ model_variance, model_log_variance = {
324
+ # for fixedlarge, we set the initial (log-)variance like so
325
+ # to get a better decoder log likelihood.
326
+ ModelVarType.FIXED_LARGE: (
327
+ th.cat([self.posterior_variance[1], self.betas[1:]]),
328
+ th.log(th.cat([self.posterior_variance[1], self.betas[1:]])),
329
+ ),
330
+ ModelVarType.FIXED_SMALL: (
331
+ self.posterior_variance,
332
+ self.posterior_log_variance_clipped,
333
+ ),
334
+ }[self.model_var_type]
335
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
336
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
337
+
338
+ def process_xstart(x):
339
+ if denoised_fn is not None:
340
+ x = denoised_fn(x)
341
+ if clip_denoised:
342
+ return x.clamp(-1, 1)
343
+ return x
344
+
345
+ if self.model_mean_type == ModelMeanType.START_X:
346
+ pred_xstart = process_xstart(model_output)
347
+ else:
348
+ pred_xstart = process_xstart(
349
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
350
+ )
351
+ model_mean, _, _ = self.q_posterior_mean_variance(
352
+ x_start=pred_xstart, x_t=x, t=t
353
+ )
354
+
355
+ assert (
356
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
357
+ )
358
+ return {
359
+ "mean": model_mean,
360
+ "variance": model_variance,
361
+ "log_variance": model_log_variance,
362
+ "pred_xstart": pred_xstart,
363
+ "extra": extra,
364
+ }
365
+
366
+ def _predict_xstart_from_eps(self, x_t, t, eps):
367
+ assert x_t.shape == eps.shape
368
+ return (
369
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
370
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
371
+ )
372
+
373
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
374
+ return (
375
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
376
+ - pred_xstart
377
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
378
+
379
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
380
+ """
381
+ Compute the mean for the previous step, given a function cond_fn that
382
+ computes the gradient of a conditional log probability with respect to
383
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
384
+ condition on y.
385
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
386
+ """
387
+ gradient = cond_fn(x, t, **model_kwargs)
388
+ new_mean = (
389
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
390
+ )
391
+ return new_mean
392
+
393
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
394
+ """
395
+ Compute what the p_mean_variance output would have been, should the
396
+ model's score function be conditioned by cond_fn.
397
+ See condition_mean() for details on cond_fn.
398
+ Unlike condition_mean(), this instead uses the conditioning strategy
399
+ from Song et al (2020).
400
+ """
401
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
402
+
403
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
404
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
405
+
406
+ out = p_mean_var.copy()
407
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
408
+ out["mean"], _, _ = self.q_posterior_mean_variance(
409
+ x_start=out["pred_xstart"], x_t=x, t=t
410
+ )
411
+ return out
412
+
413
+ def p_sample(
414
+ self,
415
+ model,
416
+ x,
417
+ t,
418
+ clip_denoised=True,
419
+ denoised_fn=None,
420
+ cond_fn=None,
421
+ model_kwargs=None,
422
+ ):
423
+ """
424
+ Sample x_{t-1} from the model at the given timestep.
425
+ :param model: the model to sample from.
426
+ :param x: the current tensor at x_{t-1}.
427
+ :param t: the value of t, starting at 0 for the first diffusion step.
428
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
429
+ :param denoised_fn: if not None, a function which applies to the
430
+ x_start prediction before it is used to sample.
431
+ :param cond_fn: if not None, this is a gradient function that acts
432
+ similarly to the model.
433
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
434
+ pass to the model. This can be used for conditioning.
435
+ :return: a dict containing the following keys:
436
+ - 'sample': a random sample from the model.
437
+ - 'pred_xstart': a prediction of x_0.
438
+ """
439
+ out = self.p_mean_variance(
440
+ model,
441
+ x,
442
+ t,
443
+ clip_denoised=clip_denoised,
444
+ denoised_fn=denoised_fn,
445
+ model_kwargs=model_kwargs,
446
+ )
447
+ noise = th.randn_like(x)
448
+ nonzero_mask = (
449
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
450
+ ) # no noise when t == 0
451
+ if cond_fn is not None:
452
+ out["mean"] = self.condition_mean(
453
+ cond_fn, out, x, t, model_kwargs=model_kwargs
454
+ )
455
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
456
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
457
+
458
+ def p_sample_loop(
459
+ self,
460
+ model,
461
+ shape,
462
+ noise=None,
463
+ clip_denoised=True,
464
+ denoised_fn=None,
465
+ cond_fn=None,
466
+ model_kwargs=None,
467
+ device=None,
468
+ progress=False,
469
+ ):
470
+ """
471
+ Generate samples from the model.
472
+ :param model: the model module.
473
+ :param shape: the shape of the samples, (N, C, H, W).
474
+ :param noise: if specified, the noise from the encoder to sample.
475
+ Should be of the same shape as `shape`.
476
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
477
+ :param denoised_fn: if not None, a function which applies to the
478
+ x_start prediction before it is used to sample.
479
+ :param cond_fn: if not None, this is a gradient function that acts
480
+ similarly to the model.
481
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
482
+ pass to the model. This can be used for conditioning.
483
+ :param device: if specified, the device to create the samples on.
484
+ If not specified, use a model parameter's device.
485
+ :param progress: if True, show a tqdm progress bar.
486
+ :return: a non-differentiable batch of samples.
487
+ """
488
+ final = None
489
+ samples = []
490
+ for sample in self.p_sample_loop_progressive(
491
+ model,
492
+ shape,
493
+ noise=noise,
494
+ clip_denoised=clip_denoised,
495
+ denoised_fn=denoised_fn,
496
+ cond_fn=cond_fn,
497
+ model_kwargs=model_kwargs,
498
+ device=device,
499
+ progress=progress,
500
+ ):
501
+ final = sample
502
+ samples.append(final["sample"])
503
+ return samples
504
+
505
+ def p_sample_loop_progressive(
506
+ self,
507
+ model,
508
+ shape,
509
+ noise=None,
510
+ clip_denoised=True,
511
+ denoised_fn=None,
512
+ cond_fn=None,
513
+ model_kwargs=None,
514
+ device=None,
515
+ progress=False,
516
+ ):
517
+ """
518
+ Generate samples from the model and yield intermediate samples from
519
+ each timestep of diffusion.
520
+ Arguments are the same as p_sample_loop().
521
+ Returns a generator over dicts, where each dict is the return value of
522
+ p_sample().
523
+ """
524
+ if device is None:
525
+ device = next(model.parameters()).device
526
+ assert isinstance(shape, (tuple, list))
527
+ if noise is not None:
528
+ img = noise
529
+ else:
530
+ img = th.randn(*shape, device=device)
531
+ indices = list(range(self.num_timesteps))[::-1]
532
+
533
+ if progress:
534
+ # Lazy import so that we don't depend on tqdm.
535
+ from tqdm.auto import tqdm
536
+
537
+ indices = tqdm(indices)
538
+
539
+ for i in indices:
540
+ t = th.tensor([i] * shape[0], device=device)
541
+ with th.no_grad():
542
+ out = self.p_sample(
543
+ model,
544
+ img,
545
+ t,
546
+ clip_denoised=clip_denoised,
547
+ denoised_fn=denoised_fn,
548
+ cond_fn=cond_fn,
549
+ model_kwargs=model_kwargs,
550
+ )
551
+ yield out
552
+ img = out["sample"]
553
+
554
+ def ddim_sample(
555
+ self,
556
+ model,
557
+ x,
558
+ t,
559
+ clip_denoised=True,
560
+ denoised_fn=None,
561
+ cond_fn=None,
562
+ model_kwargs=None,
563
+ eta=0.0,
564
+ ):
565
+ """
566
+ Sample x_{t-1} from the model using DDIM.
567
+ Same usage as p_sample().
568
+ """
569
+ out = self.p_mean_variance(
570
+ model,
571
+ x,
572
+ t,
573
+ clip_denoised=clip_denoised,
574
+ denoised_fn=denoised_fn,
575
+ model_kwargs=model_kwargs,
576
+ )
577
+ if cond_fn is not None:
578
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
579
+
580
+ # Usually our model outputs epsilon, but we re-derive it
581
+ # in case we used x_start or x_prev prediction.
582
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
583
+
584
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
585
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
586
+ sigma = (
587
+ eta
588
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
589
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
590
+ )
591
+ # Equation 12.
592
+ noise = th.randn_like(x)
593
+ mean_pred = (
594
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
595
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
596
+ )
597
+ nonzero_mask = (
598
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
599
+ ) # no noise when t == 0
600
+ sample = mean_pred + nonzero_mask * sigma * noise
601
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
602
+
603
+ def ddim_reverse_sample(
604
+ self,
605
+ model,
606
+ x,
607
+ t,
608
+ clip_denoised=True,
609
+ denoised_fn=None,
610
+ cond_fn=None,
611
+ model_kwargs=None,
612
+ eta=0.0,
613
+ ):
614
+ """
615
+ Sample x_{t+1} from the model using DDIM reverse ODE.
616
+ """
617
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
618
+ out = self.p_mean_variance(
619
+ model,
620
+ x,
621
+ t,
622
+ clip_denoised=clip_denoised,
623
+ denoised_fn=denoised_fn,
624
+ model_kwargs=model_kwargs,
625
+ )
626
+ if cond_fn is not None:
627
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
628
+ # Usually our model outputs epsilon, but we re-derive it
629
+ # in case we used x_start or x_prev prediction.
630
+ eps = (
631
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
632
+ - out["pred_xstart"]
633
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
634
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
635
+
636
+ # Equation 12. reversed
637
+ mean_pred = (
638
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
639
+ + th.sqrt(1 - alpha_bar_next) * eps
640
+ )
641
+
642
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
643
+
644
+ def ddim_sample_loop(
645
+ self,
646
+ model,
647
+ shape,
648
+ noise=None,
649
+ clip_denoised=True,
650
+ denoised_fn=None,
651
+ cond_fn=None,
652
+ model_kwargs=None,
653
+ device=None,
654
+ progress=False,
655
+ eta=0.0,
656
+ ):
657
+ """
658
+ Generate samples from the model using DDIM.
659
+ Same usage as p_sample_loop().
660
+ """
661
+ final = None
662
+ for sample in self.ddim_sample_loop_progressive(
663
+ model,
664
+ shape,
665
+ noise=noise,
666
+ clip_denoised=clip_denoised,
667
+ denoised_fn=denoised_fn,
668
+ cond_fn=cond_fn,
669
+ model_kwargs=model_kwargs,
670
+ device=device,
671
+ progress=progress,
672
+ eta=eta,
673
+ ):
674
+ final = sample
675
+ return final["sample"]
676
+
677
+ def ddim_sample_loop_progressive(
678
+ self,
679
+ model,
680
+ shape,
681
+ noise=None,
682
+ clip_denoised=True,
683
+ denoised_fn=None,
684
+ cond_fn=None,
685
+ model_kwargs=None,
686
+ device=None,
687
+ progress=False,
688
+ eta=0.0,
689
+ ):
690
+ """
691
+ Use DDIM to sample from the model and yield intermediate samples from
692
+ each timestep of DDIM.
693
+ Same usage as p_sample_loop_progressive().
694
+ """
695
+ if device is None:
696
+ device = next(model.parameters()).device
697
+ assert isinstance(shape, (tuple, list))
698
+ if noise is not None:
699
+ img = noise
700
+ else:
701
+ img = th.randn(*shape, device=device)
702
+ indices = list(range(self.num_timesteps))[::-1]
703
+
704
+ if progress:
705
+ # Lazy import so that we don't depend on tqdm.
706
+ from tqdm.auto import tqdm
707
+
708
+ indices = tqdm(indices)
709
+
710
+ for i in indices:
711
+ t = th.tensor([i] * shape[0], device=device)
712
+ with th.no_grad():
713
+ out = self.ddim_sample(
714
+ model,
715
+ img,
716
+ t,
717
+ clip_denoised=clip_denoised,
718
+ denoised_fn=denoised_fn,
719
+ cond_fn=cond_fn,
720
+ model_kwargs=model_kwargs,
721
+ eta=eta,
722
+ )
723
+ yield out
724
+ img = out["sample"]
725
+
726
+ def _vb_terms_bpd(
727
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
728
+ ):
729
+ """
730
+ Get a term for the variational lower-bound.
731
+ The resulting units are bits (rather than nats, as one might expect).
732
+ This allows for comparison to other papers.
733
+ :return: a dict with the following keys:
734
+ - 'output': a shape [N] tensor of NLLs or KLs.
735
+ - 'pred_xstart': the x_0 predictions.
736
+ """
737
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
738
+ x_start=x_start, x_t=x_t, t=t
739
+ )
740
+ out = self.p_mean_variance(
741
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
742
+ )
743
+ kl = normal_kl(
744
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
745
+ )
746
+ kl = kl / math.log(2.0)
747
+
748
+ decoder_nll = -discretized_gaussian_log_likelihood(
749
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
750
+ )
751
+ assert decoder_nll.shape == x_start.shape
752
+ decoder_nll = decoder_nll / math.log(2.0)
753
+
754
+ # At the first timestep return the decoder NLL,
755
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
756
+ output = th.where((t[:, None, None] == 0), decoder_nll, kl)
757
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
758
+
759
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
760
+ """
761
+ Compute training losses for a single timestep.
762
+ :param model: the model to evaluate loss on.
763
+ :param x_start: the [N x C x ...] tensor of inputs.
764
+ :param t: a batch of timestep indices.
765
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
766
+ pass to the model. This can be used for conditioning.
767
+ :param noise: if specified, the specific Gaussian noise to try to remove.
768
+ :return: a dict with the key "loss" containing a tensor of shape [N].
769
+ Some mean or variance settings may also have other keys.
770
+ """
771
+ if model_kwargs is None:
772
+ model_kwargs = {}
773
+ if noise is None:
774
+ noise = th.randn_like(x_start)
775
+ x_t = self.q_sample(x_start, t, noise=noise)
776
+
777
+ terms = {}
778
+
779
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
780
+ terms["loss"] = self._vb_terms_bpd(
781
+ model=model,
782
+ x_start=x_start,
783
+ x_t=x_t,
784
+ t=t,
785
+ clip_denoised=False,
786
+ model_kwargs=model_kwargs,
787
+ )["output"]
788
+ if self.loss_type == LossType.RESCALED_KL:
789
+ terms["loss"] *= self.num_timesteps
790
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
791
+ model_output = model(x_t, t, **model_kwargs)
792
+
793
+ if self.model_var_type in [
794
+ ModelVarType.LEARNED,
795
+ ModelVarType.LEARNED_RANGE,
796
+ ]:
797
+ B, C = x_t.shape[:2]
798
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
799
+ model_output, model_var_values = th.split(model_output, C, dim=1)
800
+ # Learn the variance using the variational bound, but don't let
801
+ # it affect our mean prediction.
802
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
803
+ terms["vb"] = self._vb_terms_bpd(
804
+ model=lambda *args, r=frozen_out: r,
805
+ x_start=x_start,
806
+ x_t=x_t,
807
+ t=t,
808
+ clip_denoised=False,
809
+ )["output"]
810
+ if self.loss_type == LossType.RESCALED_MSE:
811
+ # Divide by 1000 for equivalence with initial implementation.
812
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
813
+ terms["vb"] *= self.num_timesteps / 1000.0
814
+
815
+ target = {
816
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
817
+ x_start=x_start, x_t=x_t, t=t
818
+ )[0],
819
+ ModelMeanType.START_X: x_start,
820
+ ModelMeanType.EPSILON: noise,
821
+ }[self.model_mean_type]
822
+ assert model_output.shape == target.shape == x_start.shape
823
+ terms["mse"] = (target - model_output) ** 2
824
+ if "vb" in terms:
825
+ terms["loss"] = terms["mse"] + terms["vb"]
826
+ else:
827
+ terms["loss"] = terms["mse"]
828
+ else:
829
+ raise NotImplementedError(self.loss_type)
830
+
831
+ return terms
832
+
833
+ def _prior_bpd(self, x_start):
834
+ """
835
+ Get the prior KL term for the variational lower-bound, measured in
836
+ bits-per-dim.
837
+ This term can't be optimized, as it only depends on the encoder.
838
+ :param x_start: the [N x C x ...] tensor of inputs.
839
+ :return: a batch of [N] KL values (in bits), one per batch element.
840
+ """
841
+ batch_size = x_start.shape[0]
842
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
843
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
844
+ kl_prior = normal_kl(
845
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
846
+ )
847
+ return mean_flat(kl_prior) / math.log(2.0)
848
+
849
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
850
+ """
851
+ Compute the entire variational lower-bound, measured in bits-per-dim,
852
+ as well as other related quantities.
853
+ :param model: the model to evaluate loss on.
854
+ :param x_start: the [N x C x ...] tensor of inputs.
855
+ :param clip_denoised: if True, clip denoised samples.
856
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
857
+ pass to the model. This can be used for conditioning.
858
+ :return: a dict containing the following keys:
859
+ - total_bpd: the total variational lower-bound, per batch element.
860
+ - prior_bpd: the prior term in the lower-bound.
861
+ - vb: an [N x T] tensor of terms in the lower-bound.
862
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
863
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
864
+ """
865
+ device = x_start.device
866
+ batch_size = x_start.shape[0]
867
+
868
+ vb = []
869
+ xstart_mse = []
870
+ mse = []
871
+ for t in list(range(self.num_timesteps))[::-1]:
872
+ t_batch = th.tensor([t] * batch_size, device=device)
873
+ noise = th.randn_like(x_start)
874
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
875
+ # Calculate VLB term at the current timestep
876
+ with th.no_grad():
877
+ out = self._vb_terms_bpd(
878
+ model,
879
+ x_start=x_start,
880
+ x_t=x_t,
881
+ t=t_batch,
882
+ clip_denoised=clip_denoised,
883
+ model_kwargs=model_kwargs,
884
+ )
885
+ vb.append(out["output"])
886
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
887
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
888
+ mse.append(mean_flat((eps - noise) ** 2))
889
+
890
+ vb = th.stack(vb, dim=1)
891
+ xstart_mse = th.stack(xstart_mse, dim=1)
892
+ mse = th.stack(mse, dim=1)
893
+
894
+ prior_bpd = self._prior_bpd(x_start)
895
+ total_bpd = vb.sum(dim=1) + prior_bpd
896
+ return {
897
+ "total_bpd": total_bpd,
898
+ "prior_bpd": prior_bpd,
899
+ "vb": vb,
900
+ "xstart_mse": xstart_mse,
901
+ "mse": mse,
902
+ }
903
+
904
+
905
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
906
+ """
907
+ Extract values from a 1-D numpy array for a batch of indices.
908
+ :param arr: the 1-D numpy array.
909
+ :param timesteps: a tensor of indices into the array to extract.
910
+ :param broadcast_shape: a larger shape of K dimensions with the batch
911
+ dimension equal to the length of timesteps.
912
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
913
+ """
914
+ res = arr.to(device=timesteps.device)[timesteps].float()
915
+ while len(res.shape) < len(broadcast_shape):
916
+ res = res[..., None]
917
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
respace.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = th.tensor(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(model, self.timestep_map, self.original_num_steps)
109
+
110
+ def _scale_timesteps(self, t):
111
+ # Scaling is done by the wrapped model.
112
+ return t
113
+
114
+
115
+ class _WrappedModel:
116
+ def __init__(self, model, timestep_map, original_num_steps):
117
+ self.model = model
118
+ self.timestep_map = timestep_map
119
+ # self.rescale_timesteps = rescale_timesteps
120
+ self.original_num_steps = original_num_steps
121
+
122
+ def __call__(self, x, ts, **kwargs):
123
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
124
+ new_ts = map_tensor[ts]
125
+ # if self.rescale_timesteps:
126
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
127
+ return self.model(x, new_ts, **kwargs)
timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
uv.lock CHANGED
@@ -1,8 +1,10 @@
1
  version = 1
2
  requires-python = ">=3.12"
3
  resolution-markers = [
4
- "python_full_version < '3.13'",
5
- "python_full_version >= '3.13'",
 
 
6
  ]
7
 
8
  [[package]]
@@ -188,7 +190,8 @@ dependencies = [
188
  { name = "gradio" },
189
  { name = "nltk" },
190
  { name = "soundfile" },
191
- { name = "torch" },
 
192
  { name = "vocos" },
193
  ]
194
 
@@ -198,7 +201,8 @@ requires-dist = [
198
  { name = "gradio", specifier = ">=5.9.1" },
199
  { name = "nltk", specifier = ">=3.9.1" },
200
  { name = "soundfile", specifier = ">=0.12.1" },
201
- { name = "torch", specifier = ">=2.5.1" },
 
202
  { name = "vocos", specifier = ">=0.1.0" },
203
  ]
204
 
@@ -224,7 +228,8 @@ source = { registry = "https://pypi.org/simple" }
224
  dependencies = [
225
  { name = "einops" },
226
  { name = "numpy" },
227
- { name = "torch" },
 
228
  { name = "torchaudio" },
229
  ]
230
  sdist = { url = "https://files.pythonhosted.org/packages/62/59/e47bbd0542d0e6f4ce9983d5eb458a01d4b42c81e5c410cb9e159b1061ae/encodec-0.1.1.tar.gz", hash = "sha256:36dde98ccfe6c51a15576476cadfcb3b35a63507b8b8555abd69889a6fba6772", size = 3736037 }
@@ -598,7 +603,7 @@ name = "nvidia-cudnn-cu12"
598
  version = "9.1.0.70"
599
  source = { registry = "https://pypi.org/simple" }
600
  dependencies = [
601
- { name = "nvidia-cublas-cu12" },
602
  ]
603
  wheels = [
604
  { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -609,7 +614,7 @@ name = "nvidia-cufft-cu12"
609
  version = "11.2.1.3"
610
  source = { registry = "https://pypi.org/simple" }
611
  dependencies = [
612
- { name = "nvidia-nvjitlink-cu12" },
613
  ]
614
  wheels = [
615
  { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 },
@@ -630,9 +635,9 @@ name = "nvidia-cusolver-cu12"
630
  version = "11.6.1.9"
631
  source = { registry = "https://pypi.org/simple" }
632
  dependencies = [
633
- { name = "nvidia-cublas-cu12" },
634
- { name = "nvidia-cusparse-cu12" },
635
- { name = "nvidia-nvjitlink-cu12" },
636
  ]
637
  wheels = [
638
  { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 },
@@ -644,7 +649,7 @@ name = "nvidia-cusparse-cu12"
644
  version = "12.3.1.170"
645
  source = { registry = "https://pypi.org/simple" }
646
  dependencies = [
647
- { name = "nvidia-nvjitlink-cu12" },
648
  ]
649
  wheels = [
650
  { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 },
@@ -1153,12 +1158,16 @@ wheels = [
1153
  [[package]]
1154
  name = "torch"
1155
  version = "2.5.1"
1156
- source = { registry = "https://pypi.org/simple" }
 
 
 
 
1157
  dependencies = [
1158
- { name = "filelock" },
1159
- { name = "fsspec" },
1160
- { name = "jinja2" },
1161
- { name = "networkx" },
1162
  { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1163
  { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1164
  { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
@@ -1171,17 +1180,33 @@ dependencies = [
1171
  { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1172
  { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1173
  { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1174
- { name = "setuptools" },
1175
- { name = "sympy" },
1176
- { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" },
1177
- { name = "typing-extensions" },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1178
  ]
1179
  wheels = [
1180
- { url = "https://files.pythonhosted.org/packages/8b/5c/36c114d120bfe10f9323ed35061bc5878cc74f3f594003854b0ea298942f/torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03", size = 906389343 },
1181
- { url = "https://files.pythonhosted.org/packages/6d/69/d8ada8b6e0a4257556d5b4ddeb4345ea8eeaaef3c98b60d1cca197c7ad8e/torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697", size = 91811673 },
1182
- { url = "https://files.pythonhosted.org/packages/5f/ba/607d013b55b9fd805db2a5c2662ec7551f1910b4eef39653eeaba182c5b2/torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c", size = 203046841 },
1183
  { url = "https://files.pythonhosted.org/packages/57/6c/bf52ff061da33deb9f94f4121fde7ff3058812cb7d2036c97bc167793bd1/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1", size = 63858109 },
1184
- { url = "https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 },
1185
  ]
1186
 
1187
  [[package]]
@@ -1189,7 +1214,8 @@ name = "torchaudio"
1189
  version = "2.5.1"
1190
  source = { registry = "https://pypi.org/simple" }
1191
  dependencies = [
1192
- { name = "torch" },
 
1193
  ]
1194
  wheels = [
1195
  { url = "https://files.pythonhosted.org/packages/03/ab/151037a41e2cf4a5d489dfe5e7196b755e0fd83958d5ca7ad8ed85afcb1c/torchaudio-2.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1cbfdfd1bbdfbe7289d47a74f36ff6c5d87c3205606202fef5a7fb693f61cf0", size = 1798042 },
@@ -1210,17 +1236,6 @@ wheels = [
1210
  { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
1211
  ]
1212
 
1213
- [[package]]
1214
- name = "triton"
1215
- version = "3.1.0"
1216
- source = { registry = "https://pypi.org/simple" }
1217
- dependencies = [
1218
- { name = "filelock", marker = "python_full_version < '3.13'" },
1219
- ]
1220
- wheels = [
1221
- { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 },
1222
- ]
1223
-
1224
  [[package]]
1225
  name = "typeguard"
1226
  version = "4.4.1"
@@ -1299,7 +1314,8 @@ dependencies = [
1299
  { name = "numpy" },
1300
  { name = "pyyaml" },
1301
  { name = "scipy" },
1302
- { name = "torch" },
 
1303
  { name = "torchaudio" },
1304
  ]
1305
  sdist = { url = "https://files.pythonhosted.org/packages/db/48/1e4d3a4a97292e47ebaa18e3eae6ecb2f57bde47693ccab0e7acb23f9ffe/vocos-0.1.0.tar.gz", hash = "sha256:b488224dbe398ff7d4790a027ad659478b4bc02e465db992c62c12b32ca043d8", size = 21107 }
 
1
  version = 1
2
  requires-python = ">=3.12"
3
  resolution-markers = [
4
+ "python_full_version < '3.13' and platform_system != 'Darwin'",
5
+ "python_full_version < '3.13' and platform_system == 'Darwin'",
6
+ "python_full_version >= '3.13' and platform_system != 'Darwin'",
7
+ "python_full_version >= '3.13' and platform_system == 'Darwin'",
8
  ]
9
 
10
  [[package]]
 
190
  { name = "gradio" },
191
  { name = "nltk" },
192
  { name = "soundfile" },
193
+ { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_system != 'Darwin'" },
194
+ { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_system == 'Darwin'" },
195
  { name = "vocos" },
196
  ]
197
 
 
201
  { name = "gradio", specifier = ">=5.9.1" },
202
  { name = "nltk", specifier = ">=3.9.1" },
203
  { name = "soundfile", specifier = ">=0.12.1" },
204
+ { name = "torch", marker = "platform_system != 'Darwin'", specifier = ">=2.5.1", index = "https://download.pytorch.org/whl/cpu" },
205
+ { name = "torch", marker = "platform_system == 'Darwin'", specifier = ">=2.5.1" },
206
  { name = "vocos", specifier = ">=0.1.0" },
207
  ]
208
 
 
228
  dependencies = [
229
  { name = "einops" },
230
  { name = "numpy" },
231
+ { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_system != 'Darwin'" },
232
+ { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_system == 'Darwin'" },
233
  { name = "torchaudio" },
234
  ]
235
  sdist = { url = "https://files.pythonhosted.org/packages/62/59/e47bbd0542d0e6f4ce9983d5eb458a01d4b42c81e5c410cb9e159b1061ae/encodec-0.1.1.tar.gz", hash = "sha256:36dde98ccfe6c51a15576476cadfcb3b35a63507b8b8555abd69889a6fba6772", size = 3736037 }
 
603
  version = "9.1.0.70"
604
  source = { registry = "https://pypi.org/simple" }
605
  dependencies = [
606
+ { name = "nvidia-cublas-cu12", marker = "platform_system != 'Darwin'" },
607
  ]
608
  wheels = [
609
  { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
 
614
  version = "11.2.1.3"
615
  source = { registry = "https://pypi.org/simple" }
616
  dependencies = [
617
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_system != 'Darwin'" },
618
  ]
619
  wheels = [
620
  { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 },
 
635
  version = "11.6.1.9"
636
  source = { registry = "https://pypi.org/simple" }
637
  dependencies = [
638
+ { name = "nvidia-cublas-cu12", marker = "platform_system != 'Darwin'" },
639
+ { name = "nvidia-cusparse-cu12", marker = "platform_system != 'Darwin'" },
640
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_system != 'Darwin'" },
641
  ]
642
  wheels = [
643
  { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 },
 
649
  version = "12.3.1.170"
650
  source = { registry = "https://pypi.org/simple" }
651
  dependencies = [
652
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_system != 'Darwin'" },
653
  ]
654
  wheels = [
655
  { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 },
 
1158
  [[package]]
1159
  name = "torch"
1160
  version = "2.5.1"
1161
+ source = { registry = "https://download.pytorch.org/whl/cpu" }
1162
+ resolution-markers = [
1163
+ "python_full_version < '3.13' and platform_system != 'Darwin'",
1164
+ "python_full_version >= '3.13' and platform_system != 'Darwin'",
1165
+ ]
1166
  dependencies = [
1167
+ { name = "filelock", marker = "platform_system != 'Darwin'" },
1168
+ { name = "fsspec", marker = "platform_system != 'Darwin'" },
1169
+ { name = "jinja2", marker = "platform_system != 'Darwin'" },
1170
+ { name = "networkx", marker = "platform_system != 'Darwin'" },
1171
  { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1172
  { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1173
  { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
 
1180
  { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1181
  { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1182
  { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
1183
+ { name = "setuptools", marker = "platform_system != 'Darwin'" },
1184
+ { name = "sympy", marker = "platform_system != 'Darwin'" },
1185
+ { name = "typing-extensions", marker = "platform_system != 'Darwin'" },
1186
+ ]
1187
+ wheels = [
1188
+ { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d1be99281b6f602d9639bd0af3ee0006e7aab16f6718d86f709d395b6f262c" },
1189
+ ]
1190
+
1191
+ [[package]]
1192
+ name = "torch"
1193
+ version = "2.5.1"
1194
+ source = { registry = "https://pypi.org/simple" }
1195
+ resolution-markers = [
1196
+ "python_full_version < '3.13' and platform_system == 'Darwin'",
1197
+ "python_full_version >= '3.13' and platform_system == 'Darwin'",
1198
+ ]
1199
+ dependencies = [
1200
+ { name = "filelock", marker = "platform_system == 'Darwin'" },
1201
+ { name = "fsspec", marker = "platform_system == 'Darwin'" },
1202
+ { name = "jinja2", marker = "platform_system == 'Darwin'" },
1203
+ { name = "networkx", marker = "platform_system == 'Darwin'" },
1204
+ { name = "setuptools", marker = "platform_system == 'Darwin'" },
1205
+ { name = "sympy", marker = "platform_system == 'Darwin'" },
1206
+ { name = "typing-extensions", marker = "platform_system == 'Darwin'" },
1207
  ]
1208
  wheels = [
 
 
 
1209
  { url = "https://files.pythonhosted.org/packages/57/6c/bf52ff061da33deb9f94f4121fde7ff3058812cb7d2036c97bc167793bd1/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1", size = 63858109 },
 
1210
  ]
1211
 
1212
  [[package]]
 
1214
  version = "2.5.1"
1215
  source = { registry = "https://pypi.org/simple" }
1216
  dependencies = [
1217
+ { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_system != 'Darwin'" },
1218
+ { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_system == 'Darwin'" },
1219
  ]
1220
  wheels = [
1221
  { url = "https://files.pythonhosted.org/packages/03/ab/151037a41e2cf4a5d489dfe5e7196b755e0fd83958d5ca7ad8ed85afcb1c/torchaudio-2.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1cbfdfd1bbdfbe7289d47a74f36ff6c5d87c3205606202fef5a7fb693f61cf0", size = 1798042 },
 
1236
  { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
1237
  ]
1238
 
 
 
 
 
 
 
 
 
 
 
 
1239
  [[package]]
1240
  name = "typeguard"
1241
  version = "4.4.1"
 
1314
  { name = "numpy" },
1315
  { name = "pyyaml" },
1316
  { name = "scipy" },
1317
+ { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_system != 'Darwin'" },
1318
+ { name = "torch", version = "2.5.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_system == 'Darwin'" },
1319
  { name = "torchaudio" },
1320
  ]
1321
  sdist = { url = "https://files.pythonhosted.org/packages/db/48/1e4d3a4a97292e47ebaa18e3eae6ecb2f57bde47693ccab0e7acb23f9ffe/vocos-0.1.0.tar.gz", hash = "sha256:b488224dbe398ff7d4790a027ad659478b4bc02e465db992c62c12b32ca043d8", size = 21107 }