File size: 6,501 Bytes
51a05e3
 
 
 
 
 
5521127
 
e95329a
5521127
 
51a05e3
 
5458994
 
0c63b51
685c69b
0c63b51
 
 
 
 
 
 
51a05e3
 
5458994
 
9e97121
 
 
 
 
 
 
 
 
51a05e3
 
 
 
 
 
43dbd89
51a05e3
 
 
 
 
43dbd89
51a05e3
 
 
5521127
51a05e3
 
 
 
 
 
 
 
 
 
 
 
 
7e81d62
acc0f0d
 
51a05e3
 
acc0f0d
51a05e3
bb201df
51a05e3
bb201df
 
 
 
51a05e3
 
257684d
51a05e3
 
 
 
 
 
acc0f0d
51a05e3
 
 
 
 
 
0c63b51
5458994
0c63b51
e65c3a4
 
51a05e3
 
 
 
 
5521127
db68f64
 
5521127
 
51a05e3
 
 
 
 
5458994
 
 
 
 
 
51a05e3
 
58a268c
 
 
51a05e3
58a268c
 
0264f32
 
58a268c
0264f32
51a05e3
 
 
 
b78ec41
db68f64
5521127
36a252a
b78ec41
51a05e3
7e81d62
36a252a
e65c3a4
 
 
 
58a268c
7e81d62
5521127
 
 
7e81d62
 
 
e65c3a4
 
 
 
5521127
b78ec41
51a05e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#%%
import openai
import numpy as np
import pretty_midi
import re
import numpy as np
import os
import gradio as gr
import librosa

# Default API key for the app, read from the environment (None when unset).
openai.api_key = os.getenv("OPENAI_API_KEY")

# Example model outputs. The first line names the time resolution ("8th",
# "16th"); a Markdown table follows with one row per drum code and one column
# per step. 'x' marks a hit.
# markdown_table_sample is sent as the one-shot assistant answer in get_answer.
markdown_table_sample = """8th

|    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| BD | x |   | x |   |   |   | x |   | 
| SD |   |   |   | x |   |   |   | x |
| CH | x |   | x |   | x |   | x |   |
| OH |   |   |   | x |   |   | x |   |
| LT |   |   |   |   |   | x |   |   | 
| MT |   | x |   |   | x |   |   |   | 
| HT | x |   |   | x |   |   |   |   | 
"""

# 16-step (16th-note) variant of the example above.
# NOTE(review): not referenced anywhere in the visible code; confirm whether it is still needed.
markdown_table_sample2 = """16th

|    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
|----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BD | x |   | x |   |   |   | x |   | x |   | x |   | x |   | x |   |
| SD |   |   |   | x |   |   |   | x |   |   | x |   |   |   | x |   |
| CH | x |   | x |   | x |   | x |   | x |   | x |   | x |   | x |   |
| OH |   |   |   | x |   |   | x |   |   |   |   | x |   |   | x |   |
| LT |   |   |   |   |   | x |   |   |   |   |   |   |   | x |   |   |
| MT |   | x |   |   | x |   |   |   |   | x |   |   | x |   |   |   |
| HT | x |   |   | x |   |   |   |   | x |   |   | x |   |   |   |   |
"""

# Drum code (as written in the table's first column) -> General MIDI
# percussion key number on the drum channel. Every code must map to a
# distinct key so that different drums produce different sounds.
MIDI_NOTENUM = {
    "BD": 36,  # Bass Drum 1
    "SD": 38,  # Acoustic Snare
    "CH": 42,  # Closed Hi-Hat
    "HH": 44,  # Pedal Hi-Hat
    "OH": 46,  # Open Hi-Hat
    "LT": 45,  # Low Tom (was 48, which collided with MT)
    "MT": 48,  # Hi-Mid Tom
    "HT": 50,  # High Tom
    "CP": 39,  # Hand Clap (was 50, which collided with HT)
    "CB": 56,  # Cowbell
}

# Audio sample rate in Hz used for synthesis, loop trimming and playback.
SR = 44100

# Number of free generations per session before a user API key is required.
MAX_QUERY = 5

