{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MusicGen\n",
    "Welcome to MusicGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n",
    "\n",
    "First, we start by initializing MusicGen, you can choose a model from the following selection:\n",
    "1. `small` - 300M transformer decoder.\n",
    "2. `medium` - 1.5B transformer decoder.\n",
    "3. `melody` - 1.5B transformer decoder also supporting melody conditioning.\n",
    "4. `large` - 3.3B transformer decoder.\n",
    "\n",
    "We will use the `small` variant for the purpose of this demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.models import MusicGen\n",
    "\n",
    "# Using small model, better results would be obtained with `medium` or `large`.\n",
    "model = MusicGen.get_pretrained('small')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
    "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
    "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
    "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
    "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
    "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n",
    "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
    "\n",
    "When left unchanged, MusicGen will revert to its default parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_generation_params(\n",
    "    use_sampling=True,\n",
    "    top_k=250,\n",
    "    duration=5\n",
    ")"
   ]
  },
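  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `top_k` and `temperature` more concrete, the next cell gives a minimal sketch of top-k sampling on random logits. It is an illustration only, not MusicGen's internal implementation: the helper `sample_top_k` and the 2048-entry token axis are assumptions made for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def sample_top_k(logits, top_k=250, temperature=1.0):\n",
    "    # Hypothetical helper for illustration; not part of audiocraft.\n",
    "    # Temperature rescales the logits before the softmax.\n",
    "    probs = torch.softmax(logits / temperature, dim=-1)\n",
    "    # Keep only the top_k most likely tokens and renormalize.\n",
    "    k = min(top_k, probs.shape[-1])\n",
    "    top_probs, top_idx = torch.topk(probs, k, dim=-1)\n",
    "    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)\n",
    "    # Draw one token per row, then map back to indices on the token axis.\n",
    "    choice = torch.multinomial(top_probs, num_samples=1)\n",
    "    return top_idx.gather(-1, choice)\n",
    "\n",
    "logits = torch.randn(2, 2048)  # assumed shape: 2 sequences, 2048 candidate tokens\n",
    "tokens = sample_top_k(logits, top_k=250, temperature=1.0)\n",
    "print(tokens.shape)  # torch.Size([2, 1])"
   ]
  },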
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can go ahead and start generating music using one of the following modes:\n",
    "* Unconditional samples using `model.generate_unconditional`\n",
    "* Music continuation using `model.generate_continuation`\n",
    "* Text-conditional samples using `model.generate`\n",
    "* Melody-conditional samples using `model.generate_with_chroma`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Unconditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "output = model.generate_unconditional(num_samples=2, progress=True)\n",
    "display_audio(output, sample_rate=32000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Music Continuation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torchaudio\n",
    "import torch\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "def get_bip_bip(bip_duration=0.125, frequency=440,\n",
    "                duration=0.5, sample_rate=32000, device=\"cuda\"):\n",
    "    \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
    "    t = torch.arange(\n",
    "        int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
    "    wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
    "    tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
    "    envelope = (tp >= 0.5).float()\n",
    "    return wav * envelope\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we use a synthetic signal to prompt both the tonality and the BPM\n",
    "# of the generated audio.\n",
    "res = model.generate_continuation(\n",
    "    get_bip_bip(0.125).expand(2, -1, -1), \n",
    "    32000, ['Jazz jazz and only jazz', \n",
    "            'Heartful EDM with beautiful synths and chords'], \n",
    "    progress=True)\n",
    "display_audio(res, 32000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can also use any audio from a file. Make sure to trim the file if it is too long!\n",
    "prompt_waveform, prompt_sr = torchaudio.load(\"./assets/bach.mp3\")\n",
    "prompt_duration = 2\n",
    "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
    "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n",
    "display_audio(output, sample_rate=32000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-conditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "output = model.generate(\n",
    "    descriptions=[\n",
    "        '80s pop track with bassy drums and synth',\n",
    "        '90s rock song with loud guitars and heavy drums',\n",
    "    ],\n",
    "    progress=True\n",
    ")\n",
    "display_audio(output, sample_rate=32000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Melody-conditional Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "from audiocraft.utils.notebook import display_audio\n",
    "\n",
    "model = MusicGen.get_pretrained('melody')\n",
    "model.set_generation_params(duration=8)\n",
    "\n",
    "melody_waveform, sr = torchaudio.load(\"assets/bach.mp3\")\n",
    "melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n",
    "output = model.generate_with_chroma(\n",
    "    descriptions=[\n",
    "        '80s pop track with bassy drums and synth',\n",
    "        '90s rock song with loud guitars and heavy drums',\n",
    "    ],\n",
    "    melody_wavs=melody_waveform,\n",
    "    melody_sample_rate=sr,\n",
    "    progress=True\n",
    ")\n",
    "display_audio(output, sample_rate=32000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}