File size: 6,138 Bytes
51a05e3
 
 
 
 
 
5521127
 
 
 
51a05e3
 
9e97121
51a05e3
 
 
 
 
 
 
 
 
 
 
9e97121
 
 
 
 
 
 
 
 
 
51a05e3
 
 
 
 
 
43dbd89
51a05e3
 
 
 
 
43dbd89
51a05e3
 
 
5521127
51a05e3
 
 
 
 
 
 
 
 
 
 
 
 
7e81d62
51a05e3
 
 
 
bb201df
51a05e3
bb201df
 
 
 
51a05e3
 
bb201df
51a05e3
 
 
 
 
 
7e81d62
51a05e3
 
 
 
 
 
9e97121
 
51a05e3
e65c3a4
 
51a05e3
 
 
 
 
5521127
db68f64
 
5521127
 
51a05e3
 
 
 
 
9e97121
 
 
 
 
 
 
51a05e3
 
 
 
 
 
 
 
 
 
 
b78ec41
db68f64
5521127
36a252a
b78ec41
51a05e3
7e81d62
36a252a
e65c3a4
 
 
 
 
7e81d62
5521127
 
 
7e81d62
 
 
e65c3a4
 
 
 
5521127
b78ec41
51a05e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#%%
import openai
import numpy as np
import pretty_midi
import re
import numpy as np
import os
import gradio as gr

# Default API key, read from the environment; None when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Sample pattern tables. Rows are drum abbreviations, columns are time steps;
# a cell holding 'x' marks a hit.
markdown_table_sample = """
|    | 1 | 2 | 3 | 4 | 
|----|---|---|---|---|
| BD |   |   | x |   |
| SD |   |   |   | x |
| CH | x |   | x |   |
| OH |   |   |   | x |  
| LT |   |   |   |   | 
| MT |   | x |   |   |
| HT | x |   |   | x | 
"""

# 16-step example table.
markdown_table_sample2 = """
|    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x |   | x |   |   |   | x |   | x |   | x |   | x |   | x |   |
| SD |   |   |   | x |   |   |   | x |   |   | x |   |   |   | x |   |
| CH | x |   | x |   | x |   | x |   | x |   | x |   | x |   | x |   |
| OH |   |   |   | x |   |   | x |   |   |   |   | x |   |   | x |   |
| LT |   |   |   |   |   | x |   |   |   |   |   |   |   | x |   |   |
| MT |   | x |   |   | x |   |   |   |   | x |   |   | x |   |   |   |
| HT | x |   |   | x |   |   |   |   | x |   |   | x |   |   |   |   |
"""

# Drum abbreviation -> MIDI note number, following the General MIDI
# percussion key map. Every instrument has its own note: LT previously shared
# 48 with MT and CP shared 50 with HT, which made them indistinguishable.
MIDI_NOTENUM = {
    "BD": 36,  # Bass Drum 1
    "SD": 38,  # Acoustic Snare
    "CH": 42,  # Closed Hi-Hat
    "HH": 44,  # Pedal Hi-Hat
    "OH": 46,  # Open Hi-Hat
    "LT": 45,  # Low Tom
    "MT": 48,  # Hi-Mid Tom
    "HT": 50,  # High Tom
    "CP": 39,  # Hand Clap
    "CB": 56,  # Cowbell
}

# Audio sample rate in Hz.
SR = 44100

# Number of generations allowed before the user must supply an API key.
MAX_QUERY = 5

def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    """Render a Markdown drum-pattern table to a one-bar audio loop.

    The first two lines of ``markdown_table`` are skipped (they are expected
    to be the header and separator rows). In each remaining line the first
    cell names the drum (looked up case-insensitively in ``MIDI_NOTENUM``)
    and every following cell is one time step: ``x`` is an accented hit,
    ``o`` a weak hit, anything else is silence. Rows whose drum name is
    unknown are ignored.

    Args:
        markdown_table: The table text, rows separated by newlines and cells
            by ``|``.
        resolution: Number of time steps per four beats (16 for 16th notes).
        bpm: Tempo in beats per minute.

    Returns:
        A 1-D array of samples at ``SR`` Hz whose length is exactly
        ``int(SR * note_length * resolution)``, i.e. ``resolution`` steps.
    """
    # Velocity assigned to each recognised cell marker.
    velocities = {"x": 120, "o": 65}

    # Split each data line into cells, dropping the text outside the outer pipes.
    rows = [line.split('|')[1:-1] for line in markdown_table.split('\n')[2:]]

    # table to MIDI
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    drums = pretty_midi.Instrument(0, is_drum=True)
    pm.instruments.append(drums)

    note_length = (60. / bpm) * (4.0 / resolution)  # duration of one step in seconds

    for cells in rows:
        if not cells:
            continue
        pitch = MIDI_NOTENUM.get(cells[0].strip().upper())
        if pitch is None:
            continue
        for step, cell in enumerate(cells[1:], start=1):
            velocity = velocities.get(cell.strip())
            if velocity is None:
                continue
            drums.notes.append(
                pretty_midi.Note(
                    velocity=velocity,
                    pitch=pitch,
                    # Each note starts 0.05 s after the beginning of its step.
                    start=note_length * (step - 1) + 0.05,
                    end=note_length * step,
                )
            )

    # convert to audio at the sample rate the caller reports to the player
    audio_data = pm.fluidsynth(fs=SR)

    # Force the buffer to exactly one loop: cut the tail, and zero-pad when
    # the rendering is shorter (its length follows the last note, and it is
    # empty when there are no notes). Without padding a tiled loop would
    # drift out of tempo.
    loop_samples = int(SR * note_length * resolution)
    audio_data = audio_data[:loop_samples]
    missing = loop_samples - len(audio_data)
    if missing > 0:
        silence = np.zeros(missing, dtype=audio_data.dtype)
        audio_data = np.concatenate([audio_data, silence])
    return audio_data

