nakas committed
Commit 2e3a6e1
1 Parent(s): 98269fc

Update app.py

Files changed (1)
  1. app.py +288 -57
app.py CHANGED
@@ -6,7 +6,7 @@
 
 # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
 # also released under the MIT license.
-import numpy as np
+
 import argparse
 from concurrent.futures import ProcessPoolExecutor
 import os
@@ -82,7 +82,6 @@ def make_waveform(*args, **kwargs):
         warnings.simplefilter('ignore')
         out = gr.make_waveform(*args, **kwargs)
         print("Make a video took", time.time() - be)
-        print("Returning from make_waveform")
         return out
 
 
@@ -107,61 +106,293 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
             sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
             if melody.dim() == 1:
                 melody = melody[None]
-            melody = melody.to(target_ac).to(MODEL.device).float()
-            if melody.size(0) != target_sr:
-                melody = convert_audio(melody, sr, target_sr)
-            processed_melodies.append(melody[None])
-    try:
-        outputs, infos = MODEL.generate_multiple(texts, processed_melodies, progress=progress)
-    except RuntimeError as e:
-        print(f'Runtime error in _do_predictions: {e}')
-        return []
-    print(f'Generation took {time.time() - be} seconds.')
-    return outputs, infos
-
-
-def _postprocess(output):
-    be = time.time()
-    audio_path = NamedTemporaryFile(delete=False, suffix=".mp3").name
-    file_cleaner.add(audio_path)
-    audio_write(output, audio_path)
-    print(f'Audio write took {time.time() - be} seconds.')
-    print("Returning from _postprocess")
-    return audio_path
-
-
-def _predict_single(text: str, melody: tp.Tuple[tp.Optional[int], tp.Optional[np.ndarray]], duration: float, **gen_kwargs):
-    load_model()
-    print(f'_predict_single called with text: {text}, melody: {melody}, duration: {duration}, gen_kwargs: {gen_kwargs}')
-    outputs, infos = _do_predictions([text], [melody], duration, **gen_kwargs)
-    if not outputs:
-        print("No outputs in _predict_single")
-        return None
-    output = outputs[0]
-    return _postprocess(output)
-
-
-def _predict_batch(texts: tp.List[str], melodies: tp.List[tp.Tuple[tp.Optional[int], tp.Optional[np.ndarray]]], duration: float, **gen_kwargs):
-    load_model()
-    print(f'_predict_batch called with texts: {texts}, melodies: {melodies}, duration: {duration}, gen_kwargs: {gen_kwargs}')
-    outputs, infos = _do_predictions(texts, melodies, duration, **gen_kwargs)
-    if not outputs:
-        print("No outputs in _predict_batch")
-        return [None] * len(texts)
-    return [_postprocess(output) for output in outputs]
-
-
-def launch_app():
-    launch_kwargs = dict(verbose=False, debug=False, inline=False)
-    if 'PORT' in os.environ:
-        launch_kwargs['port'] = os.environ['PORT']
-    if 'HOST' in os.environ:
-        launch_kwargs['host'] = os.environ['HOST']
-
-    # lastly launch the app with the set parameters
+            melody = melody[..., :int(sr * duration)]
+            melody = convert_audio(melody, sr, target_sr, target_ac)
+            processed_melodies.append(melody)
+
+    if any(m is not None for m in processed_melodies):
+        outputs = MODEL.generate_with_chroma(
+            descriptions=texts,
+            melody_wavs=processed_melodies,
+            melody_sample_rate=target_sr,
+            progress=progress,
+        )
+    else:
+        outputs = MODEL.generate(texts, progress=progress)
+
+    outputs = outputs.detach().cpu().float()
+    out_files = []
+    for output in outputs:
+        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+            audio_write(
+                file.name, output, MODEL.sample_rate, strategy="loudness",
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+            out_files.append(pool.submit(make_waveform, file.name))
+            file_cleaner.add(file.name)
+    res = [out_file.result() for out_file in out_files]
+    for file in res:
+        file_cleaner.add(file)
+    print("batch finished", len(texts), time.time() - be)
+    print("Tempfiles currently stored: ", len(file_cleaner.files))
+    return res
+
+
+def predict_batched(texts, melodies):
+    max_text_length = 512
+    texts = [text[:max_text_length] for text in texts]
+    load_model('melody')
+    res = _do_predictions(texts, melodies, BATCHED_DURATION)
+    return [res]
+
+
+def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+    global INTERRUPTING
+    INTERRUPTING = False
+    if temperature < 0:
+        raise gr.Error("Temperature must be >= 0.")
+    if topk < 0:
+        raise gr.Error("Topk must be non-negative.")
+    if topp < 0:
+        raise gr.Error("Topp must be non-negative.")
+
+    topk = int(topk)
+    load_model(model)
+
+    def _progress(generated, to_generate):
+        progress((generated, to_generate))
+        if INTERRUPTING:
+            raise gr.Error("Interrupted.")
+    MODEL.set_custom_progress_callback(_progress)
+
+    outs = _do_predictions(
+        [text], [melody], duration, progress=True,
+        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+    return outs[0]
+
+
+def toggle_audio_src(choice):
+    if choice == "mic":
+        return gr.update(source="microphone", value=None, label="Microphone")
+    else:
+        return gr.update(source="upload", value=None, label="File")
+
+
+def ui_full(launch_kwargs):
+    with gr.Blocks() as interface:
+        gr.Markdown(
+            """
+            # MusicGen
+            This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+            a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Input Text", interactive=True)
+                    with gr.Column():
+                        radio = gr.Radio(["file", "mic"], value="file",
+                                         label="Condition on a melody (optional) File or Mic")
+                        melody = gr.Audio(source="upload", type="numpy", label="File",
+                                          interactive=True, elem_id="melody-input")
+                with gr.Row():
+                    submit = gr.Button("Submit")
+                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
+                with gr.Row():
+                    model = gr.Radio(["melody", "medium", "small", "large"],
+                                     label="Model", value="melody", interactive=True)
+                with gr.Row():
+                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
+                with gr.Row():
+                    topk = gr.Number(label="Top-k", value=250, interactive=True)
+                    topp = gr.Number(label="Top-p", value=0, interactive=True)
+                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
+                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict_full,
+                     inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef],
+                     outputs=[output])
+        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+        gr.Examples(
+            fn=predict_full,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                    "melody"
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                    "melody"
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                    "medium"
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
+                    "./assets/bach.mp3",
+                    "melody"
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                    "medium",
+                ],
+            ],
+            inputs=[text, melody, model],
+            outputs=[output]
+        )
+        gr.Markdown(
+            """
+            ### More details
+            The model will generate a short music extract based on the description you provided.
+            The model can generate up to 30 seconds of audio in one pass. It is now possible
+            to extend the generation by feeding back the end of the previous chunk of audio.
+            This can take a long time, and the model might lose consistency. The model might also
+            decide at arbitrary positions that the song ends.
+            **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
+            An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
+            are generated each time.
+            We present 4 model variations:
+            1. Melody -- a music generation model capable of generating music conditioned
+            on text and melody inputs. **Note**, you can also use text only.
+            2. Small -- a 300M transformer decoder conditioned on text only.
+            3. Medium -- a 1.5B transformer decoder conditioned on text only.
+            4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
+            When using `melody`, you can optionally provide a reference audio from
+            which a broad melody will be extracted. The model will then try to follow both
+            the description and melody provided.
+            You can also use your own GPU or a Google Colab by following the instructions on our repo.
+            See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+            for more details.
+            """
+        )
+
+        interface.queue().launch(**launch_kwargs)
+
+
+def ui_batched(launch_kwargs):
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+            # MusicGen
+            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+            a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+            <br/>
+            <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
+            style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
+            src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+            for longer sequences, more control and no queue.
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                    with gr.Column():
+                        radio = gr.Radio(["file", "mic"], value="file",
+                                         label="Condition on a melody (optional) File or Mic")
+                        melody = gr.Audio(source="upload", type="numpy", label="File",
+                                          interactive=True, elem_id="melody-input")
+                with gr.Row():
+                    submit = gr.Button("Generate")
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict_batched, inputs=[text, melody],
+                     outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+        gr.Examples(
+            fn=predict_batched,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                ],
+            ],
+            inputs=[text, melody],
+            outputs=[output]
+        )
+        gr.Markdown("""
+        ### More details
+        The model will generate 12 seconds of audio based on the description you provided.
+        You can optionally provide a reference audio from which a broad melody will be extracted.
+        The model will then try to follow both the description and melody provided.
+        All samples are generated with the `melody` model.
+        You can also use your own GPU or a Google Colab by following the instructions on our repo.
+        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+        for more details.
+        """)
+
+        demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--listen',
+        type=str,
+        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
+        help='IP to listen on for connections to Gradio',
+    )
+    parser.add_argument(
+        '--username', type=str, default='', help='Username for authentication'
+    )
+    parser.add_argument(
+        '--password', type=str, default='', help='Password for authentication'
+    )
+    parser.add_argument(
+        '--server_port',
+        type=int,
+        default=0,
+        help='Port to run the server listener on',
+    )
+    parser.add_argument(
+        '--inbrowser', action='store_true', help='Open in browser'
+    )
+    parser.add_argument(
+        '--share', action='store_true', help='Share the gradio UI'
+    )
+
+    args = parser.parse_args()
+
+    launch_kwargs = {}
+    launch_kwargs['server_name'] = args.listen
+
+    if args.username and args.password:
+        launch_kwargs['auth'] = (args.username, args.password)
+    if args.server_port:
+        launch_kwargs['server_port'] = args.server_port
+    if args.inbrowser:
+        launch_kwargs['inbrowser'] = args.inbrowser
+    if args.share:
+        launch_kwargs['share'] = args.share
+
+    # Show the interface
     if IS_BATCHED:
-        print("Launching batched UI.")
         ui_batched(launch_kwargs)
     else:
-        print("Launching full UI.")
-        ui_full(launch_kwargs)
+        ui_full(launch_kwargs)
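
The reworked `_do_predictions` trims each conditioning melody to the requested duration and resamples it to the model's format before generating. Below is a minimal sketch of that preprocessing step, assuming audiocraft is installed; `preprocess_melody` is a hypothetical helper name (the Space performs these steps inline rather than in a separate function):

```python
import numpy as np
import torch
from audiocraft.data.audio_utils import convert_audio


def preprocess_melody(sr: int, wav: np.ndarray, duration: float,
                      target_sr: int = 32000, target_ac: int = 1) -> torch.Tensor:
    """Trim and resample a Gradio (sample_rate, ndarray) melody for MusicGen."""
    melody = torch.from_numpy(wav).float().t()   # (time, channels) -> (channels, time)
    if melody.dim() == 1:
        melody = melody[None]                    # add a channel dimension for mono input
    melody = melody[..., :int(sr * duration)]    # keep only the requested duration
    # Resample to the model sample rate and downmix to the model channel count.
    return convert_audio(melody, sr, target_sr, target_ac)
```

If any melody survives preprocessing, the app conditions generation on it via `MODEL.generate_with_chroma(...)`; otherwise it falls back to the text-only `MODEL.generate(texts, progress=progress)`.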
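The commit also replaces the old `launch_app()` helper, which read `PORT`/`HOST` from the environment, with an `argparse`-based `__main__` block. As a usage sketch (assuming a local install of the app's dependencies): `python app.py --listen 127.0.0.1 --server_port 7860 --inbrowser` serves the full UI locally, `--username` and `--password` together enable basic auth, and `--share` requests a public Gradio link; inside a Space (where `SPACE_ID` is set), the listen address defaults to `0.0.0.0`.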