teticio committed on
Commit
9c0c5c8
·
1 Parent(s): 4b16ff2

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .vscode
2
+ __pycache__
3
+ .ipynb_checkpoints
4
+ data
notebooks/test-mel.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/audio_to_images.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import argparse
5
+
6
+ from tqdm.auto import tqdm
7
+
8
+ from mel import Mel
9
+
10
+
11
def main(args):
    """Convert every audio file under ``args.input_dir`` into spectrogram images.

    The input directory is walked recursively for files ending in ``.mp3``,
    ``.wav`` or ``.m4a`` (case-insensitive). Each file is loaded into a
    ``Mel`` helper, cut into slices, and every slice is saved as
    ``{audio_index}_{slice_index}.png`` inside ``args.output_dir``. A
    ``meta_data.json`` file mapping each image name to its source audio path
    is written at the end, even if processing is interrupted.

    Args:
        args: Parsed command-line namespace providing ``input_dir``,
            ``output_dir`` and ``resolution``.
    """
    mel = Mel(x_res=args.resolution, y_res=args.resolution)
    os.makedirs(args.output_dir, exist_ok=True)

    # Raw string so that ``\.`` is a regex escape rather than an (invalid)
    # Python string escape; compiled once instead of once per file name.
    audio_pattern = re.compile(r"\.(mp3|wav|m4a)$", re.IGNORECASE)
    audio_files = [
        os.path.join(root, name)
        for root, _, names in os.walk(args.input_dir)
        for name in names
        if audio_pattern.search(name)
    ]

    meta_data = {}
    try:
        for audio, audio_file in enumerate(tqdm(audio_files)):
            try:
                mel.load_audio(audio_file)
            except Exception:
                # Best effort: skip files that cannot be decoded.
                # KeyboardInterrupt and SystemExit do not derive from
                # Exception, so they still abort the run.
                continue
            for slice_index in range(mel.get_number_of_slices()):
                image = mel.audio_slice_to_image(slice_index)
                image_file = f"{audio}_{slice_index}.png"
                image.save(os.path.join(args.output_dir, image_file))
                meta_data[image_file] = audio_file
    finally:
        # Always persist what was produced so far, including on interruption.
        with open(os.path.join(args.output_dir, "meta_data.json"), "wt") as meta_file:
            meta_file.write(json.dumps(meta_data))
37
+
38
+
39
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion.
    parser = argparse.ArgumentParser(description="Convert audio into Mel spectrograms.")
    # Required: without it os.walk() would receive None and fail with an
    # unhelpful TypeError instead of a usage message.
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory searched recursively for .mp3, .wav and .m4a files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Directory that receives the PNG images and meta_data.json.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="Value used for both x_res and y_res of the spectrogram images.",
    )
    args = parser.parse_args()
    main(args)
src/mel.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings('ignore')
3
+
4
+ import librosa
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+
9
class Mel:
    """Convert between audio and greyscale Mel spectrogram images.

    Audio is loaded once with ``load_audio`` and then addressed in
    fixed-length slices of ``slice_size`` samples. Each slice can be rendered
    as an 8-bit ("L" mode) PIL image, and such an image can be turned back
    into audio with ``image_to_audio``.
    """

    def __init__(
        self,
        x_res=256,
        y_res=256,
        sample_rate=22050,
        n_fft=2048,
        hop_length=512,
        top_db=80,
    ):
        """Store the spectrogram parameters.

        Args:
            x_res: Used together with ``hop_length`` to size one audio slice.
            y_res: Number of Mel bands (stored as ``n_mels``).
            sample_rate: Sample rate used for loading and for the transforms.
            n_fft: FFT window size passed to librosa.
            hop_length: Hop length passed to librosa.
            top_db: Dynamic range in dB that is mapped onto pixel values 0-255.
        """
        self.x_res = x_res
        self.y_res = y_res
        self.sr = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = self.y_res
        # NOTE(review): the "- 1" presumably makes librosa's default framing
        # yield exactly x_res frames per slice; confirm against librosa docs.
        self.slice_size = self.x_res * self.hop_length - 1
        self.fmax = self.sr / 2
        self.top_db = top_db
        self.y = None

    def load_audio(self, audio_file):
        """Load ``audio_file`` as mono audio into ``self.y``.

        The file is loaded at ``self.sr`` so that the samples agree with the
        sample rate later passed to the spectrogram functions.
        """
        self.y, _ = librosa.load(audio_file, sr=self.sr, mono=True)

    def get_number_of_slices(self):
        """Return how many complete slices the loaded audio contains."""
        return len(self.y) // self.slice_size

    def get_sample_rate(self):
        """Return the sample rate this instance was configured with."""
        return self.sr

    def audio_slice_to_image(self, slice):
        """Render slice number ``slice`` of the loaded audio as a PIL image.

        Args:
            slice: Zero-based index of the slice to convert.

        Returns:
            A PIL image in mode "L" with one row per Mel band and one column
            per spectrogram frame.
        """
        S = librosa.feature.melspectrogram(
            y=self.y[self.slice_size * slice : self.slice_size * (slice + 1)],
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmax=self.fmax,
        )
        # dB values are taken relative to the slice's own maximum.
        log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
        # Shift by top_db and scale to 0-255; the + 0.5 rounds to nearest on
        # the truncating cast to uint8.
        bytedata = (
            ((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5
        ).astype(np.uint8)
        # PIL sizes are (width, height) while the array shape is
        # (rows, columns), so the shape has to be reversed. Passing the shape
        # unchanged only works for square spectrograms.
        n_rows, n_cols = log_S.shape
        image = Image.frombytes("L", (n_cols, n_rows), bytedata.tobytes())
        return image

    def image_to_audio(self, image):
        """Reconstruct audio from a spectrogram image.

        Args:
            image: An image as produced by ``audio_slice_to_image``.

        Returns:
            The audio computed by ``librosa.feature.inverse.mel_to_audio``.
            Because the image stores dB relative to each slice's maximum,
            the original absolute level is not recovered.
        """
        # Image bytes are row-major, so the array is (height, width).
        bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
            (image.height, image.width)
        )
        # Invert the 0-255 pixel mapping back to dB in [-top_db, 0].
        log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length
        )
        return audio
