import glob
import math
import sys
import time

import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as tvt
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
from diffusers import Mel
from PIL import Image
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = sys.argv[1:]


class Audio(Dataset):
    """Loads AudioMNIST waveforms; the spoken digit is the first character of the file name."""

    def __init__(self, folder):
        # resample = torchaudio.transforms.Resample(48000)
        self.waveforms = []
        self.labels = []
        print("Loading files...")
        for file in glob.iglob(folder + '/**/*.wav', recursive=True):  # recurse through subfolders
            self.labels.append(int(file.split('/')[-1][0]))  # digit label from the file name
            waveform, _ = torchaudio.load(file)  # load waveform
            self.waveforms.append(waveform)

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        return self.waveforms[index], self.labels[index]


image_size = 256
if len(args) >= 1:
    image_size = int(args[0])

MEL = Mel(x_res=image_size, y_res=image_size)
img_to_tensor = tvt.PILToTensor()


def collate(batch):
    """Convert each waveform in the batch into one or more mel-spectrogram image tensors."""
    spectros = []
    labels = []
    for waveform, label in batch:
        MEL.load_audio(raw_audio=waveform[0])
        for slice_idx in range(MEL.get_number_of_slices()):
            spectro = MEL.audio_slice_to_image(slice_idx)
            spectro = img_to_tensor(spectro) / 255.0  # scale to [0, 1]
            spectros.append(spectro)
            labels.append(label)
    spectros = torch.stack(spectros)
    labels = torch.tensor(labels)  # integer class indices for conditional generation
    return spectros.to(device), labels.to(device)


def initialize(scheduler=None):
    model = Unet(
        dim=64,
        num_classes=10,
        dim_mults=(1, 2, 4, 8),
        channels=1,
    )
    diffusion = GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=1000,
        loss_type='l2',
        objective='pred_x0',
    )
    diffusion.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    if scheduler:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optim, base_lr=1e-5, max_lr=1e-3, mode="exp_range", cycle_momentum=False
        )
    return diffusion, optim, scheduler


def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


start = time.time()


def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler=None):
    size = len(train_dl.dataset)
    model.train()
    losses = []
    for e in range(epochs):
        batch_loss, batch_counts = 0, 0
        for step, batch in enumerate(train_dl):
            optim.zero_grad()
            batch_counts += 1
            spectros, labels = batch
            loss = model(spectros, classes=labels)
            batch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()
            if scheduler is not None:
                scheduler.step()
            if (step % 100 == 0 and step != 0) or (step == len(train_dl) - 1):
                # epoch | step | mean batch loss | elapsed time | progress
                to_print = (f"{e + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | "
                            f"{timeSince(start)} | {step * batch_size:>5d}/{size:>5d}")
                print(to_print)
                losses.append(batch_loss)
                batch_loss, batch_counts = 0, 0

        # Sample conditionally at the end of each epoch, drawing from all ten digit classes.
        labels = torch.randint(0, 10, (8,)).to(device)
        print(labels)
        samples = model.sample(labels)
        for i, sample in enumerate(samples):
            im = Image.fromarray((sample[0].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy(), mode='L')
            audio = torch.tensor(MEL.image_to_audio(im)).unsqueeze(0)
            # Save at the Mel object's analysis rate so the reconstruction is internally consistent.
            torchaudio.save(f"audio/sample{e}_{i}_{labels[i].item()}.wav", audio, MEL.get_sample_rate())
            im.save(f"images/sample{e}_{i}_{labels[i].item()}.jpg")
    return losses


if __name__ == "__main__":
    num_epochs = 10
    if len(args) >= 2:
        num_epochs = int(args[1])
    batch_size = 32
    if len(args) >= 3:
        batch_size = int(args[2])
    print(image_size, num_epochs, batch_size)

    model, optim, scheduler = initialize(scheduler=True)
    train_data = Audio("AudioMNIST/data")
    print("Done Loading")
    train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate)
    train(model, optim, train_dl, batch_size, num_epochs, scheduler)
    torch.save(model.state_dict(), "diffusion_condition_model.pt")
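

# A minimal sketch (an addition, not part of the training run above) showing how the
# checkpoint written by torch.save could be reloaded for standalone generation. It
# assumes the same image_size and the architecture built by initialize(); the state
# dict keys match because the checkpoint was saved from the GaussianDiffusion wrapper.
# The "audio/" output folder and the function name are illustrative assumptions.
def sample_from_checkpoint(path="diffusion_condition_model.pt", n=8):
    diffusion, _, _ = initialize()
    diffusion.load_state_dict(torch.load(path, map_location=device))
    diffusion.eval()
    classes = torch.randint(0, 10, (n,)).to(device)  # one random digit class per sample
    with torch.no_grad():
        samples = diffusion.sample(classes)  # (n, 1, image_size, image_size), values in [0, 1]
    for i, sample in enumerate(samples):
        im = Image.fromarray((sample[0].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy(), mode='L')
        audio = torch.tensor(MEL.image_to_audio(im)).unsqueeze(0)
        torchaudio.save(f"audio/generated_{i}_{classes[i].item()}.wav", audio, MEL.get_sample_rate())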