Upload 2 files
Browse files- conditional-diffusion.py +149 -0
- diffusion_condition_model.pt +3 -0
conditional-diffusion.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader, Dataset
|
3 |
+
import torchaudio
|
4 |
+
import torchvision.transforms as tvt
|
5 |
+
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
|
6 |
+
import glob
|
7 |
+
import torch.nn as nn
|
8 |
+
import time, math
|
9 |
+
from PIL import Image
|
10 |
+
from diffusers import Mel
|
11 |
+
import sys
|
12 |
+
import torchaudio
|
13 |
+
import librosa
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
|
16 |
+
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Command-line arguments without the script name.
args = sys.argv[1:]
|
19 |
+
|
20 |
+
class Audio(Dataset):
    """In-memory dataset of ``.wav`` waveforms labelled by their file name.

    Every ``.wav`` file found recursively under ``folder`` is loaded eagerly
    with ``torchaudio.load`` when the dataset is constructed. The label of a
    file is the first character of its base name, parsed as an integer.
    """

    def __init__(self, folder):
        """Load all ``.wav`` files below ``folder`` together with their labels.

        Args:
            folder: Root directory searched recursively for ``*.wav`` files.

        Raises:
            ValueError: If a file's base name does not start with a digit.
        """
        self.waveforms = []
        self.labels = []
        print("Loading files...")
        for file in glob.iglob(folder + '/**/*.wav', recursive=True):
            self.labels.append(self._label_from_path(file))
            # The second return value of torchaudio.load is discarded here.
            waveform, _ = torchaudio.load(file)
            self.waveforms.append(waveform)

    @staticmethod
    def _label_from_path(path):
        """Return the integer label encoded as the first character of the file name."""
        # Local import so the class works without touching the file's import block.
        import os

        # os.path.basename honours the platform's path separator; splitting on
        # '/' breaks on Windows, where glob yields back-slashed paths.
        return int(os.path.basename(path)[0])

    def __len__(self):
        """Return the number of loaded waveforms."""
        return len(self.waveforms)

    def __getitem__(self, index):
        """Return the ``(waveform, label)`` pair stored at ``index``."""
        return self.waveforms[index], self.labels[index]
|
37 |
+
|
38 |
+
|
39 |
+
# Side length used for both spectrogram axes; the first command-line argument
# overrides the default of 256.
image_size = int(args[0]) if args else 256

# Shared converter between audio and spectrogram images (used by collate and train).
MEL = Mel(x_res=image_size, y_res=image_size)
# Turns the PIL images produced by MEL into tensors.
img_to_tensor = tvt.PILToTensor()
|
45 |
+
|
46 |
+
def collate(batch):
    """Turn a batch of ``(waveform, label)`` pairs into spectrogram tensors.

    Each waveform's first entry along dim 0 (presumably the first audio
    channel - TODO confirm) is loaded into the shared ``MEL`` converter. Every
    slice it reports becomes one spectrogram image scaled by ``1 / 255``, and
    the waveform's label is repeated once per slice. The returned batch can
    therefore hold more items than ``len(batch)``.

    Args:
        batch: Iterable of ``(waveform, label)`` pairs as produced by ``Audio``.

    Returns:
        tuple: ``(spectros, labels)`` - the stacked spectrogram tensor and the
        label tensor, both moved to ``device``.
    """
    images, targets = [], []
    for waveform, label in batch:
        MEL.load_audio(raw_audio=waveform[0])
        slice_count = MEL.get_number_of_slices()
        for slice_idx in range(slice_count):
            picture = MEL.audio_slice_to_image(slice_idx)
            images.append(img_to_tensor(picture) / 255.0)
            targets.append(label)

    stacked = torch.stack(images)
    label_tensor = torch.tensor(targets)
    return stacked.to(device), label_tensor.to(device)
|
65 |
+
|
66 |
+
|
67 |
+
def initialize(scheduler=None, batch_size=32):
    """Create the conditional diffusion model, its optimizer and an optional LR scheduler.

    Args:
        scheduler: Truthy to build a ``CyclicLR`` scheduler for the optimizer;
            a falsy value is returned unchanged.
        batch_size: Accepted for interface compatibility; not used here.

    Returns:
        tuple: ``(diffusion, optim, scheduler)`` where ``diffusion`` is the
        ``GaussianDiffusion`` wrapper already moved to ``device``.
    """
    unet_config = dict(dim=64, num_classes=10, dim_mults=(1, 2, 4, 8), channels=1)
    unet = Unet(**unet_config)

    diffusion = GaussianDiffusion(
        unet,
        image_size=image_size,
        timesteps=1000,
        loss_type='l2',
        objective='pred_x0',
    )
    diffusion.to(device)

    # The optimizer is built over the U-Net's parameters.
    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4, eps=1e-8)

    lr_schedule = scheduler
    if lr_schedule:
        lr_schedule = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=1e-5,
            max_lr=1e-3,
            mode="exp_range",
            cycle_momentum=False,
        )
    return diffusion, optimizer, lr_schedule
