{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b", "metadata": {}, "outputs": [], "source": [ "from audio_diffusion_pytorch import AudioDiffusionModel, Sampler, Schedule, VSampler, LinearSchedule, AudioDiffusionAE\n", "import torch\n", "from torch import Tensor, nn, optim\n", "from IPython.display import Audio\n", "import pytorch_lightning as pl\n", "from torch.utils.data import random_split, DataLoader, Dataset\n", "\n", "from einops import rearrange\n", "from ema_pytorch import EMA\n", "from pytorch_lightning import Callback, Trainer\n", "from typing import Any, Callable, Dict, List, Optional, Sequence, Union\n", "from pytorch_lightning.loggers import WandbLogger\n", "import wandb\n", "import torchaudio\n", "import librosa\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "a005011f-3019-4d34-bdf2-9a00e5480282", "metadata": {}, "outputs": [], "source": [ "# device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "6349ed8e-f418-436f-860e-62a51e48f79a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattricesound\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.13.7" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in ./wandb/run-20230107_213018-192gzo2n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run laced-bush-17 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "wandb_logger = WandbLogger(project=\"RemFX\", save_dir=\"./\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4", "metadata": {}, "outputs": [], "source": [ "#AudioDiffusionModel\n", "#AudioDiffusionAE\n", "model = AudioDiffusionModel(in_channels=1, \n", " patch_size=1,\n", " multipliers=[1, 2, 4, 4, 4, 4, 4],\n", " factors=[2, 2, 2, 2, 2, 2],\n", " num_blocks=[2, 2, 2, 2, 2, 2],\n", " attentions=[0, 0, 0, 0, 0, 0]\n", " )\n", "\n", "\n", "# model = model.to(device)" ] }, { "cell_type": "code", "execution_count": 5, "id": "950711d4-9e8a-4af1-8d56-204e4ce0a19b", "metadata": {}, "outputs": [], "source": [ "class Model(pl.LightningModule):\n", " def __init__(\n", " self,\n", " lr: float,\n", " lr_eps: float,\n", " lr_beta1: float,\n", " lr_beta2: float,\n", " lr_weight_decay: float,\n", " ema_beta: float,\n", " ema_power: float,\n", " model: nn.Module,\n", " ):\n", " super().__init__()\n", " self.lr = lr\n", " self.lr_eps = lr_eps\n", " self.lr_beta1 = lr_beta1\n", " self.lr_beta2 = lr_beta2\n", " self.lr_weight_decay = lr_weight_decay\n", " self.model = model\n", " self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)\n", "\n", " @property\n", " def device(self):\n", " return next(self.model.parameters()).device\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.AdamW(\n", " list(self.parameters()),\n", " lr=self.lr,\n", " betas=(self.lr_beta1, self.lr_beta2),\n", " eps=self.lr_eps,\n", " weight_decay=self.lr_weight_decay,\n", " )\n", " return optimizer\n", "\n", " def training_step(self, batch, batch_idx):\n", " waveforms = batch\n", " loss = self.model(waveforms)\n", " self.log(\"train_loss\", loss)\n", " self.model_ema.update()\n", " self.log(\"ema_decay\", self.model_ema.get_current_decay())\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " waveforms = batch\n", " loss = self.model_ema(waveforms)\n", " self.log(\"valid_loss\", loss)\n", " return loss" ] }, { "cell_type": "code", "execution_count": null, "id": "7ce9b20b-d163-425a-a92d-8ddb1a92b905", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 6, "id": "cfa42700-f190-485d-84b9-d9203f8275d7", "metadata": {}, "outputs": [], "source": [ "params = {\n", " \"lr\": 1e-4,\n", " \"lr_beta1\": 0.95,\n", " \"lr_beta2\": 0.999,\n", " \"lr_eps\": 1e-6,\n", " \"lr_weight_decay\": 1e-3,\n", " \"ema_beta\": 0.995,\n", " \"ema_power\": 0.7,\n", " \"model\": model \n", "}\n", "diffModel = Model(**params)" ] }, { "cell_type": "code", "execution_count": 7, "id": "aa4029a4-efd8-4922-a863-cf7677e86c05", "metadata": {}, "outputs": [], "source": [ "fs = 22050\n", "t = 2 ** 18 / fs # 12 seconds\n", "\n", "class SinDataset(Dataset):\n", " def __init__(self, num):\n", " self.n = num\n", " self.samples = torch.arange(t * fs) / fs\n", " def __len__(self):\n", " return self.n\n", " def __getitem__(self, i): \n", " f = 6000 * torch.rand(1) + 300\n", " signal = torch.sin(2 * torch.pi * (f*2) * self.samples).unsqueeze(0)\n", " return signal" ] }, { "cell_type": "code", "execution_count": 8, "id": "ae57ad99-fdaf-4720-91b0-ce9338e6a811", "metadata": {}, "outputs": [], "source": [ "data = DataLoader(SinDataset(1000), batch_size=2)" ] }, { "cell_type": "code", "execution_count": 9, "id": "7b131b37-485f-4d4f-8616-6e7afe25beb9", "metadata": {}, "outputs": [], "source": [ "val_data = DataLoader(SinDataset(1000), batch_size=2)" ] }, { 
"cell_type": "code", "execution_count": 10, "id": "4d98c1a0-1763-4d0b-be1d-e84ace68bebb", "metadata": {}, "outputs": [], "source": [ "dataiter = iter(data)\n", "x = next(dataiter)" ] }, { "cell_type": "code", "execution_count": 11, "id": "c3259082-20d5-415c-8a88-3b97af6615ee", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1, 262144])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape" ] }, { "cell_type": "code", "execution_count": 12, "id": "d1ec36ea-0f9c-49f6-8f24-a479084ea230", "metadata": {}, "outputs": [], "source": [ "class SampleLogger(Callback):\n", " def __init__(\n", " self,\n", " num_items: int,\n", " channels: int,\n", " sampling_rate: int,\n", " length: int,\n", " sampling_steps: List[int],\n", " diffusion_schedule: Schedule,\n", " diffusion_sampler: Sampler,\n", " use_ema_model: bool,\n", " ) -> None:\n", " self.num_items = num_items\n", " self.channels = channels\n", " self.sampling_rate = sampling_rate\n", " self.length = length\n", " self.sampling_steps = sampling_steps\n", " self.diffusion_schedule = diffusion_schedule\n", " self.diffusion_sampler = diffusion_sampler\n", " self.use_ema_model = use_ema_model\n", "\n", " self.log_next = False\n", "\n", " def on_validation_epoch_start(self, trainer, pl_module):\n", " self.log_next = True\n", "\n", " def on_validation_batch_start(\n", " self, trainer, pl_module, batch, batch_idx, dataloader_idx\n", " ):\n", " if self.log_next:\n", " self.log_sample(trainer, pl_module, batch)\n", " self.log_next = False\n", "\n", " @torch.no_grad()\n", " def log_sample(self, trainer, pl_module, batch):\n", " is_train = pl_module.training\n", " if is_train:\n", " pl_module.eval()\n", "\n", " wandb_logger = get_wandb_logger(trainer).experiment\n", "\n", " diffusion_model = pl_module.model\n", " if self.use_ema_model:\n", " diffusion_model = pl_module.model_ema.ema_model\n", " # Get start diffusion noise\n", " noise = torch.randn(\n", " (self.num_items, self.channels, self.length), device=pl_module.device\n", " )\n", "\n", " for steps in self.sampling_steps:\n", " samples = diffusion_model.sample(\n", " noise=noise,\n", " sampler=self.diffusion_sampler,\n", " sigma_schedule=self.diffusion_schedule,\n", " num_steps=steps,\n", " )\n", " log_wandb_audio_batch(\n", " logger=wandb_logger,\n", " id=\"sample\",\n", " samples=samples,\n", " sampling_rate=self.sampling_rate,\n", " caption=f\"Sampled in {steps} steps\",\n", " )\n", " # log_wandb_audio_spectrogram(\n", " # logger=wandb_logger,\n", " # id=\"sample\",\n", " # samples=samples,\n", " # sampling_rate=self.sampling_rate,\n", " # caption=f\"Sampled in {steps} steps\",\n", " # )\n", "\n", " if is_train:\n", " pl_module.train()\n", "\n", "def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:\n", " \"\"\"Safely get Weights&Biases logger from Trainer.\"\"\"\n", "\n", " if isinstance(trainer.logger, WandbLogger):\n", " return trainer.logger\n", "\n", " if isinstance(trainer.logger, LoggerCollection):\n", " for logger in trainer.logger:\n", " if isinstance(logger, WandbLogger):\n", " return logger\n", "\n", " print(\"WandbLogger not found.\")\n", " return None\n", "\n", "\n", "def log_wandb_audio_batch(\n", " logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = \"\"\n", "):\n", " num_items = samples.shape[0]\n", " samples = rearrange(samples, \"b c t -> b t c\").detach().cpu().numpy()\n", " logger.log(\n", " {\n", " f\"sample_{idx}_{id}\": wandb.Audio(\n", " samples[idx],\n", " 
caption=caption,\n", " sample_rate=sampling_rate,\n", " )\n", " for idx in range(num_items)\n", " }\n", " )\n", "\n", "\n", "def log_wandb_audio_spectrogram(\n", " logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = \"\"\n", "):\n", " num_items = samples.shape[0]\n", " samples = samples.detach().cpu()\n", " transform = torchaudio.transforms.MelSpectrogram(\n", " sample_rate=sampling_rate,\n", " n_fft=1024,\n", " hop_length=512,\n", " n_mels=80,\n", " center=True,\n", " norm=\"slaney\",\n", " )\n", "\n", " def get_spectrogram_image(x):\n", " spectrogram = transform(x[0])\n", " image = librosa.power_to_db(spectrogram)\n", " trace = [go.Heatmap(z=image, colorscale=\"viridis\")]\n", " layout = go.Layout(\n", " yaxis=dict(title=\"Mel Bin (Log Frequency)\"),\n", " xaxis=dict(title=\"Frame\"),\n", " title_text=caption,\n", " title_font_size=10,\n", " )\n", " fig = go.Figure(data=trace, layout=layout)\n", " return fig\n", "\n", " logger.log(\n", " {\n", " f\"mel_spectrogram_{idx}_{id}\": get_spectrogram_image(samples[idx])\n", " for idx in range(num_items)\n", " }\n", " )" ] }, { "cell_type": "code", "execution_count": 13, "id": "27c038a6-38f1-4a61-a472-2591ae39af3b", "metadata": {}, "outputs": [], "source": [ "vsampler = VSampler()\n", "linear_schedule = LinearSchedule()\n", "samples_config = {\n", " \"num_items\": 3,\n", " \"channels\": 1,\n", " \"sampling_rate\": fs,\n", " \"sampling_steps\": [3,5,10,25,50,100],\n", " \"use_ema_model\": True,\n", " \"diffusion_sampler\": vsampler,\n", " \"length\": 262144,\n", " \"diffusion_schedule\": linear_schedule\n", "}\n", "s = SampleLogger(**samples_config)" ] }, { "cell_type": "code", "execution_count": null, "id": "ffe84ea2-6e3f-42f0-a261-57649574a601", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 14, "id": "8f8f3cda-da27-477c-b553-bca4eaad69ea", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ "trainer = pl.Trainer(limit_train_batches=100, max_epochs=100, accelerator='gpu', devices=[1], callbacks=[s], logger=wandb_logger)" ] }, { "cell_type": "code", "execution_count": null, "id": "47b8760a-8ee3-4212-8817-a804fd02fade", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", "\n", " | Name | Type | Params\n", "--------------------------------------------------\n", "0 | model | AudioDiffusionModel | 74.3 M\n", "1 | model_ema | EMA | 148 M \n", "--------------------------------------------------\n", "74.3 M Trainable params\n", "74.3 M Non-trainable params\n", "148 M Total params\n", "594.631 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n", "/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5327d73bb6114877adb4e9f991058eea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"c6e6b42717824054b576e47f92878ef5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.fit(model=diffModel, train_dataloaders=data, val_dataloaders=val_data)" ] }, { "cell_type": "code", "execution_count": null, "id": "1f64d981-c9dc-4afa-b783-d017f99633da", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 12, "id": "53bba197-83eb-40a2-b748-a4c25e628356", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "49db25f0-8bda-4693-9872-cbf24c40b575", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "29ed502f-2daf-4210-81ff-a90ade519086", "metadata": {}, "outputs": [], "source": [ "# Old code below" ] }, { "cell_type": "code", "execution_count": 14, "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'device' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn [14], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m signal2 \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpi \u001b[38;5;241m*\u001b[39m (f\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m*\u001b[39m samples)\n\u001b[1;32m 11\u001b[0m stacked_signal \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack((signal1, signal2))\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 12\u001b[0m stacked_signal \u001b[38;5;241m=\u001b[39m stacked_signal\u001b[38;5;241m.\u001b[39mto(\u001b[43mdevice\u001b[49m)\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m model(stacked_signal)\n\u001b[1;32m 14\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward() \n", "\u001b[0;31mNameError\u001b[0m: name 'device' is not defined" ] } ], "source": [ "fs = 22050\n", "t = 2 ** 18 / 22050\n", "samples = torch.arange(t * fs) / fs\n", "\n", "for i in range(300, 8000):\n", " f = i\n", " # Create 2 sine waves (one at f=step, other is octave up) \n", " # There is aliasing at higher freq, but since it is sinusoids, that doesn't matter too much\n", " signal1 = torch.sin(2 * torch.pi * f * samples)\n", " signal2 = torch.sin(2 * torch.pi * (f*2) * samples)\n", " stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)\n", " stacked_signal = stacked_signal.to(device)\n", " loss = model(stacked_signal)\n", " loss.backward() \n", " if i % 10 == 0:\n", " print(\"Step\", i)" ] }, { "cell_type": "code", "execution_count": 8, "id": "71d17c51-842c-40a1-81a1-a53bf358bc8a", "metadata": {}, "outputs": [], "source": [ "# Sample 2 sources given start noise\n", "noise = torch.randn(2, 1, 2 ** 18)\n", "noise = noise.to(device)\n", "sampled = model.sample(\n", " noise=noise,\n", " num_steps=10 # Suggested range: 2-50\n", ") # [2, 1, 2 ** 18]" ] }, { "cell_type": "code", "execution_count": 9, "id": "59d71efa-05ac-4545-84da-8c09c033dfd7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z = sampled[1]\n", "Audio(z.cpu(), rate=22050)" ] }, { "cell_type": "code", "execution_count": 
12, "id": "81eddd71-bba7-4c62-8d50-900b295bb2f8", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'z' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn [12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mz\u001b[49m\u001b[38;5;241m.\u001b[39mshape\n", "\u001b[0;31mNameError\u001b[0m: name 'z' is not defined" ] } ], "source": [ "z.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "8a3f582f-a956-4326-872b-416cc13b77ee", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }