TR-ChatGPT /
naotokui's picture
add api key text input
history blame
5.18 kB
import openai
import numpy as np
import pretty_midi
import re
import numpy as np
import os
import gradio as gr
# Start with the deployment's key from the environment (None if unset);
# on_token_change later swaps in a key typed into the UI.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Sample data: few-shot examples of the reply format expected from ChatGPT.
# The text before the first "|" is the time resolution (steps per bar); the
# table has a header row, a separator row, then one row per drum.
# convert_table_to_audio skips the first two table lines (header +
# separator), so the separator row must be present for the model to imitate.

# Quarter-note example (4 steps per bar).
markdown_table_sample = """4
| | 1 | 2 | 3 | 4 |
|---|---|---|---|---|
| BD | | | x | |
| SD | | | | x |
| CH | x | | x | |
| OH | | | | x |
| LT | | | | |
| MT | | x | | |
| HT | x | | | x |
"""

# Eighth-note example (8 steps per bar).
markdown_table_sample2 = """8
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|
| BD | | | x | | | | x | |
| SD | | | | x | | | | x |
| CH | x | | x | | x | | x | |
| OH | | | | x | | | x | |
| LT | | | | | | x | | |
| MT | | x | | | x | | | |
| HT | x | | | x | | | | |
"""
# General MIDI percussion note numbers for each drum symbol that may appear
# in the first column of a rhythm table. Rows whose symbol is not listed here
# are ignored when the table is rendered.
MIDI_NOTENUM = {
    "BD": 36,  # Bass Drum 1
    "SD": 38,  # Acoustic Snare
    "CH": 42,  # Closed Hi-Hat
    "HH": 42,  # alias for the closed hi-hat
    "OH": 46,  # Open Hi-Hat
    "LT": 45,  # Low Tom (distinct from MT so the two toms sound different)
    "MT": 48,  # Hi-Mid Tom
    "HT": 50,  # High Tom
    "CP": 39,  # Hand Clap (50 would be the High Tom)
}

SR = 44100  # audio sample rate in Hz, used for rendering and playback

# Free generations per session before the user must enter their own API key
# (matches the "more than 5 times" wording shown in the UI).
MAX_QUERY = 5
def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    """Render a markdown drum table as a one-bar audio loop.

    The table must start with a header row and a separator row, followed by
    one row per drum, e.g. ``| BD | | x | ... |``. The first cell of each row
    is a drum symbol looked up in ``MIDI_NOTENUM`` (unknown symbols are
    skipped); each following cell is one step, and a cell containing ``x``
    (either case) triggers that drum on that step.

    Args:
        markdown_table: Table text beginning at the header row.
        resolution: Number of steps in one bar of four beats (8 = eighth notes).
        bpm: Tempo in beats per minute.

    Returns:
        Audio samples rendered at ``SR`` Hz, truncated to exactly four beats
        (``resolution`` steps) so the clip loops cleanly.

    Raises:
        ValueError: If ``resolution`` is not positive.
    """
    if resolution <= 0:
        raise ValueError("resolution must be a positive number of steps, got %r" % (resolution,))

    # convert table to array: skip the header and separator lines, then split
    # each row into cells, e.g. "| BD | | x |" -> [" BD ", " ", " x "]
    rhythm_pattern = [line.split('|')[1:-1] for line in markdown_table.split('\n')[2:]]

    # table to MIDI
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)  # midi object
    pm_inst = pretty_midi.Instrument(0, is_drum=True)  # midi instrument
    pm.instruments.append(pm_inst)  # must be attached, otherwise nothing is rendered

    note_length = (60. / bpm) * (4.0 / resolution)  # duration of one step in seconds
    # NOTE(review): each onset is delayed by 0.05 s (intent undocumented); the
    # delay is capped at half a step so a note never starts after it ends at
    # fine resolutions / fast tempos.
    onset_delay = min(0.05, note_length / 2)

    for row in rhythm_pattern:
        if not row:
            continue  # blank line or text without cells
        midinote = MIDI_NOTENUM.get(row[0].strip().upper())
        if midinote is None:
            continue  # drum symbol not in the mapping: ignore the row
        for step, cell in enumerate(row[1:]):
            if cell.strip().lower() == 'x':
                pm_inst.notes.append(pretty_midi.Note(
                    velocity=80,
                    pitch=midinote,
                    start=note_length * step + onset_delay,
                    end=note_length * (step + 1),
                ))

    # convert to audio at the same rate used for slicing below
    audio_data = pm.fluidsynth(fs=SR)

    # for looping, keep exactly one bar and cut the reverb tail
    return audio_data[:int(SR * note_length * resolution)]
