{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "remfx-title",
   "metadata": {},
   "source": [
    "# RemFX: audio diffusion experiments\n",
    "\n",
    "Trains an `AudioDiffusionModel` (from `audio_diffusion_pytorch`) in two stages:\n",
    "\n",
    "1. a sanity check on synthetic sine tones, and\n",
    "2. training on the guitar recordings under `DATA_ROOT`.\n",
    "\n",
    "Both stages update the same model. Losses and generated audio are logged to Weights & Biases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2807819",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "import torchaudio.transforms as T\n",
    "import wandb\n",
    "from audio_diffusion_pytorch import AudioDiffusionModel\n",
    "from IPython.display import Audio\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdd08230-c057-4a6e-83b9-435b2c0fbaaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "config-header",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Everything a reader might want to tune lives here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7eacfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "SAMPLE_RATE = 22050  # Hz; every clip is resampled to this rate\n",
    "LENGTH = 2**17  # samples per training clip (~5.9 s at SAMPLE_RATE)\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 1e-4  # TODO: tune\n",
    "DATA_ROOT = Path(\"Clean\")\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "469edd04",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(\n",
    "    project=\"RemFX\",\n",
    "    entity=\"mattricesound\",\n",
    "    config={\n",
    "        \"seed\": SEED,\n",
    "        \"sample_rate\": SAMPLE_RATE,\n",
    "        \"length\": LENGTH,\n",
    "        \"batch_size\": BATCH_SIZE,\n",
    "        \"learning_rate\": LEARNING_RATE,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dataset-header",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "Each item is a mono clip resampled to `SAMPLE_RATE` and zero-padded or truncated to exactly `LENGTH` samples.\n",
    "Tensors stay on the CPU; batches are moved to `device` in the training step, which keeps the dataset usable with `DataLoader` worker processes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc0fb64",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GuitarDataset(Dataset):\n",
    "    \"\"\"Fixed-length mono clips built from every ``.wav`` file under ``root``.\n",
    "\n",
    "    Each item is downmixed to one channel, resampled to ``sample_rate`` and\n",
    "    zero-padded or truncated to exactly ``length`` samples, giving a\n",
    "    ``[1, length]`` tensor on the CPU.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root, length=LENGTH, sample_rate=SAMPLE_RATE):\n",
    "        # Sorted so that item indices are stable across runs and machines.\n",
    "        self.files = sorted(Path(root).glob(\"**/*.wav\"))\n",
    "        self.length = length\n",
    "        self.sample_rate = sample_rate\n",
    "        # One Resample transform per source rate, built lazily and reused.\n",
    "        self._resamplers = {}\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.files)\n",
    "\n",
    "    def _resample(self, x, sr):\n",
    "        \"\"\"Resample ``x`` from ``sr`` to ``self.sample_rate`` (no-op if equal).\"\"\"\n",
    "        if sr == self.sample_rate:\n",
    "            return x\n",
    "        if sr not in self._resamplers:\n",
    "            self._resamplers[sr] = T.Resample(sr, self.sample_rate)\n",
    "        return self._resamplers[sr](x)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        x, sr = torchaudio.load(self.files[idx])\n",
    "        # The model below is built with in_channels=1, so downmix multi-channel files.\n",
    "        if x.shape[0] > 1:\n",
    "            x = x.mean(dim=0, keepdim=True)\n",
    "        # Use the file's real sample rate rather than assuming a fixed one.\n",
    "        x = self._resample(x, sr)\n",
    "        if x.shape[1] < self.length:\n",
    "            x = F.pad(x, (0, self.length - x.shape[1]))\n",
    "        return x[:, : self.length]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148c2a96",
   "metadata": {},
   "outputs": [],
   "source": [
    "g = GuitarDataset(DATA_ROOT)\n",
    "len(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670c94a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "example = g[10]\n",
    "assert example.shape == (1, LENGTH)\n",
    "example.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c83600",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = DataLoader(g, batch_size=BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d46f992",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(data))\n",
    "assert batch.shape[1:] == (1, LENGTH)\n",
    "batch.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "model-header",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "`model(batch)` returns the diffusion training loss. `train_step` turns it into an actual weight update (zero grads, backward, optimizer step)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eff19abd-304c-449e-9fb5-4e9ce4d4b19c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AudioDiffusionModel(\n",
    "    in_channels=1,\n",
    "    patch_size=1,\n",
    "    multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
    "    factors=[2, 2, 2, 2, 2, 2],\n",
    "    num_blocks=[2, 2, 2, 2, 2, 2],\n",
    "    attentions=[0, 0, 0, 0, 0, 0],\n",
    ").to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "train-step-fn",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(model, optimizer, batch):\n",
    "    \"\"\"Run one optimisation step on ``batch`` and return the loss as a float.\n",
    "\n",
    "    ``batch`` may live on any device; it is moved to ``device`` here.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    loss = model(batch.to(device))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sine-header",
   "metadata": {},
   "source": [
    "## Sanity check: sine tones\n",
    "\n",
    "Before touching real data, train on synthetic tones. Each step uses a batch of two pure tones: a frequency `f` and its octave `2f`.\n",
    "`f` is capped at a quarter of the sample rate so the octave stays below the Nyquist frequency and does not alias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75dd6e95-5e31-43f5-a0f8-05c7e13e7a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "SINE_LENGTH = 2**18  # samples per synthetic example (~11.9 s at SAMPLE_RATE)\n",
    "sine_time = torch.arange(SINE_LENGTH) / SAMPLE_RATE  # time axis in seconds\n",
    "# Cap f so the octave 2f stays below Nyquist (SAMPLE_RATE / 2) and does not alias.\n",
    "max_freq = min(8000, SAMPLE_RATE // 4)\n",
    "\n",
    "for f in tqdm(range(300, max_freq), desc=\"sine tones\"):\n",
    "    fundamental = torch.sin(2 * torch.pi * f * sine_time)\n",
    "    octave = torch.sin(2 * torch.pi * (2 * f) * sine_time)\n",
    "    # Batch of two mono examples: shape [2, 1, SINE_LENGTH].\n",
    "    stacked_signal = torch.stack((fundamental, octave)).unsqueeze(1)\n",
    "    loss = train_step(model, optimizer, stacked_signal)\n",
    "\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bda06495-0546-4474-ba5c-bf55e4887329",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample two sources from Gaussian start noise.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    sine_noise = torch.randn(2, 1, SINE_LENGTH, device=device)\n",
    "    sine_sampled = model.sample(noise=sine_noise, num_steps=10)  # suggested range: 2-50\n",
    "sine_sampled.shape  # [2, 1, SINE_LENGTH]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeec47b7-4b99-4239-9c61-fd36ad881876",
   "metadata": {},
   "outputs": [],
   "source": [
    "Audio(sine_sampled[1].cpu().numpy(), rate=SAMPLE_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "reference-header",
   "metadata": {},
   "source": [
    "### Reference tones\n",
    "\n",
    "For comparison: the last training example (the octave of the highest `f`), a 4 kHz tone, and a 440 Hz tone shaped like a model batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ccb733d-706a-4535-93b6-73ae2469de8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Audio(stacked_signal[1].numpy(), rate=SAMPLE_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf6a106-7967-470a-932e-156b00e46ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "reference_tone = torch.sin(2 * torch.pi * 4000 * sine_time)\n",
    "Audio(reference_tone.numpy(), rate=SAMPLE_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddf58e57-4660-4e1a-83e3-5909da3b42fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "signal = torch.sin(2 * torch.pi * 440 * sine_time)\n",
    "sig = torch.stack((signal, signal)).unsqueeze(1)\n",
    "sig.shape  # [2, 1, SINE_LENGTH]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e79b1b33-1905-4ae6-9dbe-73b68eec1dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "Audio(sig[0].numpy(), rate=SAMPLE_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "guitar-header",
   "metadata": {},
   "source": [
    "## Training on guitar recordings\n",
    "\n",
    "The loss of every batch is logged. Every `SAMPLE_EVERY` epochs one clip is generated from noise and logged as audio; sampling sits outside the batch loop because 40 sampling steps are expensive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a2bb97",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 500\n",
    "SAMPLE_EVERY = 10  # epochs between logged audio samples\n",
    "\n",
    "for epoch in tqdm(range(EPOCHS), desc=\"epochs\"):\n",
    "    for batch in data:\n",
    "        loss = train_step(model, optimizer, batch)\n",
    "        wandb.log({\"loss\": loss, \"epoch\": epoch})\n",
    "    if epoch % SAMPLE_EVERY == 0:\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            noise = torch.randn(1, 1, LENGTH, device=device)\n",
    "            sampled = model.sample(noise=noise, num_steps=40)\n",
    "        z = sampled.view(-1)\n",
    "        wandb.log({f\"Audio_{epoch}\": wandb.Audio(z.cpu().numpy(), sample_rate=SAMPLE_RATE)})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "generate-header",
   "metadata": {},
   "source": [
    "## Generate a sample from the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18e4816",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    noise = torch.randn(1, 1, LENGTH, device=device)\n",
    "    sampled = model.sample(noise=noise, num_steps=50)\n",
    "sampled.shape  # [1, 1, LENGTH]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc8becc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "z = sampled.view(-1)\n",
    "Audio(z.cpu().numpy(), rate=SAMPLE_RATE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}