diff --git "a/gen_with_our_model.ipynb" "b/gen_with_our_model.ipynb" new file mode 100644--- /dev/null +++ "b/gen_with_our_model.ipynb" @@ -0,0 +1,806 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import sys,time\n", + "\n", + "import midi2audio\n", + "import transformers\n", + "from transformers import AutoModelForCausalLM, BertTokenizer, GPT2LMHeadModel\n", + "\n", + "from IPython.display import Audio\n", + "\n", + "from anticipation import ops\n", + "from anticipation.sample import generate\n", + "from anticipation.tokenize import extract_instruments\n", + "from anticipation.convert import events_to_midi,midi_to_events\n", + "from anticipation.visuals import visualize\n", + "from anticipation.config import *\n", + "from anticipation.vocab import *\n", + "\n", + "import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at stanford-crfm/music-small-800k were not used when initializing GPT2LMHeadModel: ['token_out_embeddings']\n", + "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Downloading model.safetensors: 100%|██████████| 166M/166M [00:23<00:00, 7.18MB/s]]\n", + "Downloading generation_config.json: 100%|██████████| 144/144 [00:00<00:00, 47.9kB/s]\n" + ] + } + ], + "source": [ + "small_john_path = 'stanford-crfm/music-small-800k' \n", + "john_model = AutoModelForCausalLM.from_pretrained(small_john_path)\n", + "\n", + "# john_model.device(\"\")\n", + "\n", + "# our_model = GPT2LMHeadModel.from_pretrained(\"beat-goes-on/wiki-bert-tiny\")\n", + "our_model = GPT2LMHeadModel.from_pretrained(\"beat-goes-on/HF_wiki_kmeans\")\n", + "\n", + "# our_model.device(\"\")\n", + "\n", + "# a MIDI synthesizer\n", + "fs = midi2audio.FluidSynth('word2house\\8MBGMSFX.SF2')\n", + "\n", + "# sem_tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-base-uncased\", padding='max_length', max_length=64)\n", + "\n", + "# the MIDI synthesis script\n", + "def synthesize(fs, tokens, filename):\n", + " mid = events_to_midi(tokens)\n", + " mid.save(f'{filename}.mid')\n", + " fs.midi_to_audio(f'{filename}.mid', f'{filename}.wav')\n", + " return f'{filename}.wav'" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 34%|███▎ | 337/1000 [05:38<11:06, 1.00s/it]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + 
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20572\\166056788.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m# Generate Unconditional John\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mlength\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m10\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0munconditional_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mjohn_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlength\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m.98\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[0mAudio\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msynthesize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munconditional_tokens\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'unconditional_default'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\anticipation\\sample.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(model, start_time, end_time, inputs, controls, top_p, debug, delta, text_prompt, sem_tokenizer)\u001b[0m\n\u001b[0;32m 174\u001b[0m \u001b[0manticipated_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minf\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 175\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 176\u001b[1;33m \u001b[0mnew_token\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0madd_token\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mz\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtokens\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstart_time\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcurrent_time\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 177\u001b[0m \u001b[0mnew_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnew_token\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mTIME_OFFSET\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 178\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mnew_time\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\anticipation\\sample.py\u001b[0m in \u001b[0;36madd_token\u001b[1;34m(model, z, tokens, top_p, current_time, debug)\u001b[0m\n\u001b[0;32m 87\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 88\u001b[0m 
\u001b[0minput_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mz\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mhistory\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mnew_token\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 89\u001b[1;33m \u001b[0mlogits\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_tokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlogits\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 90\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[0midx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_tokens\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 1074\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0muse_return_dict\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1075\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1076\u001b[1;33m transformer_outputs = self.transformer(\n\u001b[0m\u001b[0;32m 1077\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1078\u001b[0m \u001b[0mpast_key_values\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpast_key_values\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m 
\u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 898\u001b[0m )\n\u001b[0;32m 899\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 900\u001b[1;33m outputs = block(\n\u001b[0m\u001b[0;32m 901\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 902\u001b[0m \u001b[0mlayer_past\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlayer_past\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m 
\u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[0;32m 388\u001b[0m \u001b[0mresidual\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 389\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mln_1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 390\u001b[1;33m attn_outputs = self.attn(\n\u001b[0m\u001b[0;32m 391\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 392\u001b[0m \u001b[0mlayer_past\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlayer_past\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 
1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[0;32m 332\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 333\u001b[0m \u001b[0mattn_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_merge_heads\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattn_output\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_heads\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhead_dim\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 334\u001b[1;33m \u001b[0mattn_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc_proj\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattn_output\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 335\u001b[0m \u001b[0mattn_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mresid_dropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattn_output\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 336\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\pytorch_utils.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[0msize_out\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnf\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msize_out\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Generate Unconditional John\n", + "length = 10\n", + "unconditional_tokens = 
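"# sample ~10 seconds of unconditional music (end_time is in seconds; generate() scales it by TIME_RESOLUTION)\n", + "unconditional_tokens = 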
generate(john_model, start_time=0, end_time=length, top_p=.98, debug=False)\n", + "Audio(synthesize(fs, unconditional_tokens, 'unconditional_default'))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "API functions for sampling from anticipatory infilling models.\n", + "\"\"\"\n", + "\n", + "import math\n", + "import random\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "from anticipation import ops\n", + "from anticipation.config import *\n", + "from anticipation.vocab import *\n", + "\n", + "\n", + "def safe_logits(logits, idx):\n", + " logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls\n", + " logits[SPECIAL_OFFSET:] = -float('inf') # don't generate special tokens\n", + "\n", + " # don't generate stuff in the wrong time slot\n", + " if idx % 3 == 0:\n", + " logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')\n", + " logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')\n", + " elif idx % 3 == 1:\n", + " logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')\n", + " logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')\n", + " elif idx % 3 == 2:\n", + " logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')\n", + " logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')\n", + "\n", + " return logits\n", + "\n", + "\n", + "def nucleus(logits, top_p):\n", + " # from HF implementation\n", + " if top_p < 1.0:\n", + " sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n", + " cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n", + "\n", + " # Remove tokens with cumulative probability above the threshold (token with 0 are kept)\n", + " sorted_indices_to_remove = cumulative_probs > top_p\n", + "\n", + " # Shift the indices to the right to keep also the first token above the threshold\n", + " sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()\n", + " sorted_indices_to_remove[..., 0] = 0\n", + "\n", + " # scatter sorted tensors to original indexing\n", + " indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)\n", + " logits[indices_to_remove] = -float(\"inf\")\n", + "\n", + " return logits\n", + "\n", + "\n", + "def future_logits(logits, curtime):\n", + " \"\"\" don't sample events in the past \"\"\"\n", + " if curtime > 0:\n", + " logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')\n", + "\n", + " return logits\n", + "\n", + "\n", + "def instr_logits(logits, full_history):\n", + " \"\"\" don't sample more than 16 instruments \"\"\"\n", + " instrs = ops.get_instruments(full_history)\n", + " if len(instrs) < 16:\n", + " return logits\n", + " # print(instrs)\n", + " for instr in range(MAX_INSTR):\n", + " if instr not in instrs:\n", + " # print(instr)\n", + " logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')\n", + "\n", + " return logits\n", + "\n", + "\n", + "def our_add_token(model, z, tokens, top_p, current_time, debug=False):\n", + " assert len(tokens) % 3 == 0\n", + "\n", + " history = tokens.copy()\n", + " lookback = max(len(tokens) - 1017, 0)\n", + " history = history[lookback:] # Markov window\n", + " offset = ops.min_time(history, seconds=False)\n", + " history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer\n", + "\n", + " new_token = []\n", + " with torch.no_grad():\n", + " for i in range(3):\n", + " input_tokens = 
torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)\n", + " logits = model(input_tokens).logits[0,-1]\n", + "\n", + " # print(logits)\n", + " idx = input_tokens.shape[1]-1\n", + " logits = safe_logits(logits, idx)\n", + " if i == 0:\n", + " logits = future_logits(logits, current_time - offset)\n", + " elif i == 2:\n", + " logits = instr_logits(logits, tokens)\n", + " logits = nucleus(logits, top_p)\n", + "\n", + " probs = F.softmax(logits, dim=-1)\n", + " token = torch.multinomial(probs, 1)\n", + " instrs = ops.get_instruments(tokens)\n", + " if i == 2 and len(instrs.keys()) >= 16:\n", + " # force the note onto an instrument already in use (note: this resets the pitch to 0)\n", + " # pitch = token % MAX_PITCH\n", + " instr_probs = {}\n", + " for instr in instrs:\n", + " instr_probs[probs[instr]] = instr\n", + " token = instr_probs[max(instr_probs.keys())]*MAX_PITCH + NOTE_OFFSET\n", + " new_token.append(int(token))\n", + "\n", + " new_token[0] += offset # revert to full sequence timing\n", + " if debug:\n", + " print(f' OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')\n", + "\n", + " # print(new_token)\n", + " return new_token\n", + "\n", + "\n", + "def our_generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION, text_prompt=None, sem_tokenizer=None):\n", + " if inputs is None:\n", + " inputs = []\n", + "\n", + " if controls is None:\n", + " controls = []\n", + "\n", + " start_time = int(TIME_RESOLUTION*start_time)\n", + " end_time = int(TIME_RESOLUTION*end_time)\n", + "\n", + " # prompt is events up to start_time\n", + " prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)\n", + "\n", + " # treat events beyond start_time as controls\n", + " future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)\n", + " if debug:\n", + " print('Future')\n", + " ops.print_tokens(future)\n", + "\n", + " # clip controls that precede the sequence\n", + " controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)\n", + "\n", + " if debug:\n", + " print('Controls')\n", + " ops.print_tokens(controls)\n", + "\n", + " # z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 or text_prompt == None else sem_tokenizer.encode(text_prompt)\n", + " z = [] # FIND A CLUSTER: placeholder control prefix (a str here would break z + history in our_add_token)\n", + " if debug:\n", + " print('AR Mode' if z and z[0] == AUTOREGRESS else 'AAR Mode')\n", + "\n", + " # interleave the controls with the events\n", + " tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))\n", + " # tokens = prompt\n", + " if debug:\n", + " print('Prompt')\n", + " ops.print_tokens(tokens)\n", + "\n", + " current_time = ops.max_time(prompt, seconds=False)\n", + " if debug:\n", + " print('Current time:', current_time)\n", + "\n", + " with tqdm(range(end_time-start_time)) as progress:\n", + " if controls:\n", + " atime, adur, anote = controls[0:3]\n", + " anticipated_tokens = controls[3:]\n", + " anticipated_time = atime - ATIME_OFFSET\n", + " else:\n", + " # nothing to anticipate\n", + " anticipated_time = math.inf\n", + "\n", + " while True:\n", + " while current_time >= anticipated_time - delta:\n", + " tokens.extend([atime, adur, anote])\n", + " if debug:\n", + " note = anote - ANOTE_OFFSET\n", + " instr = note//2**7\n", + " print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)\n", + "\n", + " if len(anticipated_tokens) > 0:\n", + " atime, adur, anote = anticipated_tokens[0:3]\n", + " 
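# pop the next anticipated (time, duration, note) triple\n", + " 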
anticipated_tokens = anticipated_tokens[3:]\n", + " anticipated_time = atime - ATIME_OFFSET\n", + " else:\n", + " # nothing more to anticipate\n", + " anticipated_time = math.inf\n", + "\n", + " new_token = our_add_token(model, z, tokens, top_p, max(start_time,current_time))\n", + " new_time = new_token[0] - TIME_OFFSET\n", + " # new_token is one generated (onset, duration, note) triple\n", + " # new_time = current_time + random.randint(1, 5)\n", + "\n", + " # optionally clamp duration/note back into range:\n", + " # new_token[1] = new_token[1] % MAX_DUR + DUR_OFFSET\n", + " # new_token[2] = new_token[2] % MAX_NOTE + NOTE_OFFSET\n", + "\n", + " if new_time >= end_time:\n", + " break\n", + "\n", + " if debug:\n", + " new_note = new_token[2] - NOTE_OFFSET\n", + " new_instr = new_note//2**7\n", + " new_pitch = new_note - (2**7)*new_instr\n", + " print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)\n", + "\n", + " tokens.extend(new_token)\n", + " dt = new_time - current_time\n", + " assert dt >= 0\n", + " current_time = new_time\n", + " progress.update(dt)\n", + "\n", + " events, _ = ops.split(tokens)\n", + " return ops.sort(ops.unpad(events) + future)\n", + "\n", + "def our_generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION, text_prompt=None, sem_tokenizer=None):\n", + " if inputs is None:\n", + " inputs = []\n", + "\n", + " if controls is None:\n", + " controls = []\n", + " else:\n", + " # treat controls as ordinary tokens\n", + " controls = [token-CONTROL_OFFSET for token in controls]\n", + "\n", + " start_time = int(TIME_RESOLUTION*start_time)\n", + " end_time = int(TIME_RESOLUTION*end_time)\n", + "\n", + " inputs = ops.sort(inputs + controls)\n", + "\n", + " # prompt is events up to start_time\n", + " prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)\n", + " if debug:\n", + " print('Prompt')\n", + " ops.print_tokens(prompt)\n", + "\n", + " # treat events beyond start_time as controls\n", + " controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)\n", + " if debug:\n", + " print('Future')\n", + " ops.print_tokens(controls)\n", + "\n", + " z = sem_tokenizer.encode(text_prompt)\n", + " if debug:\n", + " print('AR Mode')\n", + "\n", + " current_time = ops.max_time(prompt, seconds=False)\n", + " if debug:\n", + " print('Current time:', current_time)\n", + "\n", + " tokens = prompt\n", + " with tqdm(range(end_time-start_time)) as progress:\n", + " if controls:\n", + " atime, adur, anote = controls[0:3]\n", + " anticipated_tokens = controls[3:]\n", + " anticipated_time = atime - TIME_OFFSET\n", + " else:\n", + " # nothing to anticipate\n", + " anticipated_time = math.inf\n", + "\n", + " while True:\n", + " new_token = our_add_token(model, z, tokens, top_p, max(start_time,current_time))\n", + " # new_time = new_token[0] - TIME_OFFSET\n", + " new_time = current_time + 5 # fixed 5-tick steps instead of the sampled onset time\n", + " if new_time >= end_time:\n", + " print(tokens)\n", + " # re-time the sequence onto consecutive ticks, keeping at most 16 instruments\n", + " normed_tokens = [TIME_OFFSET]\n", + " normed_tokens.extend(tokens[1:3])\n", + " instrs = set()\n", + " instrs.add((tokens[2] - NOTE_OFFSET) // 2 ** 7)\n", + " for i in range(3, len(tokens), 3):\n", + " note = tokens[i + 2] - NOTE_OFFSET\n", + " instr = note//2**7\n", + " if instr in instrs or len(instrs) < 16:\n", + " normed_tokens.extend([normed_tokens[-3] + 1, tokens[i + 1], tokens[i + 2]])\n", + " instrs.add(instr)\n", + " tokens = normed_tokens\n", + " 
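# stop here; the re-timed tokens above are the final sequence\n", + " 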
break\n", + "\n", + " dt = new_time - current_time\n", + " assert dt >= 0\n", + " current_time = new_time\n", + "\n", + " # backfill anything that should have come before the new token\n", + " while current_time >= anticipated_time:\n", + " tokens.extend([atime, adur, anote])\n", + " if debug:\n", + " note = anote - NOTE_OFFSET\n", + " instr = note//2**7\n", + " print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)\n", + "\n", + " if len(anticipated_tokens) > 0:\n", + " atime, adur, anote = anticipated_tokens[0:3]\n", + " anticipated_tokens = anticipated_tokens[3:]\n", + " anticipated_time = atime - TIME_OFFSET\n", + " else:\n", + " # nothing more to anticipate\n", + " anticipated_time = math.inf\n", + "\n", + " if debug:\n", + " new_note = new_token[2] - NOTE_OFFSET\n", + " new_instr = new_note//2**7\n", + " new_pitch = new_note - (2**7)*new_instr\n", + " print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)\n", + "\n", + " tokens.extend(new_token)\n", + " progress.update(dt)\n", + "\n", + " if anticipated_time != math.inf:\n", + " tokens.extend([atime, adur, anote])\n", + "\n", + " return ops.sort(ops.unpad(tokens) + controls)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "UnpicklingError", + "evalue": "pickle stream refers to out-of-band data but no *buffers* argument was given", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mUnpicklingError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20572\\4172336658.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mcluster_inference\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mget_cluster\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mgenerate_kmeans\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtextPrompt\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcontrols\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mDELTA\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mTIME_RESOLUTION\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\Users\\phili\\OneDrive\\Desktop\\CS224N\\finalproj\\word2house\\cluster_inference.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 
17\u001b[0m # load the kmeans model\n\u001b[1;32m---> 18\u001b[1;33m kmeans_model = joblib.load('GPT_128k_means_model.joblib')\u001b[0m\n", + "(remaining frames condensed: joblib.numpy_pickle.load -> _unpickle -> pickle.Unpickler.load -> load_next_buffer)", + "\u001b[1;31mUnpicklingError\u001b[0m: pickle stream refers to out-of-band data but no *buffers* argument was given" + ] + } + ], + "source": [ + "\n", + "\n", + "from cluster_inference import get_cluster\n", + "\n", + "def generate_kmeans(model, start_time, end_time, textPrompt=None, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):\n", + " if inputs is None:\n", + " inputs = []\n", + "\n", + " if controls is None:\n", + " controls = []\n", + " else:\n", + " # treat controls as ordinary tokens\n", + " controls = [token-CONTROL_OFFSET for token in controls]\n", + "\n", + " start_time = int(TIME_RESOLUTION*start_time)\n", + " end_time = int(TIME_RESOLUTION*end_time)\n", + "\n", + " inputs = ops.sort(inputs + controls)\n", + "\n", + " # prompt is events up to start_time\n", + " prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)\n", + " if debug:\n", + " print('Prompt')\n", + " ops.print_tokens(prompt)\n", + "\n", + " # treat events beyond start_time as controls\n", + " controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)\n", + " if debug:\n", + " print('Future')\n", + " ops.print_tokens(controls)\n", + "\n", + " # control prefix: the k-means cluster id of the text prompt, or plain autoregression\n", + " z = [get_cluster(textPrompt)] if textPrompt is not None else [AUTOREGRESS]\n", + " if debug:\n", + " print('Cluster Mode' if textPrompt is not None else 'AR Mode')\n", + "\n", + " current_time = ops.max_time(prompt, seconds=False)\n", + " if debug:\n", + " print('Current time:', current_time)\n", + "\n", + " tokens = prompt\n", + " with tqdm(range(end_time-start_time)) as progress:\n", + " if controls:\n", + " atime, adur, anote = controls[0:3]\n", + " anticipated_tokens = controls[3:]\n", + " anticipated_time = atime - TIME_OFFSET\n", + " else:\n", + " # nothing to anticipate\n", + " anticipated_time = math.inf\n", + "\n", + " while True:\n", + " # NOTE: add_token is anticipation.sample.add_token; import it (or swap in our_add_token)\n", + " new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))\n", + " new_time = new_token[0] - TIME_OFFSET\n", + " if new_time >= end_time:\n", + " break\n", + "\n", + " dt = new_time - current_time\n", + " assert dt >= 0\n", + " current_time = new_time\n", + "\n", + " # backfill anything that should have come before the new token\n", + " while current_time 
>= anticipated_time:\n", + " tokens.extend([atime, adur, anote])\n", + " if debug:\n", + " note = anote - NOTE_OFFSET\n", + " instr = note//2**7\n", + " print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)\n", + "\n", + " if len(anticipated_tokens) > 0:\n", + " atime, adur, anote = anticipated_tokens[0:3]\n", + " anticipated_tokens = anticipated_tokens[3:]\n", + " anticipated_time = atime - TIME_OFFSET\n", + " else:\n", + " # nothing more to anticipate\n", + " anticipated_time = math.inf\n", + "\n", + " if debug:\n", + " new_note = new_token[2] - NOTE_OFFSET\n", + " new_instr = new_note//2**7\n", + " new_pitch = new_note - (2**7)*new_instr\n", + " print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)\n", + "\n", + " tokens.extend(new_token)\n", + " progress.update(dt)\n", + "\n", + " if anticipated_time != math.inf:\n", + " tokens.extend([atime, adur, anote])\n", + "\n", + " return ops.sort(ops.unpad(tokens) + controls)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 95%|█████████▌| 95/100 [00:02<00:00, 33.39it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1050, 10022, 14255, 2098, 10022, 21199, 1093, 10230, 20274, 1150, 10107, 14745, 1150, 10716, 17460, 2582, 10043, 15385, 5608, 10011, 17221, 3507, 10022, 14264, 1150, 10508, 12498, 1108, 10011, 14520, 3590, 10013, 14513, 1109, 10613, 21334, 5976, 10754, 27426, 2894, 10015, 27459, 2106, 10022, 14264, 2073, 10067, 14252, 1111, 10125, 14503, 3477, 10006, 14757, 1222, 10029, 14252]\n", + "[0, 10022, 14255, 1, 10022, 21199, 2, 10230, 20274, 3, 10107, 14745, 4, 10716, 17460, 5, 10043, 15385, 6, 10011, 17221, 7, 10022, 14264, 8, 10508, 12498, 9, 10011, 14520, 10, 10013, 14513, 11, 10613, 21334, 12, 10754, 27426, 13, 10015, 27459, 14, 10022, 14264, 15, 10067, 14252, 16, 10125, 14503, 17, 10006, 14757, 18, 10029, 14252]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "ename": "NameError", + "evalue": "name 'generate_kmeans' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20572\\1479155271.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mlength\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;31m# unconditional_tokens = our_generate(our_model, debug=False, start_time=0, end_time=length, top_p=.98, text_prompt=\"make me a silly song for my friends\", sem_tokenizer=sem_tokenizer)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0munconditional_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgenerate_kmeans\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mour_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m 
end_time=length, top_p=.98, text_prompt=\"make me a silly song for my friends\")\n 5 print(unconditional_tokens)\n 6 Audio(synthesize(fs, unconditional_tokens, \"unconditional_new_kmeans\"))\n", + "\u001b[1;31mNameError\u001b[0m: name 'generate_kmeans' is not defined" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [] + } + ], + "source": [ + "# Generate Unconditional Our Model\n", + "length = 1\n", + "# unconditional_tokens = our_generate(our_model, debug=False, start_time=0, end_time=length, top_p=.98, text_prompt=\"make me a silly song for my friends\", sem_tokenizer=sem_tokenizer)\n", + "unconditional_tokens = generate_kmeans(our_model, debug=False, start_time=0, end_time=length, top_p=.98, textPrompt=\"make me a silly song for my friends\")\n", + "print(unconditional_tokens)\n", + "Audio(synthesize(fs, unconditional_tokens, \"unconditional_new_kmeans\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 169, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/20 [00:00