def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    """Render a Markdown drum table to exactly one loop of audio.

    The first two lines of ``markdown_table`` (header row and ``|---|``
    separator) are skipped. Each remaining line is split on ``|``; its first
    cell is an instrument code looked up in ``MIDI_NOTENUM``, and each further
    cell is one step. ``x``/``X`` is an accented hit (velocity 120),
    ``o``/``O`` a weak hit (velocity 65), anything else a rest. Rows whose
    code is not in ``MIDI_NOTENUM`` are ignored.

    Args:
        markdown_table: Table text with rows like ``| BD | x |   | ... |``.
        resolution: Steps per whole note (8 -> eighth notes, 16 -> sixteenths).
        bpm: Tempo in quarter notes per minute.

    Returns:
        1-D numpy array synthesized by ``PrettyMIDI.fluidsynth`` at ``SR``.
        It is trimmed or zero-padded to exactly the table's step count, so
        tiling it gives a seamless loop.

    Raises:
        ValueError: If ``resolution`` or ``bpm`` is not positive.
    """
    if resolution <= 0 or bpm <= 0:
        raise ValueError(
            f"resolution and bpm must be positive, got {resolution!r} and {bpm!r}"
        )
    velocities = {'x': 120, 'o': 65}

    # convert table to array (skip header and separator rows)
    rhythm_pattern = [line.split('|')[1:-1] for line in markdown_table.split('\n')[2:]]
    print(rhythm_pattern)

    # table to MIDI
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)  # midi object
    pm_inst = pretty_midi.Instrument(0, is_drum=True)  # midi instrument
    pm.instruments.append(pm_inst)

    # One step lasts (60 / bpm) s per quarter note times 4 / resolution quarters.
    note_length = (60. / bpm) * (4.0 / resolution)

    # Loop length in steps = widest row (minus its instrument cell). Falls
    # back to `resolution` when no row has any step cells.
    beat_num = max(
        (len(row) - 1 for row in rhythm_pattern if len(row) > 1),
        default=resolution,
    )

    for row in rhythm_pattern:
        if not row:
            continue
        midinote = MIDI_NOTENUM.get(row[0].strip().upper())
        if midinote is None:
            continue  # unknown drum code: skip the row
        for step, cell in enumerate(row[1:], start=1):
            velocity = velocities.get(cell.strip().lower(), 0)
            if velocity > 0:
                # Small start offset kept from the original implementation
                # (presumably avoids a note at exactly t=0 — confirm).
                note = pretty_midi.Note(
                    velocity=velocity,
                    pitch=midinote,
                    start=note_length * (step - 1) + 0.0001,
                    end=note_length * step,
                )
                pm_inst.notes.append(note)

    # convert to audio at the same rate used for trimming and playback
    audio_data = pm.fluidsynth(fs=SR)

    # Make the buffer exactly one loop long. Cut the reverb/release tail, and
    # zero-pad when trailing steps are silent so tiled copies stay in time.
    loop_len = int(SR * note_length * beat_num)
    audio_data = audio_data[:loop_len]
    if len(audio_data) < loop_len:
        audio_data = np.pad(audio_data, (0, loop_len - len(audio_data)))
    return audio_data

def get_answer(question):
    """Ask gpt-3.5-turbo for a rhythm table and return the reply text.

    The conversation is primed with a system role, an instruction message,
    and ``markdown_table_sample`` as a one-shot example answer. *question*
    is appended as the final user turn.
    """
    instruction = "Please generate a rhythm pattern in a Markdown table.  Time resolution is the 8th note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."
    conversation = [
        {"role": "system", "content": "You are a rhythm generator. "},
        {"role": "user", "content": instruction},
        {"role": "assistant", "content": markdown_table_sample},
        {"role": "user", "content": question},
    ]
    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=conversation)
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]

def _parse_resolution(text, default=8):
    """Return the step resolution announced before the table in *text*.

    Reads the first run of digits before the first ``|`` (e.g. ``"16th"`` ->
    16). Returns *default* when there is no number or the number is 0.
    """
    match = re.search(r'\d+', text.split('|')[0])
    value = int(match.group()) if match else 0
    return value if value > 0 else default


