irow committed on
Commit 21853a2
1 Parent(s): 8f3ccae

Upload 2 files

conditional-diffusion.py ADDED
@@ -0,0 +1,149 @@
+ import glob
+ import math
+ import os
+ import sys
+ import time
+ 
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.nn as nn
+ import torchaudio
+ import torchvision.transforms as tvt
+ from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
+ from diffusers import Mel
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ 
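+ # Overview: train a class-conditional diffusion model on AudioMNIST.
+ # Each spoken-digit recording is sliced into mel-spectrogram images, a
+ # classifier-free-guidance U-Net learns to generate them per digit, and
+ # samples are decoded back to waveforms with Mel.image_to_audio.
+ 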
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ 
+ args = sys.argv[1:]  # optional CLI args: image_size, num_epochs, batch_size
+ 
+ class Audio(Dataset):
+     """Loads every AudioMNIST .wav into memory; the spoken digit is the first
+     character of the file name."""
+ 
+     def __init__(self, folder):
+         self.waveforms = []
+         self.labels = []
+         print("Loading files...")
+         for file in glob.iglob(folder + '/**/*.wav', recursive=True):  # recurse through subfolders
+             self.labels.append(int(file.split('/')[-1][0]))  # digit label from file name
+             waveform, _ = torchaudio.load(file)
+             self.waveforms.append(waveform)
+ 
+     def __len__(self):
+         return len(self.waveforms)
+ 
+     def __getitem__(self, index):
+         return self.waveforms[index], self.labels[index]
+ 
+ 
+ image_size = 256
+ if len(args) >= 1:
+     image_size = int(args[0])
+ 
+ # AudioMNIST is recorded at 48 kHz; passing sample_rate here keeps
+ # image_to_audio consistent with the 48000 used in torchaudio.save below
+ # (Mel defaults to 22050 Hz).
+ MEL = Mel(x_res=image_size, y_res=image_size, sample_rate=48000)
+ img_to_tensor = tvt.PILToTensor()
+ 
+ def collate(batch):
+     # Convert raw waveforms to normalized mel-spectrogram image tensors.
+     # NB: each waveform can yield several slices, so the effective batch the
+     # model sees can be larger than the DataLoader batch_size.
+     spectros = []
+     labels = []
+     for waveform, label in batch:
+         MEL.load_audio(raw_audio=waveform[0].numpy())  # Mel expects a NumPy array
+         for slice_idx in range(MEL.get_number_of_slices()):
+             spectro = MEL.audio_slice_to_image(slice_idx)
+             spectro = img_to_tensor(spectro) / 255.0
+             # debug: plt.imshow(spectro[0]); plt.show()
+             spectros.append(spectro)
+             labels.append(label)
+ 
+     spectros = torch.stack(spectros)
+     labels = torch.tensor(labels)
+     # classifier_free_guidance conditions on integer class indices; no one-hot needed
+     return spectros.to(device), labels.to(device)
+ 
+ 
+ def initialize(scheduler=None):
+     model = Unet(
+         dim=64,
+         num_classes=10,  # digits 0-9
+         dim_mults=(1, 2, 4, 8),
+         channels=1  # single-channel (grayscale) spectrograms
+     )
+     diffusion = GaussianDiffusion(
+         model,
+         image_size=image_size,
+         timesteps=1000,
+         loss_type='l2',
+         objective='pred_x0',
+     )
+     diffusion.to(device)
+ 
+     optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
+     if scheduler:
+         scheduler = torch.optim.lr_scheduler.CyclicLR(optim, base_lr=1e-5, max_lr=1e-3, mode="exp_range", cycle_momentum=False)
+     return diffusion, optim, scheduler
+ 
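+ # Assumption about this version of denoising_diffusion_pytorch: the
+ # classifier_free_guidance sample() also accepts a cond_scale argument to set
+ # the guidance strength, e.g. diffusion.sample(labels, cond_scale=6.).
+ 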
+ def timeSince(since):
+     now = time.time()
+     s = now - since
+     m = math.floor(s / 60)
+     s -= m * 60
+     return '%dm %ds' % (m, s)
+ 
+ start = time.time()
+ 
+ def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler=None):
+     size = len(train_dl.dataset)
+     model.train()
+     losses = []
+ 
+     for e in range(epochs):
+         batch_loss, batch_counts = 0, 0
+         for step, batch in enumerate(train_dl):
+             model.zero_grad()
+             batch_counts += 1
+             spectros, labels = batch
+             loss = model(spectros, classes=labels)
+ 
+             batch_loss += loss.item()
+             loss.backward()
+             nn.utils.clip_grad_norm_(model.parameters(), 1)
+             optim.step()
+             if scheduler is not None:
+                 scheduler.step()
+ 
+             if (step % 100 == 0 and step != 0) or (step == len(train_dl) - 1):
+                 to_print = f"{e + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {timeSince(start)} | {step*batch_size:>5d}/{size:>5d}"
+                 print(to_print)
+                 losses.append(batch_loss)
+                 batch_loss, batch_counts = 0, 0
+ 
+         # sample a small batch at the end of each epoch
+         labels = torch.randint(0, 10, (8,)).to(device)  # high is exclusive, so 10 covers digits 0-9
+         print(labels)
+         samples = model.sample(labels)
+         os.makedirs("audio", exist_ok=True)
+         os.makedirs("images", exist_ok=True)
+         for i, sample in enumerate(samples):
+             # samples are in [0, 1]; convert to an 8-bit grayscale image
+             arr = (sample[0].clamp(0, 1).cpu().numpy() * 255).astype("uint8")
+             im = Image.fromarray(arr, mode="L")
+             audio = torch.tensor([MEL.image_to_audio(im)])
+             torchaudio.save(f"audio/sample{e}_{i}_{labels[i].item()}.wav", audio, 48000)
+             im.save(f"images/sample{e}_{i}_{labels[i].item()}.jpg")
+     return losses
+ 
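+ # train() returns the running losses logged every 100 steps; they can be
+ # plotted with the matplotlib import above (e.g. plt.plot(losses)) if a
+ # training curve is wanted.
+ 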
+ if __name__ == "__main__":
+     num_epochs = 10
+     if len(args) >= 2:
+         num_epochs = int(args[1])
+ 
+     batch_size = 32
+     if len(args) >= 3:
+         batch_size = int(args[2])
+ 
+     print(image_size, num_epochs, batch_size)
+     model, optim, scheduler = initialize(scheduler=True)
+     train_data = Audio("AudioMNIST/data")
+     print("Done Loading")
+     train_dl = DataLoader(train_data, batch_size, shuffle=True, collate_fn=collate)
+     train(model, optim, train_dl, batch_size, num_epochs, scheduler)
+     torch.save(model.state_dict(), "diffusion_condition_model.pt")
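+ 
+     # Sketch for reloading the checkpoint later (assumes the same initialize() config):
+     #   diffusion, _, _ = initialize()
+     #   diffusion.load_state_dict(torch.load("diffusion_condition_model.pt"))
+     #   samples = diffusion.sample(torch.arange(10).to(device))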
diffusion_condition_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25021d8c6a1813ba51f2f7cb9d015b132f7f21d2deaad397fac0d641cdc671cc
+ size 153739669