diff --git "a/Experiments.ipynb" "b/Experiments.ipynb" new file mode 100644--- /dev/null +++ "b/Experiments.ipynb" @@ -0,0 +1,493 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c2807819", + "metadata": {}, + "outputs": [], + "source": [ + "from audio_diffusion_pytorch import AudioDiffusionModel\n", + "import torch\n", + "from tqdm import tqdm\n", + "from IPython.display import Audio\n", + "from pathlib import Path\n", + "import torchaudio\n", + "import torchaudio.transforms as T\n", + "import pytorch_lightning as pl\n", + "from torch.utils.data import random_split, DataLoader, Dataset\n", + "import torch.nn.functional as F\n", + "import wandb\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "469edd04", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattricesound\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /Users/matthewrice/Developer/remfx/wandb/run-20221203_231820-3tdw8zp6" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run ruby-oath-2 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wandb.init(project=\"RemFX\", entity=\"mattricesound\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "8d7eacfc", + "metadata": {}, + "outputs": [], + "source": [ + "SAMPLE_RATE = 22050\n", + "LENGTH = 2**17#round(5 * SAMPLE_RATE) 6 seconds" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a70e9942", + "metadata": {}, + "outputs": [], + "source": [ + "model = AudioDiffusionModel(in_channels=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "cdc0fb64", + "metadata": {}, + "outputs": [], + "source": [ + "class GuitarDataset(Dataset):\n", + " def __init__(self, root, length=LENGTH):\n", + " self.files = list(Path().glob(f\"{root}/**/*.wav\"))\n", + " self.resampler = T.Resample(48000, SAMPLE_RATE)\n", + " \n", + " def __len__(self):\n", + " return len(self.files)\n", + " \n", + " def __getitem__(self, idx):\n", + " x, sr = torchaudio.load(self.files[idx])\n", + "# x = x.view() # Duplicate channel\n", + " resampled_x = self.resampler(x)\n", + " if resampled_x.shape[1] < LENGTH:\n", + " resampled_x = F.pad(resampled_x, (0, LENGTH - resampled_x.shape[1]))\n", + " elif resampled_x.shape[1] > LENGTH:\n", + " resampled_x = resampled_x[:, :LENGTH]\n", + " return resampled_x" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "148c2a96", + "metadata": {}, + "outputs": [], + "source": [ + "g = GuitarDataset(Path(\"Clean\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 237, + "id": "670c94a5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 131072])\n" + ] + } + ], + "source": [ + "x = g[10]\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "f031f044", + "metadata": {}, + "outputs": [], + "source": [ + "class RemFXDataModule(pl.LightningDataModule):\n", + " def __init__(self, data_dir: str = \"path/to/dir\", batch_size: int = 32):\n", + " super().__init__()\n", + " self.data_dir = Path(data_dir)\n", + " self.batch_size = batch_size\n", + "\n", + " def setup(self, stage: str):\n", + "# self.guitar_test = GuitarDataset(self.data_dir, train=False)\n", + "# self.guitar_predict = GuitarDataset(self.data_dir, train=False)\n", + " guitar_full = GuitarDataset(self.data_dir)\n", + " self.guitar_train, self.guitar_val = random_split(guitar_full, [55000, 5000])\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.guitar_train, batch_size=self.batch_size)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.guitar_val, batch_size=self.batch_size)\n", + "\n", + "# def test_dataloader(self):\n", + "# return DataLoader(self.guitar_test, batch_size=self.batch_size)\n", + "\n", + "# def predict_dataloader(self):\n", + "# return DataLoader(self.guitar_predict, batch_size=self.batch_size)\n", + "\n", + " def teardown(self, stage: str):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "e1c83600", + "metadata": {}, + "outputs": [], + "source": [ + "data = DataLoader(GuitarDataset(Path(\"Clean\")), batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4d46f992", + "metadata": {}, + "outputs": [], + "source": [ + "dataiter = iter(data)\n", + "x = 
+ { + "cell_type": "code", + "execution_count": 25, + "id": "1103e520", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 131072])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "6b0f1575", + "metadata": {}, + "outputs": [], + "source": [ + "# wandb.log({\"Audio\": wandb.Audio(x[0].view(-1).numpy(), sample_rate=SAMPLE_RATE)})" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "a6a2bb97", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 14%|█████████▏ | 7/50 [7:29:41<59:56:20, 5018.16s/it]wandb: Network error (ConnectionError), entering retry loop.\n", + " 14%|█████████▏ | 7/50 [8:13:48<50:33:21, 4232.58s/it]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "KeyboardInterrupt (traceback truncated): training was interrupted by hand during loss = model(batch), inside audio_diffusion_pytorch's UNet1d forward (torch F.conv1d)" + ] + }
+ ], + "source": [ + "epochs = 50\n", + "optimizer = torch.optim.Adam(model.parameters())  # assumed optimizer (Adam, default lr) for the update step\n", + "for i in tqdm(range(epochs)):\n", + "    for batch in data:\n", + "        optimizer.zero_grad()\n", + "        loss = model(batch)\n", + "        loss.backward()\n", + "        optimizer.step()\n", + "    if i % 5 == 0:\n", + "        wandb.log({\"loss\": loss})\n", + "        with torch.no_grad():\n", + "            noise = torch.randn(1, 1, 2**17)\n", + "            sampled = model.sample(noise=noise, num_steps=40)\n", + "            z = sampled.view(-1)\n", + "            wandb.log({f\"Audio_{i}\": wandb.Audio(z.numpy(), sample_rate=SAMPLE_RATE)})" + ] + }, + { + "cell_type": "code", + "execution_count": 259, + "id": "d18e4816", + "metadata": {}, + "outputs": [], + "source": [ + "noise = torch.randn(1, 1, 2**17)\n", + "sampled = model.sample(noise=noise, num_steps=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 260, + "id": "054e708f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1, 131072]) tensor([[[-0.4879, -0.4534, -0.4094, ..., -1.0000, 0.8554, -0.9605]]])\n" + ] + } + ], + "source": [ + "print(sampled.shape, sampled)" + ] + }, + { + "cell_type": "code", + "execution_count": 261, + "id": "fc8becc0", + "metadata": {}, + "outputs": [], + "source": [ + "z = sampled.view(-1)\n", + "# z = z.mean(axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 262, + "id": "7731b0e7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 262, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(z, rate=22050)" + ] + }, + { + "cell_type": "code", + "execution_count": 258, + "id": "907c7dfb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0.6213, grad_fn=)\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 164, + "id": "698bc12e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6890.625" + ] + }, + "execution_count": 164, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "110250 / 16" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "d4d8bb49", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "264600" + ] + }, + "execution_count": 165, + "metadata": {}, +
"output_type": "execute_result" + } + ], + "source": [ + "12 * 22050" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "ecda7e0e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "16537.5" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "264600 / 16" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9980287", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}