def get_answer(question):
    """Ask the chat model for a drum pattern and return the reply text.

    The conversation is primed with a system instruction, one example
    request and ``markdown_table_sample2`` as the example reply, followed by
    the caller's ``question``.

    Args:
        question: The user's free-text request.

    Returns:
        The ``content`` of the first choice in the API response.
    """
    system_prompt = {
        "role": "system",
        "content": "You are a rhythm generator. Time resolution is the 16th note. ",
    }
    example_request = {
        "role": "user",
        "content": "Please generate a rhythm pattern in a Markdown table. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat.",
    }
    example_reply = {"role": "assistant", "content": markdown_table_sample2}
    user_request = {"role": "user", "content": question}

    conversation = [system_prompt, example_request, example_reply, user_request]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]

def generate_rhythm(query, state):
    """Generate a drum loop for ``query`` and return it with the model's text.

    Args:
        query: The user's free-text request, forwarded to ``get_answer``.
        state: Session dict with ``"gen_count"`` (generations so far) and
            ``"user_token"`` (the user's own API key, empty if none). The
            counter is incremented in place.

    Returns:
        A two-element list ``[audio, text]``. ``audio`` is ``(SR, samples)``,
        or ``None`` when the free quota is used up or nothing could be
        rendered. ``text`` is the model's reply or the quota message.
    """
    # `state` is deliberately not logged: it holds the user's API key.
    # Exactly MAX_QUERY free generations are allowed without a key.
    if state["gen_count"] >= MAX_QUERY and not state["user_token"]:
        return [None, "You need to set your ChatGPT API Key to try more than %d times" % MAX_QUERY]
    state["gen_count"] = state["gen_count"] + 1

    # get response from ChatGPT
    text_output = get_answer(query)

    # The system prompt in get_answer fixes the time resolution to 16th notes.
    resolution = 16

    # Extract the rhythm table: keep everything between the first and the
    # last '|' of the reply.
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)

    # No samples were rendered; show the text without an audio clip.
    if len(audio_data) == 0:
        return [None, text_output]

    # repeat the loop four times
    audio_data = np.tile(audio_data, 4)

    return [(SR, audio_data), text_output]
# %%

def on_token_change(user_token, state):
    """Store the user's API key and make it the active OpenAI key.

    Args:
        user_token: Text from the API-key textbox; may be empty.
        state: Session dict; its ``"user_token"`` entry is updated in place.

    Returns:
        The updated ``state``.
    """
    # The token is a secret, so it is not printed or logged.
    token = (user_token or "").strip()

    # An empty token falls back to the key from the environment.
    # NOTE(review): this assigns a module-level attribute, so the key is not
    # scoped to `state`; presumably it applies to every session served by
    # this process -- confirm.
    openai.api_key = token or os.environ.get("OPENAI_API_KEY")
    state["user_token"] = token
    return state

# Build the Gradio UI: a heading row, an input/output row and an API-key row.
with gr.Blocks() as demo:
    # State dict: number of generations so far and the user's own API key.
    state = gr.State({"gen_count": 0, "user_token":""})
    # Row 1: title and a reminder of the drum abbreviations.
    with gr.Row():
        with gr.Column():
        # gr.Markdown("Ask ChatGPT to generate rhythm patterns")
            gr.Markdown("***Hey TR-ChatGPT, give me a drum pattern!***")
            gr.Markdown("You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. Use 'x' for an accented beat, 'o' for a weak beat. ", elem_id="label")
    # Row 2: request textbox and button on the left, audio and text outputs on the right.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    # Row 3: optional user-supplied OpenAI API key.
    with gr.Row():
        with gr.Column():
            gr.Markdown("Enter your own OpenAI API Key to try out more than 5 times. You can get it [here](https://platform.openai.com/account/api-keys).")
            user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)
    # Clicking the button runs generate_rhythm and fills the audio and text outputs.
    btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    # Editing the key textbox runs on_token_change, which writes the key into the state.
    user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])
# Start the app; this runs whenever the module is executed or imported.
demo.launch()