{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "8UdXU6Bs9FQ6" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "P80evX4s4BhX", "outputId": "c388da56-06e1-45dd-982c-594f030dd2e6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading package lists... Done\n", "Building dependency tree... Done\n", "Reading state information... Done\n", "The following additional packages will be installed:\n", " fluid-soundfont-gm libevdev2 libfluidsynth3 libgudev-1.0-0 libinput-bin\n", " libinput10 libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a libqt5dbus5\n", " libqt5gui5 libqt5network5 libqt5svg5 libqt5widgets5 libwacom-bin\n", " libwacom-common libwacom9 libxcb-icccm4 libxcb-image0 libxcb-keysyms1\n", " libxcb-render-util0 libxcb-util1 libxcb-xinerama0 libxcb-xinput0 libxcb-xkb1\n", " libxkbcommon-x11-0 qsynth qt5-gtk-platformtheme qttranslations5-l10n\n", " timgm6mb-soundfont\n", "Suggested packages:\n", " fluid-soundfont-gs qt5-image-formats-plugins qtwayland5 jackd\n", "The following NEW packages will be installed:\n", " fluid-soundfont-gm fluidsynth libevdev2 libfluidsynth3 libgudev-1.0-0\n", " libinput-bin libinput10 libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a\n", " libqt5dbus5 libqt5gui5 libqt5network5 libqt5svg5 libqt5widgets5 libwacom-bin\n", " libwacom-common libwacom9 libxcb-icccm4 libxcb-image0 libxcb-keysyms1\n", " libxcb-render-util0 libxcb-util1 libxcb-xinerama0 libxcb-xinput0 libxcb-xkb1\n", " libxkbcommon-x11-0 qsynth qt5-gtk-platformtheme qttranslations5-l10n\n", " timgm6mb-soundfont\n", "0 upgraded, 32 newly installed, 0 to remove and 18 not upgraded.\n", "Need to get 148 MB of archives.\n", "After this operation, 207 MB of additional disk space will be used.\n", "Get:1 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libqt5core5a amd64 5.15.3+dfsg-2ubuntu0.2 [2,006 kB]\n", "Get:2 
http://archive.ubuntu.com/ubuntu jammy/main amd64 libevdev2 amd64 1.12.1+dfsg-1 [39.5 kB]\n", "Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libmtdev1 amd64 1.1.6-1build4 [14.5 kB]\n", "Get:4 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgudev-1.0-0 amd64 1:237-2build1 [16.3 kB]\n", "Get:5 http://archive.ubuntu.com/ubuntu jammy/main amd64 libwacom-common all 2.2.0-1 [54.3 kB]\n", "Get:6 http://archive.ubuntu.com/ubuntu jammy/main amd64 libwacom9 amd64 2.2.0-1 [22.0 kB]\n", "Get:7 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 libinput-bin amd64 1.20.0-1ubuntu0.3 [19.9 kB]\n", "Get:8 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 libinput10 amd64 1.20.0-1ubuntu0.3 [131 kB]\n", "Get:9 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libmd4c0 amd64 0.4.8-1 [42.0 kB]\n", "Get:10 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libqt5dbus5 amd64 5.15.3+dfsg-2ubuntu0.2 [222 kB]\n", "Get:11 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libqt5network5 amd64 5.15.3+dfsg-2ubuntu0.2 [731 kB]\n", "Get:12 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-icccm4 amd64 0.4.1-1.1build2 [11.5 kB]\n", "Get:13 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-util1 amd64 0.4.0-1build2 [11.4 kB]\n", "Get:14 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-image0 amd64 0.4.0-2 [11.5 kB]\n", "Get:15 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-keysyms1 amd64 0.4.0-1build3 [8,746 B]\n", "Get:16 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-render-util0 amd64 0.3.9-1build3 [10.3 kB]\n", "Get:17 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-xinerama0 amd64 1.14-3ubuntu3 [5,414 B]\n", "Get:18 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-xinput0 amd64 1.14-3ubuntu3 [34.3 kB]\n", "Get:19 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxcb-xkb1 amd64 1.14-3ubuntu3 [32.8 kB]\n", "Get:20 http://archive.ubuntu.com/ubuntu jammy/main amd64 
libxkbcommon-x11-0 amd64 1.4.0-1 [14.4 kB]\n", "Get:21 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libqt5gui5 amd64 5.15.3+dfsg-2ubuntu0.2 [3,722 kB]\n", "Get:22 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libqt5widgets5 amd64 5.15.3+dfsg-2ubuntu0.2 [2,561 kB]\n", "Get:23 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libqt5svg5 amd64 5.15.3-1 [149 kB]\n", "Get:24 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fluid-soundfont-gm all 3.1-5.3 [130 MB]\n", "Get:25 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libinstpatch-1.0-2 amd64 1.1.6-1 [240 kB]\n", "Get:26 http://archive.ubuntu.com/ubuntu jammy/universe amd64 timgm6mb-soundfont all 1.3-5 [5,427 kB]\n", "Get:27 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libfluidsynth3 amd64 2.2.5-1 [246 kB]\n", "Get:28 http://archive.ubuntu.com/ubuntu jammy/universe amd64 fluidsynth amd64 2.2.5-1 [27.4 kB]\n", "Get:29 http://archive.ubuntu.com/ubuntu jammy/main amd64 libwacom-bin amd64 2.2.0-1 [13.6 kB]\n", "Get:30 http://archive.ubuntu.com/ubuntu jammy/universe amd64 qsynth amd64 0.9.6-1 [305 kB]\n", "Get:31 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 qt5-gtk-platformtheme amd64 5.15.3+dfsg-2ubuntu0.2 [130 kB]\n", "Get:32 http://archive.ubuntu.com/ubuntu jammy/universe amd64 qttranslations5-l10n all 5.15.3-1 [1,983 kB]\n", "Fetched 148 MB in 8s (19.4 MB/s)\n", "debconf: unable to initialize frontend: Dialog\n", "debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 32.)\n", "debconf: falling back to frontend: Readline\n", "debconf: unable to initialize frontend: Readline\n", "debconf: (This frontend requires a controlling tty.)\n", "debconf: falling back to frontend: Teletype\n", "dpkg-preconfigure: unable to re-open stdin: \n", "Selecting previously unselected package libqt5core5a:amd64.\n", "(Reading database ... 
120895 files and directories currently installed.)\n", "Preparing to unpack .../00-libqt5core5a_5.15.3+dfsg-2ubuntu0.2_amd64.deb ...\n", "Unpacking libqt5core5a:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Selecting previously unselected package libevdev2:amd64.\n", "Preparing to unpack .../01-libevdev2_1.12.1+dfsg-1_amd64.deb ...\n", "Unpacking libevdev2:amd64 (1.12.1+dfsg-1) ...\n", "Selecting previously unselected package libmtdev1:amd64.\n", "Preparing to unpack .../02-libmtdev1_1.1.6-1build4_amd64.deb ...\n", "Unpacking libmtdev1:amd64 (1.1.6-1build4) ...\n", "Selecting previously unselected package libgudev-1.0-0:amd64.\n", "Preparing to unpack .../03-libgudev-1.0-0_1%3a237-2build1_amd64.deb ...\n", "Unpacking libgudev-1.0-0:amd64 (1:237-2build1) ...\n", "Selecting previously unselected package libwacom-common.\n", "Preparing to unpack .../04-libwacom-common_2.2.0-1_all.deb ...\n", "Unpacking libwacom-common (2.2.0-1) ...\n", "Selecting previously unselected package libwacom9:amd64.\n", "Preparing to unpack .../05-libwacom9_2.2.0-1_amd64.deb ...\n", "Unpacking libwacom9:amd64 (2.2.0-1) ...\n", "Selecting previously unselected package libinput-bin.\n", "Preparing to unpack .../06-libinput-bin_1.20.0-1ubuntu0.3_amd64.deb ...\n", "Unpacking libinput-bin (1.20.0-1ubuntu0.3) ...\n", "Selecting previously unselected package libinput10:amd64.\n", "Preparing to unpack .../07-libinput10_1.20.0-1ubuntu0.3_amd64.deb ...\n", "Unpacking libinput10:amd64 (1.20.0-1ubuntu0.3) ...\n", "Selecting previously unselected package libmd4c0:amd64.\n", "Preparing to unpack .../08-libmd4c0_0.4.8-1_amd64.deb ...\n", "Unpacking libmd4c0:amd64 (0.4.8-1) ...\n", "Selecting previously unselected package libqt5dbus5:amd64.\n", "Preparing to unpack .../09-libqt5dbus5_5.15.3+dfsg-2ubuntu0.2_amd64.deb ...\n", "Unpacking libqt5dbus5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Selecting previously unselected package libqt5network5:amd64.\n", "Preparing to unpack 
.../10-libqt5network5_5.15.3+dfsg-2ubuntu0.2_amd64.deb ...\n", "Unpacking libqt5network5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Selecting previously unselected package libxcb-icccm4:amd64.\n", "Preparing to unpack .../11-libxcb-icccm4_0.4.1-1.1build2_amd64.deb ...\n", "Unpacking libxcb-icccm4:amd64 (0.4.1-1.1build2) ...\n", "Selecting previously unselected package libxcb-util1:amd64.\n", "Preparing to unpack .../12-libxcb-util1_0.4.0-1build2_amd64.deb ...\n", "Unpacking libxcb-util1:amd64 (0.4.0-1build2) ...\n", "Selecting previously unselected package libxcb-image0:amd64.\n", "Preparing to unpack .../13-libxcb-image0_0.4.0-2_amd64.deb ...\n", "Unpacking libxcb-image0:amd64 (0.4.0-2) ...\n", "Selecting previously unselected package libxcb-keysyms1:amd64.\n", "Preparing to unpack .../14-libxcb-keysyms1_0.4.0-1build3_amd64.deb ...\n", "Unpacking libxcb-keysyms1:amd64 (0.4.0-1build3) ...\n", "Selecting previously unselected package libxcb-render-util0:amd64.\n", "Preparing to unpack .../15-libxcb-render-util0_0.3.9-1build3_amd64.deb ...\n", "Unpacking libxcb-render-util0:amd64 (0.3.9-1build3) ...\n", "Selecting previously unselected package libxcb-xinerama0:amd64.\n", "Preparing to unpack .../16-libxcb-xinerama0_1.14-3ubuntu3_amd64.deb ...\n", "Unpacking libxcb-xinerama0:amd64 (1.14-3ubuntu3) ...\n", "Selecting previously unselected package libxcb-xinput0:amd64.\n", "Preparing to unpack .../17-libxcb-xinput0_1.14-3ubuntu3_amd64.deb ...\n", "Unpacking libxcb-xinput0:amd64 (1.14-3ubuntu3) ...\n", "Selecting previously unselected package libxcb-xkb1:amd64.\n", "Preparing to unpack .../18-libxcb-xkb1_1.14-3ubuntu3_amd64.deb ...\n", "Unpacking libxcb-xkb1:amd64 (1.14-3ubuntu3) ...\n", "Selecting previously unselected package libxkbcommon-x11-0:amd64.\n", "Preparing to unpack .../19-libxkbcommon-x11-0_1.4.0-1_amd64.deb ...\n", "Unpacking libxkbcommon-x11-0:amd64 (1.4.0-1) ...\n", "Selecting previously unselected package libqt5gui5:amd64.\n", "Preparing to unpack 
.../20-libqt5gui5_5.15.3+dfsg-2ubuntu0.2_amd64.deb ...\n", "Unpacking libqt5gui5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Selecting previously unselected package libqt5widgets5:amd64.\n", "Preparing to unpack .../21-libqt5widgets5_5.15.3+dfsg-2ubuntu0.2_amd64.deb ...\n", "Unpacking libqt5widgets5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Selecting previously unselected package libqt5svg5:amd64.\n", "Preparing to unpack .../22-libqt5svg5_5.15.3-1_amd64.deb ...\n", "Unpacking libqt5svg5:amd64 (5.15.3-1) ...\n", "Selecting previously unselected package fluid-soundfont-gm.\n", "Preparing to unpack .../23-fluid-soundfont-gm_3.1-5.3_all.deb ...\n", "Unpacking fluid-soundfont-gm (3.1-5.3) ...\n", "Selecting previously unselected package libinstpatch-1.0-2:amd64.\n", "Preparing to unpack .../24-libinstpatch-1.0-2_1.1.6-1_amd64.deb ...\n", "Unpacking libinstpatch-1.0-2:amd64 (1.1.6-1) ...\n", "Selecting previously unselected package timgm6mb-soundfont.\n", "Preparing to unpack .../25-timgm6mb-soundfont_1.3-5_all.deb ...\n", "Unpacking timgm6mb-soundfont (1.3-5) ...\n", "Selecting previously unselected package libfluidsynth3:amd64.\n", "Preparing to unpack .../26-libfluidsynth3_2.2.5-1_amd64.deb ...\n", "Unpacking libfluidsynth3:amd64 (2.2.5-1) ...\n", "Selecting previously unselected package fluidsynth.\n", "Preparing to unpack .../27-fluidsynth_2.2.5-1_amd64.deb ...\n", "Unpacking fluidsynth (2.2.5-1) ...\n", "Selecting previously unselected package libwacom-bin.\n", "Preparing to unpack .../28-libwacom-bin_2.2.0-1_amd64.deb ...\n", "Unpacking libwacom-bin (2.2.0-1) ...\n", "Selecting previously unselected package qsynth.\n", "Preparing to unpack .../29-qsynth_0.9.6-1_amd64.deb ...\n", "Unpacking qsynth (0.9.6-1) ...\n", "Selecting previously unselected package qt5-gtk-platformtheme:amd64.\n", "Preparing to unpack .../30-qt5-gtk-platformtheme_5.15.3+dfsg-2ubuntu0.2_amd64.deb ...\n", "Unpacking qt5-gtk-platformtheme:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Selecting previously 
unselected package qttranslations5-l10n.\n", "Preparing to unpack .../31-qttranslations5-l10n_5.15.3-1_all.deb ...\n", "Unpacking qttranslations5-l10n (5.15.3-1) ...\n", "Setting up libxcb-xinput0:amd64 (1.14-3ubuntu3) ...\n", "Setting up libxcb-keysyms1:amd64 (0.4.0-1build3) ...\n", "Setting up libxcb-render-util0:amd64 (0.3.9-1build3) ...\n", "Setting up libxcb-icccm4:amd64 (0.4.1-1.1build2) ...\n", "Setting up libxcb-util1:amd64 (0.4.0-1build2) ...\n", "Setting up libxcb-xkb1:amd64 (1.14-3ubuntu3) ...\n", "Setting up libxcb-image0:amd64 (0.4.0-2) ...\n", "Setting up libxcb-xinerama0:amd64 (1.14-3ubuntu3) ...\n", "Setting up qttranslations5-l10n (5.15.3-1) ...\n", "Setting up libxkbcommon-x11-0:amd64 (1.4.0-1) ...\n", "Setting up libqt5core5a:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Setting up libmtdev1:amd64 (1.1.6-1build4) ...\n", "Setting up libqt5dbus5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Setting up libmd4c0:amd64 (0.4.8-1) ...\n", "Setting up fluid-soundfont-gm (3.1-5.3) ...\n", "update-alternatives: using /usr/share/sounds/sf2/FluidR3_GM.sf2 to provide /usr/share/sounds/sf2/default-GM.sf2 (default-GM.sf2) in auto mode\n", "update-alternatives: using /usr/share/sounds/sf2/FluidR3_GM.sf2 to provide /usr/share/sounds/sf3/default-GM.sf3 (default-GM.sf3) in auto mode\n", "Setting up timgm6mb-soundfont (1.3-5) ...\n", "Setting up libevdev2:amd64 (1.12.1+dfsg-1) ...\n", "Setting up libinstpatch-1.0-2:amd64 (1.1.6-1) ...\n", "Setting up libgudev-1.0-0:amd64 (1:237-2build1) ...\n", "Setting up libfluidsynth3:amd64 (2.2.5-1) ...\n", "Setting up libwacom-common (2.2.0-1) ...\n", "Setting up libwacom9:amd64 (2.2.0-1) ...\n", "Setting up libqt5network5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Setting up libinput-bin (1.20.0-1ubuntu0.3) ...\n", "Setting up fluidsynth (2.2.5-1) ...\n", "Created symlink /etc/systemd/user/default.target.wants/fluidsynth.service → /usr/lib/systemd/user/fluidsynth.service.\n", "Setting up libwacom-bin (2.2.0-1) ...\n", "Setting up 
libinput10:amd64 (1.20.0-1ubuntu0.3) ...\n", "Setting up libqt5gui5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Setting up libqt5widgets5:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Setting up qt5-gtk-platformtheme:amd64 (5.15.3+dfsg-2ubuntu0.2) ...\n", "Setting up libqt5svg5:amd64 (5.15.3-1) ...\n", "Setting up qsynth (0.9.6-1) ...\n", "Processing triggers for hicolor-icon-theme (0.17-2) ...\n", "Processing triggers for libc-bin (2.35-0ubuntu3.1) ...\n", "/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n", "\n", "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n", "\n", "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n", "\n", "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n", "\n", "/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n", "\n", "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n", "\n", "Processing triggers for man-db (2.10.2-1) ...\n", "Collecting pyfluidsynth\n", " Downloading pyFluidSynth-1.3.2-py3-none-any.whl (19 kB)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from pyfluidsynth) (1.23.5)\n", "Installing collected packages: pyfluidsynth\n", "Successfully installed pyfluidsynth-1.3.2\n", "Collecting pretty_midi\n", " Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: numpy>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.23.5)\n", "Collecting mido>=1.1.16 (from pretty_midi)\n", " Downloading mido-1.3.0-py3-none-any.whl (50 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.3/50.3 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from pretty_midi) (1.16.0)\n", "Requirement already satisfied: packaging~=23.1 in /usr/local/lib/python3.10/dist-packages (from mido>=1.1.16->pretty_midi) (23.1)\n", "Building wheels for collected packages: pretty_midi\n", " Building wheel for pretty_midi (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592287 sha256=22bff5ef5eb5356c3b24982736b3095092aa128de352268d39c316086e3b2f8d\n", " Stored in directory: /root/.cache/pip/wheels/cd/a5/30/7b8b7f58709f5150f67f98fde4b891ebf0be9ef07a8af49f25\n", "Successfully built pretty_midi\n", "Installing collected packages: mido, pretty_midi\n", "Successfully installed mido-1.3.0 pretty_midi-0.2.10\n" ] } ], "source": [ "!sudo apt install -y fluidsynth\n", "!pip install --upgrade pyfluidsynth\n", "!pip install pretty_midi\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "GTyF4vV17Pwg" }, "outputs": [], "source": [ "import collections\n", "import datetime\n", "import fluidsynth\n", "import glob\n", "import numpy as np\n", "import pathlib\n", "import pandas as pd\n", "import pretty_midi\n", "import seaborn as sns\n", "import tensorflow as tf\n", "\n", "from IPython import display\n", "from matplotlib import pyplot as plt\n", "from typing import Dict, List, Optional, Sequence, Tuple" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "GLJoYrD-7to-" }, "outputs": [], "source": [ "seed = 42\n", 
# Collect every MIDI file one directory level below `data_dir`.
# glob returns matches in arbitrary, filesystem-dependent order, so the list
# is sorted to make index-based picks such as `filenames[1]` and
# `filenames[:num_files]` select the same pieces on every machine.
filenames = sorted(glob.glob(str(data_dir/'**/*.mid*')))
print('Number of files:', len(filenames))
def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):
  """Synthesizes `pm` with FluidSynth and returns an inline audio player.

  Args:
    pm: The MIDI object to render.
    seconds: Maximum amount of audio to keep, in seconds. Fractional values
      are accepted; the resulting sample count is truncated to an integer.

  Returns:
    An `IPython.display.Audio` object holding at most the first `seconds`
    seconds of the rendered waveform, played back at `_SAMPLING_RATE`.
  """
  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
  # Take a sample of the generated waveform to mitigate kernel resets.
  # Slice bounds must be integers, so convert the sample count explicitly;
  # otherwise a fractional `seconds` (e.g. 2.5) raises a TypeError.
  num_samples = int(seconds * _SAMPLING_RATE)
  waveform_short = waveform[:num_samples]
  return display.Audio(waveform_short, rate=_SAMPLING_RATE)
