teticio committed
Commit 1dea888
Parent: 2d68d80

initial working version

Files changed (4)
  1. .gitignore +1 -0
  2. README.md +18 -0
  3. src/audio_to_images.py +37 -10
  4. src/train_unconditional.py +68 -19
.gitignore CHANGED
@@ -2,3 +2,4 @@
 __pycache__
 .ipynb_checkpoints
 data
+ddpm-ema-audio-*
README.md CHANGED
@@ -1 +1,19 @@
 # audio-diffusion
+```bash
+python src/audio_to_images.py \
+--resolution=256 \
+--input_dir=path-to-audio-files \
+--output_dir=data
+```
+```bash
+accelerate launch src/train_unconditional.py \
+--dataset_name="data" \
+--resolution=256 \
+--output_dir="ddpm-ema-audio-256" \
+--train_batch_size=16 \
+--num_epochs=100 \
+--gradient_accumulation_steps=1 \
+--learning_rate=1e-4 \
+--lr_warmup_steps=500 \
+--mixed_precision=no
+```
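
The first command writes a Hugging Face dataset to `data` via `save_to_disk` (see `src/audio_to_images.py` below), which the training script then reads back with `load_from_disk`. A minimal sketch of loading it for inspection, assuming only the `datasets` library:

```python
from datasets import load_from_disk

# Load the DatasetDict written by src/audio_to_images.py and take its train split.
ds = load_from_disk("data")["train"]
print(ds)  # columns: image, audio_file, slice

# The Image() feature decodes the stored PNG bytes back to a PIL image on access.
ds[0]["image"].show()
```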
src/audio_to_images.py CHANGED
@@ -1,15 +1,17 @@
 import os
 import re
-import json
+import io
 import argparse
 
+import pandas as pd
 from tqdm.auto import tqdm
+from datasets import Dataset, DatasetDict, Features, Image, Value
 
 from mel import Mel
 
 
 def main(args):
-    mel = Mel(x_res=args.resolution, y_res=args.resolution)
+    mel = Mel(x_res=args.resolution, y_res=args.resolution, hop_length=args.hop_length)
     os.makedirs(args.output_dir, exist_ok=True)
     audio_files = [
         os.path.join(root, file)
@@ -17,9 +19,9 @@ def main(args):
         for file in files
         if re.search("\.(mp3|wav|m4a)$", file, re.IGNORECASE)
     ]
-    meta_data = {}
+    examples = []
     try:
-        for audio, audio_file in enumerate(tqdm(audio_files)):
+        for audio_file in tqdm(audio_files):
             try:
                 mel.load_audio(audio_file)
             except KeyboardInterrupt:
@@ -28,18 +30,43 @@ def main(args):
                 continue
             for slice in range(mel.get_number_of_slices()):
                 image = mel.audio_slice_to_image(slice)
-                image_file = f"{audio}_{slice}.png"
-                image.save(os.path.join(args.output_dir, image_file))
-                meta_data[image_file] = audio_file
+                assert (
+                    image.width == args.resolution and image.height == args.resolution
+                )
+                with io.BytesIO() as output:
+                    image.save(output, format="PNG")
+                    bytes = output.getvalue()
+                examples.extend(
+                    [
+                        {
+                            "image": {"bytes": bytes},
+                            "audio_file": audio_file,
+                            "slice": slice,
+                        }
+                    ]
+                )
     finally:
-        with open(os.path.join(args.output_dir, 'meta_data.json'), 'wt') as file:
-            file.write(json.dumps(meta_data))
+        ds = Dataset.from_pandas(
+            pd.DataFrame(examples),
+            features=Features(
+                {
+                    "image": Image(),
+                    "audio_file": Value(dtype="string"),
+                    "slice": Value(dtype="int16"),
+                }
+            ),
+        )
+        dsd = DatasetDict({"train": ds})
+        dsd.save_to_disk(os.path.join(args.output_dir))
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert audio into Mel spectrograms.")
+    parser = argparse.ArgumentParser(
+        description="Create dataset of Mel spectrograms from directory of audio files."
+    )
     parser.add_argument("--input_dir", type=str)
     parser.add_argument("--output_dir", type=str, default="data")
     parser.add_argument("--resolution", type=int, default=256)
+    parser.add_argument("--hop_length", type=int, default=512)
    args = parser.parse_args()
    main(args)
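
Note on the change above: instead of writing loose PNG files plus a `meta_data.json`, each spectrogram slice is now encoded to in-memory PNG bytes and stored in a `datasets` table, where the `Image()` feature decodes it back to PIL on access. A minimal sketch of that round trip under the same assumptions (`datasets`, `pandas`, Pillow; names are illustrative):

```python
import io

import pandas as pd
from PIL import Image as PILImage
from datasets import Dataset, Features, Image, Value

# Encode a PIL image to PNG bytes, as the script does for each slice.
img = PILImage.new("L", (256, 256))
with io.BytesIO() as output:
    img.save(output, format="PNG")
    png_bytes = output.getvalue()

# Store the raw bytes; the Image() feature handles decoding on read.
ds = Dataset.from_pandas(
    pd.DataFrame([{"image": {"bytes": png_bytes}, "audio_file": "example.wav", "slice": 0}]),
    features=Features(
        {"image": Image(), "audio_file": Value(dtype="string"), "slice": Value(dtype="int16")}
    ),
)
print(type(ds[0]["image"]))  # a decoded PIL image
```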
src/train_unconditional.py CHANGED
@@ -3,10 +3,12 @@ import os
 
 import torch
 import torch.nn.functional as F
+import numpy as np
+from PIL import Image
 
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from datasets import load_dataset
+from datasets import load_from_disk, load_dataset
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
@@ -22,6 +24,7 @@ from torchvision.transforms import (
 )
 from tqdm.auto import tqdm
 
+from mel import Mel
 
 logger = get_logger(__name__)
 
@@ -77,35 +80,42 @@ def main(args):
     )
 
     if args.dataset_name is not None:
+        dataset = load_from_disk(args.dataset_name, args.dataset_config_name)["train"]
+    else:
         dataset = load_dataset(
-            args.dataset_name,
-            args.dataset_config_name,
+            "imagefolder",
+            data_dir=args.train_data_dir,
             cache_dir=args.cache_dir,
-            use_auth_token=True if args.use_auth_token else None,
             split="train",
         )
-    else:
-        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
 
     def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True
+    )
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
         num_warmup_steps=args.lr_warmup_steps,
-        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
+        num_training_steps=(len(train_dataloader) * args.num_epochs)
+        // args.gradient_accumulation_steps,
     )
 
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
 
-    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
+    ema_model = EMAModel(
+        model,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay,
+    )
 
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -114,10 +124,14 @@ def main(args):
     run = os.path.split(__file__)[-1].split(".")[0]
     accelerator.init_trackers(run)
 
+    mel = Mel(x_res=args.resolution, y_res=args.resolution, hop_length=args.hop_length)
+
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(
+            total=len(train_dataloader), disable=not accelerator.is_local_main_process
+        )
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
@@ -126,7 +140,10 @@ def main(args):
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(
-                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
+                0,
+                noise_scheduler.num_train_timesteps,
+                (bsz,),
+                device=clean_images.device,
             ).long()
 
             # Add noise to the clean images according to the noise magnitude at each timestep
@@ -147,7 +164,11 @@ def main(args):
             optimizer.zero_grad()
 
             progress_bar.update(1)
-            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            logs = {
+                "loss": loss.detach().item(),
+                "lr": lr_scheduler.get_last_lr()[0],
+                "step": global_step,
+            }
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
@@ -161,24 +182,44 @@ def main(args):
         if accelerator.is_main_process:
             if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                 pipeline = DDPMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
+                    unet=accelerator.unwrap_model(
+                        ema_model.averaged_model if args.use_ema else model
+                    ),
                     scheduler=noise_scheduler,
                 )
 
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
+                images = pipeline(
+                    generator=generator,
+                    batch_size=args.eval_batch_size,
+                    output_type="numpy",
+                )["sample"]
 
                 # denormalize the images and save to tensorboard
-                images_processed = (images * 255).round().astype("uint8")
+                images_processed = (
+                    (images * 255).round().astype("uint8").transpose(0, 3, 1, 2)
+                )
                 accelerator.trackers[0].writer.add_images(
-                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
+                    "test_samples", images_processed, epoch
                 )
+                for image in images_processed:
+                    image = Image.fromarray(np.mean(image, axis=0).astype("uint8"))
+                    audio = mel.image_to_audio(image)
+                    accelerator.trackers[0].writer.add_audio(
+                        "test_samples", audio, epoch, sample_rate=mel.get_sample_rate()
+                    )
 
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
                 if args.push_to_hub:
-                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                    push_to_hub(
+                        args,
+                        pipeline,
+                        repo,
+                        commit_message=f"Epoch {epoch}",
+                        blocking=False,
+                    )
                 else:
                     pipeline.save_pretrained(args.output_dir)
             accelerator.wait_for_everyone()
@@ -191,7 +232,12 @@ if __name__ == "__main__":
     parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset_name", type=str, default=None)
     parser.add_argument("--dataset_config_name", type=str, default=None)
-    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help="A folder containing the training data.",
+    )
     parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
     parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--cache_dir", type=str, default=None)
@@ -230,6 +276,7 @@ if __name__ == "__main__":
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument("--hop_length", type=int, default=512)
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -237,6 +284,8 @@ if __name__ == "__main__":
        args.local_rank = env_local_rank
 
     if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+        raise ValueError(
+            "You must specify either a dataset name from the hub or a train data directory."
+        )
 
     main(args)
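
The audio logging added above collapses each generated RGB spectrogram to grayscale before inverting it with `Mel.image_to_audio`. The same round trip can be run offline on any generated image; a minimal sketch, where `sample.png` is a hypothetical generated spectrogram and `soundfile` is an assumed extra dependency:

```python
import numpy as np
import soundfile as sf
from PIL import Image

from mel import Mel

# Match the parameters the training script passes to Mel.
mel = Mel(x_res=256, y_res=256, hop_length=512)

# Average the RGB channels to grayscale, as the training loop does
# (there it averages over the channel axis of a (C, H, W) array).
image = Image.open("sample.png").convert("RGB")
image = Image.fromarray(np.mean(np.array(image), axis=2).astype("uint8"))

audio = mel.image_to_audio(image)
sf.write("sample.wav", audio, mel.get_sample_rate())
```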