teticio commited on
Commit
5bc60f9
1 Parent(s): c51d0e3

add sample_rate and n_fft params

Browse files
README.md CHANGED
@@ -71,7 +71,9 @@ python scripts/audio_to_images.py \
71
  --output_dir data/audio-diffusion-256 \
72
  --push_to_hub teticio/audio-diffusion-256
73
  ```
74
-
 
 
75
  ## Train model
76
  #### Run training on local machine.
77
  ```bash
 
71
  --output_dir data/audio-diffusion-256 \
72
  --push_to_hub teticio/audio-diffusion-256
73
  ```
74
+
75
+ Note that the default `sample_rate` is 22050 and audios will be resampled if they are at a different rate. If you change this value, you may find that the results in the `test_mel.ipynb` notebook are not good (for example, if `sample_rate` is 48000) and that it is necessary to adjust `n_fft` (for example, to 2000 instead of the default value of 2048; alternatively, you can resample to a `sample_rate` of 44100). Make sure you use the same parameters for training and inference. You should also bear in mind that not all resolutions work with the neural network architecture as currently configured - you should be safe if you stick to powers of 2.
76
+
77
  ## Train model
78
  #### Run training on local machine.
79
  ```bash
scripts/audio_to_images.py CHANGED
@@ -19,7 +19,8 @@ def main(args):
19
  mel = Mel(x_res=args.resolution[0],
20
  y_res=args.resolution[1],
21
  hop_length=args.hop_length,
22
- sample_rate=args.sample_rate)
 
23
  os.makedirs(args.output_dir, exist_ok=True)
24
  audio_files = [
25
  os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
@@ -86,6 +87,7 @@ if __name__ == "__main__":
86
  parser.add_argument("--hop_length", type=int, default=512)
87
  parser.add_argument("--push_to_hub", type=str, default=None)
88
  parser.add_argument("--sample_rate", type=int, default=22050)
 
89
  args = parser.parse_args()
90
 
91
  if args.input_dir is None:
 
19
  mel = Mel(x_res=args.resolution[0],
20
  y_res=args.resolution[1],
21
  hop_length=args.hop_length,
22
+ sample_rate=args.sample_rate,
23
+ n_fft=args.n_fft)
24
  os.makedirs(args.output_dir, exist_ok=True)
25
  audio_files = [
26
  os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
 
87
  parser.add_argument("--hop_length", type=int, default=512)
88
  parser.add_argument("--push_to_hub", type=str, default=None)
89
  parser.add_argument("--sample_rate", type=int, default=22050)
90
+ parser.add_argument("--n_fft", type=int, default=2048)
91
  args = parser.parse_args()
92
 
93
  if args.input_dir is None:
scripts/train_unconditional.py CHANGED
@@ -173,7 +173,9 @@ def main(args):
173
 
174
  mel = Mel(x_res=resolution[1],
175
  y_res=resolution[0],
176
- hop_length=args.hop_length)
 
 
177
 
178
  global_step = 0
179
  for epoch in range(args.num_epochs):
@@ -362,6 +364,8 @@ if __name__ == "__main__":
362
  "and an Nvidia Ampere GPU."),
363
  )
364
  parser.add_argument("--hop_length", type=int, default=512)
 
 
365
  parser.add_argument("--from_pretrained", type=str, default=None)
366
  parser.add_argument("--start_epoch", type=int, default=0)
367
  parser.add_argument("--num_train_steps", type=int, default=1000)
 
173
 
174
  mel = Mel(x_res=resolution[1],
175
  y_res=resolution[0],
176
+ hop_length=args.hop_length,
177
+ sample_rate=args.sample_rate,
178
+ n_fft=args.n_fft)
179
 
180
  global_step = 0
181
  for epoch in range(args.num_epochs):
 
364
  "and an Nvidia Ampere GPU."),
365
  )
366
  parser.add_argument("--hop_length", type=int, default=512)
367
+ parser.add_argument("--sample_rate", type=int, default=22050)
368
+ parser.add_argument("--n_fft", type=int, default=2048)
369
  parser.add_argument("--from_pretrained", type=str, default=None)
370
  parser.add_argument("--start_epoch", type=int, default=0)
371
  parser.add_argument("--num_train_steps", type=int, default=1000)
scripts/train_vae.py CHANGED
@@ -60,10 +60,16 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
60
 
61
  class ImageLogger(Callback):
62
 
63
- def __init__(self, every=1000, hop_length=512):
 
 
 
 
64
  super().__init__()
65
  self.every = every
66
  self.hop_length = hop_length
 
 
67
 
68
  @rank_zero_only
69
  def log_images_and_audios(self, pl_module, batch):
@@ -76,7 +82,9 @@ class ImageLogger(Callback):
76
  channels = image_shape[1]
77
  mel = Mel(x_res=image_shape[2],
78
  y_res=image_shape[3],
79
- hop_length=self.hop_length)
 
 
80
 
81
  for k in images:
82
  images[k] = images[k].detach().cpu()
@@ -145,6 +153,8 @@ if __name__ == "__main__":
145
  type=int,
146
  default=1)
147
  parser.add_argument("--hop_length", type=int, default=512)
 
 
148
  parser.add_argument("--save_images_batches", type=int, default=1000)
149
  parser.add_argument("--max_epochs", type=int, default=100)
150
  args = parser.parse_args()
@@ -166,7 +176,9 @@ if __name__ == "__main__":
166
  resume_from_checkpoint=args.resume_from_checkpoint,
167
  callbacks=[
168
  ImageLogger(every=args.save_images_batches,
169
- hop_length=args.hop_length),
 
 
170
  HFModelCheckpoint(ldm_config=config,
171
  hf_checkpoint=args.hf_checkpoint_dir,
172
  dirpath=args.ldm_checkpoint_dir,
 
60
 
61
  class ImageLogger(Callback):
62
 
63
+ def __init__(self,
64
+ every=1000,
65
+ hop_length=512,
66
+ sample_rate=22050,
67
+ n_fft=2048):
68
  super().__init__()
69
  self.every = every
70
  self.hop_length = hop_length
71
+ self.sample_rate = sample_rate
72
+ self.n_fft = n_fft
73
 
74
  @rank_zero_only
75
  def log_images_and_audios(self, pl_module, batch):
 
82
  channels = image_shape[1]
83
  mel = Mel(x_res=image_shape[2],
84
  y_res=image_shape[3],
85
+ hop_length=self.hop_length,
86
+ sample_rate=self.sample_rate,
87
+ n_fft=self.n_fft)
88
 
89
  for k in images:
90
  images[k] = images[k].detach().cpu()
 
153
  type=int,
154
  default=1)
155
  parser.add_argument("--hop_length", type=int, default=512)
156
+ parser.add_argument("--sample_rate", type=int, default=22050)
157
+ parser.add_argument("--n_fft", type=int, default=2048)
158
  parser.add_argument("--save_images_batches", type=int, default=1000)
159
  parser.add_argument("--max_epochs", type=int, default=100)
160
  args = parser.parse_args()
 
176
  resume_from_checkpoint=args.resume_from_checkpoint,
177
  callbacks=[
178
  ImageLogger(every=args.save_images_batches,
179
+ hop_length=args.hop_length,
180
+ sample_rate=args.sample_rate,
181
+ n_fft=args.n_fft),
182
  HFModelCheckpoint(ldm_config=config,
183
  hf_checkpoint=args.hf_checkpoint_dir,
184
  dirpath=args.ldm_checkpoint_dir,