{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Jukebox tokens <-> audio converter\n",
    "\n",
    "A small Gradio app that converts Jukebox music tokens (`.jt` files) to audio and back with `JukeboxVQVAE` from Transformers.\n",
    "\n",
    "Run the cells in order: the first one installs the dependencies (and, on Colab, optionally mounts Google Drive to cache the model weights); the last one launches the UI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Google Drive is only available on Colab; elsewhere this cell just installs the dependencies.\n",
    "IN_COLAB = 'google.colab' in sys.modules\n",
    "mount_drive = True #@param {type:\"boolean\"}\n",
    "if IN_COLAB and mount_drive:\n",
    "    from google.colab import drive\n",
    "    drive.mount('/content/drive')\n",
    "\n",
    "requirements_txt = \"git+https://github.com/ArthurZucker/transformers.git@jukebox\\naccelerate\\nbitsandbytes==0.31.8\\ngradio\"\n",
    "\n",
    "# Save the requirements.txt file\n",
    "with open('requirements.txt', 'w') as f:\n",
    "    f.write(requirements_txt)\n",
    "\n",
    "# Install the dependencies\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A simple gradio app that converts music tokens to and from audio using JukeboxVQVAE as the model and Gradio as the UI\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "\n",
    "import torch as t\n",
    "from transformers import JukeboxVQVAE\n",
    "import gradio as gr\n",
    "\n",
    "model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']\n",
    "\n",
    "# Sample rate (Hz) of the audio Jukebox works with; waveforms exchanged with the UI use it.\n",
    "SAMPLE_RATE = 44100\n",
    "\n",
    "if 'google.colab' in sys.modules and os.path.isdir('/content/drive/My Drive'):\n",
    "    # Drive was mounted by the setup cell: keep the weights there so they survive runtime resets.\n",
    "    cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:\"string\"}\n",
    "else:\n",
    "    cache_path = '~/.cache/'\n",
    "\n",
    "# '~' is a shell convention, so expand it before the path is handed to other libraries.\n",
    "cache_path = os.path.expanduser(cache_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate_tokens_list(tokens_list):\n",
    "    \"\"\"Check that `tokens_list` has the layout of a tokens file and return it unchanged.\n",
    "\n",
    "    Expected: a list (or tuple) of exactly 3 tensors, one per level (0 to 2), which\n",
    "    - all have the same number of dimensions (at least 2),\n",
    "    - have the same size along dimension 0,\n",
    "    - have a size along dimension 1 that decreases (or stays the same) from level 0 to level 2.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: with an \"Invalid file format: ...\" message when a check fails.\n",
    "    \"\"\"\n",
    "    assert isinstance(tokens_list, (list, tuple)) and len(tokens_list) == 3, \"Invalid file format: expecting a list of 3 tensors\"\n",
    "    assert all(t.is_tensor(tokens) for tokens in tokens_list), \"Invalid file format: expecting a list of 3 tensors\"\n",
    "\n",
    "    level_0, level_1, level_2 = tokens_list\n",
    "\n",
    "    assert level_0.dim() == level_1.dim() == level_2.dim(), \"Invalid file format: each tensor in the list should have the same number of dimensions\"\n",
    "\n",
    "    # Dimension 1 is compared below, so 1-D tensors must be rejected here rather than raise an IndexError.\n",
    "    assert level_0.dim() >= 2, \"Invalid file format: each tensor in the list should have at least 2 dimensions\"\n",
    "\n",
    "    assert level_0.shape[0] == level_1.shape[0] == level_2.shape[0], \"Invalid file format: the shape along dimension 0 should be the same for all tensors in the list\"\n",
    "\n",
    "    assert level_0.shape[1] >= level_1.shape[1] >= level_2.shape[1], \"Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2\"\n",
    "\n",
    "    return tokens_list\n",
    "\n",
    "\n",
    "def init():\n",
    "    \"\"\"Load the Jukebox VQ-VAE into the global `model`, once per kernel session.\"\"\"\n",
    "    global model\n",
    "\n",
    "    if 'model' in globals():\n",
    "        print(\"Model already initialized\")\n",
    "        return\n",
    "\n",
    "    model = JukeboxVQVAE.from_pretrained(\n",
    "        model_id,\n",
    "        # Full precision on purpose: the VQ-VAE encoder casts its input with `.float()`, which would\n",
    "        # clash with float16 weights, and float16 is poorly supported on CPU anyway.\n",
    "        torch_dtype = t.float32,\n",
    "        cache_dir = os.path.join(cache_path, 'jukebox', 'models'),\n",
    "    ).to('cuda' if t.cuda.is_available() else 'cpu')\n",
    "\n",
    "\n",
    "class Convert:\n",
    "    \"\"\"Converters between a tokens list, a tokens file (`.jt`, written with `torch.save`) and audio.\n",
    "\n",
    "    A tokens list holds 3 tensors, one per VQ-VAE level (see `validate_tokens_list`).\n",
    "    All conversions use the global `model` loaded by `init()`.\n",
    "    \"\"\"\n",
    "\n",
    "    class TokenList:\n",
    "\n",
    "        @staticmethod\n",
    "        def to_tokens_file(tokens_list):\n",
    "            \"\"\"Save `tokens_list` to a new temporary `.jt` file and return the file path.\"\"\"\n",
    "            tokens_list = validate_tokens_list(tokens_list)\n",
    "            # mkstemp creates a unique file that already exists, unlike a random name under a\n",
    "            # `tmp/` directory that may be missing or collide with an earlier file.\n",
    "            fd, filename = tempfile.mkstemp(suffix='.jt')\n",
    "            with os.fdopen(fd, 'wb') as f:\n",
    "                t.save(tokens_list, f)\n",
    "            return filename\n",
    "\n",
    "        @staticmethod\n",
    "        def to_audio(tokens_list):\n",
    "            \"\"\"Decode the level-2 tokens of `tokens_list` into a waveform tensor (last axis squeezed).\"\"\"\n",
    "            # TODO: Implement converting other levels besides 2\n",
    "            top_level = [tokens.to(model.device) for tokens in validate_tokens_list(tokens_list)[2:]]\n",
    "            # Inference only; no_grad also lets callers turn the result into a numpy array.\n",
    "            with t.no_grad():\n",
    "                return model.decode(top_level, start_level=2).squeeze(-1)\n",
    "\n",
    "    class TokensFile:\n",
    "\n",
    "        @staticmethod\n",
    "        def to_tokens_list(file):\n",
    "            \"\"\"Load and validate the tokens list stored in `file` (a path, a file object, or a tempfile wrapper).\"\"\"\n",
    "            # gr.File may hand over a tempfile-like object instead of a path, depending on the Gradio version.\n",
    "            path = file if isinstance(file, (str, os.PathLike)) else getattr(file, 'name', file)\n",
    "            # NOTE(security): t.load unpickles the file, so only open token files you trust.\n",
    "            # map_location keeps files saved on a GPU machine loadable on a CPU-only one.\n",
    "            return validate_tokens_list(t.load(path, map_location='cpu'))\n",
    "\n",
    "        @staticmethod\n",
    "        def to_audio(file):\n",
    "            \"\"\"Decode the tokens stored in `file` into a waveform tensor.\"\"\"\n",
    "            return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))\n",
    "\n",
    "    class Audio:\n",
    "\n",
    "        @staticmethod\n",
    "        def to_tokens_list(audio):\n",
    "            \"\"\"Encode a waveform tensor (presumably (samples, channels)) into a 3-level tokens list on the CPU.\"\"\"\n",
    "            # Encode every level (default start_level=0): the tokens file holds all 3 levels, and encoding\n",
    "            # only level 2 produced a 1-tensor list that validate_tokens_list rejects.\n",
    "            with t.no_grad():\n",
    "                tokens_list = model.encode(audio.unsqueeze(0).to(model.device))\n",
    "            return [tokens.cpu() for tokens in tokens_list]\n",
    "\n",
    "        @staticmethod\n",
    "        def to_tokens_file(audio):\n",
    "            \"\"\"Encode a waveform tensor and save its tokens to a temporary `.jt` file; return the path.\"\"\"\n",
    "            return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokens_file_to_ui_audio(file):\n",
    "    \"\"\"Gradio handler: decode an uploaded tokens file into the `(sample_rate, samples)` tuple `gr.Audio` plays.\"\"\"\n",
    "    if file is None:\n",
    "        raise gr.Error(\"Please upload a music tokens file first\")\n",
    "    # First item of the batch; assumes mono output, whose channel axis `to_audio` already squeezed.\n",
    "    waveform = Convert.TokensFile.to_audio(file)[0]\n",
    "    # Keep samples in [-1, 1] so Gradio's conversion to 16-bit PCM cannot overflow.\n",
    "    return SAMPLE_RATE, waveform.clamp(-1, 1).float().cpu().numpy()\n",
    "\n",
    "\n",
    "def ui_audio_to_tokens_file(audio):\n",
    "    \"\"\"Gradio handler: encode the `(sample_rate, samples)` tuple from `gr.Audio` and return a tokens file path.\"\"\"\n",
    "    if audio is None:\n",
    "        raise gr.Error(\"Please record or upload some audio first\")\n",
    "    sample_rate, samples = audio\n",
    "    waveform = t.tensor(samples)\n",
    "    if not waveform.is_floating_point():\n",
    "        # Integer PCM (typically int16) -> floats in [-1, 1).\n",
    "        full_scale = t.iinfo(waveform.dtype).max + 1\n",
    "        waveform = waveform.float() / full_scale\n",
    "    if waveform.dim() == 2:\n",
    "        # (samples, channels) -> mono.\n",
    "        waveform = waveform.mean(dim=1)\n",
    "    if sample_rate != SAMPLE_RATE:\n",
    "        # Crude linear resampling to avoid an extra dependency; a proper resampler sounds better.\n",
    "        new_length = round(waveform.shape[0] * SAMPLE_RATE / sample_rate)\n",
    "        waveform = t.nn.functional.interpolate(waveform[None, None], size=new_length, mode='linear', align_corners=False)[0, 0]\n",
    "    # Add a trailing channel axis; Convert.Audio.to_tokens_list adds the batch axis.\n",
    "    return Convert.Audio.to_tokens_file(waveform.unsqueeze(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as ui:\n",
    "\n",
    "    # File input to upload or download the music tokens file\n",
    "    tokens = gr.File(label='music_tokens_file')\n",
    "\n",
    "    # Audio component to play or upload audio; type='numpy' exchanges (sample_rate, samples) tuples\n",
    "    audio = gr.Audio(label='audio', type='numpy')\n",
    "\n",
    "    # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)\n",
    "    gr.Button(\"Convert tokens to audio\", variant='primary').click(tokens_file_to_ui_audio, tokens, audio)\n",
    "    gr.Button(\"Convert audio to tokens\", variant='secondary').click(ui_audio_to_tokens_file, audio, tokens)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    init()\n",
    "    ui.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}