def generate_rhythm(query, state):
    """Gradio click handler: ask ChatGPT for a drum pattern and render it.

    Args:
        query: The user's free-text request.
        state: Per-session dict with ``"gen_count"`` (int) and
            ``"user_token"`` (str). ``"gen_count"`` is incremented in place.

    Returns:
        ``[audio, text]``. ``audio`` is ``(SR, samples)`` (samples is a
        one-sample placeholder when rendering fails) or ``None`` when no
        request was made. ``text`` is the ChatGPT reply or an explanatory
        message.
    """
    # Log only the counter: state also holds the user's secret API key.
    print({"gen_count": state["gen_count"]})
    # gen_count counts requests already made, so `>=` allows exactly MAX_QUERY
    # free requests (the old `>` allowed MAX_QUERY + 1).
    if state["gen_count"] >= MAX_QUERY and not state.get("user_token"):
        return [None, "You need to set your ChatGPT API Key to try more than %d times" % MAX_QUERY]
    state["gen_count"] = state["gen_count"] + 1

    # get response from ChatGPT; report failures instead of crashing the UI
    try:
        text_output = get_answer(query)
    except Exception as err:  # UI boundary: any API/network error ends here
        print(f"ChatGPT request failed: {err!r}")
        return [None, f"ChatGPT request failed: {err}"]

    # The reply is expected to state its time resolution before the table.
    resolution = _parse_resolution(text_output)

    # Extract the rhythm table: everything from the first '|' to the last '|'.
    try:
        table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
        audio_data = convert_table_to_audio(table, resolution)

        # loop x4
        audio_data = np.tile(audio_data, 4)
        if not np.any(audio_data):
            audio_data = np.ones(1)  # silent or empty render: placeholder clip
    except Exception as err:  # malformed table or synthesis failure
        print(f"Failed to render rhythm table: {err!r}")
        audio_data = np.ones(1)

    return [(SR, audio_data), text_output]
# %%

def on_token_change(user_token, state):
    """Store the user-supplied OpenAI API key for this session.

    When the field is cleared, the client falls back to the
    ``OPENAI_API_KEY`` environment variable.

    Args:
        user_token: Text from the password textbox (may be empty or None).
        state: Per-session dict; its ``"user_token"`` entry is updated in place.

    Returns:
        The same ``state`` dict, for Gradio to persist.
    """
    # Never log the key itself (it is a secret); only whether one was given.
    print("user API key set" if user_token else "user API key cleared")
    # NOTE(review): openai.api_key is module-global, so this key is shared by
    # all concurrent sessions of the app, not just this user. Presumably it
    # should be passed per request instead — confirm against the openai
    # client version in use.
    openai.api_key = user_token or os.environ.get("OPENAI_API_KEY")
    # Keep it a str so length/truthiness checks elsewhere never see None.
    state["user_token"] = user_token or ""
    return state

# Gradio UI: a prompt box and Generate button, audio + raw-text outputs, and
# an optional user API key field that lifts the MAX_QUERY free-use limit.
with gr.Blocks() as demo:
    # Per-session state: generations made so far and the user's own API key.
    state = gr.State({"gen_count": 0, "user_token": ""})
    with gr.Row():
        with gr.Column():
            gr.Markdown("***Hey TR-ChatGPT, give me a drum pattern!***")
            gr.Markdown("Use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT and 'x' for an accented beat, 'o' for a weak beat!")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
            btn = gr.Button("Generate")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    with gr.Row():
        with gr.Column():
            # Derive the limit from MAX_QUERY so the text cannot drift from the check.
            gr.Markdown(f"Enter your own OpenAI API Key to try out more than {MAX_QUERY} times. You can get it [here](https://platform.openai.com/account/api-keys).")
            user_token = gr.Textbox(placeholder="OpenAI API Key", type="password", show_label=False)
    btn.click(fn=generate_rhythm, inputs=[inp, state], outputs=[out_audio, out_text])
    user_token.change(on_token_change, inputs=[user_token, state], outputs=[state])
demo.launch()