{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"id": "4c52cc1c-91f1-4b79-924b-041d2929ef7b",
"metadata": {},
"outputs": [],
"source": [
"from audio_diffusion_pytorch import AudioDiffusionModel\n",
"import torch\n",
"from IPython.display import Audio\n",
"import matplotlib.pyplot as plt\n",
"from tqdm import tqdm\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "a005011f-3019-4d34-bdf2-9a00e5480282",
"metadata": {},
"outputs": [],
"source": [
"# Use the first CUDA GPU when torch reports one; otherwise fall back to the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda:0\")\n",
"else:\n",
"    device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1b689f18-375f-4b40-9ddc-a4ced6a5e5e4",
"metadata": {},
"outputs": [],
"source": [
"# Hyper-parameters forwarded unchanged to AudioDiffusionModel.\n",
"model_config = dict(\n",
"    in_channels=1,  # single-channel input (training batches are built with one channel)\n",
"    patch_size=1,\n",
"    multipliers=[1, 2, 4, 4, 4, 4, 4],\n",
"    factors=[2, 2, 2, 2, 2, 2],\n",
"    num_blocks=[2, 2, 2, 2, 2, 2],\n",
"    attentions=[0, 0, 0, 0, 0, 0],\n",
")\n",
"\n",
"# Build the network, then move it to the selected device.\n",
"model = AudioDiffusionModel(**model_config)\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "bd8a1cb4-42b5-43bc-9a12-f594ce069b33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 32768])\n"
]
}
],
"source": [
"# Signal and training settings shared by the cells below.\n",
"fs = 22050  # sampling rate: sample indices are divided by this to get the time axis\n",
"t = 32768  # number of samples in each generated signal\n",
"fc_min, fc_max = 220, 440  # bounds used when drawing the sine frequencies\n",
"batch_size = 8\n",
"n_iters = 1000\n",
"lr = 1e-4\n",
"\n",
"# Time axis stored as a single row so it broadcasts against a column of frequencies.\n",
"samples = (torch.arange(t) / fs).reshape(1, -1)\n",
"print(samples.shape)\n",
"\n",
"optimizer = torch.optim.Adam(\n",
"    model.parameters(),\n",
"    lr=lr,\n",
"    betas=(0.95, 0.999),\n",
"    eps=1e-6,\n",
"    weight_decay=1e-3,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "01265072",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"999 - loss step: 0.0457 loss mean: 0.1161: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [09:38<00:00, 1.73it/s]\n"
]
}
],
"source": [
"# Train the diffusion model on freshly generated batches of sine waves.\n",
"losses = []  # per-step loss values, kept for later inspection\n",
"loss_sum = 0.0  # running total, so the mean does not re-scan `losses` every step\n",
"pbar = tqdm(range(n_iters))\n",
"for i in pbar:\n",
"    optimizer.zero_grad()\n",
"\n",
"    # Create a batch of random sine waves, one integer frequency per batch item.\n",
"    # NOTE: torch.randint excludes its upper bound, so fc_max itself is never drawn.\n",
"    f = torch.randint(fc_min, fc_max, [batch_size, 1])\n",
"    signals = torch.sin(2 * torch.pi * f * samples)\n",
"    signals = signals.view(batch_size, 1, -1)\n",
"    signals = signals.to(device)\n",
"\n",
"    # Calling the model on a batch returns the loss that is back-propagated.\n",
"    loss = model(signals)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    # Read the scalar once and reuse it for the history, the mean and the progress bar.\n",
"    loss_value = loss.item()\n",
"    losses.append(loss_value)\n",
"    loss_sum += loss_value\n",
"    pbar.set_description(f\"{i} - loss step: {loss_value:0.4f} loss mean: {loss_sum / len(losses):0.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "71d17c51-842c-40a1-81a1-a53bf358bc8a",
"metadata": {},
"outputs": [],
"source": [
"# Sample one signal from the trained model, starting from Gaussian noise.\n",
"noise = torch.randn(1, 1, t)  # [batch, channels, samples] = [1, 1, t]\n",
"noise = noise.to(device)\n",
"\n",
"# Inference only: skip autograd bookkeeping while sampling.\n",
"with torch.no_grad():\n",
"    sampled = model.sample(\n",
"        noise=noise,\n",
"        num_steps=50  # Suggested range: 2-50\n",
"    )  # presumably the same shape as `noise`, i.e. [1, 1, t] -- TODO confirm"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "59d71efa-05ac-4545-84da-8c09c033dfd7",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Play back the first item of the sampled batch.\n",
"z = sampled[0]\n",
"# Use the notebook's sampling rate `fs` rather than repeating the literal 22050,\n",
"# so playback stays correct if `fs` is ever changed.\n",
"Audio(z.cpu(), rate=fs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81eddd71-bba7-4c62-8d50-900b295bb2f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}