src/train_unconditional.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from accelerate import Accelerator
8
+ from accelerate.logging import get_logger
9
+ from datasets import load_dataset
10
+ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
11
+ from diffusers.hub_utils import init_git_repo, push_to_hub
12
+ from diffusers.optimization import get_scheduler
13
+ from diffusers.training_utils import EMAModel
14
+ from torchvision.transforms import (
15
+ CenterCrop,
16
+ Compose,
17
+ InterpolationMode,
18
+ Normalize,
19
+ RandomHorizontalFlip,
20
+ Resize,
21
+ ToTensor,
22
+ )
23
+ from tqdm.auto import tqdm
24
+
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
def main(args):
    """Train an unconditional DDPM on an image dataset.

    Builds a ``UNet2DModel`` and a ``DDPMScheduler``, loads the training
    images either from a named dataset or from ``args.train_data_dir``, and
    runs the noise-prediction training loop under ``accelerate``. Sample
    images are logged to TensorBoard and the pipeline is saved (or pushed to
    the hub) every ``save_images_epochs`` / ``save_model_epochs`` epochs and
    after the last epoch.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this script for the available options).
    """
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator = Accelerator(
        # Without this, accelerator.accumulate() below never accumulates and
        # --gradient_accumulation_steps would only shorten the LR schedule.
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # NOTE(review): RandomHorizontalFlip mirrors the images left-to-right;
    # confirm this is intended if the images are spectrograms.
    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )

    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            use_auth_token=True if args.use_auth_token else None,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    def transforms(examples):
        # The model is built with in_channels=3, so images are converted to RGB.
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    if accelerator.is_main_process:
        # Name the tracker run after this script file (without extension).
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps)["sample"]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
                    ema_model.step(model)
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,
                )

                # Fixed seed so samples from different epochs are comparable.
                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")
                accelerator.trackers[0].writer.add_images(
                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                )

            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model; `pipeline` was created at epoch 0 at the
                # latest, because 0 % n == 0 for any save_images_epochs.
                if args.push_to_hub:
                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                else:
                    pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()

    accelerator.end_training()
187
+
188
+
189
if __name__ == "__main__":
    # Command-line entry point: parse the options, reconcile the local rank
    # with the environment, validate the data source and start training.
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_config_name", type=str, default=None)
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10)
    parser.add_argument("--save_model_epochs", type=int, default=10)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    # EMA is on by default, so --use_ema is kept only for backward
    # compatibility; --no_ema is the switch that actually turns it off.
    parser.add_argument("--use_ema", action="store_true", default=True)
    parser.add_argument(
        "--no_ema",
        dest="use_ema",
        action="store_false",
        help="Disable the exponential moving average of the model weights.",
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--use_auth_token", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        # Adjacent string literals are concatenated verbatim, so each piece
        # needs its own separating space.
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    # A LOCAL_RANK set in the environment takes precedence over the flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    main(args)