diff --git "a/diffusion_test.ipynb" "b/diffusion_test.ipynb"
new file mode 100644
--- /dev/null
+++ "b/diffusion_test.ipynb"
@@ -0,0 +1,129 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from audio_diffusion_pytorch import AudioDiffusionModel\n",
+    "import torch\n",
+    "from IPython.display import Audio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a005011f-3019-4d34-bdf2-9a00e5480282",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Shared configuration: later cells read these names instead of repeating literals.\n",
+    "SAMPLE_RATE = 22050  # Hz; used for both synthesis and playback\n",
+    "NUM_SAMPLES = 2 ** 18  # samples per clip (about 11.9 s at SAMPLE_RATE)\n",
+    "FREQ_START_HZ = 300  # first training frequency (inclusive)\n",
+    "FREQ_STOP_HZ = 8000  # end of the training frequency sweep (exclusive)\n",
+    "LEARNING_RATE = 1e-4  # Adam step size; a starting point, tune as needed\n",
+    "SEED = 0\n",
+    "\n",
+    "# Seed torch's global RNG before the model is built and before noise is drawn.\n",
+    "torch.manual_seed(SEED)\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hyper-parameters are unchanged from the first version of this notebook; see the\n",
+    "# audio_diffusion_pytorch documentation for the meaning of each argument.\n",
+    "model = AudioDiffusionModel(\n",
+    "    in_channels=1,\n",
+    "    patch_size=1,\n",
+    "    multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
+    "    factors=[2, 2, 2, 2, 2, 2],\n",
+    "    num_blocks=[2, 2, 2, 2, 2, 2],\n",
+    "    attentions=[0, 0, 0, 0, 0, 0],\n",
+    ")\n",
+    "model = model.to(device)\n",
+    "\n",
+    "# loss.backward() only fills .grad; an optimizer is required to change the weights.\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Time axis in seconds, built from the integer sample count so the clip length never\n",
+    "# depends on float rounding. float64 is used because the sine phases reach hundreds of\n",
+    "# thousands of radians, where float32 spacing is several hundredths of a radian.\n",
+    "times = torch.arange(NUM_SAMPLES, dtype=torch.float64) / SAMPLE_RATE\n",
+    "\n",
+    "model.train()\n",
+    "for freq_hz in range(FREQ_START_HZ, FREQ_STOP_HZ):\n",
+    "    # One batch per frequency: a sine at freq_hz and a second one an octave up.\n",
+    "    # The octave aliases once 2 * freq_hz exceeds SAMPLE_RATE / 2, but since it is a\n",
+    "    # pure sinusoid that doesn't matter too much.\n",
+    "    fundamental = torch.sin(2 * torch.pi * freq_hz * times)\n",
+    "    octave = torch.sin(2 * torch.pi * (2 * freq_hz) * times)\n",
+    "    # Shape [2, 1, NUM_SAMPLES]; cast to float32, torch's default parameter dtype.\n",
+    "    batch = torch.stack((fundamental, octave)).unsqueeze(1)\n",
+    "    batch = batch.to(device=device, dtype=torch.float32)\n",
+    "\n",
+    "    optimizer.zero_grad()  # drop gradients left over from the previous step\n",
+    "    loss = model(batch)\n",
+    "    loss.backward()\n",
+    "    optimizer.step()  # apply the gradients; without this the weights never change\n",
+    "\n",
+    "    if freq_hz % 10 == 0:\n",
+    "        print(\"Step\", freq_hz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "71d17c51-842c-40a1-81a1-a53bf358bc8a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sample 2 sources given start noise\n",
+    "model.eval()\n",
+    "noise = torch.randn(2, 1, NUM_SAMPLES, device=device)\n",
+    "with torch.no_grad():  # inference only; no autograd graph is needed\n",
+    "    sampled = model.sample(\n",
+    "        noise=noise,\n",
+    "        num_steps=10,  # Suggested range: 2-50\n",
+    "    )  # [2, 1, 2 ** 18]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "59d71efa-05ac-4545-84da-8c09c033dfd7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Play the second generated clip; use index 0 to hear the first one.\n",
+    "clip = sampled[1]\n",
+    "Audio(clip.cpu(), rate=SAMPLE_RATE)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}