File size: 4,542 Bytes
51a05e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e81d62
51a05e3
 
 
 
 
 
 
 
 
 
 
 
 
 
7e81d62
51a05e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e81d62
 
 
 
 
 
 
 
 
51a05e3
7e81d62
 
 
 
 
 
51a05e3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#%%
import openai
import numpy as np
import pretty_midi
import re
import numpy as np

# Few-shot example tables. Each starts with the time resolution (the number
# of steps per 4/4 bar: 4 = quarter notes, 8 = eighth notes), a blank line,
# then a markdown table with one row per drum code and "x" marking a hit.
# markdown_table_sample: quarter-note (4-step) example; not used by any
# active code visible in this file.
markdown_table_sample = """4

|    | 1 | 2 | 3 | 4 | 
|----|---|---|---|---|
| BD |   |   | x |   |
| SD |   |   |   | x |
| CH | x |   | x |   |
| OH |   |   |   | x |  
| LT |   |   |   |   | 
| MT |   | x |   |   |
| HT | x |   |   | x | 
"""

# markdown_table_sample2: eighth-note (8-step) example, sent to ChatGPT as
# the example assistant reply in get_answer's few-shot prompt.
markdown_table_sample2 = """8

|    | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|----|---|---|---|---|---|---|---|---|
| BD |   |   | x |   |   |   | x |   |
| SD |   |   |   | x |   |   |   | x |
| CH | x |   | x |   | x |   | x |   |
| OH |   |   |   | x |   |   | x |   |
| LT |   |   |   |   |   | x |   |   |
| MT |   | x |   |   | x |   |   |   |
| HT | x |   |   | x |   |   |   |   |
"""

# General MIDI percussion note numbers for each drum code the model may emit.
# Every drum code must map to a distinct sound, except deliberate aliases
# (HH is an alias for the closed hi-hat).
MIDI_NOTENUM = {
    "BD": 36,  # Bass Drum 1
    "SD": 38,  # Acoustic Snare
    "CH": 42,  # Closed Hi-Hat
    "HH": 42,  # alias of CH
    "OH": 46,  # Open Hi-Hat
    "LT": 45,  # Low Tom (previously 48, which made it identical to MT)
    "MT": 48,  # Hi-Mid Tom
    "HT": 50,  # High Tom
    "CP": 39,  # Hand Clap (previously 50, which is the High Tom note)
}
# Audio sample rate in Hz; also the rate pretty_midi's fluidsynth() renders at
# by default, so audio lengths computed with SR match the rendered data.
SR = 44100

count = 0  # queries handled so far; incremented by generate_rhythm
MAX_QUERY = 100  # generate_rhythm refuses further queries once count exceeds this

def convert_table_to_audio(markdown_table, resolution=8, bpm=120.0):
    """Render a markdown drum table to a seamlessly loopable audio clip.

    The table is expected in the shape of ``markdown_table_sample2``:
    - a header row;
    - a ``|----|`` separator row;
    - one row per instrument. The first cell of each row is a drum code
      (a key of ``MIDI_NOTENUM``), and each remaining cell is one step,
      with ``x`` (case-insensitive) marking a hit.

    Rows whose drum code is unknown are ignored.

    Args:
        markdown_table: Table text. The first two lines (header and
            separator) are skipped.
        resolution: Steps per 4/4 bar (8 = eighth notes). It sets the
            duration of one step and is the fallback loop length when the
            table has no step columns.
        bpm: Tempo in beats per minute.

    Returns:
        The samples rendered by ``pm.fluidsynth(fs=SR)``. They are trimmed
        (dropping the release/reverb tail) or zero-padded to exactly the
        length of the pattern's steps, so tiling the clip keeps the tempo.
    """
    # Each line becomes its list of cells. The outer pipes produce empty
    # first/last fields, which [1:-1] drops.
    rhythm_pattern = [line.split('|')[1:-1] for line in markdown_table.split('\n')[2:]]
    print(rhythm_pattern)

    # Table to MIDI: one drum instrument on a PrettyMIDI object.
    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    pm_inst = pretty_midi.Instrument(0, is_drum=True)
    pm.instruments.append(pm_inst)

    # One step lasts 4/resolution quarter notes; a quarter note lasts 60/bpm s.
    note_length = (60. / bpm) * (4.0 / resolution)

    num_steps = 0
    for row in rhythm_pattern:
        if not row:
            continue  # blank line, e.g. from a trailing newline
        # The loop length follows the widest row, so multi-bar tables are
        # not truncated to a single bar.
        num_steps = max(num_steps, len(row) - 1)
        midinote = MIDI_NOTENUM.get(row[0].strip().upper())
        if midinote is None:
            continue  # unknown drum code: nothing to play for this row
        for step, cell in enumerate(row[1:]):
            if cell.strip().lower() == 'x':
                # NOTE(review): every onset is delayed by 0.05 s, as in the
                # original code; the reason is not documented here.
                note = pretty_midi.Note(velocity=80, pitch=midinote,
                                        start=note_length * step + 0.05,
                                        end=note_length * (step + 1))
                pm_inst.notes.append(note)

    if num_steps == 0:
        num_steps = resolution  # no step columns: fall back to one bar
    loop_len = int(SR * note_length * num_steps)

    # Render at SR explicitly so the trim length and the audio rate agree.
    audio_data = pm.fluidsynth(fs=SR)

    # Cut the tail for looping. Pad a short (or empty) render with silence,
    # otherwise tiling would shift the beat.
    audio_data = audio_data[:loop_len]
    if len(audio_data) < loop_len:
        audio_data = np.pad(audio_data, (0, loop_len - len(audio_data)))
    return audio_data