def midi_to_notes(midi_file: str) -> pd.DataFrame:
  """Extracts the notes of a MIDI file's first instrument into a DataFrame.

  Args:
    midi_file: Path of the MIDI file, passed to `pretty_midi.PrettyMIDI`.

  Returns:
    A DataFrame with one row per note, ordered by start time, and the columns
    `pitch`, `start`, `end`, `step` (time since the previous note's start;
    0 for the first note) and `duration` (`end - start`). When the file has
    no instruments, or its first instrument has no notes, the frame is empty
    but still carries all five columns.
  """
  pm = pretty_midi.PrettyMIDI(midi_file)
  # Only the first instrument is read. Fall back to an empty note list rather
  # than raising IndexError on a file without instruments.
  instrument_notes = pm.instruments[0].notes if pm.instruments else []

  # Sort the notes by start time
  sorted_notes = sorted(instrument_notes, key=lambda note: note.start)

  # Explicit dtypes keep the column types stable when there are no notes.
  pitch = np.array([note.pitch for note in sorted_notes], dtype=int)
  start = np.array([note.start for note in sorted_notes], dtype=float)
  end = np.array([note.end for note in sorted_notes], dtype=float)

  # Prepending the first start time makes the first step exactly 0 and is a
  # no-op on an empty array, so the empty case needs no special branch.
  step = np.diff(start, prepend=start[:1])

  return pd.DataFrame({
      'pitch': pitch,
      'start': start,
      'end': end,
      'step': step,
      'duration': end - start,
  })
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pitchstartendstepduration
0681.0083331.4135420.0000000.405208
1561.0395832.6322920.0312501.592708
2591.0531251.6541670.0135420.601042
3701.2750001.5239580.2218750.248958
4711.5083331.7531250.2333330.244792
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
\n" ], "text/plain": [ " pitch start end step duration\n", "0 68 1.008333 1.413542 0.000000 0.405208\n", "1 56 1.039583 2.632292 0.031250 1.592708\n", "2 59 1.053125 1.654167 0.013542 0.601042\n", "3 70 1.275000 1.523958 0.221875 0.248958\n", "4 71 1.508333 1.753125 0.233333 0.244792" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_notes = midi_to_notes(sample_file)\n", "raw_notes.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZwtHRgq374oR", "outputId": "07c37b0c-b597-4090-948f-f83fffb0c1bc" }, "outputs": [ { "data": { "text/plain": [ "array(['G#4', 'G#3', 'B3', 'A#4', 'B4', 'D#4', 'G#4', 'A#4', 'C#4', 'C#5'],\n", " dtype='" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_piano_roll(raw_notes, count=100)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 289 }, "id": "2OrEm8Gv8CXS", "outputId": "b8189e78-9bcd-49a2-8c6d-ff436eb86e63" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):
  """Draws side-by-side histograms of pitch, step and duration.

  Args:
    notes: DataFrame with `pitch`, `step` and `duration` columns.
    drop_percentile: Upper share (in percent) of `step` and `duration` values
      that falls outside the histogram bins.
  """
  keep_percentile = 100 - drop_percentile

  plt.figure(figsize=[15, 5])
  for panel, column in enumerate(("pitch", "step", "duration"), start=1):
    plt.subplot(1, 3, panel)
    if column == "pitch":
      # Pitch is bounded, so a fixed bin count over the full range is enough.
      bin_spec = 20
    else:
      # 20 equal-width bins from 0 up to the chosen percentile of the column.
      upper_edge = np.percentile(notes[column], keep_percentile)
      bin_spec = np.linspace(0, upper_edge, 21)
    sns.histplot(notes, x=column, bins=bin_spec)
def notes_to_midi(
  notes: pd.DataFrame,
  out_file: str,
  instrument_name: str,
  velocity: int = 100,  # note loudness
) -> pretty_midi.PrettyMIDI:
  """Builds a single-instrument MIDI object from a notes table and saves it.

  Args:
    notes: DataFrame with `pitch`, `step` and `duration` columns. Absolute
      start times are rebuilt by accumulating `step` from 0.
    out_file: Path the MIDI data is written to.
    instrument_name: Instrument name, converted to a program number with
      `pretty_midi.instrument_name_to_program`.
    velocity: Velocity assigned to every note.

  Returns:
    The `pretty_midi.PrettyMIDI` object that was written to `out_file`.
  """
  midi = pretty_midi.PrettyMIDI()
  program = pretty_midi.instrument_name_to_program(instrument_name)
  track = pretty_midi.Instrument(program=program)

  onset = 0
  for _, row in notes.iterrows():
    # Each step is relative to the previous note's start time.
    onset = float(onset + row['step'])
    release = float(onset + row['duration'])
    track.notes.append(
        pretty_midi.Note(
            velocity=velocity,
            pitch=int(row['pitch']),
            start=onset,
            end=release,
        ))

  midi.instruments.append(track)
  midi.write(out_file)
  return midi
def create_sequences(
    dataset: tf.data.Dataset,
    seq_length: int,
    vocab_size = 128,
) -> tf.data.Dataset:
  """Turns a dataset of notes into (input sequence, next-note labels) pairs.

  Args:
    dataset: Dataset whose elements are single notes laid out in the order
      given by the module-level `key_order`.
    seq_length: Number of notes in each input sequence.
    vocab_size: Divisor applied to the first feature of the inputs (pitch).

  Returns:
    A dataset of `(inputs, labels)` pairs. `inputs` holds `seq_length`
    consecutive notes with the first feature divided by `vocab_size`;
    `labels` maps each name in `key_order` to the matching unscaled feature
    of the note that follows.
  """
  # One extra note per window: it becomes the label.
  window_size = seq_length + 1
  feature_scale = [vocab_size, 1.0, 1.0]

  def to_example(window):
    # Everything but the last note is the input; the last note is the target.
    context, target = window[:-1], window[-1]
    labels = {}
    for position, name in enumerate(key_order):
      labels[name] = target[position]
    # Only the inputs are normalized; labels keep their raw values.
    return context / feature_scale, labels

  # `window` yields a dataset of datasets; batching each window and
  # flat-mapping turns it back into a dataset of tensors.
  nested_windows = dataset.window(
      window_size, shift=1, stride=1, drop_remainder=True)
  note_windows = nested_windows.flat_map(
      lambda window: window.batch(window_size, drop_remainder=True))

  return note_windows.map(to_example, num_parallel_calls=tf.data.AUTOTUNE)
batch_size = 64
buffer_size = n_notes - seq_length  # the number of items in the dataset
# Cache the sequences *before* shuffling. `cache()` replays exactly the
# elements it saw on the first pass, so caching after `shuffle`/`batch`
# would freeze the first epoch's order and batch composition for all later
# epochs. With the cache first, every epoch is reshuffled and re-batched.
train_ds = (seq_ds
            .cache()
            .shuffle(buffer_size)
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))
# %% Build and compile the model: one LSTM trunk with three output heads
input_shape = (seq_length, 3)
learning_rate = 0.005

inputs = tf.keras.Input(input_shape)
x = tf.keras.layers.LSTM(128)(inputs)

outputs = {
    # One logit per pitch class. Sized from `vocab_size` (instead of a
    # hard-coded 128) so the head stays in sync with the pitch vocabulary
    # used when the sequences were created.
    'pitch': tf.keras.layers.Dense(vocab_size, name='pitch')(x),
    'step': tf.keras.layers.Dense(1, name='step')(x),
    'duration': tf.keras.layers.Dense(1, name='duration')(x),
}

model = tf.keras.Model(inputs, outputs)

loss = {
    # The pitch head emits raw logits, hence from_logits=True.
    'pitch': tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True),
    'step': mse_with_positive_pressure,
    'duration': mse_with_positive_pressure,
}

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

