dx2102 commited on
Commit
0d8d838
·
verified ·
1 Parent(s): 3cc65f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -18
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import random
2
  import os
3
  import time
 
 
4
  from queue import Queue
5
  from threading import Thread
6
 
@@ -91,6 +93,8 @@ def postprocess(txt, path):
91
 
92
 
93
 
 
 
94
  with gr.Blocks() as demo:
95
  chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
96
  prefix_box = gr.TextArea(value="Twinkle Twinkle Little Star", label="Score title / text prefix")
@@ -104,9 +108,21 @@ with gr.Blocks() as demo:
104
  audio_box = gr.Audio()
105
  midi_box = gr.File()
106
  piano_roll_box = gr.Image()
 
 
 
 
 
 
 
 
 
 
 
 
107
  example_box = gr.Examples(
108
  [
109
- [example_prefix],
110
  ["Twinkle Twinkle Little Star"], ["Twinkle Twinkle Little Star (Minor Key Version)"],
111
  ["The Entertainer - Scott Joplin (Piano Solo)"], ["Clair de Lune – Debussy"], ["Nocturne | Frederic Chopin"],
112
  ["Fugue I in C major, BWV 846"], ["Beethoven Symphony No. 7 (2nd movement) Piano solo"],
@@ -120,27 +136,47 @@ with gr.Blocks() as demo:
120
  def user_fn(user_message, history: list):
121
  return "", history + [{"role": "user", "content": user_message}]
122
 
123
- def generate_fn(history: list):
 
 
 
 
 
124
  # continue from user input
125
- prefix = history[-1]["content"]
126
  # prevent the model from continuing user's score title
127
  if prefix != '' and '\n' not in prefix:
128
- # prefix is a single line --> prefix is the score title
129
- prefix += '\n'
 
130
 
131
  history.append({"role": "assistant", "content": ""})
132
  # history[-1]["content"] += "Generating with the given prefix...\n"
133
- for history in model_fn(prefix, history):
134
  yield history
135
 
136
- def continue_fn(history: list):
137
  # continue from the last model output
138
  prefix = history[-1]["content"]
139
- for history in model_fn(prefix, history):
140
  yield history
141
-
142
- @spaces.GPU
143
- def model_fn(prefix, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  queue = Queue(maxsize=10)
145
  class MyStreamer:
146
  def put(self, tokens):
@@ -172,22 +208,44 @@ with gr.Blocks() as demo:
172
  history[-1]["content"] += text
173
  yield history
174
 
175
- prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
176
- generate_fn, chatbot_box, chatbot_box
177
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
179
- generate_fn, chatbot_box, chatbot_box
180
  )
181
  continue_event = continue_btn.click(
182
- continue_fn, chatbot_box, chatbot_box
183
  )
184
  clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[submit_event, continue_event], queue=False)
185
 
186
  def get_audio_fn(history):
187
  i = random.randint(0, 1000_000_000)
188
  path = f'./temp/{i}.mid'
 
189
  try:
190
- postprocess(history[-1]["content"], path)
191
  except Exception as e:
192
  raise gr.Error(f'Error: {type(e)}, {e}')
193
  # turn midi into audio with timidity
@@ -201,8 +259,9 @@ with gr.Blocks() as demo:
201
  def get_midi_fn(history):
202
  i = random.randint(0, 1000_000_000)
203
  # turn the text into midi
 
204
  try:
205
- postprocess(history[-1]["content"], f'./temp/{i}.mid')
206
  except Exception as e:
207
  raise gr.Error(f'Error: {type(e)}, {e}')
208
  # also render the piano roll
 
1
  import random
2
  import os
3
  import time
4
+ import requests
5
+
6
  from queue import Queue
7
  from threading import Thread
8
 
 
93
 
94
 
95
 
96
+
97
+
98
  with gr.Blocks() as demo:
99
  chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
100
  prefix_box = gr.TextArea(value="Twinkle Twinkle Little Star", label="Score title / text prefix")
 
108
  audio_box = gr.Audio()
109
  midi_box = gr.File()
110
  piano_roll_box = gr.Image()
111
+ server_box = gr.Dropdown(
112
+ choices=["Huggingface ZeroGPU", "CPU"],
113
+ label="GPU Server",
114
+ )
115
+ gr.Markdown('''
116
+ ZeroGPU comes with a time limit currently:
117
+ - 3 minutes (not logged in)
118
+ - 5 minutes (logged in)
119
+ - 25 minutes (Pro user)
120
+
121
+ CPUs will be slower but there is no time limit.
122
+ '''.strip())
123
  example_box = gr.Examples(
124
  [
125
+ # [example_prefix],
126
  ["Twinkle Twinkle Little Star"], ["Twinkle Twinkle Little Star (Minor Key Version)"],
127
  ["The Entertainer - Scott Joplin (Piano Solo)"], ["Clair de Lune – Debussy"], ["Nocturne | Frederic Chopin"],
128
  ["Fugue I in C major, BWV 846"], ["Beethoven Symphony No. 7 (2nd movement) Piano solo"],
 
136
  def user_fn(user_message, history: list):
137
  return "", history + [{"role": "user", "content": user_message}]
138
 
139
+ def get_last(history: list):
140
+ if len(history) == 0:
141
+ raise gr.Error('''No messages to read yet. Try the "Generate" button first!''')
142
+ return history[-1]["content"]
143
+
144
+ def generate_fn(history, server):
145
  # continue from user input
146
+ prefix = get_last(history)
147
  # prevent the model from continuing user's score title
148
  if prefix != '' and '\n' not in prefix:
149
+ # prefix is a single line => prefix is the score title
150
+ # add '\n' to prevent model from continuing the title
151
+ prefix += '\n'
152
 
153
  history.append({"role": "assistant", "content": ""})
154
  # history[-1]["content"] += "Generating with the given prefix...\n"
155
+ for history in model_fn(prefix, history, server):
156
  yield history
157
 
158
+ def continue_fn(history, server):
159
  # continue from the last model output
160
  prefix = history[-1]["content"]
161
+ for history in model_fn(prefix, history, server):
162
  yield history
163
+
164
+
165
+
166
+
167
+ def model_fn(prefix, history, server):
168
+ if server == "Huggingface ZeroGPU":
169
+ generator = zerogpu_model_fn(prefix, history)
170
+ elif server == "CPU":
171
+ generator = cpu_model_fn(prefix, history)
172
+ # elif server == "RunPod":
173
+ # generator = runpod_model_fn(prefix, history)
174
+ else:
175
+ raise gr.Error(f"Unknown server: {server}")
176
+ for history in generator:
177
+ yield history
178
+
179
+ def cpu_model_fn(prefix, history):
180
  queue = Queue(maxsize=10)
181
  class MyStreamer:
182
  def put(self, tokens):
 
208
  history[-1]["content"] += text
209
  yield history
210
 
211
+ zerogpu_model_fn = spaces.GPU(cpu_model_fn)
212
+
213
+ def runpod_model_fn(prefix, history):
214
+ # NOTE
215
+ runpod_api_key = os.getenv('RUNPOD_API_KEY')
216
+ runpod_endpoint = os.getenv('RUNPOD_ENDPOINT')
217
+
218
+ # synchronized request
219
+ response = requests.post(
220
+ f"https://api.runpod.ai/v2/{runpod_endpoint}/runsync",
221
+ headers={"Authorization": f"Bearer {runpod_api_key}"},
222
+ json={"input": {"prompt": prefix}}
223
+ ).json()['output'][0]['choices'][0]['tokens'][0]
224
+ # yield just once
225
+ history[-1]["content"] += response
226
+ yield history
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
+
235
  submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
236
+ generate_fn, [chatbot_box, server_box], chatbot_box
237
  )
238
  continue_event = continue_btn.click(
239
+ continue_fn, [chatbot_box, server_box], chatbot_box
240
  )
241
  clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[submit_event, continue_event], queue=False)
242
 
243
  def get_audio_fn(history):
244
  i = random.randint(0, 1000_000_000)
245
  path = f'./temp/{i}.mid'
246
+ text = get_last(history)
247
  try:
248
+ postprocess(text, path)
249
  except Exception as e:
250
  raise gr.Error(f'Error: {type(e)}, {e}')
251
  # turn midi into audio with timidity
 
259
  def get_midi_fn(history):
260
  i = random.randint(0, 1000_000_000)
261
  # turn the text into midi
262
+ text = get_last(history)
263
  try:
264
+ postprocess(text, f'./temp/{i}.mid')
265
  except Exception as e:
266
  raise gr.Error(f'Error: {type(e)}, {e}')
267
  # also render the piano roll