{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "DWLOSBkp0A2U" }, "source": [ "# GPT-2 for music.\n", "\n", "This notebook shows you how to generate music with GPT-2.\n", "\n", "---\n", "\n", "## Install dependencies." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6J_AnhV8D5p6" }, "outputs": [], "source": [ "!pip install transformers\n", "!pip install note_seq" ] }, { "cell_type": "markdown", "metadata": { "id": "RzhHhFll0JVl" }, "source": [ "## Load the tokenizer and the model from 🤗 Hub." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g3ih12FMD7bs" }, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"TristanBehrens/js-fakes-4bars\")\n", "model = AutoModelForCausalLM.from_pretrained(\"TristanBehrens/js-fakes-4bars\")" ] }, { "cell_type": "markdown", "metadata": { "id": "GxRBk--Q0P1q" }, "source": [ "## How to generate." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZZSLX96ID7t8" }, "outputs": [], "source": [ "# Encode the conditioning tokens.\n", "input_ids = tokenizer.encode(\"PIECE_START STYLE=JSFAKES GENRE=JSFAKES TRACK_START INST=48 BAR_START NOTE_ON=60\", return_tensors=\"pt\")\n", "print(input_ids)\n", "\n", "# Generate more tokens.\n", "generated_ids = model.generate(input_ids, max_length=500)\n", "generated_sequence = tokenizer.decode(generated_ids[0])\n", "print(generated_sequence)" ] }, { "cell_type": "markdown", "metadata": { "id": "YfHXFugA0WdI" }, "source": [ "## Convert the generated tokens to music that you can listen to."
# Durations in seconds at 120 BPM: a quarter of a beat (one sixteenth note)
# and a bar of four beats.
NOTE_LENGTH_16TH_120BPM = 0.25 * 60 / 120
BAR_LENGTH_120BPM = 4.0 * 60 / 120


def token_sequence_to_note_sequence(token_sequence, use_program=True, use_drums=True, instrument_mapper=None, only_piano=False):
    """Render a token sequence into a note sequence.

    Args:
        token_sequence: Either a whitespace-separated string of tokens or an
            iterable of token strings.
        use_program: When true, a non-drum ``INST=<n>`` token sets the current
            program to ``int(<n>)`` and the current instrument to the running
            track count.
        use_drums: When true, an ``INST=DRUMS`` token switches the following
            notes to drum notes (program 0, instrument 0).
        instrument_mapper: Optional mapping applied to the instrument string of
            a non-drum ``INST`` token before it is converted to ``int``.
        only_piano: When true, every non-drum note is finally forced to
            instrument 0 and program 0.

    Returns:
        The object created by ``empty_note_sequence()`` with one note added
        per ``NOTE_ON`` token.

    Raises:
        ValueError: If the value of an ``INST``, ``NOTE_ON``, ``NOTE_OFF`` or
            ``TIME_DELTA`` token cannot be parsed as a number.
    """
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    note_sequence = empty_note_sequence()

    # Rendering state.
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    # These three used to be bound only inside the TRACK_START / BAR_START
    # branches, so a stream that reached a bar or note token first (e.g.
    # truncated or free-running model output) raised UnboundLocalError.
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}

    for token in token_sequence:
        if token == "PIECE_END":
            print("The end.")
            break
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None and instrument in instrument_mapper:
                    instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            # Each bar starts on its own grid position and tracks its own
            # sounding notes.
            current_time = current_bar_index * BAR_LENGTH_120BPM
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Provisional length of four sixteenths; a matching NOTE_OFF in
            # the same bar overwrites it.
            note.end_time = current_time + 4 * NOTE_LENGTH_16TH_120BPM
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                current_notes[pitch].end_time = current_time
        elif token.startswith("TIME_DELTA"):
            current_time += float(token.split("=")[-1]) * NOTE_LENGTH_16TH_120BPM
        # Every other token is ignored, including PIECE_START, TRACK_END,
        # KEYS_START, KEYS_END, KEY=..., DENSITY=... and [PAD].

    # Renumber instruments so that each distinct (program, is_drum) pair gets
    # one index, assigned in order of first appearance.
    instrument_ids = {}
    for note in note_sequence.notes:
        key = (note.program, note.is_drum)
        note.instrument = instrument_ids.setdefault(key, len(instrument_ids))

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    return note_sequence


def empty_note_sequence(qpm=120.0, total_time=0.0):
    """Create a ``NoteSequence`` with a single tempo and no notes.

    Args:
        qpm: Value stored on the one tempo entry that is added.
        total_time: Value stored in the sequence's ``total_time`` field.

    Returns:
        A new ``note_seq.protobuf.music_pb2.NoteSequence`` whose
        ``ticks_per_quarter`` is ``note_seq.constants.STANDARD_PPQ``.
    """
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()
    note_sequence.tempos.add().qpm = qpm
    note_sequence.ticks_per_quarter = note_seq.constants.STANDARD_PPQ
    note_sequence.total_time = total_time
    return note_sequence
# Encode a second prompt; this one opens on NOTE_ON=61.
input_ids = tokenizer.encode(
    "PIECE_START STYLE=JSFAKES GENRE=JSFAKES TRACK_START INST=48 BAR_START NOTE_ON=61",
    return_tensors="pt",
)

# Let the model continue the prompt up to 500 tokens, then turn the ids back
# into a token string.
# NOTE(review): `temperature` is passed without `do_sample=True`; it
# presumably has no effect under the default decoding strategy. Confirm
# whether sampling was intended.
generated_ids = model.generate(
    input_ids,
    max_length=500,
    temperature=1.0,
)
generated_sequence = tokenizer.decode(generated_ids[0])

# Convert the generated tokens to a note sequence, then plot it and play it
# through note_seq's MIDI synthesizer.
synth = note_seq.midi_synth.synthesize
note_sequence = token_sequence_to_note_sequence(generated_sequence)

note_seq.plot_sequence(note_sequence)
note_seq.play_sequence(note_sequence, synth)