#%%
"""Gradio demo: ask ChatGPT for a drum pattern and render it to audio.

ChatGPT is prompted to answer with a Markdown table (one row per drum, one
column per time step).  The table is converted to MIDI drum notes with
pretty_midi, synthesized with FluidSynth and looped for playback.
"""
import os
import re

import gradio as gr
import numpy as np
import openai
import pretty_midi

# Default (server-side) key; visitors may supply their own key per session.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# sample data
markdown_table_sample = """
|    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| BD | x |   | x |   |   |   | x |   |
| SD |   |   |   | x |   |   |   | x |
| CH | x |   | x |   | x |   | x |   |
| OH |   |   |   | x |   |   | x |   |
| LT |   |   |   |   |   | x |   |   |
| MT |   | x |   |   | x |   |   |   |
| HT | x |   |   | x |   |   |   |   |
"""

markdown_table_sample2 = """
|    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x |   | x |   |   |   | x |   | x |   | x |   | x |   | x |   |
| SD |   |   |   | x |   |   |   | x |   |   | x |   |   |   | x |   |
| CH | x |   | x |   | x |   | x |   | x |   | x |   | x |   | x |   |
| OH |   |   |   | x |   |   | x |   |   |   |   | x |   |   | x |   |
| LT |   |   |   |   |   | x |   |   |   |   |   |   |   | x |   |   |
| MT |   | x |   |   | x |   |   |   |   | x |   |   | x |   |   |   |
| HT | x |   |   | x |   |   |   |   | x |   |   | x |   |   |   |   |
"""

# Drum abbreviation -> General MIDI percussion key number.
# LT and CP previously collided with MT (48) and HT (50) respectively, which
# made those drums indistinguishable; they now use the GM Low Tom / Hand Clap
# keys.
MIDI_NOTENUM = {
    "BD": 36,
    "SD": 38,
    "CP": 39,
    "CH": 42,
    "HH": 44,
    "LT": 45,
    "OH": 46,
    "MT": 48,
    "HT": 50,
    "CB": 56,
}

# Table cell marker -> MIDI velocity ('x' accented beat, 'o' weak beat).
HIT_VELOCITY = {"x": 120, "o": 65}

SR = 44100  # audio sample rate in Hz
MAX_QUERY = 5  # free generations per session without a personal API key
LOOP_COUNT = 4  # how many times the rendered pattern is repeated for playback


def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    """Render a Markdown rhythm table to an audio sample array.

    Args:
        markdown_table: Table text.  The first two lines are skipped (they
            are expected to be the header and separator rows); every other
            line is read as ``| DRUM | cell | cell | ... |``.
        resolution: Number of time steps per whole note (8 = 8th notes).
        bpm: Tempo in quarter notes per minute.

    Returns:
        The array produced by ``PrettyMIDI.fluidsynth`` at ``SR`` Hz,
        truncated to ``resolution`` steps so it can be looped seamlessly.
    """
    # convert table to array: drop the outer empty fields around the pipes
    rhythm_pattern = [
        line.split('|')[1:-1] for line in markdown_table.split('\n')[2:]
    ]
    print(rhythm_pattern)

    # table to MIDI
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)  # midi object
    pm_inst = pretty_midi.Instrument(0, is_drum=True)  # midi instrument
    pm.instruments.append(pm_inst)

    note_length = (60. / bpm) * (4.0 / resolution)  # duration of one step

    for row in rhythm_pattern:
        if not row:
            continue
        midinote = MIDI_NOTENUM.get(row[0].strip().upper())
        if midinote is None:
            # Unknown drum name (or a separator/blank row): nothing to play.
            continue
        for step, cell in enumerate(row[1:], start=1):
            # Accept upper-case markers too; any other cell content is a rest.
            velocity = HIT_VELOCITY.get(cell.strip().lower(), 0)
            if velocity > 0:
                note = pretty_midi.Note(
                    velocity=velocity,
                    pitch=midinote,
                    start=note_length * (step - 1) + 0.05,
                    end=note_length * step,
                )
                pm_inst.notes.append(note)

    # convert to audio at the same rate that is reported to the audio widget
    audio_data = pm.fluidsynth(fs=SR)

    # cut off the reverb section: for looping, cut the tail
    return audio_data[:int(SR * note_length * resolution)]


