# TR-ChatGPT: ask ChatGPT for a drum pattern and play it back (Gradio app)
#%%
# Standard library
import os
import re

# Third-party
import gradio as gr
import numpy as np  # was imported twice; deduplicated
import openai
import pretty_midi

# Default API key comes from the environment; a user-entered key can
# override it at runtime (see on_token_change).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# --- few-shot sample data for the ChatGPT prompt ---
# A rhythm table is Markdown: the first line states the grid resolution
# ("8th" / "16th"), then a header row of step numbers, a separator row,
# and one row per drum ('x' = accented hit, 'o' = weak hit, blank = rest).
markdown_table_sample = """8th
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | |
| SD | | | | x | | | | x |
| CH | x | | x | | x | | x | |
| OH | | | | x | | | x | |
| LT | | | | | | x | | |
| MT | | x | | | x | | | |
| HT | x | | | x | | | | |
"""
# 16-step variant; currently unused by get_answer but kept as an
# alternative few-shot example.
markdown_table_sample2 = """16th
| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x | | x | | | | x | | x | | x | | x | | x | |
| SD | | | | x | | | | x | | | x | | | | x |
| CH | x | | x | | x | | x | | x | | x | | x | | x | |
| OH | | | | x | | | x | | | | | x | | | x |
| LT | | | | | | x | | | | | | | | x | |
| MT | | x | | | x | | | | | x | | | x | |
| HT | x | | | x | | | | | x | | | x | | |
"""
# Drum label -> General-MIDI percussion note number.
# NOTE(review): LT/MT both map to 48 and HT/CP both map to 50 — this looks
# unintentional (GM uses 45 for low tom, 39 for hand clap); confirm the
# intended mapping before changing, since it alters the rendered sound.
MIDI_NOTENUM = {
    "BD": 36,  # bass (kick) drum
    "SD": 38,  # snare drum
    "CH": 42,  # closed hi-hat
    "HH": 44,  # pedal hi-hat
    "OH": 46,  # open hi-hat
    "LT": 48,  # low tom (see NOTE above)
    "MT": 48,  # mid tom (see NOTE above)
    "HT": 50,  # high tom (see NOTE above)
    "CP": 50,  # clap (see NOTE above)
    "CB": 56,  # cowbell
}
SR = 44100     # sample rate used when trimming fluidsynth output
MAX_QUERY = 5  # free generations allowed before an API key is required
def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    """Render a Markdown rhythm table to a mono audio loop.

    Parameters:
        markdown_table: table whose first two lines are the step-number
            header and the ``|---|`` separator; each later row looks like
            ``| BD | x | | o | ... |`` ('x' accented, 'o' weak, blank rest).
        resolution: grid steps per bar (8 = 8th notes, 16 = 16th notes).
        bpm: tempo in beats per minute.

    Returns:
        numpy array of samples synthesized with FluidSynth, truncated to
        exactly the pattern length so it can be tiled into a seamless loop.
    """
    # Parse the table body (skip header and separator rows) into lists of
    # cells; drop blank lines such as the one left by a trailing newline.
    rhythm_pattern = []
    for line in markdown_table.split('\n')[2:]:
        cells = line.split('|')[1:-1]
        if cells:
            rhythm_pattern.append(cells)

    # Build a one-bar drum MIDI sequence from the grid.
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)       # midi object
    pm_inst = pretty_midi.Instrument(0, is_drum=True)    # drum instrument
    pm.instruments.append(pm_inst)
    note_length = (60. / bpm) * (4.0 / resolution)       # seconds per grid step
    beat_num = resolution  # steps actually present in the table (for trimming)
    for row in rhythm_pattern:
        if len(row) > 1:
            beat_num = len(row) - 1
        inst = row[0].strip().upper()       # hoisted: constant per row
        midinote = MIDI_NOTENUM.get(inst)   # None for unknown drum labels
        for j in range(1, len(row)):
            mark = row[j].strip()
            if mark == 'x':
                velocity = 120  # accented hit
            elif mark == 'o':
                velocity = 65   # weak hit
            else:
                continue        # blank cell: rest
            if midinote is not None:
                # +0.0001 nudges each onset just past the step boundary —
                # presumably avoids fluidsynth note-off/on ordering
                # artifacts at exactly coincident times; TODO confirm.
                note = pretty_midi.Note(velocity=velocity, pitch=midinote,
                                        start=note_length * (j - 1) + 0.0001,
                                        end=note_length * j)
                pm_inst.notes.append(note)

    # Synthesize, then cut off the release/reverb tail so the clip is
    # exactly beat_num steps long and loops cleanly.
    audio_data = pm.fluidsynth()
    audio_data = audio_data[:int(SR * note_length * beat_num)]
    return audio_data