def get_answer(question, model="gpt-3.5-turbo"):
    """Ask the chat model for a rhythm pattern and return its raw reply text.

    ``markdown_table_sample2`` is supplied as a prior assistant turn
    (one-shot example), so the model is nudged to answer with the time
    resolution followed by a markdown drum table.

    Args:
        question: Free-form user request (genre, feel, ...).
        model: Chat model name passed to the OpenAI API.
            NOTE(review): default assumed to be the ChatGPT API model — confirm.

    Returns:
        The ``content`` string of the first choice in the API response.
    """
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a rhythm generator. You generate rhythm patterns with the resolution of the 8th note"},
            {"role": "user", "content": "Please generate a rhythm pattern. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT。You need to write the time resolution first."},
            {"role": "assistant", "content": markdown_table_sample2},
            # A quarter-note variant was also tried (disabled): a user turn
            # asking "Please generate in quarter-note units. The drums are BD,
            # SD, CH, OH, LT, MT, HT. First write the reciprocal of the time
            # resolution." followed by markdown_table_sample as the answer.
            {"role": "user", "content": question},
        ],
    )
    return response["choices"][0]["message"]["content"]
def generate_rhythm(query, state):
    """Gradio callback: ask ChatGPT for a rhythm and render it as audio.

    Args:
        query: The user's request text.
        state: Per-session dict holding a ``"gen_count"`` counter. If it also
            holds a non-empty ``"user_token"``, the free-request limit is lifted.

    Returns:
        ``[(SR, audio), text]`` for the audio and text outputs, or
        ``[None, message]`` once ``MAX_QUERY`` free generations are used up.
    """
    # Allow exactly MAX_QUERY free generations (gen_count starts at 0)
    # unless the user has supplied their own key.
    if state["gen_count"] >= MAX_QUERY and not state.get("user_token"):
        return [None, "You need to set your ChatGPT API Key to try more than %d times" % MAX_QUERY]
    state["gen_count"] += 1

    # get response from ChatGPT
    text_output = get_answer(query)

    # Try to use the text before the table as the time resolution (steps per
    # bar); fall back to 8th notes when it is missing or not a positive number.
    resolution = 8  # default
    digits = re.findall(r'\d+', text_output.split('|')[0])
    if digits and int(digits[0]) > 0:
        resolution = int(digits[0])

    # Extract rhythm table: everything from the first "|" to the last "|"
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)

    # repeat the one-bar loop 4 times
    audio_data = np.tile(audio_data, 4)
    return [(SR, audio_data), text_output]
# %%
def on_token_change(user_token, state=None):
    """Use the API key typed by the user, or fall back to the environment key.

    Surrounding whitespace (common when pasting) is stripped; an empty box
    restores ``OPENAI_API_KEY``. When a session ``state`` dict is passed,
    the stripped key is also stored under ``"user_token"`` and the dict is
    returned so it can be wired back as a Gradio output.

    NOTE(review): ``openai.api_key`` is module-global, so a key entered in
    one browser session is used by every session served by this process.
    Consider passing the key per request instead.
    """
    user_token = (user_token or "").strip()
    openai.api_key = user_token or os.environ.get("OPENAI_API_KEY")
    if state is not None:
        state["user_token"] = user_token
        return state
    return None
with gr.Blocks() as demo:
    # Per-session state; gen_count counts generations for the free-request limit.
    state = gr.State({"gen_count": 0})
    gr.Markdown("Ask ChatGPT to generate rhythm patterns")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    # Keep the advertised limit in sync with the MAX_QUERY check in generate_rhythm.
    gr.Markdown(
        "Enter your own OpenAI API Key to try out more than %d times. "
        "You can get it [here](https://platform.openai.com/account/api-keys)." % MAX_QUERY,
        elem_id="label",
    )
    user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)

    btn.click(generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    user_token.change(on_token_change, inputs=[user_token], outputs=[])

if __name__ == "__main__":
    demo.launch()
# demo = gr.Interface(
# fn=generate_rhythm,
# inputs=gr.Textbox(label="command",show_label=True, placeholder="Give me a dope beat!", visible=True).style(container=False),
# outputs=["audio", "text"]
# )
# demo.launch()
# %%
# %%