def get_answer(question):
    """Ask ChatGPT (gpt-3.5-turbo) for a rhythm pattern and return its reply text.

    The request is a few-shot conversation, followed by *question* as the
    final user turn:
    - a system role describing a rhythm generator;
    - an instruction listing the drum codes;
    - ``markdown_table_sample2`` as an example assistant answer.
    """
    system_prompt = "You are a rhythm generator. You generate rhythm patterns with the resolution of the 8th note"
    instruction = "Please generate a rhythm pattern. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT。You need to write the time resolution first."

    # An alternative example pair, unused, relied on markdown_table_sample
    # with the user instruction: "Please generate in quarter-note units. The
    # drums are BD, SD, CH, OH, LT, MT, HT. First write the reciprocal of the
    # time resolution."
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": instruction},
        {"role": "assistant", "content": markdown_table_sample2},
        {"role": "user", "content": question},
    ]

    reply = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=conversation)
    first_choice = reply["choices"][0]
    return first_choice["message"]["content"]

def generate_rhythm(query):
    """Ask ChatGPT for a rhythm pattern and render it as looping audio.

    Args:
        query: The user's free-text request.

    Returns:
        A two-element list ``[audio, text]`` matching the Gradio outputs.
        ``audio`` is ``(SR, samples)`` with the rendered pattern repeated
        four times. It is ``None`` when the query limit has been exceeded
        or the reply contains no table. ``text`` is ChatGPT's raw reply, or
        a status message when the limit is hit.
    """
    global count
    count += 1
    if count > MAX_QUERY:
        return [None, "Query limit reached: you can try up to %d times" % MAX_QUERY]

    # Get the response from ChatGPT.
    text_output = get_answer(query)

    # Without any '|' there is no markdown table to render.
    if '|' not in text_output:
        return [None, text_output]

    # The reply should state the time resolution before the table. Use the
    # first integer found before the first '|'. Fall back to 8 when there is
    # none, or when it is 0 (which would divide by zero later).
    resolution = 8
    match = re.search(r'\d+', text_output.split('|')[0])
    if match and int(match.group()) > 0:
        resolution = int(match.group())

    # Keep everything from the first '|' to the last '|' as the table.
    table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
    audio_data = convert_table_to_audio(table, resolution)

    # Loop x4 so the pattern is heard several times.
    audio_data = np.tile(audio_data, 4)

    return [(SR, audio_data), text_output]
# %%

import gradio as gr

# Gradio UI layout:
# - a prompt textbox;
# - next to it, a column with the rendered audio loop and ChatGPT's raw reply;
# - a "Generate" button that calls generate_rhythm.
with gr.Blocks() as demo:
    gr.Markdown("Ask ChatGPT to generate rhythm patterns")
    with gr.Row():
        inp = gr.Textbox(placeholder="Give me a Hiphop rhythm pattern with some reggae twist!")
        with gr.Column():
            out_audio = gr.Audio()
            out_text = gr.Textbox(placeholder="ChatGPT output")
    btn = gr.Button("Generate")
    btn.click(fn=generate_rhythm, inputs=inp, outputs=[out_audio, out_text])

# Start the web server only when this file is run as a script, so importing
# the module (e.g. to reuse its functions) does not block on demo.launch().
if __name__ == "__main__":
    demo.launch()
# %%

# %%