def get_answer(question):
    """Ask gpt-3.5-turbo, primed as a drum-pattern generator, for a reply.

    The conversation carries one few-shot example (markdown_table_sample)
    so the model answers with the time resolution on the first line
    followed by a Markdown rhythm table.  Returns the reply text.
    """
    priming = [
        {"role": "system", "content": "You are a rhythm generator. "},
        {"role": "user", "content": "Please generate a rhythm pattern in a Markdown table. Time resolution is the 8th note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
        {"role": "assistant", "content": markdown_table_sample},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=priming + [{"role": "user", "content": question}],
    )
    return response["choices"][0]["message"]["content"]
def generate_rhythm(query, state):
    """Gradio callback: turn a text prompt into (audio, ChatGPT text).

    Parameters:
        query: free-text rhythm request typed by the user.
        state: per-session dict with "gen_count" (queries so far) and
            "user_token" (user-supplied API key, "" when absent).

    Returns:
        [(SR, samples), chatgpt_text] on success, or [None, message] when
        the free-query quota is exhausted and no API key was provided.
    """
    # Enforce the free-usage quota for users without their own key.
    if state["gen_count"] > MAX_QUERY and len(state["user_token"]) == 0:
        return [None, "You need to set your ChatGPT API Key to try more than %d times" % MAX_QUERY]
    state["gen_count"] = state["gen_count"] + 1

    # Get the response from ChatGPT.
    text_output = get_answer(query)

    # The model is asked to state the time resolution first ("8th"/"16th");
    # pull the number out of everything before the first '|'.
    resolution_text = text_output.split('|')[0]
    try:
        resolution = int(re.findall(r'\d+', resolution_text)[0])
    except (IndexError, ValueError):
        resolution = 8  # no digits found: default to 8th notes

    # Keep only the Markdown table (from the first '|' to the last).
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)
    # Tile the one-bar pattern 4 times so the clip plays as a loop.
    audio_data = np.tile(audio_data, 4)
    return [(SR, audio_data), text_output]
# %%
def on_token_change(user_token, state):
    """Gradio callback: switch the OpenAI key to the user-supplied one.

    Falls back to the OPENAI_API_KEY environment variable when the textbox
    is empty.  The raw token is kept in the session state so
    generate_rhythm can lift the free-query quota.

    NOTE(review): openai.api_key is module-global, so concurrent sessions
    share whichever key was set last — confirm that is acceptable here.
    """
    # SECURITY: the token is a secret — never print/log it (the original
    # debug print leaked it to stdout).
    openai.api_key = user_token or os.environ.get("OPENAI_API_KEY")
    state["user_token"] = user_token
    return state
with gr.Blocks() as demo:
    # Per-session state: free-query counter plus the user's API key.
    state = gr.State({"gen_count": 0, "user_token": ""})

    # Header: title and instructions for the drum-label vocabulary.
    with gr.Row():
        with gr.Column():
            gr.Markdown("***Hey TR-ChatGPT, give me a drum pattern!***")
            gr.Markdown("You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. Use 'x' for an accented beat, 'o' for a weak beat. ", elem_id="label")

    # Prompt entry on the left; generated audio and the raw ChatGPT
    # table on the right.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")

    # API-key entry for users who exhaust the free quota.
    with gr.Row():
        with gr.Column():
            gr.Markdown("Enter your own OpenAI API Key to try out more than 5 times. You can get it [here](https://platform.openai.com/account/api-keys).")
            user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)

    # Wire up the event handlers.
    btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])

demo.launch()