# TR-ChatGPT / app.py
#%%
import openai
import numpy as np
import pretty_midi
import re
import os
import gradio as gr
openai.api_key = os.environ.get("OPENAI_API_KEY")
# sample data
markdown_table_sample = """8th
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | |
| SD | | | | x | | | | x |
| CH | x | | x | | x | | x | |
| OH | | | | x | | | x | |
| LT | | | | | | x | | |
| MT | | x | | | x | | | |
| HT | x | | | x | | | | |
"""
markdown_table_sample2 = """16th
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | | x | | x | | x | | x | |
| SD | | | | x | | | | x | | | x | | | | x | |
| CH | x | | x | | x | | x | | x | | x | | x | | x | |
| OH | | | | x | | | x | | | | | x | | | x | |
| LT | | | | | | x | | | | | | | | x | | |
| MT | | x | | | x | | | | | x | | | x | | | |
| HT | x | | | x | | | | | x | | | x | | | | |
"""
# General MIDI percussion note numbers for each drum abbreviation
MIDI_NOTENUM = {
    "BD": 36,  # Bass Drum 1
    "SD": 38,  # Acoustic Snare
    "CH": 42,  # Closed Hi-Hat
    "HH": 44,  # Pedal Hi-Hat
    "OH": 46,  # Open Hi-Hat
    "LT": 45,  # Low Tom
    "MT": 47,  # Low-Mid Tom
    "HT": 50,  # High Tom
    "CP": 39,  # Hand Clap
    "CB": 56,  # Cowbell
}
SR = 44100
MAX_QUERY = 5
def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    # parse the Markdown table into a 2D list of cells,
    # skipping the header and separator rows
    rhythm_pattern = []
    for line in markdown_table.split('\n')[2:]:
        rhythm_pattern.append(line.split('|')[1:-1])
print(rhythm_pattern)
# table to MIDI
pm = pretty_midi.PrettyMIDI(initial_tempo=bpm) # midi object
pm_inst = pretty_midi.Instrument(0, is_drum=True) # midi instrument
pm.instruments.append(pm_inst)
note_length = (60. / bpm) * (4.0 / resolution) # note duration
beat_num = resolution
    for row in rhythm_pattern:
        if not row:
            continue  # skip blank lines (e.g. a trailing newline)
        inst = row[0].strip().upper()
        for j in range(1, len(row)):
            beat_num = j  # remember pattern length for looping
            cell = row[j].strip()
            velocity = 120 if cell == 'x' else 65 if cell == 'o' else 0
            if velocity > 0 and inst in MIDI_NOTENUM:
                note = pretty_midi.Note(
                    velocity=velocity,
                    pitch=MIDI_NOTENUM[inst],
                    start=note_length * (j - 1) + 0.0001,  # small offset; avoids a note at exactly t=0
                    end=note_length * j,
                )
                pm_inst.notes.append(note)
    # synthesize audio (pm.fluidsynth requires fluidsynth and a soundfont)
    audio_data = pm.fluidsynth(fs=SR)
    # cut off the synth's release/reverb tail so the pattern loops cleanly
    audio_data = audio_data[:int(SR * note_length * beat_num)]
return audio_data
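# Usage sketch (commented out; not part of the app). Assumes fluidsynth is
# installed, which pm.fluidsynth() needs. At bpm=120 and resolution=8, each grid
# step lasts (60/120) * (4/8) = 0.25 s, so one bar is 8 * 0.25 = 2.0 s, i.e.
# int(44100 * 2.0) = 88200 samples:
#
#   audio = convert_table_to_audio(markdown_table_sample, resolution=8, bpm=120.0)
#   print(len(audio))  # -> 88200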
def get_answer(question):
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a rhythm generator. "},
{"role": "user", "content": "Please generate a rhythm pattern in a Markdown table. Time resolution is the 8th note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
{"role": "assistant", "content": markdown_table_sample},
# {"role": "user", "content": "Please generate a rhythm pattern. The resolution is the fourth note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
# {"role": "assistant", "content": markdown_table_sample},
{"role": "user", "content": question}
]
)
return response["choices"][0]["message"]["content"]
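# The few-shot example above steers the model to answer in the same shape as
# markdown_table_sample: a resolution line ("8th" or "16th") followed by the
# table itself, e.g.
#
#   8th
#   |    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
#   |----|---|---|---|---|---|---|---|---|
#   | BD | x |   | x |   |   |   | x |   |
#
# generate_rhythm() below parses the resolution and table back out of this
# free-form text, falling back to defaults if the reply has another shape.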
def generate_rhythm(query, state):
    if state["gen_count"] >= MAX_QUERY and len(state["user_token"]) == 0:
        return [None, "You need to set your OpenAI API key to try more than %d times" % MAX_QUERY]
    state["gen_count"] += 1
    # get response from ChatGPT
text_output = get_answer(query)
    # Try to read the time resolution from the text before the first '|'
    resolution_text = text_output.split('|')[0]
    try:
        resolution = int(re.findall(r'\d+', resolution_text)[0])
    except (IndexError, ValueError):
        resolution = 8  # default
    # Extract the rhythm table: everything between the first and last '|'
    try:
        table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
        audio_data = convert_table_to_audio(table, resolution)
        # loop the one-bar pattern x4
        audio_data = np.tile(audio_data, 4)
        if np.max(np.abs(audio_data)) == 0.0:  # synthesis produced silence
            audio_data = np.ones(1)  # dummy signal so gr.Audio still gets data
    except Exception:
        audio_data = np.ones(1)
return [(SR, audio_data), text_output]
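# Parsing sketch (hypothetical reply, for illustration): given
#   text_output = "16th\n|    | 1 | 2 |\n|----|---|---|\n| BD | x |   |\n"
# text_output.split('|')[0] is "16th\n", so resolution becomes 16, and joining
# everything between the first and last '|' reconstructs the table that is
# handed to convert_table_to_audio().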
# %%
def on_token_change(user_token, state):
    # deliberately not logged: user_token is the user's secret API key
openai.api_key = user_token or os.environ.get("OPENAI_API_KEY")
state["user_token"] = user_token
return state
with gr.Blocks() as demo:
state = gr.State({"gen_count": 0, "user_token":""})
with gr.Row():
with gr.Column():
# gr.Markdown("Ask ChatGPT to generate rhythm patterns")
gr.Markdown("***Hey TR-ChatGPT, give me a drum pattern!***")
gr.Markdown("Use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT and 'x' for an accented beat, 'o' for a weak beat!")
with gr.Row():
with gr.Column():
inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
btn = gr.Button("Generate")
with gr.Column():
out_audio = gr.Audio()
out_text = gr.Textbox(placeholder="ChatGPT output")
with gr.Row():
with gr.Column():
gr.Markdown("Enter your own OpenAI API Key to try out more than 5 times. You can get it [here](https://platform.openai.com/account/api-keys).")
user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)
btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])
demo.launch()