mattricesound committed on
Commit a22b103
1 Parent(s): 3356688

Change to train model using pytorch-lightning

Files changed (3):
  1. Experiments.ipynb +0 -0
  2. diffusion_test.ipynb +747 -12
  3. setup.py +4 -0
Experiments.ipynb CHANGED
The diff for this file is too large to render. See raw diff
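The substantive change is in diffusion_test.ipynb below: the manual training loop (the old cell whose output printed `Step 300` / `Step 310` / `Step 320`) is replaced by a `pl.LightningModule` wrapped around the diffusion model and a `pl.Trainer` driving the loop. As a condensed, self-contained sketch of that pattern with a toy model and dataset (these toy classes are illustrative, not the notebook's actual diffusion classes):

```python
import torch
import pytorch_lightning as pl
from torch import nn
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Stand-in for the notebook's SinDataset."""
    def __len__(self):
        return 64
    def __getitem__(self, i):
        return torch.randn(1, 16)

class ToyModule(pl.LightningModule):
    """Stand-in for the notebook's Model(pl.LightningModule)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 16)

    def training_step(self, batch, batch_idx):
        # The notebook's model returns its loss directly; emulate that here.
        loss = (self.net(batch) - batch).pow(2).mean()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

trainer = pl.Trainer(max_epochs=1, accelerator="cpu")
trainer.fit(ToyModule(), train_dataloaders=DataLoader(ToyDataset(), batch_size=8))
```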
 
diffusion_test.ipynb CHANGED
@@ -2,24 +2,91 @@
  "cells": [
  {
  "cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
  "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b",
  "metadata": {},
  "outputs": [],
  "source": [
- "from audio_diffusion_pytorch import AudioDiffusionModel\n",
+ "from audio_diffusion_pytorch import AudioDiffusionModel, Sampler, Schedule, VSampler, LinearSchedule, AudioDiffusionAE\n",
  "import torch\n",
- "from IPython.display import Audio"
+ "from torch import Tensor, nn, optim\n",
+ "from IPython.display import Audio\n",
+ "import pytorch_lightning as pl\n",
+ "from torch.utils.data import random_split, DataLoader, Dataset\n",
+ "\n",
+ "from einops import rearrange\n",
+ "from ema_pytorch import EMA\n",
+ "from pytorch_lightning import Callback, Trainer\n",
+ "from typing import Any, Callable, Dict, List, Optional, Sequence, Union\n",
+ "from pytorch_lightning.loggers import WandbLogger\n",
+ "import wandb\n",
+ "import torchaudio\n",
+ "import librosa\n"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
  "id": "a005011f-3019-4d34-bdf2-9a00e5480282",
  "metadata": {},
  "outputs": [],
  "source": [
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ "# device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "6349ed8e-f418-436f-860e-62a51e48f79a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
+ "wandb: Currently logged in as: mattricesound. Use `wandb login --relogin` to force relogin\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.13.7"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in <code>./wandb/run-20230107_213018-192gzo2n</code>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run <strong><a href=\"https://wandb.ai/mattricesound/RemFX/runs/192gzo2n\" target=\"_blank\">laced-bush-17</a></strong> to <a href=\"https://wandb.ai/mattricesound/RemFX\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://wandb.me/run\" target=\"_blank\">docs</a>)<br/>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "wandb_logger = WandbLogger(project=\"RemFX\", save_dir=\"./\")"
  ]
  },
  {
@@ -29,6 +96,8 @@
  "metadata": {},
  "outputs": [],
  "source": [
+ "#AudioDiffusionModel\n",
+ "#AudioDiffusionAE\n",
  "model = AudioDiffusionModel(in_channels=1, \n",
  "                            patch_size=1,\n",
  "                            multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
@@ -36,22 +105,666 @@
  "                            num_blocks=[2, 2, 2, 2, 2, 2],\n",
  "                            attentions=[0, 0, 0, 0, 0, 0]\n",
  "                            )\n",
- "model = model.to(device)"
+ "\n",
+ "\n",
+ "# model = model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "950711d4-9e8a-4af1-8d56-204e4ce0a19b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Model(pl.LightningModule):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        lr: float,\n",
+ "        lr_eps: float,\n",
+ "        lr_beta1: float,\n",
+ "        lr_beta2: float,\n",
+ "        lr_weight_decay: float,\n",
+ "        ema_beta: float,\n",
+ "        ema_power: float,\n",
+ "        model: nn.Module,\n",
+ "    ):\n",
+ "        super().__init__()\n",
+ "        self.lr = lr\n",
+ "        self.lr_eps = lr_eps\n",
+ "        self.lr_beta1 = lr_beta1\n",
+ "        self.lr_beta2 = lr_beta2\n",
+ "        self.lr_weight_decay = lr_weight_decay\n",
+ "        self.model = model\n",
+ "        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)\n",
+ "\n",
+ "    @property\n",
+ "    def device(self):\n",
+ "        return next(self.model.parameters()).device\n",
+ "\n",
+ "    def configure_optimizers(self):\n",
+ "        optimizer = torch.optim.AdamW(\n",
+ "            list(self.parameters()),\n",
+ "            lr=self.lr,\n",
+ "            betas=(self.lr_beta1, self.lr_beta2),\n",
+ "            eps=self.lr_eps,\n",
+ "            weight_decay=self.lr_weight_decay,\n",
+ "        )\n",
+ "        return optimizer\n",
+ "\n",
+ "    def training_step(self, batch, batch_idx):\n",
+ "        waveforms = batch\n",
+ "        loss = self.model(waveforms)\n",
+ "        self.log(\"train_loss\", loss)\n",
+ "        self.model_ema.update()\n",
+ "        self.log(\"ema_decay\", self.model_ema.get_current_decay())\n",
+ "        return loss\n",
+ "\n",
+ "    def validation_step(self, batch, batch_idx):\n",
+ "        waveforms = batch\n",
+ "        loss = self.model_ema(waveforms)\n",
+ "        self.log(\"valid_loss\", loss)\n",
+ "        return loss"
+ ]
+ },
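For context on the `Model` module added above: it trains `self.model` with AdamW while an `ema_pytorch.EMA` wrapper maintains a smoothed copy (`model_ema`), refreshed once per training step, and the validation forward pass runs through that smoothed copy. A minimal standalone sketch of the same pattern, using a toy linear network and the same `beta`/`power` values as the `params` cell further down (the toy network and loop are assumptions, not notebook code):

```python
import torch
from torch import nn
from ema_pytorch import EMA  # added to setup.py by this commit

net = nn.Linear(4, 1)                    # stand-in for AudioDiffusionModel
ema = EMA(net, beta=0.995, power=0.7)    # smoothed copy, as in Model.__init__
opt = torch.optim.AdamW(net.parameters(), lr=1e-4)

for _ in range(5):                       # toy training loop
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update()                         # mirrors self.model_ema.update() in training_step

with torch.no_grad():
    out = ema(torch.randn(8, 4))         # forwards through the EMA weights, as validation_step does
```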
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7ce9b20b-d163-425a-a92d-8ddb1a92b905",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "cfa42700-f190-485d-84b9-d9203f8275d7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params = {\n",
+ "    \"lr\": 1e-4,\n",
+ "    \"lr_beta1\": 0.95,\n",
+ "    \"lr_beta2\": 0.999,\n",
+ "    \"lr_eps\": 1e-6,\n",
+ "    \"lr_weight_decay\": 1e-3,\n",
+ "    \"ema_beta\": 0.995,\n",
+ "    \"ema_power\": 0.7,\n",
+ "    \"model\": model\n",
+ "}\n",
+ "diffModel = Model(**params)"
  ]
  },
  {
  "cell_type": "code",
  "execution_count": 7,
- "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
+ "id": "aa4029a4-efd8-4922-a863-cf7677e86c05",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fs = 22050\n",
+ "t = 2 ** 18 / fs # 12 seconds\n",
+ "\n",
+ "class SinDataset(Dataset):\n",
+ "    def __init__(self, num):\n",
+ "        self.n = num\n",
+ "        self.samples = torch.arange(t * fs) / fs\n",
+ "    def __len__(self):\n",
+ "        return self.n\n",
+ "    def __getitem__(self, i): \n",
+ "        f = 6000 * torch.rand(1) + 300\n",
+ "        signal = torch.sin(2 * torch.pi * (f*2) * self.samples).unsqueeze(0)\n",
+ "        return signal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "ae57ad99-fdaf-4720-91b0-ce9338e6a811",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = DataLoader(SinDataset(1000), batch_size=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7b131b37-485f-4d4f-8616-6e7afe25beb9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "val_data = DataLoader(SinDataset(1000), batch_size=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "4d98c1a0-1763-4d0b-be1d-e84ace68bebb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataiter = iter(data)\n",
+ "x = next(dataiter)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "c3259082-20d5-415c-8a88-3b97af6615ee",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1, 262144])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x.shape"
+ ]
+ },
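Each `SinDataset` item is a single-channel sine of 2 ** 18 = 262144 samples, about 11.9 s at fs = 22050 (the `# 12 seconds` comment rounds this), so a batch of two has shape `[2, 1, 262144]`, matching the output above. Note that the `f*2` factor puts the top of the random frequency range (2 × 6300 = 12600 Hz) above the 11025 Hz Nyquist limit for this rate, so the highest draws will alias. A quick standalone check of the shape arithmetic:

```python
import torch

fs = 22050
n = 2 ** 18                       # samples per item
print(n / fs)                     # ~11.89 seconds

samples = torch.arange(n) / fs    # same time base as SinDataset
f = 6000 * torch.rand(1) + 300    # f drawn from [300, 6300) Hz
signal = torch.sin(2 * torch.pi * (f * 2) * samples).unsqueeze(0)
print(signal.shape)               # torch.Size([1, 262144]); DataLoader batching adds the leading 2
```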
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "d1ec36ea-0f9c-49f6-8f24-a479084ea230",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SampleLogger(Callback):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        num_items: int,\n",
+ "        channels: int,\n",
+ "        sampling_rate: int,\n",
+ "        length: int,\n",
+ "        sampling_steps: List[int],\n",
+ "        diffusion_schedule: Schedule,\n",
+ "        diffusion_sampler: Sampler,\n",
+ "        use_ema_model: bool,\n",
+ "    ) -> None:\n",
+ "        self.num_items = num_items\n",
+ "        self.channels = channels\n",
+ "        self.sampling_rate = sampling_rate\n",
+ "        self.length = length\n",
+ "        self.sampling_steps = sampling_steps\n",
+ "        self.diffusion_schedule = diffusion_schedule\n",
+ "        self.diffusion_sampler = diffusion_sampler\n",
+ "        self.use_ema_model = use_ema_model\n",
+ "\n",
+ "        self.log_next = False\n",
+ "\n",
+ "    def on_validation_epoch_start(self, trainer, pl_module):\n",
+ "        self.log_next = True\n",
+ "\n",
+ "    def on_validation_batch_start(\n",
+ "        self, trainer, pl_module, batch, batch_idx, dataloader_idx\n",
+ "    ):\n",
+ "        if self.log_next:\n",
+ "            self.log_sample(trainer, pl_module, batch)\n",
+ "            self.log_next = False\n",
+ "\n",
+ "    @torch.no_grad()\n",
+ "    def log_sample(self, trainer, pl_module, batch):\n",
+ "        is_train = pl_module.training\n",
+ "        if is_train:\n",
+ "            pl_module.eval()\n",
+ "\n",
+ "        wandb_logger = get_wandb_logger(trainer).experiment\n",
+ "\n",
+ "        diffusion_model = pl_module.model\n",
+ "        if self.use_ema_model:\n",
+ "            diffusion_model = pl_module.model_ema.ema_model\n",
+ "        # Get start diffusion noise\n",
+ "        noise = torch.randn(\n",
+ "            (self.num_items, self.channels, self.length), device=pl_module.device\n",
+ "        )\n",
+ "\n",
+ "        for steps in self.sampling_steps:\n",
+ "            samples = diffusion_model.sample(\n",
+ "                noise=noise,\n",
+ "                sampler=self.diffusion_sampler,\n",
+ "                sigma_schedule=self.diffusion_schedule,\n",
+ "                num_steps=steps,\n",
+ "            )\n",
+ "            log_wandb_audio_batch(\n",
+ "                logger=wandb_logger,\n",
+ "                id=\"sample\",\n",
+ "                samples=samples,\n",
+ "                sampling_rate=self.sampling_rate,\n",
+ "                caption=f\"Sampled in {steps} steps\",\n",
+ "            )\n",
+ "            # log_wandb_audio_spectrogram(\n",
+ "            #     logger=wandb_logger,\n",
+ "            #     id=\"sample\",\n",
+ "            #     samples=samples,\n",
+ "            #     sampling_rate=self.sampling_rate,\n",
+ "            #     caption=f\"Sampled in {steps} steps\",\n",
+ "            # )\n",
+ "\n",
+ "        if is_train:\n",
+ "            pl_module.train()\n",
+ "\n",
+ "def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:\n",
+ "    \"\"\"Safely get Weights&Biases logger from Trainer.\"\"\"\n",
+ "\n",
+ "    if isinstance(trainer.logger, WandbLogger):\n",
+ "        return trainer.logger\n",
+ "\n",
+ "    if isinstance(trainer.logger, LoggerCollection):\n",
+ "        for logger in trainer.logger:\n",
+ "            if isinstance(logger, WandbLogger):\n",
+ "                return logger\n",
+ "\n",
+ "    print(\"WandbLogger not found.\")\n",
+ "    return None\n",
+ "\n",
+ "\n",
+ "def log_wandb_audio_batch(\n",
+ "    logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = \"\"\n",
+ "):\n",
+ "    num_items = samples.shape[0]\n",
+ "    samples = rearrange(samples, \"b c t -> b t c\").detach().cpu().numpy()\n",
+ "    logger.log(\n",
+ "        {\n",
+ "            f\"sample_{idx}_{id}\": wandb.Audio(\n",
+ "                samples[idx],\n",
+ "                caption=caption,\n",
+ "                sample_rate=sampling_rate,\n",
+ "            )\n",
+ "            for idx in range(num_items)\n",
+ "        }\n",
+ "    )\n",
+ "\n",
+ "\n",
+ "def log_wandb_audio_spectrogram(\n",
+ "    logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = \"\"\n",
+ "):\n",
+ "    num_items = samples.shape[0]\n",
+ "    samples = samples.detach().cpu()\n",
+ "    transform = torchaudio.transforms.MelSpectrogram(\n",
+ "        sample_rate=sampling_rate,\n",
+ "        n_fft=1024,\n",
+ "        hop_length=512,\n",
+ "        n_mels=80,\n",
+ "        center=True,\n",
+ "        norm=\"slaney\",\n",
+ "    )\n",
+ "\n",
+ "    def get_spectrogram_image(x):\n",
+ "        spectrogram = transform(x[0])\n",
+ "        image = librosa.power_to_db(spectrogram)\n",
+ "        trace = [go.Heatmap(z=image, colorscale=\"viridis\")]\n",
+ "        layout = go.Layout(\n",
+ "            yaxis=dict(title=\"Mel Bin (Log Frequency)\"),\n",
+ "            xaxis=dict(title=\"Frame\"),\n",
+ "            title_text=caption,\n",
+ "            title_font_size=10,\n",
+ "        )\n",
+ "        fig = go.Figure(data=trace, layout=layout)\n",
+ "        return fig\n",
+ "\n",
+ "    logger.log(\n",
+ "        {\n",
+ "            f\"mel_spectrogram_{idx}_{id}\": get_spectrogram_image(samples[idx])\n",
+ "            for idx in range(num_items)\n",
+ "        }\n",
+ "    )"
+ ]
+ },
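`SampleLogger` arms a flag at the start of each validation epoch and, on the first validation batch, draws fixed noise and logs audio sampled at several step counts to wandb. Two caveats visible in the committed code: `get_wandb_logger` references `LoggerCollection`, which is never imported in this notebook, so the multi-logger branch would raise a NameError if reached; and the (commented-out) spectrogram helper builds plotly `go.Heatmap`/`go.Figure` objects without importing plotly. A minimal sketch of the same once-per-validation-epoch callback pattern, with a placeholder action standing in for the diffusion sampling (the class and print here are illustrative, not notebook code):

```python
import pytorch_lightning as pl

class LogOncePerValEpoch(pl.Callback):
    """Fire a single logging action on the first validation batch of each epoch."""
    def __init__(self):
        self.log_next = False

    def on_validation_epoch_start(self, trainer, pl_module):
        self.log_next = True   # arm the flag once per epoch

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        if self.log_next:
            # SampleLogger would sample from the (EMA) model and log audio here.
            print(f"epoch {trainer.current_epoch}: would sample and log audio")
            self.log_next = False   # disarm until the next epoch
```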
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "27c038a6-38f1-4a61-a472-2591ae39af3b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vsampler = VSampler()\n",
+ "linear_schedule = LinearSchedule()\n",
+ "samples_config = {\n",
+ "    \"num_items\": 3,\n",
+ "    \"channels\": 1,\n",
+ "    \"sampling_rate\": fs,\n",
+ "    \"sampling_steps\": [3,5,10,25,50,100],\n",
+ "    \"use_ema_model\": True,\n",
+ "    \"diffusion_sampler\": vsampler,\n",
+ "    \"length\": 262144,\n",
+ "    \"diffusion_schedule\": linear_schedule\n",
+ "}\n",
+ "s = SampleLogger(**samples_config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ffe84ea2-6e3f-42f0-a261-57649574a601",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "8f8f3cda-da27-477c-b553-bca4eaad69ea",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer = pl.Trainer(limit_train_batches=100, max_epochs=100, accelerator='gpu', devices=[1], callbacks=[s], logger=wandb_logger)"
+ ]
+ },
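The Trainer above is pinned to GPU index 1 (`devices=[1]`), consistent with the commented-out `cuda:1` device cell at the top of the notebook, and `limit_train_batches=100` caps each epoch at 100 batches. A hedged variant for a single-GPU machine (assumption: GPU 0 is the only device available):

```python
# Sketch: same Trainer configuration targeting the first (only) GPU.
trainer = pl.Trainer(
    limit_train_batches=100,   # 100 batches per epoch, as above
    max_epochs=100,
    accelerator="gpu",
    devices=[0],               # GPU 0 instead of GPU 1
    callbacks=[s],
    logger=wandb_logger,
)
```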
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "47b8760a-8ee3-4212-8817-a804fd02fade",
  "metadata": {},
  "outputs": [
  {
- "name": "stdout",
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
+ "\n",
+ "  | Name      | Type                | Params\n",
+ "--------------------------------------------------\n",
+ "0 | model     | AudioDiffusionModel | 74.3 M\n",
+ "1 | model_ema | EMA                 | 148 M \n",
+ "--------------------------------------------------\n",
+ "74.3 M    Trainable params\n",
+ "74.3 M    Non-trainable params\n",
+ "148 M     Total params\n",
+ "594.631   Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
  "output_type": "stream",
  "text": [
- "Step 300\n",
- "Step 310\n",
- "Step 320\n"
+ "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ "  rank_zero_warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ "  rank_zero_warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5327d73bb6114877adb4e9f991058eea",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
[… 11 further "Validation: 0it [00:00, ?it/s]" display_data outputs, identical to the block above with empty model_ids, elided …]
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c6e6b42717824054b576e47f92878ef5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer.fit(model=diffModel, train_dataloaders=data, val_dataloaders=val_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1f64d981-c9dc-4afa-b783-d017f99633da",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "53bba197-83eb-40a2-b748-a4c25e628356",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "49db25f0-8bda-4693-9872-cbf24c40b575",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "29ed502f-2daf-4210-81ff-a90ade519086",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Old code below"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'device' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "---------------------------------------------------------------------------",
+ "NameError                                 Traceback (most recent call last)",
+ "Cell In [14], line 12\n     10 signal2 = torch.sin(2 * torch.pi * (f*2) * samples)\n     11 stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)\n---> 12 stacked_signal = stacked_signal.to(device)\n     13 loss = model(stacked_signal)\n     14 loss.backward() \n",
+ "NameError: name 'device' is not defined"
  ]
  }
  ],
@@ -122,9 +835,31 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
  "id": "81eddd71-bba7-4c62-8d50-900b295bb2f8",
  "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'z' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "---------------------------------------------------------------------------",
+ "NameError                                 Traceback (most recent call last)",
+ "Cell In [12], line 1\n----> 1 z.shape\n",
+ "NameError: name 'z' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "z.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8a3f582f-a956-4326-872b-416cc13b77ee",
+ "metadata": {},
  "outputs": [],
  "source": []
  }
setup.py CHANGED
@@ -38,6 +38,10 @@ setup(
         "pytorch-lightning",
         "numba",
         "wandb",
+        "audio-diffusion-pytorch",
+        "ema_pytorch",
+        "einops",
+        "librosa",
     ],
     include_package_data=True,
     license="Apache License 2.0",