def get_answer(question, api_key=None):
    """Ask ChatGPT for a rhythm pattern and return the raw reply text.

    Args:
        question: The user's free-text request.
        api_key: Optional OpenAI API key used for this request only.  When
            ``None`` the library falls back to the module-level
            ``openai.api_key``.

    Returns:
        The content of the first choice in the chat completion response.
    """
    request_options = {}
    if api_key:
        # Per-request key: never touches the process-wide default, so one
        # visitor's key is not used for other sessions.
        request_options["api_key"] = api_key

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a rhythm generator. "},
            # One-shot example: the instruction plus a sample answer table.
            {
                "role": "user",
                "content": (
                    "Please generate a rhythm pattern in a Markdown table. "
                    "Time resolution is the 8th note. "
                    "You use the following drums. "
                    "Kick drum:BD, Snare drum:SD, Closed-hihat:CH, "
                    "Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. "
                    "use 'x' for an accented beat, 'o' for a weak beat."
                ),
            },
            {"role": "assistant", "content": markdown_table_sample},
            {"role": "user", "content": question},
        ],
        **request_options,
    )
    return response["choices"][0]["message"]["content"]


def generate_rhythm(query, state):
    """Generate a rhythm for ``query`` and return ``[audio, text]``.

    Args:
        query: The user's request text.
        state: Per-session dict with ``"gen_count"`` (generations so far)
            and ``"user_token"`` (the visitor's own API key, or ``""``).
            ``"gen_count"`` is incremented in place.

    Returns:
        A two-item list: ``(SR, samples)`` or ``None`` for the audio
        output, and the text shown to the user (ChatGPT's reply, or the
        quota message).
    """
    # Log only the counter: the state also holds the user's secret API key.
    print({"gen_count": state["gen_count"]})

    user_key = state["user_token"]
    # '>=' so that exactly MAX_QUERY free generations are allowed.
    if state["gen_count"] >= MAX_QUERY and not user_key:
        return [
            None,
            f"You need to set your ChatGPT API Key to try more than {MAX_QUERY} times",
        ]
    state["gen_count"] += 1

    # get response from ChatGPT
    text_output = get_answer(query, api_key=user_key or None)

    # NOTE: the prompt fixes the time resolution to 8th notes.
    resolution = 8

    # Extract rhythm table: keep everything between the first and last pipe
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)

    if audio_data.size == 0:
        # The reply contained no playable hits; show the text without audio.
        return [None, text_output]

    # repeat the one-bar pattern so playback lasts a few bars
    audio_data = np.tile(audio_data, LOOP_COUNT)
    return [(SR, audio_data), text_output]


# %%
def on_token_change(user_token, state):
    """Store the visitor's API key in their session state.

    The key is deliberately neither printed nor written to the global
    ``openai.api_key``; ``generate_rhythm`` passes it per request instead.

    Args:
        user_token: Text currently in the API-key textbox.
        state: Per-session dict; ``"user_token"`` is updated in place.

    Returns:
        The updated state dict.
    """
    state["user_token"] = (user_token or "").strip()
    return state


with gr.Blocks() as demo:
    state = gr.State({"gen_count": 0, "user_token": ""})
    with gr.Row():
        with gr.Column():
            gr.Markdown("***Hey TR-ChatGPT, give me a drum pattern!***")
            gr.Markdown(
                "You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. Use 'x' for an accented beat, 'o' for a weak beat. ",
                elem_id="label",
            )
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    with gr.Row():
        with gr.Column():
            gr.Markdown("Enter your own OpenAI API Key to try out more than 5 times. You can get it [here](https://platform.openai.com/account/api-keys).")
            user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)

    btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])

demo.launch()