naotokui commited on
Commit
9e97121
1 Parent(s): 61159bb
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -10,8 +10,7 @@ import gradio as gr
10
  openai.api_key = os.environ.get("OPENAI_API_KEY")
11
 
12
  # sample data
13
- markdown_table_sample = """4
14
-
15
  | | 1 | 2 | 3 | 4 |
16
  |----|---|---|---|---|
17
  | BD | | | x | |
@@ -23,17 +22,16 @@ markdown_table_sample = """4
23
  | HT | x | | | x |
24
  """
25
 
26
- markdown_table_sample2 = """8
27
-
28
- | | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
29
- |----|---|---|---|---|---|---|---|---|
30
- | BD | | | x | | | | x | |
31
- | SD | | | | x | | | | x |
32
- | CH | x | | x | | x | | x | |
33
- | OH | | | | x | | | x | |
34
- | LT | | | | | | x | | |
35
- | MT | | x | | | x | | | |
36
- | HT | x | | | x | | | | |
37
  """
38
 
39
  MIDI_NOTENUM = {
@@ -91,8 +89,8 @@ def get_answer(question):
91
  response = openai.ChatCompletion.create(
92
  model="gpt-3.5-turbo",
93
  messages=[
94
- {"role": "system", "content": "You are a rhythm generator."},
95
- {"role": "user", "content": "Please generate a rhythm pattern in a Markdown table. Time resolution is the 8th note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
96
  {"role": "assistant", "content": markdown_table_sample2},
97
  # {"role": "user", "content": "Please generate a rhythm pattern. The resolution is the fourth note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
98
  # {"role": "assistant", "content": markdown_table_sample},
@@ -111,12 +109,13 @@ def generate_rhythm(query, state):
111
  text_output = get_answer(query)
112
 
113
  # Try to use the first row as time resolution
114
- resolution_text = text_output.split('|')[0]
115
- try:
116
- resolution_text = re.findall(r'\d+', resolution_text)[0]
117
- resolution = int(resolution_text)
118
- except:
119
- resolution = 8 # default
 
120
 
121
  # Extract rhythm table
122
  table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"
 
10
  openai.api_key = os.environ.get("OPENAI_API_KEY")
11
 
12
  # sample data
13
+ markdown_table_sample = """
 
14
  | | 1 | 2 | 3 | 4 |
15
  |----|---|---|---|---|
16
  | BD | | | x | |
 
22
  | HT | x | | | x |
23
  """
24
 
25
+ markdown_table_sample2 = """
26
+ | | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16|
27
+ |----|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
28
+ | BD | x | | x | | | | x | | x | | x | | x | | x | |
29
+ | SD | | | | x | | | | x | | | x | | | | x | |
30
+ | CH | x | | x | | x | | x | | x | | x | | x | | x | |
31
+ | OH | | | | x | | | x | | | | | x | | | x | |
32
+ | LT | | | | | | x | | | | | | | | x | | |
33
+ | MT | | x | | | x | | | | | x | | | x | | | |
34
+ | HT | x | | | x | | | | | x | | | x | | | | |
 
35
  """
36
 
37
  MIDI_NOTENUM = {
 
89
  response = openai.ChatCompletion.create(
90
  model="gpt-3.5-turbo",
91
  messages=[
92
+ {"role": "system", "content": "You are a rhythm generator. Time resolution is the 16th note. "},
93
+ {"role": "user", "content": "Please generate a rhythm pattern in a Markdown table. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat."},
94
  {"role": "assistant", "content": markdown_table_sample2},
95
  # {"role": "user", "content": "Please generate a rhythm pattern. The resolution is the fourth note. You use the following drums. Kick drum:BD, Snare drum:SD, Closed-hihat:CH, Open-hihat:OH, Low-tom:LT, Mid-tom:MT, High-tom:HT. use 'x' for an accented beat, 'o' for a weak beat. You need to write the time resolution first."},
96
  # {"role": "assistant", "content": markdown_table_sample},
 
109
  text_output = get_answer(query)
110
 
111
  # Try to use the first row as time resolution
112
+ # resolution_text = text_output.split('|')[0]
113
+ # try:
114
+ # resolution_text = re.findall(r'\d+', resolution_text)[0]
115
+ # resolution = int(resolution_text)
116
+ # except:
117
+ # resolution = 8 # default
118
+ resolution = 16 # default
119
 
120
  # Extract rhythm table
121
  table = "|" + "|".join(text_output.split('|')[1:-1]) + "|"