model.compile(loss=loss, optimizer=optimizer)

model.summary()
[==============================] - 17s 24ms/step - loss: 5.0208 - duration_loss: 0.1338 - pitch_loss: 4.8559 - step_loss: 0.0311\n" ] }, { "data": { "text/plain": [ "{'loss': 5.020809173583984,\n", " 'duration_loss': 0.13375774025917053,\n", " 'pitch_loss': 4.8559160232543945,\n", " 'step_loss': 0.031137656420469284}" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "losses = model.evaluate(train_ds, return_dict=True)\n", "losses" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "HlFGhQ6M8j1r" }, "outputs": [], "source": [ "model.compile(\n", " loss=loss,\n", " loss_weights={\n", " 'pitch': 0.05,\n", " 'step': 1.0,\n", " 'duration':1.0,\n", " },\n", " optimizer=optimizer,\n", ")" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gyOsQbNn8jz0", "outputId": "dd5932e8-05a3-4a53-ff39-3fae3c431308" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "276/276 [==============================] - 6s 17ms/step - loss: 0.4077 - duration_loss: 0.1338 - pitch_loss: 4.8559 - step_loss: 0.0311\n" ] }, { "data": { "text/plain": [ "{'loss': 0.4076911509037018,\n", " 'duration_loss': 0.13375774025917053,\n", " 'pitch_loss': 4.8559160232543945,\n", " 'step_loss': 0.031137656420469284}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(train_ds, return_dict=True)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "_CX7wBuZ8jxy" }, "outputs": [], "source": [ "callbacks = [\n", " tf.keras.callbacks.ModelCheckpoint(\n", " filepath='./training_checkpoints/ckpt_{epoch}',\n", " save_weights_only=True),\n", " tf.keras.callbacks.EarlyStopping(\n", " monitor='loss',\n", " patience=5,\n", " verbose=1,\n", " restore_best_weights=True),\n", "]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, 
"id": "WsAE82108jua", "outputId": "1eb06560-c252-4103-b727-9910a2186e52" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/50\n", "276/276 [==============================] - 14s 41ms/step - loss: 0.3227 - duration_loss: 0.0862 - pitch_loss: 4.1751 - step_loss: 0.0278\n", "Epoch 2/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.3047 - duration_loss: 0.0823 - pitch_loss: 3.9569 - step_loss: 0.0246\n", "Epoch 3/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2992 - duration_loss: 0.0808 - pitch_loss: 3.8803 - step_loss: 0.0244\n", "Epoch 4/50\n", "276/276 [==============================] - 13s 48ms/step - loss: 0.2957 - duration_loss: 0.0805 - pitch_loss: 3.8175 - step_loss: 0.0244\n", "Epoch 5/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2929 - duration_loss: 0.0795 - pitch_loss: 3.7917 - step_loss: 0.0239\n", "Epoch 6/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2908 - duration_loss: 0.0788 - pitch_loss: 3.7624 - step_loss: 0.0239\n", "Epoch 7/50\n", "276/276 [==============================] - 14s 51ms/step - loss: 0.2883 - duration_loss: 0.0779 - pitch_loss: 3.7422 - step_loss: 0.0232\n", "Epoch 8/50\n", "276/276 [==============================] - 13s 46ms/step - loss: 0.2862 - duration_loss: 0.0776 - pitch_loss: 3.7249 - step_loss: 0.0224\n", "Epoch 9/50\n", "276/276 [==============================] - 19s 69ms/step - loss: 0.2870 - duration_loss: 0.0774 - pitch_loss: 3.7197 - step_loss: 0.0236\n", "Epoch 10/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2842 - duration_loss: 0.0766 - pitch_loss: 3.6995 - step_loss: 0.0226\n", "Epoch 11/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2826 - duration_loss: 0.0757 - pitch_loss: 3.6895 - step_loss: 0.0224\n", "Epoch 12/50\n", "276/276 [==============================] - 12s 42ms/step - loss: 0.2818 - duration_loss: 0.0761 - 
pitch_loss: 3.6835 - step_loss: 0.0215\n", "Epoch 13/50\n", "276/276 [==============================] - 14s 50ms/step - loss: 0.2795 - duration_loss: 0.0751 - pitch_loss: 3.6725 - step_loss: 0.0208\n", "Epoch 14/50\n", "276/276 [==============================] - 13s 48ms/step - loss: 0.2793 - duration_loss: 0.0746 - pitch_loss: 3.6657 - step_loss: 0.0214\n", "Epoch 15/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2771 - duration_loss: 0.0740 - pitch_loss: 3.6701 - step_loss: 0.0196\n", "Epoch 16/50\n", "276/276 [==============================] - 14s 52ms/step - loss: 0.2771 - duration_loss: 0.0738 - pitch_loss: 3.6571 - step_loss: 0.0205\n", "Epoch 17/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2759 - duration_loss: 0.0736 - pitch_loss: 3.6453 - step_loss: 0.0200\n", "Epoch 18/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2741 - duration_loss: 0.0727 - pitch_loss: 3.6368 - step_loss: 0.0196\n", "Epoch 19/50\n", "276/276 [==============================] - 11s 41ms/step - loss: 0.2720 - duration_loss: 0.0722 - pitch_loss: 3.6271 - step_loss: 0.0184\n", "Epoch 20/50\n", "276/276 [==============================] - 14s 51ms/step - loss: 0.2722 - duration_loss: 0.0722 - pitch_loss: 3.6205 - step_loss: 0.0190\n", "Epoch 21/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2708 - duration_loss: 0.0719 - pitch_loss: 3.6173 - step_loss: 0.0180\n", "Epoch 22/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2677 - duration_loss: 0.0709 - pitch_loss: 3.6021 - step_loss: 0.0167\n", "Epoch 23/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2651 - duration_loss: 0.0697 - pitch_loss: 3.5895 - step_loss: 0.0160\n", "Epoch 24/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2636 - duration_loss: 0.0687 - pitch_loss: 3.5760 - step_loss: 0.0161\n", "Epoch 25/50\n", "276/276 
[==============================] - 12s 45ms/step - loss: 0.2600 - duration_loss: 0.0676 - pitch_loss: 3.5618 - step_loss: 0.0143\n", "Epoch 26/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2587 - duration_loss: 0.0667 - pitch_loss: 3.5527 - step_loss: 0.0144\n", "Epoch 27/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2562 - duration_loss: 0.0649 - pitch_loss: 3.5443 - step_loss: 0.0141\n", "Epoch 28/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2550 - duration_loss: 0.0641 - pitch_loss: 3.5399 - step_loss: 0.0139\n", "Epoch 29/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2603 - duration_loss: 0.0667 - pitch_loss: 3.5656 - step_loss: 0.0154\n", "Epoch 30/50\n", "276/276 [==============================] - 11s 42ms/step - loss: 0.2555 - duration_loss: 0.0644 - pitch_loss: 3.5318 - step_loss: 0.0145\n", "Epoch 31/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2549 - duration_loss: 0.0639 - pitch_loss: 3.5218 - step_loss: 0.0149\n", "Epoch 32/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2511 - duration_loss: 0.0618 - pitch_loss: 3.5122 - step_loss: 0.0137\n", "Epoch 33/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2558 - duration_loss: 0.0647 - pitch_loss: 3.5160 - step_loss: 0.0153\n", "Epoch 34/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2512 - duration_loss: 0.0630 - pitch_loss: 3.4960 - step_loss: 0.0135\n", "Epoch 35/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2439 - duration_loss: 0.0577 - pitch_loss: 3.4700 - step_loss: 0.0126\n", "Epoch 36/50\n", "276/276 [==============================] - 14s 50ms/step - loss: 0.2490 - duration_loss: 0.0600 - pitch_loss: 3.4660 - step_loss: 0.0157\n", "Epoch 37/50\n", "276/276 [==============================] - 12s 43ms/step - loss: 0.2441 - duration_loss: 0.0587 - 
pitch_loss: 3.4499 - step_loss: 0.0130\n", "Epoch 38/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2455 - duration_loss: 0.0592 - pitch_loss: 3.4704 - step_loss: 0.0128\n", "Epoch 39/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2444 - duration_loss: 0.0592 - pitch_loss: 3.4366 - step_loss: 0.0134\n", "Epoch 40/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2364 - duration_loss: 0.0532 - pitch_loss: 3.4152 - step_loss: 0.0124\n", "Epoch 41/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2349 - duration_loss: 0.0517 - pitch_loss: 3.4171 - step_loss: 0.0123\n", "Epoch 42/50\n", "276/276 [==============================] - 13s 46ms/step - loss: 0.2489 - duration_loss: 0.0618 - pitch_loss: 3.4786 - step_loss: 0.0132\n", "Epoch 43/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2396 - duration_loss: 0.0567 - pitch_loss: 3.4111 - step_loss: 0.0123\n", "Epoch 44/50\n", "276/276 [==============================] - 12s 44ms/step - loss: 0.2323 - duration_loss: 0.0503 - pitch_loss: 3.3994 - step_loss: 0.0120\n", "Epoch 45/50\n", "276/276 [==============================] - 14s 49ms/step - loss: 0.2330 - duration_loss: 0.0516 - pitch_loss: 3.3868 - step_loss: 0.0121\n", "Epoch 46/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2301 - duration_loss: 0.0484 - pitch_loss: 3.3779 - step_loss: 0.0128\n", "Epoch 47/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2290 - duration_loss: 0.0481 - pitch_loss: 3.3812 - step_loss: 0.0118\n", "Epoch 48/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2509 - duration_loss: 0.0626 - pitch_loss: 3.5011 - step_loss: 0.0132\n", "Epoch 49/50\n", "276/276 [==============================] - 12s 45ms/step - loss: 0.2470 - duration_loss: 0.0603 - pitch_loss: 3.4742 - step_loss: 0.0130\n", "Epoch 50/50\n", "276/276 
[==============================] - 12s 43ms/step - loss: 0.2404 - duration_loss: 0.0570 - pitch_loss: 3.4128 - step_loss: 0.0128\n", "CPU times: user 13min 28s, sys: 35.6 s, total: 14min 3s\n", "Wall time: 14min 57s\n" ] } ], "source": [ "%%time\n", "epochs = 50\n", "\n", "history = model.fit(\n", " train_ds,\n", " epochs=epochs,\n", " callbacks=callbacks,\n", ")" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 430 }, "id": "KZ6u433y8jrf", "outputId": "8fb59db5-b55b-4a66-a901-2a2ae81e3433" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
# %% Plot the total training loss per epoch
plt.plot(history.epoch, history.history['loss'], label='total loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

# %% Sampling helper
def predict_next_note(
        notes: np.ndarray,
        keras_model: tf.keras.Model,
        temperature: float = 1.0) -> tuple[int, float, float]:
    """Samples the next note from a trained sequence model.

    Args:
        notes: A single input sequence of notes (no batch dimension); a batch
            dimension of size 1 is added before prediction.
        keras_model: Model whose ``predict`` returns a dict with the keys
            ``'pitch'`` (logits), ``'step'`` and ``'duration'``.
        temperature: Positive divisor applied to the pitch logits before
            sampling; larger values flatten the sampling distribution.

    Returns:
        A ``(pitch, step, duration)`` tuple where ``pitch`` is the sampled
        class index (int) and ``step`` / ``duration`` are floats clipped to
        be non-negative.

    Raises:
        AssertionError: If ``temperature`` is not strictly positive.
    """
    assert temperature > 0

    # Add batch dimension
    inputs = tf.expand_dims(notes, 0)

    # Use the model that was passed in rather than the notebook-global `model`,
    # and silence the per-call progress bar (this runs once per generated note).
    predictions = keras_model.predict(inputs, verbose=0)
    pitch_logits = predictions['pitch'] / temperature
    step = predictions['step']
    duration = predictions['duration']

    # Draw one pitch index from the temperature-scaled logits.
    pitch = tf.random.categorical(pitch_logits, num_samples=1)
    pitch = tf.squeeze(pitch, axis=-1)
    duration = tf.squeeze(duration, axis=-1)
    step = tf.squeeze(step, axis=-1)

    # `step` and `duration` values should be non-negative
    step = tf.maximum(0, step)
    duration = tf.maximum(0, duration)

    return int(pitch), float(step), float(duration)
33ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 32ms/step\n", "1/1 [==============================] - 0s 36ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 32ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 37ms/step\n", "1/1 [==============================] - 0s 37ms/step\n", "1/1 [==============================] - 0s 41ms/step\n", "1/1 [==============================] - 0s 32ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 37ms/step\n", "1/1 [==============================] - 0s 36ms/step\n", "1/1 [==============================] - 0s 37ms/step\n", "1/1 [==============================] - 0s 23ms/step\n", "1/1 [==============================] - 0s 22ms/step\n", "1/1 [==============================] - 0s 49ms/step\n", "1/1 [==============================] - 0s 25ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 29ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 30ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 
31ms/step\n", "1/1 [==============================] - 0s 25ms/step\n", "1/1 [==============================] - 0s 29ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 24ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 24ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 32ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 25ms/step\n", "1/1 [==============================] - 0s 30ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 29ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 39ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 25ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 25ms/step\n", "1/1 [==============================] - 0s 25ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 24ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 24ms/step\n", "1/1 [==============================] - 0s 
32ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 39ms/step\n", "1/1 [==============================] - 0s 32ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 30ms/step\n", "1/1 [==============================] - 0s 36ms/step\n", "1/1 [==============================] - 0s 34ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 28ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 31ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 69ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 36ms/step\n", "1/1 [==============================] - 0s 32ms/step\n", "1/1 [==============================] - 0s 48ms/step\n", "1/1 [==============================] - 0s 36ms/step\n", "1/1 [==============================] - 0s 35ms/step\n", "1/1 [==============================] - 0s 26ms/step\n", "1/1 [==============================] - 0s 27ms/step\n", "1/1 [==============================] - 0s 30ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 33ms/step\n", "1/1 [==============================] - 0s 41ms/step\n", "1/1 [==============================] - 0s 39ms/step\n", "1/1 [==============================] - 0s 
# %% Generate a sequence of notes by repeatedly sampling from the model
temperature = 2.0
num_predictions = 120

sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)

# Per-column divisor: only the pitch column is scaled, by `vocab_size`.
pitch_scale = np.array([vocab_size, 1, 1])

# The initial sequence of notes; pitch is normalized similar to training
# sequences
input_notes = sample_notes[:seq_length] / pitch_scale

generated_rows = []
prev_start = 0
for _ in range(num_predictions):
    pitch, step, duration = predict_next_note(input_notes, model, temperature)
    start = prev_start + step
    end = start + duration
    # Record the raw (un-normalized) pitch for the output table.
    generated_rows.append((pitch, step, duration, start, end))
    # Feed the new note back with its pitch normalized exactly like the seed
    # window; appending the raw pitch would put values up to `vocab_size`
    # times larger than anything the model saw during training into the input.
    next_input = np.array([[pitch, step, duration]]) / pitch_scale
    # Slide the window: drop the oldest note, append the newest.
    input_notes = np.concatenate([input_notes[1:], next_input], axis=0)
    prev_start = start

generated_notes = pd.DataFrame(
    generated_rows, columns=(*key_order, 'start', 'end'))
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pitchstepdurationstartend
0610.1374860.4473030.1374860.584789
1960.2098480.3520040.3473340.699338
2960.2287300.3957370.5760640.971800
3810.2349880.4108940.8110521.221946
4910.2348530.4063171.0459051.452222
5860.2360360.4075661.2819411.689507
6650.2367970.4113491.5187391.930088
7890.2363230.4060151.7550622.161077
8510.2353290.3926451.9903912.383036
9890.2392920.3974352.2296832.627118
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
# %% Preview the first generated notes
generated_notes.head(10)

# %% Convert the generated notes to MIDI and play them back
# `notes_to_midi` and `display_audio` are helpers defined earlier in the
# notebook; `out_file` is the path handed to `notes_to_midi`.
out_file = 'output.mid'
out_pm = notes_to_midi(
    generated_notes, out_file=out_file, instrument_name=instrument_name)
display_audio(out_pm)

# %% Download the MIDI file (only possible when running inside Colab)
try:
    from google.colab import files
except ImportError:
    # Outside Colab there is no browser-download helper; keep the notebook
    # runnable instead of failing on the import.
    print(f'Not running in Colab; skipping download of {out_file}.')
else:
    files.download(out_file)

# %% Visualize the generated notes as a piano roll
plot_piano_roll(generated_notes)
# %% Plot the distributions of the generated notes
plot_distributions(generated_notes)

# %% Save the model
model.save('model/')

# %% Fine-tune the saved model
# Number of extra epochs to train for; adjust as needed.
additional_epochs = 10

# Load the model. It was compiled with the custom loss
# `mse_with_positive_pressure`, which cannot be deserialized without
# `custom_objects`; since the model is recompiled right below, skip restoring
# the compile state instead.
model = tf.keras.models.load_model('model/', compile=False)

# Recompile the model with the same losses and loss weights used for the
# original training run, so the total loss stays comparable.
model.compile(
    loss=loss,
    loss_weights={
        'pitch': 0.05,
        'step': 1.0,
        'duration': 1.0,
    },
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
)

# Training for fine-tuning.
# The existing train_ds is used here, but a new dataset can be used as well.
model.fit(train_ds, epochs=additional_epochs)