diff --git "a/diffusion_test.ipynb" "b/diffusion_test.ipynb"
new file mode 100644--- /dev/null
+++ "b/diffusion_test.ipynb"
@@ -0,0 +1,141 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from audio_diffusion_pytorch import AudioDiffusionModel\n",
+ "import torch\n",
+ "from IPython.display import Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "a005011f-3019-4d34-bdf2-9a00e5480282",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Prefer the GPU when PyTorch reports one as available; otherwise run on the CPU.\n",
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build the diffusion model and move it to the selected device in one expression.\n",
+ "# The three per-level lists all have six entries, so they are written as repeats.\n",
+ "model = AudioDiffusionModel(\n",
+ "    in_channels=1,\n",
+ "    patch_size=1,\n",
+ "    multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
+ "    factors=[2] * 6,\n",
+ "    num_blocks=[2] * 6,\n",
+ "    attentions=[0] * 6,\n",
+ ").to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Step 300\n",
+ "Step 310\n",
+ "Step 320\n"
+ ]
+ }
+ ],
+ "source": [
+ "fs = 22050  # sample rate in Hz\n",
+ "t = 2 ** 18 / 22050  # clip duration in seconds (kept for reference)\n",
+ "# Build the time axis from an integer sample count. torch.arange(t * fs) depended on a\n",
+ "# float round-trip landing on exactly 2 ** 18, which rounding could shift by one sample.\n",
+ "samples = torch.arange(2 ** 18) / fs\n",
+ "\n",
+ "# backward() alone only fills .grad; an optimizer is needed to actually update the weights.\n",
+ "# The learning rate is a starting point - tune as needed.\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
+ "\n",
+ "for i in range(300, 8000):\n",
+ "    f = i\n",
+ "    # Create 2 sine waves (one at f=step, other is octave up)\n",
+ "    # There is aliasing at higher freq, but since it is sinusoids, that doesn't matter too much\n",
+ "    signal1 = torch.sin(2 * torch.pi * f * samples)\n",
+ "    signal2 = torch.sin(2 * torch.pi * (f * 2) * samples)\n",
+ "    # Stack into a batch of two mono signals: [2, 1, 2 ** 18]\n",
+ "    stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)\n",
+ "    stacked_signal = stacked_signal.to(device)\n",
+ "\n",
+ "    # Clear gradients from the previous step so they do not accumulate across iterations.\n",
+ "    optimizer.zero_grad()\n",
+ "    loss = model(stacked_signal)\n",
+ "    loss.backward()\n",
+ "    # Apply the freshly computed gradients to the model parameters.\n",
+ "    optimizer.step()\n",
+ "\n",
+ "    if i % 10 == 0:\n",
+ "        print(\"Step\", i)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "71d17c51-842c-40a1-81a1-a53bf358bc8a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sample 2 sources, starting from standard-normal noise placed on the model's device\n",
+ "noise = torch.randn(2, 1, 2 ** 18).to(device)\n",
+ "# num_steps suggested range: 2-50; result is [2, 1, 2 ** 18]\n",
+ "sampled = model.sample(noise=noise, num_steps=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "59d71efa-05ac-4545-84da-8c09c033dfd7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Listen to the second sampled source; it is moved to the CPU before being handed to the player.\n",
+ "z = sampled[1]\n",
+ "Audio(data=z.cpu(), rate=22050)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "81eddd71-bba7-4c62-8d50-900b295bb2f8",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}