|
88 |
+
|
89 |
+
def timeSince(since):
    """Format the time elapsed since ``since`` as ``'<minutes>m <seconds>s'``.

    Args:
        since: Start timestamp in seconds, on the same clock as ``time.time()``.

    Returns:
        str: Whole minutes and the remaining seconds (formatted with ``%d``),
        e.g. ``'2m 5s'``.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    leftover = elapsed - minutes * 60
    return '%dm %ds' % (minutes, leftover)
|
95 |
+
|
96 |
+
# Reference timestamp taken when the module is executed; train() passes it to
# timeSince() for the elapsed-time column of its progress log.
start = time.time()
|
97 |
+
|
98 |
+
def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler=None):
    """Train the diffusion model and write generated samples after every epoch.

    For each batch the loss is obtained by calling
    ``model(spectros, classes=labels)``, gradients are clipped to norm 1, and
    the optimizer (plus the scheduler, if given) is stepped once. Every 100
    steps and on the last step of an epoch a progress line is printed and the
    loss accumulated since the previous log is recorded. After each epoch,
    8 samples with random class labels are generated and saved as images under
    ``images/`` and as audio under ``audio/``.

    Args:
        model: Diffusion wrapper exposing ``__call__(x, classes=...)`` returning
            a loss, and ``sample(labels)``.
        optim: Optimizer stepping the model's parameters.
        train_dl: Data loader yielding ``(spectros, labels)`` batches.
        batch_size: Number of dataset items per batch; only used for the
            progress counter.
        epochs: Number of passes over ``train_dl``.
        scheduler: Optional LR scheduler, stepped once per batch.

    Returns:
        list: The summed (not averaged) loss of each logging window.
    """
    # Local import so the function works without touching the file's import block.
    import os

    size = len(train_dl.dataset)
    num_batches = len(train_dl)
    model.train()
    losses = []

    # Create the output folders up front; otherwise the first save would crash
    # the run only after a whole epoch of training.
    os.makedirs("audio", exist_ok=True)
    os.makedirs("images", exist_ok=True)

    for e in range(epochs):
        batch_loss, batch_counts = 0, 0
        for step, batch in enumerate(train_dl):
            model.zero_grad()
            batch_counts += 1
            spectros, labels = batch
            loss = model(spectros, classes=labels)

            batch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()
            if scheduler is not None:
                scheduler.step()

            if (step % 100 == 0 and step != 0) or (step == num_batches - 1):
                # step is zero-based, so step + 1 batches are done; cap at the
                # dataset size because the final batch may be partial.
                seen = min((step + 1) * batch_size, size)
                to_print = f"{e + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {timeSince(start)} | {seen:>5d}/{size:>5d}"
                print(to_print)
                losses.append(batch_loss)
                batch_loss, batch_counts = 0, 0

        # torch.randint's upper bound is exclusive: use 10 so that all ten
        # classes (0-9) of the model can be drawn, not just 0-8.
        sample_labels = torch.randint(0, 10, (8, )).to(device)
        print(sample_labels)
        samples = model.sample(sample_labels)
        for i, sample in enumerate(samples):
            # Scale the first channel by 255 and convert to an 8-bit greyscale image.
            im = Image.fromarray(sample[0].cpu().numpy() * 255).convert('L')
            audio = torch.tensor([MEL.image_to_audio(im)])
            torchaudio.save(f"audio/sample{e}_{i}_{sample_labels[i]}.wav", audio, 48000)
            im.save(f"images/sample{e}_{i}_{sample_labels[i]}.jpg")
    return losses
|
133 |
+
|
134 |
+
if __name__ == "__main__":
    # Optional overrides: second CLI argument -> epochs, third -> batch size.
    num_epochs = int(args[1]) if len(args) >= 2 else 10
    batch_size = int(args[2]) if len(args) >= 3 else 32

    print(image_size, num_epochs, batch_size)

    # Build the model with the LR scheduler enabled, then load the dataset.
    model, optim, scheduler = initialize(scheduler=True, batch_size=batch_size)
    train_data = Audio("AudioMNIST/data")
    print("Done Loading")

    train_dl = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate,
    )
    train(model, optim, train_dl, batch_size, num_epochs, scheduler)

    # Save the trained weights.
    torch.save(model.state_dict(), "diffusion_condition_model.pt")
|
diffusion_condition_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25021d8c6a1813ba51f2f7cb9d015b132f7f21d2deaad397fac0d641cdc671cc
|
3 |
+
size 153739669
|