{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "c2807819", "metadata": {}, "outputs": [], "source": [ "from audio_diffusion_pytorch import AudioDiffusionModel\n", "import torch\n", "from tqdm import tqdm\n", "from IPython.display import Audio\n", "from pathlib import Path\n", "import torchaudio\n", "import torchaudio.transforms as T\n", "import pytorch_lightning as pl\n", "from torch.utils.data import random_split, DataLoader, Dataset\n", "import torch.nn.functional as F\n", "import wandb\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "469edd04", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattricesound\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.13.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /Users/matthewrice/Developer/remfx/wandb/run-20221203_231820-3tdw8zp6" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run ruby-oath-2 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wandb.init(project=\"RemFX\", entity=\"mattricesound\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "8d7eacfc", "metadata": {}, "outputs": [], "source": [ "SAMPLE_RATE = 22050\n", "LENGTH = 2**17#round(5 * SAMPLE_RATE) 6 seconds" ] }, { "cell_type": "code", "execution_count": 19, "id": "a70e9942", "metadata": {}, "outputs": [], "source": [ "model = AudioDiffusionModel(in_channels=1)" ] }, { "cell_type": "code", "execution_count": 20, "id": "cdc0fb64", "metadata": {}, "outputs": [], "source": [ "class GuitarDataset(Dataset):\n", " def __init__(self, root, length=LENGTH):\n", " self.files = list(Path().glob(f\"{root}/**/*.wav\"))\n", " self.resampler = T.Resample(48000, SAMPLE_RATE)\n", " \n", " def __len__(self):\n", " return len(self.files)\n", " \n", " def __getitem__(self, idx):\n", " x, sr = torchaudio.load(self.files[idx])\n", "# x = x.view() # Duplicate channel\n", " resampled_x = self.resampler(x)\n", " if resampled_x.shape[1] < LENGTH:\n", " resampled_x = F.pad(resampled_x, (0, LENGTH - resampled_x.shape[1]))\n", " elif resampled_x.shape[1] > LENGTH:\n", " resampled_x = resampled_x[:, :LENGTH]\n", " return resampled_x" ] }, { "cell_type": "code", "execution_count": 21, "id": "148c2a96", "metadata": {}, "outputs": [], "source": [ "g = GuitarDataset(Path(\"Clean\"))" ] }, { "cell_type": "code", "execution_count": 237, "id": "670c94a5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 131072])\n" ] } ], "source": [ "x = g[10]\n", "print(x.shape)" ] }, { "cell_type": "code", "execution_count": 22, "id": "f031f044", "metadata": {}, "outputs": [], "source": [ "class RemFXDataModule(pl.LightningDataModule):\n", " def __init__(self, data_dir: str = \"path/to/dir\", batch_size: int = 32):\n", " super().__init__()\n", " self.data_dir = Path(data_dir)\n", " self.batch_size = batch_size\n", "\n", " def setup(self, stage: str):\n", "# self.guitar_test = GuitarDataset(self.data_dir, train=False)\n", "# self.guitar_predict = GuitarDataset(self.data_dir, train=False)\n", " guitar_full = GuitarDataset(self.data_dir)\n", " self.guitar_train, self.guitar_val = random_split(guitar_full, [55000, 5000])\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.guitar_train, batch_size=self.batch_size)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.guitar_val, batch_size=self.batch_size)\n", "\n", "# def test_dataloader(self):\n", "# return DataLoader(self.guitar_test, batch_size=self.batch_size)\n", "\n", "# def predict_dataloader(self):\n", "# return DataLoader(self.guitar_predict, batch_size=self.batch_size)\n", "\n", " def teardown(self, stage: str):\n", " pass" ] }, { "cell_type": "code", "execution_count": 23, "id": "e1c83600", "metadata": {}, "outputs": [], "source": [ "data = DataLoader(GuitarDataset(Path(\"Clean\")), batch_size=32)" ] }, { "cell_type": "code", "execution_count": 24, "id": "4d46f992", "metadata": {}, "outputs": [], "source": [ "dataiter = iter(data)\n", "x = next(dataiter)" ] }, { "cell_type": "code", "execution_count": 25, "id": "1103e520", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 131072])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[0].shape" ] }, { "cell_type": "code", 
"execution_count": 26, "id": "6b0f1575", "metadata": {}, "outputs": [], "source": [ "# wandb.log({\"Audio\": wandb.Audio(x[0].view(-1).numpy(), sample_rate=SAMPLE_RATE)})" ] }, { "cell_type": "code", "execution_count": 28, "id": "a6a2bb97", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 14%|█████████▏ | 7/50 [7:29:41<59:56:20, 5018.16s/it]wandb: Network error (ConnectionError), entering retry loop.\n", " 14%|█████████▏ | 7/50 [8:13:48<50:33:21, 4232.58s/it]\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[28], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(epochs)):\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m data:\n\u001b[0;32m----> 4\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m5\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/model.py:36\u001b[0m, in \u001b[0;36mModel1d.forward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m---> 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
{ "cell_type": "code", "execution_count": 259, "id": "d18e4816", "metadata": {}, "outputs": [], "source": [ "noise = torch.randn(1, 1, 2**17)\n", "sampled = model.sample(noise=noise, num_steps=50)" ] },
{ "cell_type": "code", "execution_count": 260, "id": "054e708f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 1, 131072]) tensor([[[-0.4879, -0.4534, -0.4094, ..., -1.0000, 0.8554, -0.9605]]])\n" ] } ], "source": [ "print(sampled.shape, sampled)" ] },
{ "cell_type": "code", "execution_count": 261, "id": "fc8becc0", "metadata": {}, "outputs": [], "source": [ "z = sampled.view(-1)\n", "# z = z.mean(axis=0)" ] },
{ "cell_type": "code", "execution_count": 262, "id": "7731b0e7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<IPython.lib.display.Audio object>" ] }, "execution_count": 262, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Audio(z, rate=22050)" ] },
{ "cell_type": "code", "execution_count": 258, "id": "907c7dfb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.6213, grad_fn=<MseLossBackward0>)\n" ] } ], "source": [] },
{ "cell_type": "code", "execution_count": 164, "id": "698bc12e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6890.625" ] }, "execution_count": 164, "metadata": {}, "output_type": "execute_result" } ], "source": [ "110250 / 16" ] },
{ "cell_type": "code", "execution_count": 165, "id": "d4d8bb49", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "264600" ] }, "execution_count": 165, "metadata": {}, "output_type": "execute_result" } ], "source": [ "12 * 22050" ] },
{ "cell_type": "code", "execution_count": 166, "id": "ecda7e0e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16537.5" ] }, "execution_count": 166, "metadata": {}, "output_type": "execute_result" } ], "source": [ "264600 / 16" ] },
{ "cell_type": "code", "execution_count": null, "id": "a9980287", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }