{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f2d9a3e-7c41-4e0b-8d6a-1b9e0c7a4f21",
   "metadata": {},
   "source": [
    "# Toy audio diffusion on random sine tones\n",
    "\n",
    "Trains an `AudioDiffusionModel` (from `audio_diffusion_pytorch`) on batches of unit-amplitude sine waves whose integer frequencies are drawn uniformly from `[FC_MIN, FC_MAX]` Hz, then samples one clip from pure Gaussian noise and plays it back.\n",
    "\n",
    "The training cell is slow: the original run took roughly 10 minutes for 1000 iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from IPython.display import Audio\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from audio_diffusion_pytorch import AudioDiffusionModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a005011f-3019-4d34-bdf2-9a00e5480282",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's RNGs so training batches and sampling noise repeat across runs\n",
    "# (some GPU kernels may still be nondeterministic).\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "FS = 22050           # sample rate (Hz), used for both the time axis and playback\n",
    "NUM_SAMPLES = 32768  # clip length in samples (2**15, ~1.49 s at FS)\n",
    "FC_MIN = 220         # lowest sine frequency (Hz, inclusive)\n",
    "FC_MAX = 440         # highest sine frequency (Hz, inclusive)\n",
    "BATCH_SIZE = 8\n",
    "N_ITERS = 1000\n",
    "LR = 1e-4\n",
    "SAMPLING_STEPS = 50  # suggested range: 2-50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably NUM_SAMPLES must be divisible by the product of `factors`\n",
    "# (2**6 = 64); 32768 is. Confirm against the library if the clip length changes.\n",
    "model = AudioDiffusionModel(\n",
    "    in_channels=1,\n",
    "    patch_size=1,\n",
    "    multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
    "    factors=[2, 2, 2, 2, 2, 2],\n",
    "    num_blocks=[2, 2, 2, 2, 2, 2],\n",
    "    attentions=[0, 0, 0, 0, 0, 0],\n",
    ").to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), LR, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time axis in seconds, shaped [1, NUM_SAMPLES] so it broadcasts against a [batch, 1]\n",
    "# column of frequencies. Built once on `device` instead of once per training step.\n",
    "time_s = (torch.arange(NUM_SAMPLES) / FS).view(1, -1).to(device)\n",
    "assert time_s.shape == (1, NUM_SAMPLES)\n",
    "\n",
    "\n",
    "def make_sine_batch(batch_size: int, time_s: torch.Tensor, fc_min: int, fc_max: int) -> torch.Tensor:\n",
    "    \"\"\"Return a [batch_size, 1, T] batch of unit-amplitude sine waves.\n",
    "\n",
    "    Each row gets one integer frequency drawn uniformly from [fc_min, fc_max] Hz\n",
    "    (both ends inclusive). `time_s` is the [1, T] time axis in seconds; the batch\n",
    "    is created on the same device as `time_s`.\n",
    "    \"\"\"\n",
    "    # torch.randint's upper bound is exclusive, hence the +1 to include fc_max.\n",
    "    freqs = torch.randint(fc_min, fc_max + 1, (batch_size, 1), device=time_s.device)\n",
    "    return torch.sin(2 * torch.pi * freqs * time_s).unsqueeze(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e8b1a4-2d6f-4a9e-b7c5-8f1d0e2a6b93",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01265072",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.train()\n",
    "losses = []\n",
    "pbar = tqdm(range(N_ITERS))\n",
    "for i in pbar:\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    signals = make_sine_batch(BATCH_SIZE, time_s, FC_MIN, FC_MAX)\n",
    "\n",
    "    # Calling the model on a clean batch returns the scalar training loss.\n",
    "    loss = model(signals)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    losses.append(loss.item())\n",
    "    pbar.set_description(f\"{i} - loss step: {losses[-1]:0.4f} loss mean: {np.mean(losses):0.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a9c2e5f-1b3d-4f60-8e2a-9d4c6b1f0e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 4))\n",
    "ax.plot(losses)\n",
    "ax.set(title=\"Training loss per iteration\", xlabel=\"iteration\", ylabel=\"loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4b7d0c2-9a1f-4c3e-a6d8-2f5b7e9c1a04",
   "metadata": {},
   "source": [
    "## Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71d17c51-842c-40a1-81a1-a53bf358bc8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample one clip starting from pure Gaussian noise.\n",
    "model.eval()\n",
    "noise = torch.randn(1, 1, NUM_SAMPLES, device=device)\n",
    "with torch.no_grad():\n",
    "    sampled = model.sample(noise=noise, num_steps=SAMPLING_STEPS)\n",
    "# NOTE(review): expected shape [1, 1, NUM_SAMPLES] (same as `noise`); confirm against the library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59d71efa-05ac-4545-84da-8c09c033dfd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First (and only) clip in the batch; .numpy() needs a detached CPU tensor.\n",
    "clip = sampled[0].detach().cpu().numpy()\n",
    "Audio(clip, rate=FS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}