{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import sys,time\n", "\n", "import midi2audio\n", "import transformers\n", "from transformers import AutoModelForCausalLM, BertTokenizer, GPT2LMHeadModel\n", "\n", "from IPython.display import Audio\n", "\n", "from anticipation import ops\n", "from anticipation.sample import generate\n", "from anticipation.tokenize import extract_instruments\n", "from anticipation.convert import events_to_midi,midi_to_events\n", "from anticipation.visuals import visualize\n", "from anticipation.config import *\n", "from anticipation.vocab import *\n", "\n", "import tqdm" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at stanford-crfm/music-small-800k were not used when initializing GPT2LMHeadModel: ['token_out_embeddings']\n", "- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Downloading model.safetensors: 100%|██████████| 166M/166M [00:23<00:00, 7.18MB/s]]\n", "Downloading generation_config.json: 100%|██████████| 144/144 [00:00<00:00, 47.9kB/s]\n" ] } ], "source": [ 
"small_john_path = 'stanford-crfm/music-small-800k'\n", 
"john_model = AutoModelForCausalLM.from_pretrained(small_john_path)\n", 
"\n", 
"# john_model.device(\"\")\n", 
"\n", 
"# our_model = GPT2LMHeadModel.from_pretrained(\"beat-goes-on/wiki-bert-tiny\")\n", 
"our_model = GPT2LMHeadModel.from_pretrained(\"beat-goes-on/HF_wiki_kmeans\")\n", 
"\n", 
"# our_model.device(\"\")\n", 
"\n", 
"# a MIDI synthesizer\n", 
"# Forward slashes work on both Windows and POSIX and avoid Python's\n", 
"# invalid-backslash-escape warning that the previous Windows-style path triggered.\n", 
"SOUNDFONT_PATH = 'word2house/8MBGMSFX.SF2'\n", 
"fs = midi2audio.FluidSynth(SOUNDFONT_PATH)\n", 
"\n", 
"# sem_tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-base-uncased\", padding='max_length', max_length=64)\n", 
"\n", 
"# the MIDI synthesis script\n", 
"def synthesize(fs, tokens, filename):\n", 
"    \"\"\"Render an event-token sequence to audio via an intermediate MIDI file.\n", 
"\n", 
"    Converts `tokens` with events_to_midi, saves `<filename>.mid`, renders it to\n", 
"    `<filename>.wav` with the given FluidSynth instance `fs`, and returns the .wav path.\n", 
"    \"\"\"\n", 
"    midi_path = f'{filename}.mid'\n", 
"    wav_path = f'{filename}.wav'\n", 
"    mid = events_to_midi(tokens)\n", 
"    mid.save(midi_path)\n", 
"    fs.midi_to_audio(midi_path, wav_path)\n", 
"    return wav_path" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 34%|███▎ | 337/1000 [05:38<11:06, 1.00s/it]\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20572\\166056788.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m# Generate Unconditional 
John\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mlength\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m10\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0munconditional_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mjohn_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlength\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m.98\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[0mAudio\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msynthesize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munconditional_tokens\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'unconditional_default'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\anticipation\\sample.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(model, start_time, end_time, inputs, controls, top_p, debug, delta, text_prompt, sem_tokenizer)\u001b[0m\n\u001b[0;32m 174\u001b[0m \u001b[0manticipated_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minf\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 175\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 176\u001b[1;33m \u001b[0mnew_token\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0madd_token\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mz\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mtokens\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstart_time\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcurrent_time\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 177\u001b[0m \u001b[0mnew_time\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnew_token\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mTIME_OFFSET\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 178\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mnew_time\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\anticipation\\sample.py\u001b[0m in \u001b[0;36madd_token\u001b[1;34m(model, z, tokens, top_p, current_time, debug)\u001b[0m\n\u001b[0;32m 87\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 88\u001b[0m \u001b[0minput_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mz\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mhistory\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mnew_token\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 89\u001b[1;33m \u001b[0mlogits\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_tokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlogits\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 90\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[0midx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_tokens\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 1074\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32melse\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0muse_return_dict\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1075\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1076\u001b[1;33m transformer_outputs = self.transformer(\n\u001b[0m\u001b[0;32m 1077\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1078\u001b[0m \u001b[0mpast_key_values\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpast_key_values\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 898\u001b[0m )\n\u001b[0;32m 899\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 900\u001b[1;33m outputs = block(\n\u001b[0m\u001b[0;32m 901\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 902\u001b[0m 
\u001b[0mlayer_past\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlayer_past\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m 
\u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[0;32m 388\u001b[0m \u001b[0mresidual\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 389\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mln_1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 390\u001b[1;33m attn_outputs = self.attn(\n\u001b[0m\u001b[0;32m 391\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 392\u001b[0m \u001b[0mlayer_past\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlayer_past\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in 
\u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\gpt2\\modeling_gpt2.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[0;32m 332\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 333\u001b[0m \u001b[0mattn_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_merge_heads\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattn_output\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_heads\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhead_dim\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 334\u001b[1;33m \u001b[0mattn_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc_proj\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattn_output\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 335\u001b[0m \u001b[0mattn_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mresid_dropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattn_output\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 336\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 
"\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# type: ignore[misc]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1510\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1511\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1512\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1513\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1518\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1519\u001b[0m or _global_forward_hooks or 
_global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1521\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\pytorch_utils.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[0msize_out\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnf\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msize_out\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 104\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ 
"# Generate Unconditional John\n", 
"length = 10\n", 
"unconditional_tokens = generate(john_model, start_time=0, end_time=length, top_p=.98, debug=False)\n", 
"Audio(synthesize(fs, unconditional_tokens, 'unconditional_default'))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ 
"\"\"\"\n", 
"API functions for sampling from anticipatory infilling models.\n", 
"\"\"\"\n", 
"\n", 
"import math\n", 
"import random\n", 
"import torch\n", 
"import torch.nn.functional as F\n", 
"\n", 
"from tqdm import tqdm\n", 
"\n", 
"from anticipation import ops\n", 
"from anticipation.config import *\n", 
"from anticipation.vocab import *\n", 
"\n", 
"\n", 
"def safe_logits(logits, idx):\n", 
"    \"\"\"Mask logits for tokens that may never be sampled at sequence position `idx`.\n", 
"\n", 
"    Control and special tokens are always masked. Events are (time, duration, note)\n", 
"    triples, so `idx % 3` decides which single event field is allowed at this position.\n", 
"    Modifies `logits` in place and returns it.\n", 
"    \"\"\"\n", 
"    logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf')  # don't generate controls\n", 
"    logits[SPECIAL_OFFSET:] = -float('inf')  # don't generate special tokens\n", 
"\n", 
"    # don't generate stuff in the wrong time slot\n", 
"    if idx % 3 == 0:\n", 
"        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')\n", 
"        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')\n", 
"    elif idx % 3 == 1:\n", 
"        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')\n", 
"        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')\n", 
"    elif idx % 3 == 2:\n", 
"        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')\n", 
"        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')\n", 
"\n", 
"    return logits\n", 
"\n", 
"\n", 
"def nucleus(logits, top_p):\n", 
"    \"\"\"Top-p (nucleus) filtering, adapted from the Hugging Face implementation.\n", 
"\n", 
"    Keeps the highest-probability tokens up to and including the first one whose\n", 
"    cumulative probability exceeds `top_p`; every other logit is set to -inf in place.\n", 
"    A `top_p` of 1.0 or more disables filtering. Expects a 1-D logit vector\n", 
"    (the scatter below works along dim 0).\n", 
"    \"\"\"\n", 
"    if top_p < 1.0:\n", 
"        sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n", 
"        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n", 
"\n", 
"        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)\n", 
"        sorted_indices_to_remove = cumulative_probs > top_p\n", 
"\n", 
"        # Shift the indices to the right to keep also the first token above the threshold\n", 
"        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()\n", 
"        sorted_indices_to_remove[..., 0] = 0\n", 
"\n", 
"        # scatter sorted tensors to original indexing\n", 
"        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)\n", 
"        logits[indices_to_remove] = -float(\"inf\")\n", 
"\n", 
"    return logits\n", 
"\n", 
"\n", 
"def future_logits(logits, curtime):\n", 
"    \"\"\"Don't sample events in the past: mask time tokens below `curtime` (in place).\"\"\"\n", 
"    if curtime > 0:\n", 
"        logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')\n", 
"\n", 
"    return logits\n", 
"\n", 
"\n", 
"def instr_logits(logits, full_history):\n", 
"    \"\"\"Don't sample more than 16 instruments.\n", 
"\n", 
"    Once `full_history` already uses 16 instruments, mask the note tokens of every\n", 
"    instrument that is not yet present (in place). Otherwise `logits` is returned unchanged.\n", 
"    \"\"\"\n", 
"    instrs = ops.get_instruments(full_history)\n", 
"    if len(instrs) < 16:\n", 
"        return logits\n", 
"\n", 
"    for instr in range(MAX_INSTR):\n", 
"        if instr not in instrs:\n", 
"            logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')\n", 
"\n", 
"    return logits\n", 
"\n", 
"\n", 
"def our_add_token(model, z, tokens, top_p, current_time, debug=False):\n", 
"    \"\"\"Sample one (time, duration, note) event and return it as three token ids.\n", 
"\n", 
"    model        -- causal LM whose output `.logits` are indexed as [batch, position, vocab]\n", 
"    z            -- conditioning prefix prepended to the context (any sequence of ints)\n", 
"    tokens       -- full event history so far; its length must be a multiple of 3\n", 
"    top_p        -- nucleus-sampling threshold forwarded to nucleus()\n", 
"    current_time -- earliest allowed onset, in absolute time ticks\n", 
"    Returns [time, duration, note] with the time token shifted back to absolute time.\n", 
"    \"\"\"\n", 
"    assert len(tokens) % 3 == 0\n", 
"\n", 
"    # Markov window: keep only the last 1017 tokens (a multiple of 3, so event triples\n", 
"    # stay aligned). Presumably sized for a 1024-token model context -- TODO confirm.\n", 
"    lookback = max(len(tokens) - 1017, 0)\n", 
"    history = tokens[lookback:]  # slicing copies, so `tokens` itself is never mutated\n", 
"    offset = ops.min_time(history, seconds=False)\n", 
"    history[::3] = [tok - offset for tok in history[::3]]  # relativize time in the history buffer\n", 
"\n", 
"    # Accept any int sequence for z (list, tuple, or the empty-string placeholder).\n", 
"    prefix = list(z)\n", 
"    new_token = []\n", 
"    with torch.no_grad():\n", 
"        for i in range(3):\n", 
"            input_tokens = torch.tensor(prefix + history + new_token).unsqueeze(0).to(model.device)\n", 
"            logits = model(input_tokens).logits[0,-1]\n", 
"\n", 
"            idx = input_tokens.shape[1]-1\n", 
"            logits = safe_logits(logits, idx)\n", 
"            if i == 0:\n", 
"                logits = future_logits(logits, current_time - offset)\n", 
"            elif i == 2:\n", 
"                # instr_logits already restricts notes to existing instruments once 16 are\n", 
"                # in use, so the sampled note needs no post-hoc override. (The previous\n", 
"                # override read probs[instr] -- a time-token slot -- and forced pitch 0.)\n", 
"                logits = instr_logits(logits, tokens)\n", 
"            logits = nucleus(logits, top_p)\n", 
"\n", 
"            probs = F.softmax(logits, dim=-1)\n", 
"            token = torch.multinomial(probs, 1)\n", 
"            new_token.append(int(token))\n", 
"\n", 
"    new_token[0] += offset  # revert to full sequence timing\n", 
"    if debug:\n", 
"        print(f' OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')\n", 
"\n", 
"    return new_token\n", 
"\n", 
"\n", 
"def our_generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION, text_prompt=None, sem_tokenizer=None):\n", " if inputs is None:\n", " inputs = []\n", "\n", " if controls is None:\n", " controls = []\n", "\n", " start_time = int(TIME_RESOLUTION*start_time)\n", " end_time = int(TIME_RESOLUTION*end_time)\n", "\n", " # prompt is events up to start_time\n", " prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)\n", "\n", " # treat events beyond start_time as controls\n", " future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)\n", " if 
debug:\n", " print('Future')\n", " ops.print_tokens(future)\n", "\n", " # clip controls that preceed the sequence\n", " controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)\n", "\n", " if debug:\n", " print('Controls')\n", " ops.print_tokens(controls)\n", "\n", " # z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 or text_prompt == None else sem_tokenizer.encode(text_prompt)\n", " z = \"\" # FIND A CLUSTER\n", " if debug:\n", " print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')\n", "\n", " # interleave the controls with the events\n", " tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))\n", " # tokens = prompt\n", " if debug:\n", " print('Prompt')\n", " ops.print_tokens(tokens)\n", "\n", " current_time = ops.max_time(prompt, seconds=False)\n", " if debug:\n", " print('Current time:', current_time)\n", "\n", " with tqdm(range(end_time-start_time)) as progress:\n", " if controls:\n", " atime, adur, anote = controls[0:3]\n", " anticipated_tokens = controls[3:]\n", " anticipated_time = atime - ATIME_OFFSET\n", " else:\n", " # nothing to anticipate\n", " anticipated_time = math.inf\n", " \n", " while True:\n", " while current_time >= anticipated_time - delta:\n", " tokens.extend([atime, adur, anote])\n", " if debug:\n", " note = anote - ANOTE_OFFSET\n", " instr = note//2**7\n", " print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)\n", "\n", " if len(anticipated_tokens) > 0:\n", " atime, adur, anote = anticipated_tokens[0:3]\n", " anticipated_tokens = anticipated_tokens[3:]\n", " anticipated_time = atime - ATIME_OFFSET\n", " else:\n", " # nothing more to anticipate\n", " anticipated_time = math.inf\n", " new_token = our_add_token(model, z, tokens, top_p, max(start_time,current_time))\n", " new_time = new_token[0] - TIME_OFFSET\n", " # We generated one note!\n", " # (note_start, duration, note_value)\n", " # new_time = 
current_time + random.randint(1, 5)\n", "\n", " # Check to make sure not too long. Yup.\n", " # new_token[1] = new_token[1] % MAX_DUR + DUR_OFFSET\n", " \n", " # new_token[2] = new_token[2] % MAX_NOTE + NOTE_OFFSET\n", "\n", "\n", " \n", " if new_time >= end_time:\n", " break\n", "\n", " if debug:\n", " new_note = new_token[2] - NOTE_OFFSET\n", " new_instr = new_note//2**7\n", " new_pitch = new_note - (2**7)*new_instr\n", " print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)\n", "\n", " tokens.extend(new_token)\n", " dt = new_time - current_time\n", " assert dt >= 0\n", " current_time = new_time\n", " progress.update(dt)\n", "\n", " events, _ = ops.split(tokens)\n", " return ops.sort(ops.unpad(events) + future)\n", "\n", "def our_generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION, text_prompt=None, sem_tokenizer=None):\n", " if inputs is None:\n", " inputs = []\n", "\n", " if controls is None:\n", " controls = []\n", " else:\n", " # treat controls as ordinary tokens\n", " controls = [token-CONTROL_OFFSET for token in controls]\n", "\n", " start_time = int(TIME_RESOLUTION*start_time)\n", " end_time = int(TIME_RESOLUTION*end_time)\n", "\n", " inputs = ops.sort(inputs + controls)\n", "\n", " # prompt is events up to start_time\n", " prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)\n", " if debug:\n", " print('Prompt')\n", " ops.print_tokens(prompt)\n", "\n", " # treat events beyond start_time as controls\n", " controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)\n", " if debug:\n", " print('Future')\n", " ops.print_tokens(controls)\n", "\n", " z = sem_tokenizer.encode(text_prompt)\n", " if debug:\n", " print('AR Mode')\n", "\n", " current_time = ops.max_time(prompt, seconds=False)\n", " if debug:\n", " print('Current time:', current_time)\n", "\n", " tokens = prompt\n", " 
with tqdm(range(end_time-start_time)) as progress:\n", " if controls:\n", " atime, adur, anote = controls[0:3]\n", " anticipated_tokens = controls[3:]\n", " anticipated_time = atime - TIME_OFFSET\n", " else:\n", " # nothing to anticipate\n", " anticipated_time = math.inf\n", " \n", " while True:\n", " new_token = our_add_token(model, z, tokens, top_p, max(start_time,current_time))\n", " # new_time = new_token[0] - TIME_OFFSET\n", " new_time = current_time + 5\n", " if new_time >= end_time:\n", " print(tokens)\n", " normed_tokens = [TIME_OFFSET]\n", " normed_tokens.extend(tokens[1:3])\n", " instrs = set()\n", " instrs.add((tokens[2] - NOTE_OFFSET) / 2 ** 7)\n", " for i in range(3, len(tokens), 3):\n", " note = tokens[i + 2] - NOTE_OFFSET\n", " instr = note//2**7\n", " if instr in instrs or len(instrs) < 16:\n", " normed_tokens.extend([normed_tokens[-3] + 1, tokens[i + 1], tokens[i + 2]])\n", " instrs.add(instr)\n", " tokens = normed_tokens\n", " break\n", "\n", " dt = new_time - current_time\n", " assert dt >= 0\n", " current_time = new_time\n", "\n", " # backfill anything that should have come before the new token\n", " while current_time >= anticipated_time:\n", " tokens.extend([atime, adur, anote])\n", " if debug:\n", " note = anote - NOTE_OFFSET\n", " instr = note//2**7\n", " print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)\n", "\n", " if len(anticipated_tokens) > 0:\n", " atime, adur, anote = anticipated_tokens[0:3]\n", " anticipated_tokens = anticipated_tokens[3:]\n", " anticipated_time = atime - TIME_OFFSET\n", " else:\n", " # nothing more to anticipate\n", " anticipated_time = math.inf\n", "\n", " if debug:\n", " new_note = new_token[2] - NOTE_OFFSET\n", " new_instr = new_note//2**7\n", " new_pitch = new_note - (2**7)*new_instr\n", " print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)\n", "\n", " tokens.extend(new_token)\n", " progress.update(dt)\n", "\n", " if anticipated_time != math.inf:\n", " 
tokens.extend([atime, adur, anote])\n", "\n", " return ops.sort(ops.unpad(tokens) + controls)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "ename": "UnpicklingError", "evalue": "pickle stream refers to out-of-band data but no *buffers* argument was given", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mUnpicklingError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20572\\4172336658.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mcluster_inference\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mget_cluster\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mgenerate_kmeans\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtextPrompt\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcontrols\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdelta\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mDELTA\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mTIME_RESOLUTION\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;32mis\u001b[0m 
\u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\OneDrive\\Desktop\\CS224N\\finalproj\\word2house\\cluster_inference.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[1;31m# load the kmeans model\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 18\u001b[1;33m \u001b[0mkmeans_model\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mjoblib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'GPT_128k_means_model.joblib'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 19\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mget_cluster\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprompt\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\joblib\\numpy_pickle.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(filename, mmap_mode)\u001b[0m\n\u001b[0;32m 583\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mload_compatibility\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfobj\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 584\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 585\u001b[1;33m \u001b[0mobj\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_unpickle\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfobj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mmmap_mode\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 586\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mobj\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\joblib\\numpy_pickle.py\u001b[0m in \u001b[0;36m_unpickle\u001b[1;34m(fobj, filename, mmap_mode)\u001b[0m\n\u001b[0;32m 502\u001b[0m \u001b[0mobj\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 503\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 504\u001b[1;33m \u001b[0mobj\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 505\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcompat_mode\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 506\u001b[0m warnings.warn(\"The file '%s' has been generated with a \"\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\pickle.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1210\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mEOFError\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1211\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbytes_types\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1212\u001b[1;33m 
\u001b[0mdispatch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1213\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0m_Stop\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mstopinst\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1214\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mstopinst\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\Users\\phili\\AppData\\Local\\Programs\\Python\\Python39\\lib\\pickle.py\u001b[0m in \u001b[0;36mload_next_buffer\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1395\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mload_next_buffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1396\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_buffers\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1397\u001b[1;33m raise UnpicklingError(\"pickle stream refers to out-of-band data \"\n\u001b[0m\u001b[0;32m 1398\u001b[0m \"but no *buffers* argument was given\")\n\u001b[0;32m 1399\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mUnpicklingError\u001b[0m: pickle stream refers to out-of-band data but no *buffers* argument was given" ] } ], "source": [
    "# NOTE(review): importing cluster_inference runs joblib.load('GPT_128k_means_model.joblib');\n",
    "# the UnpicklingError recorded in this cell's output is raised by that load (likely a\n",
    "# joblib/pickle version mismatch with the environment that saved the model - confirm).\n",
    "from cluster_inference import get_cluster\n",
    "\n",
    "\n",
    "def generate_kmeans(model, start_time, end_time, textPrompt=None, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION, text_prompt=None):\n",
    "    \"\"\"Autoregressive generation conditioned on the k-means cluster of a text prompt.\n",
    "\n",
    "    The prompt may be passed as textPrompt or, like our_generate/our_generate_ar, as\n",
    "    text_prompt (textPrompt wins if both are given). [get_cluster(prompt)] is used as the\n",
    "    conditioning prefix z; without a prompt z = [AUTOREGRESS]. start_time/end_time are in\n",
    "    seconds (scaled by TIME_RESOLUTION). Controls are treated as ordinary events.\n",
    "    Returns the sorted token list of the generated events plus the future inputs.\n",
    "    \"\"\"\n",
    "    if textPrompt is None:\n",
    "        textPrompt = text_prompt\n",
    "\n",
    "    if inputs is None:\n",
    "        inputs = []\n",
    "\n",
    "    if controls is None:\n",
    "        controls = []\n",
    "    else:\n",
    "        # treat controls as ordinary tokens\n",
    "        controls = [token-CONTROL_OFFSET for token in controls]\n",
    "\n",
    "    start_time = int(TIME_RESOLUTION*start_time)\n",
    "    end_time = int(TIME_RESOLUTION*end_time)\n",
    "\n",
    "    inputs = ops.sort(inputs + controls)\n",
    "\n",
    "    # prompt is events up to start_time\n",
    "    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)\n",
    "    if debug:\n",
    "        print('Prompt')\n",
    "        ops.print_tokens(prompt)\n",
    "\n",
    "    # treat events beyond start_time as controls\n",
    "    controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)\n",
    "    if debug:\n",
    "        print('Future')\n",
    "        ops.print_tokens(controls)\n",
    "\n",
    "    # NOTE(review): assumes get_cluster returns a single integer token id - confirm\n",
    "    z = [get_cluster(textPrompt)] if textPrompt is not None else [AUTOREGRESS]\n",
    "    if debug:\n",
    "        print('AR Mode')\n",
    "\n",
    "    current_time = ops.max_time(prompt, seconds=False)\n",
    "    if debug:\n",
    "        print('Current time:', current_time)\n",
    "\n",
    "    tokens = prompt\n",
    "    with tqdm(range(end_time-start_time)) as progress:\n",
    "        if controls:\n",
    "            atime, adur, anote = controls[0:3]\n",
    "            anticipated_tokens = controls[3:]\n",
    "            anticipated_time = atime - TIME_OFFSET\n",
    "        else:\n",
    "            # nothing to anticipate\n",
    "            anticipated_time = math.inf\n",
    "\n",
    "        while True:\n",
    "            # our_add_token is this notebook's sampler (add_token is not defined here)\n",
    "            new_token = our_add_token(model, z, tokens, top_p, max(start_time,current_time))\n",
    "            new_time = new_token[0] - TIME_OFFSET\n",
    "            if new_time >= end_time:\n",
    "                break\n",
    "\n",
    "            dt = new_time - current_time\n",
    "            assert dt >= 0\n",
    "            current_time = new_time\n",
    "\n",
    "            # backfill anything that should have come before the new token\n",
    "            while current_time >= anticipated_time:\n",
    "                tokens.extend([atime, adur, anote])\n",
    "                if debug:\n",
    "                    note = anote - NOTE_OFFSET\n",
    "                    instr = note//2**7\n",
    "                    print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)\n",
    "\n",
    "                if len(anticipated_tokens) > 0:\n",
    "                    atime, adur, anote = anticipated_tokens[0:3]\n",
    "                    anticipated_tokens = anticipated_tokens[3:]\n",
    "                    anticipated_time = atime - TIME_OFFSET\n",
    "                else:\n",
    "                    # nothing more to anticipate\n",
    "                    anticipated_time = math.inf\n",
    "\n",
    "            if debug:\n",
    "                new_note = new_token[2] - NOTE_OFFSET\n",
    "                new_instr = new_note//2**7\n",
    "                new_pitch = new_note - (2**7)*new_instr\n",
    "                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)\n",
    "\n",
    "            tokens.extend(new_token)\n",
    "            progress.update(dt)\n",
    "\n",
    "        if anticipated_time != math.inf:\n",
    "            tokens.extend([atime, adur, anote])\n",
    "\n",
    "    return ops.sort(ops.unpad(tokens) + controls)\n"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 95%|█████████▌| 95/100 [00:02<00:00, 33.39it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[1050, 10022, 14255, 2098, 10022, 21199, 1093, 10230, 20274, 1150, 10107, 14745, 1150, 10716, 17460, 2582, 10043, 15385, 5608, 10011, 17221, 3507, 10022, 14264, 1150, 10508, 12498, 1108, 10011, 14520, 3590, 10013, 14513, 1109, 10613, 21334, 5976, 10754, 27426, 2894, 10015, 27459, 2106, 10022, 14264, 2073, 10067, 14252, 1111, 10125, 14503, 3477, 10006, 14757, 1222, 10029, 14252]\n", "[0, 10022, 14255, 1, 10022, 21199, 2, 10230, 20274, 3, 10107, 14745, 4, 10716, 17460, 5, 10043, 15385, 6, 10011, 17221, 7, 10022, 14264, 8, 10508, 12498, 9, 10011, 14520, 10, 10013, 14513, 11, 10613, 21334, 12, 10754, 27426, 13, 10015, 27459, 14, 10022, 14264, 15, 10067, 14252, 16, 10125, 14503, 17, 10006, 14757, 18, 10029, 14252]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "ename": "NameError", "evalue": "name 'generate_kmeans' is not defined", "output_type": "error", "traceback": [ 
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20572\\1479155271.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mlength\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;31m# unconditional_tokens = our_generate(our_model, debug=False, start_time=0, end_time=length, top_p=.98, text_prompt=\"make me a silly song for my friends\", sem_tokenizer=sem_tokenizer)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0munconditional_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgenerate_kmeans\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mour_model\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlength\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtop_p\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m.98\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtext_prompt\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"make me a silly song for my friends\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0munconditional_tokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mAudio\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msynthesize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munconditional_tokens\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[1;34m\"unconditional_new_kmeans\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mNameError\u001b[0m: name 'generate_kmeans' is not defined" ] }, { "name": "stderr", "output_type": "stream", "text": [] } ], "source": [ "# Generate Unconditional Our Model\n", "length = 1\n", "# unconditional_tokens = our_generate(our_model, debug=False, start_time=0, end_time=length, top_p=.98, text_prompt=\"make me a silly song for my friends\", sem_tokenizer=sem_tokenizer)\n", "unconditional_tokens = generate_kmeans(our_model, debug=False, start_time=0, end_time=length, top_p=.98, text_prompt=\"make me a silly song for my friends\")\n", "print(unconditional_tokens)\n", "Audio(synthesize(fs, unconditional_tokens, \"unconditional_new_kmeans\"))" ] }, { "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/20 [00:00