{ "cells": [ { "cell_type": "markdown", "source": [ "## Prerequisites" ], "metadata": { "id": "w4LtdMb23tZ4" } }, { "cell_type": "markdown", "metadata": { "id": "JgJLBIh3fm-W" }, "source": [ "### Install Dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EodUpreufqD-" }, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bOn11huvfuXc" }, "outputs": [], "source": [ "!pip install --upgrade --quiet pip\n", "!pip install --quiet git+https://github.com/huggingface/transformers.git" ] }, { "cell_type": "code", "source": [ "# Specifiers containing '>' must be quoted: unquoted, the shell treats '>=6.5.5'\n", "# as an output redirection, so pip installs an unconstrained version.\n", "!pip install typing-extensions==4.5.0\n", "!pip install python-multipart\n", "!pip install kaleido\n", "!pip install \"notebook>=6.5.5\"\n", "!pip install \"click>=8.0\"\n", "!pip install fastapi\n", "!pip install \"uvicorn[standard]\"\n", "!pip install pyngrok" ], "metadata": { "id": "Nl0CQxwHCrFd" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ROxnljVbf6_o" }, "source": [ "### Load the models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ToxW1gbLf6tr" }, "outputs": [], "source": [ "from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed\n", "\n", "# Model and processor are loaded from the same checkpoint.\n", "model = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")\n", "processor = MusicgenProcessor.from_pretrained(\"facebook/musicgen-small\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "omP9Hg1ajUKM" }, "outputs": [], "source": [ "import torch\n", "from IPython.display import Audio\n", "\n", "sampling_rate = model.config.audio_encoder.sampling_rate\n", "# Use the first GPU when available, otherwise fall back to the CPU.\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "model.to(device)\n", "# Trailing None keeps the cell from displaying the value returned by model.to().\n", "None" ] }, { "cell_type": "markdown", "metadata": { "id": "852jZKSqiKoT" }, "source": [ "## Music Generation functionality" ] }, { "cell_type": "markdown", "source": [ "#### Model Class" ], "metadata": { "id": 
"8nydshMdxKab" } }, { "cell_type": "code", "source": [ "import numpy as np\n", "import typing\n", "\n", "class AudioPalette:\n", " def __init__(self):\n", " pass\n", "\n", " def set_prompt(self, caption: str | typing.List[str]):\n", " self.caption = caption\n", "\n", " def generate(self):\n", " if isinstance(self.caption, str):\n", " return self.generate_single(max_new_tokens=1024)\n", " else:\n", " return self.generate_multiple()\n", "\n", " def generate_single(self, prompt=None, max_new_tokens=512):\n", " if not prompt:\n", " prompt = self.caption\n", " inputs = processor(\n", " text=[prompt],\n", " padding=True,\n", " return_tensors=\"pt\",\n", " sampling_rate=sampling_rate\n", " )\n", "\n", " audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=max_new_tokens)\n", " return audio_values\n", "\n", " def generate_audio_with_melody_conditioning(self, prompt, melody, max_new_tokens=256):\n", " inputs = processor(\n", " text=[prompt],\n", " audio=melody[0, 0].cpu().numpy(),\n", " padding=True,\n", " return_tensors=\"pt\",\n", " sampling_rate=sampling_rate\n", " )\n", "\n", " # set_seed(1)\n", " audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=max_new_tokens)\n", " return audio_values\n", "\n", " def generate_multiple(self):\n", " for idx, prompt in enumerate(self.caption):\n", " if idx == 0:\n", " audio = self.generate_single(prompt, 256)\n", " else:\n", " audio = self.generate_audio_with_melody_conditioning(prompt, audio)\n", " return audio" ], "metadata": { "id": "4V49E7xpxNPu" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "audiopalette = AudioPalette()" ], "metadata": { "id": "qW65Q68o-R7f" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "P3OmxnaBA9E-" }, "source": [ "#### API Creation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dxlw3ODsTtuB" }, "outputs": [], 
"source": [ "from fastapi import FastAPI\n", "from pydantic import BaseModel, Field\n", "from fastapi.middleware.cors import CORSMiddleware\n", "\n", "app = FastAPI()\n", "\n", "app.add_middleware(\n", " CORSMiddleware,\n", " allow_origins=['*'],\n", " allow_credentials=True,\n", " allow_methods=['*'],\n", " allow_headers=['*'],\n", ")" ] }, { "cell_type": "code", "source": [ "import typing\n", "import numpy as np\n", "\n", "class Prompt(BaseModel):\n", " caption: str | typing.List[str]\n", "\n", "class FileData(BaseModel):\n", " file_path: str\n", "\n", "# class Melody(BaseModel):\n", "# audio: np.ndarray\n", "\n", "# class Config:\n", "# arbitrary_types_allowed = True" ], "metadata": { "id": "iYUH3-GpfbN8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTQCdon0A9FA" }, "outputs": [], "source": [ "import tempfile\n", "import scipy\n", "\n", "from fastapi.responses import FileResponse\n", "\n", "@app.get('/')\n", "async def root():\n", " return {\"message\": \"Hello World\"}\n", "\n", "@app.post('/download')\n", "async def download(file_data: FileData):\n", " file_path = file_data.file_path\n", " return FileResponse(file_path)\n", "\n", "@app.post('/generate')\n", "async def gen_music(prompt: Prompt):\n", " audiopalette.set_prompt(prompt.caption)\n", " audio = audiopalette.generate()\n", "\n", " file_path = None\n", " with tempfile.NamedTemporaryFile(delete=False) as f:\n", " scipy.io.wavfile.write(f, rate=sampling_rate, data=audio[0, 0].cpu().numpy())\n", " file_path = f.name\n", "\n", " if not file_path:\n", " return {\"error\": \"There has been an error\"}\n", " return {\"file_path\": f\"{file_path}\"}\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ALpNtVpHA9FA" }, "source": [ "#### Run the API" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w3eG8rfRA9FB" }, "outputs": [], "source": [ "from getpass import getpass\n", "\n", "import nest_asyncio\n", "import 
uvicorn\n", "from pyngrok import ngrok" ] }, { "cell_type": "code", "source": [ "# Read the token interactively so it is never stored in the notebook.\n", "ngrok_auth_token = getpass(prompt=\"Enter ngrok auth token: \")\n", "ngrok.set_auth_token(ngrok_auth_token)" ], "metadata": { "id": "QFDDncCJEs4f" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yOhMLxA5A9FB" }, "outputs": [], "source": [ "PORT = 8000  # local port shared by the ngrok tunnel and the uvicorn server\n", "\n", "# Open the tunnel first so its URL is printed before uvicorn.run() takes over the cell.\n", "ngrok_tunnel = ngrok.connect(PORT)\n", "print(\"Public URL:\", ngrok_tunnel.public_url)\n", "\n", "# Patch asyncio so uvicorn can run from inside the notebook kernel.\n", "nest_asyncio.apply()\n", "uvicorn.run(app, port=PORT)" ] }, { "cell_type": "markdown", "metadata": { "id": "606pRql4A9FC" }, "source": [ "#### Kill ngrok Connection" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k7Tbq8w-A9FC" }, "outputs": [], "source": [ "# Stop the ngrok process started by this notebook.\n", "ngrok.kill()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "w4LtdMb23tZ4" ], "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 0 }