TristanBehrens committed on
Commit
1b5fc25
1 Parent(s): 4f552a8

Enables non-square mel spectrograms

Browse files
audiodiffusion/__init__.py CHANGED
@@ -181,8 +181,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
181
  self.scheduler.set_timesteps(steps)
182
  mask = None
183
  images = noise = torch.randn(
184
- (batch_size, self.unet.in_channels, self.unet.sample_size,
185
- self.unet.sample_size),
186
  generator=generator)
187
 
188
  if audio_file is not None or raw_audio is not None:
@@ -206,7 +206,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
206
  noise, torch.tensor(steps - start_step))
207
 
208
  pixels_per_second = (mel.get_sample_rate() *
209
- self.unet.sample_size / mel.hop_length /
210
  mel.x_res)
211
  mask_start = int(mask_start_secs * pixels_per_second)
212
  mask_end = int(mask_end_secs * pixels_per_second)
 
181
  self.scheduler.set_timesteps(steps)
182
  mask = None
183
  images = noise = torch.randn(
184
+ (batch_size, self.unet.in_channels, self.unet.sample_size[0],
185
+ self.unet.sample_size[1]),
186
  generator=generator)
187
 
188
  if audio_file is not None or raw_audio is not None:
 
206
  noise, torch.tensor(steps - start_step))
207
 
208
  pixels_per_second = (mel.get_sample_rate() *
209
+ mel.y_res / mel.hop_length /
210
  mel.x_res)
211
  mask_start = int(mask_start_secs * pixels_per_second)
212
  mask_end = int(mask_end_secs * pixels_per_second)
scripts/train_unconditional.py CHANGED
@@ -26,6 +26,9 @@ import numpy as np
26
  from tqdm.auto import tqdm
27
  from librosa.util import normalize
28
 
 
 
 
29
  from audiodiffusion.mel import Mel
30
  from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline
31
 
@@ -42,6 +45,18 @@ def main(args):
42
  logging_dir=logging_dir,
43
  )
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  if args.vae is not None:
46
  vqvae = AutoencoderKL.from_pretrained(args.vae)
47
 
@@ -156,9 +171,9 @@ def main(args):
156
  run = os.path.split(__file__)[-1].split(".")[0]
157
  accelerator.init_trackers(run)
158
 
159
- mel = Mel(x_res=args.resolution,
160
- y_res=args.resolution,
161
- hop_length=args.hop_length)
162
 
163
  global_step = 0
164
  for epoch in range(args.num_epochs):
@@ -311,7 +326,7 @@ if __name__ == "__main__":
311
  parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
312
  parser.add_argument("--overwrite_output_dir", type=bool, default=False)
313
  parser.add_argument("--cache_dir", type=str, default=None)
314
- parser.add_argument("--resolution", type=int, default=256)
315
  parser.add_argument("--train_batch_size", type=int, default=16)
316
  parser.add_argument("--eval_batch_size", type=int, default=16)
317
  parser.add_argument("--num_epochs", type=int, default=100)
 
26
  from tqdm.auto import tqdm
27
  from librosa.util import normalize
28
 
29
+ import sys
30
+ sys.path.append('.')
31
+ sys.path.append('..')
32
  from audiodiffusion.mel import Mel
33
  from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline
34
 
 
45
  logging_dir=logging_dir,
46
  )
47
 
48
+ # Handle the resolutions.
49
+ try:
50
+ args.resolution = (int(args.resolution), int(args.resolution))
51
+ except:
52
+ try :
53
+ args.resolution = tuple(int(x) for x in args.resolution.split(","))
54
+ if len(args.resolution) != 2:
55
+ raise ValueError("Resolution must be a tuple of two integers or a single integer.")
56
+ except:
57
+ raise ValueError("Resolution must be a tuple of two integers or a single integer.")
58
+ assert isinstance(args.resolution, tuple)
59
+
60
  if args.vae is not None:
61
  vqvae = AutoencoderKL.from_pretrained(args.vae)
62
 
 
171
  run = os.path.split(__file__)[-1].split(".")[0]
172
  accelerator.init_trackers(run)
173
 
174
+ mel = Mel(x_res=args.resolution[0],
175
+ y_res=args.resolution[1],
176
+ hop_length=args.hop_length)
177
 
178
  global_step = 0
179
  for epoch in range(args.num_epochs):
 
326
  parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
327
  parser.add_argument("--overwrite_output_dir", type=bool, default=False)
328
  parser.add_argument("--cache_dir", type=str, default=None)
329
+ parser.add_argument("--resolution", type=str, default="256")
330
  parser.add_argument("--train_batch_size", type=int, default=16)
331
  parser.add_argument("--eval_batch_size", type=int, default=16)
332
  parser.add_argument("--num_epochs", type=int, default=100)