{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "b451ab22", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import random\n", "import numpy as np\n", "from PIL import Image\n", "from datasets import load_dataset\n", "from IPython.display import Audio\n", "from diffusers import AutoencoderKL, AudioDiffusionPipeline, Mel" ] }, { "cell_type": "code", "execution_count": null, "id": "324cef44", "metadata": {}, "outputs": [], "source": [ "mel = Mel()\n", "vae = AutoencoderKL.from_pretrained('../models/autoencoder-kl')" ] }, { "cell_type": "code", "execution_count": null, "id": "da55ce79", "metadata": {}, "outputs": [], "source": [ "vae.config" ] }, { "cell_type": "code", "execution_count": null, "id": "5fea99ff", "metadata": {}, "outputs": [], "source": [ "ds = load_dataset('teticio/audio-diffusion-256')" ] }, { "cell_type": "markdown", "id": "3a65ec4d", "metadata": {}, "source": [ "### Reconstruct audio" ] }, { "cell_type": "code", "execution_count": null, "id": "426c6edd", "metadata": {}, "outputs": [], "source": [ "image = random.choice(ds['train'])['image']\n", "display(image)\n", "Audio(data=mel.image_to_audio(image), rate=mel.get_sample_rate())" ] }, { "cell_type": "code", "execution_count": null, "id": "29c9011d", "metadata": {}, "outputs": [], "source": [ "# encode\n", "input_image = np.frombuffer(image.tobytes(), dtype=\"uint8\").reshape(\n", " (image.height, image.width, 1))\n", "input_image = ((input_image / 255) * 2 - 1).transpose(2, 0, 1)\n", "posterior = vae.encode(torch.tensor([input_image],\n", " dtype=torch.float32)).latent_dist\n", "latents = posterior.sample()" ] }, { "cell_type": "code", "execution_count": null, "id": "323ba46d", "metadata": {}, "outputs": [], "source": [ "# reconstruct\n", "output_image = vae.decode(latents)['sample']\n", "output_image = torch.clamp(output_image, -1., 1.)\n", "output_image = (output_image + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w\n", "output_image = (output_image.detach().cpu().numpy() *\n", " 255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0, :, :, 0]\n", "output_image = Image.fromarray(output_image)\n", "display(output_image)\n", "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())" ] }, { "cell_type": "markdown", "id": "00ff2ffa", "metadata": {}, "source": [ "### Random sample from latent space\n", "(Don't expect interesting results!)" ] }, { "cell_type": "code", "execution_count": null, "id": "156a06a2", "metadata": {}, "outputs": [], "source": [ "# sample\n", "output_image = vae.decode(torch.randn_like(latents))['sample']\n", "output_image = torch.clamp(output_image, -1., 1.)\n", "output_image = (output_image + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w\n", "output_image = (output_image.detach().cpu().numpy() *\n", " 255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0, :, :, 0]\n", "output_image = Image.fromarray(output_image)\n", "display(output_image)\n", "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())" ] }, { "cell_type": "markdown", "id": "ee3997cf", "metadata": {}, "source": [ "### Interpolate between two audios in latent space" ] }, { "cell_type": "code", "execution_count": null, "id": "46019770", "metadata": {}, "outputs": [], "source": [ "image2 = random.choice(ds['train'])['image']\n", "display(image2)\n", "Audio(data=mel.image_to_audio(image2), rate=mel.get_sample_rate())" ] }, { "cell_type": "code", "execution_count": null, "id": "e6552b19", "metadata": {}, "outputs": [], "source": [ "# encode\n", "input_image2 = np.frombuffer(image2.tobytes(), 
dtype=\"uint8\").reshape(\n", " (image2.height, image2.width, 1))\n", "input_image2 = ((input_image2 / 255) * 2 - 1).transpose(2, 0, 1)\n", "posterior2 = vae.encode(torch.tensor([input_image2],\n", " dtype=torch.float32)).latent_dist\n", "latents2 = posterior2.sample()" ] }, { "cell_type": "code", "execution_count": null, "id": "060a0b25", "metadata": {}, "outputs": [], "source": [ "# interpolate\n", "alpha = 0.5 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", "output_image = vae.decode(\n", " AudioDiffusionPipeline.slerp(latents, latents2, alpha))['sample']\n", "output_image = torch.clamp(output_image, -1., 1.)\n", "output_image = (output_image + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w\n", "output_image = (output_image.detach().cpu().numpy() *\n", " 255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0, :, :, 0]\n", "output_image = Image.fromarray(output_image)\n", "display(output_image)\n", "display(Audio(data=mel.image_to_audio(image), rate=mel.get_sample_rate()))\n", "display(Audio(data=mel.image_to_audio(image2), rate=mel.get_sample_rate()))\n", "display(\n", " Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate()))" ] }, { "cell_type": "code", "execution_count": null, "id": "d6c74105", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "huggingface", "language": "python", "name": "huggingface" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }