Commit 1b5fc25
TristanBehrens committed
Parent(s): 4f552a8

Enables non-square mel spectrograms

Files changed:
- audiodiffusion/__init__.py (+3 -3)
- scripts/train_unconditional.py (+19 -4)
audiodiffusion/__init__.py CHANGED

@@ -181,8 +181,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(steps)
         mask = None
         images = noise = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size,
-             self.unet.sample_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size[0],
+             self.unet.sample_size[1]),
             generator=generator)
 
         if audio_file is not None or raw_audio is not None:
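For context, a minimal sketch of what this change does to the initial noise tensor, using illustrative values (the batch size, channel count, and 64x256 sample size are assumptions, not values from the repo):

    import torch

    # sample_size is now a (height, width) tuple rather than a single int.
    batch_size, in_channels = 2, 1
    sample_size = (64, 256)  # assumed rectangular resolution

    noise = torch.randn(
        (batch_size, in_channels, sample_size[0], sample_size[1]))
    print(noise.shape)  # torch.Size([2, 1, 64, 256])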
@@ -206,7 +206,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                 noise, torch.tensor(steps - start_step))
 
             pixels_per_second = (mel.get_sample_rate() *
-                                 mel.x_res / mel.hop_length /
+                                 mel.y_res / mel.hop_length /
                                  mel.x_res)
             mask_start = int(mask_start_secs * pixels_per_second)
             mask_end = int(mask_end_secs * pixels_per_second)
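A worked example of the conversion, with assumed Mel parameters (22050 Hz sample rate, hop length 512, and a square 256x256 resolution, for which the old and new expressions agree):

    # Assumed parameter values, not the repo's defaults.
    sample_rate = 22050
    hop_length = 512
    x_res = y_res = 256  # square case: y_res and x_res cancel

    pixels_per_second = sample_rate * y_res / hop_length / x_res
    print(pixels_per_second)             # 43.06640625
    print(int(1.0 * pixels_per_second))  # mask_start for mask_start_secs=1.0 -> 43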
scripts/train_unconditional.py CHANGED

@@ -26,6 +26,9 @@ import numpy as np
 from tqdm.auto import tqdm
 from librosa.util import normalize
 
+import sys
+sys.path.append('.')
+sys.path.append('..')
 from audiodiffusion.mel import Mel
 from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline
 
@@ -42,6 +45,18 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    # Handle the resolutions.
+    try:
+        args.resolution = (int(args.resolution), int(args.resolution))
+    except:
+        try:
+            args.resolution = tuple(int(x) for x in args.resolution.split(","))
+            if len(args.resolution) != 2:
+                raise ValueError("Resolution must be a tuple of two integers or a single integer.")
+        except:
+            raise ValueError("Resolution must be a tuple of two integers or a single integer.")
+    assert isinstance(args.resolution, tuple)
+
     if args.vae is not None:
         vqvae = AutoencoderKL.from_pretrained(args.vae)
 
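The same parsing logic, restated as a standalone sketch to make the accepted inputs explicit (the script itself mutates args.resolution in place as shown above; parse_resolution is a hypothetical helper, not part of the commit):

    def parse_resolution(value: str) -> tuple:
        """Accept "256" for a square size or "64,256" for a rectangular one."""
        try:
            size = int(value)
            return (size, size)
        except ValueError:
            parts = tuple(int(x) for x in value.split(","))
            if len(parts) != 2:
                raise ValueError(
                    "Resolution must be a tuple of two integers or a single integer.")
            return parts

    assert parse_resolution("256") == (256, 256)
    assert parse_resolution("64,256") == (64, 256)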
@@ -156,9 +171,9 @@ def main(args):
     run = os.path.split(__file__)[-1].split(".")[0]
     accelerator.init_trackers(run)
 
-    mel = Mel(x_res=args.resolution,
-              y_res=args.resolution,
-              hop_length=args.hop_length)
+    mel = Mel(x_res=args.resolution[0],
+              y_res=args.resolution[1],
+              hop_length=args.hop_length)
 
     global_step = 0
     for epoch in range(args.num_epochs):
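A usage sketch of the updated construction, with assumed values (the (64, 256) resolution and hop length 512 are illustrative; Mel's remaining parameters keep their defaults):

    from audiodiffusion.mel import Mel

    resolution = (64, 256)  # e.g. the parsed value of --resolution "64,256"
    mel = Mel(x_res=resolution[0],
              y_res=resolution[1],
              hop_length=512)  # assumed hop length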
@@ -311,7 +326,7 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
     parser.add_argument("--overwrite_output_dir", type=bool, default=False)
     parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=256)
+    parser.add_argument("--resolution", type=str, default="256")
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
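With the flag now string-typed, both forms are accepted on the command line (remaining training flags elided):

    python scripts/train_unconditional.py --resolution 256      # square, the old behavior
    python scripts/train_unconditional.py --resolution 64,256   # non-square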