jbilcke-hf HF Staff committed on
Commit
243ff9b
·
1 Parent(s): 66eea88

Update demos/musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +14 -102
demos/musicgen_app.py CHANGED
@@ -32,8 +32,7 @@ SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
32
 
33
  MODEL = None # Last used model
34
  SPACE_ID = os.environ.get('SPACE_ID', '')
35
- IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID
36
- print(IS_BATCHED)
37
  MAX_BATCH_SIZE = 12
38
  BATCHED_DURATION = 15
39
  INTERRUPTING = False
@@ -82,17 +81,6 @@ class FileCleaner:
82
 
83
  file_cleaner = FileCleaner()
84
 
85
-
86
- def make_waveform(*args, **kwargs):
87
- # Further remove some warnings.
88
- be = time.time()
89
- with warnings.catch_warnings():
90
- warnings.simplefilter('ignore')
91
- out = gr.make_waveform(*args, **kwargs)
92
- print("Make a video took", time.time() - be)
93
- return out
94
-
95
-
96
  def load_model(version='facebook/musicgen-melody'):
97
  global MODEL
98
  print("Loading model", version)
@@ -153,30 +141,25 @@ def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=N
153
  outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
154
  outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
155
  outputs = outputs.detach().cpu().float()
156
- pending_videos = []
157
  out_wavs = []
158
  for output in outputs:
159
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
160
  audio_write(
161
  file.name, output, MODEL.sample_rate, strategy="loudness",
162
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
163
- pending_videos.append(pool.submit(make_waveform, file.name))
164
  out_wavs.append(file.name)
165
  file_cleaner.add(file.name)
166
- out_videos = [pending_video.result() for pending_video in pending_videos]
167
- for video in out_videos:
168
- file_cleaner.add(video)
169
  print("batch finished", len(texts), time.time() - be)
170
  print("Tempfiles currently stored: ", len(file_cleaner.files))
171
- return out_videos, out_wavs
172
 
173
 
174
  def predict_batched(texts, melodies):
175
  max_text_length = 512
176
  texts = [text[:max_text_length] for text in texts]
177
  load_model('facebook/musicgen-stereo-melody')
178
- res = _do_predictions(texts, melodies, BATCHED_DURATION)
179
- return res
180
 
181
 
182
  def predict_full(secret_token, model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
@@ -222,14 +205,13 @@ def predict_full(secret_token, model, model_path, decoder, text, melody, duratio
222
  raise gr.Error("Interrupted.")
223
  MODEL.set_custom_progress_callback(_progress)
224
 
225
- videos, wavs = _do_predictions(
226
  [text], [melody], duration, progress=True,
227
  top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
228
  gradio_progress=progress)
229
  if USE_DIFFUSION:
230
- return videos[0], wavs[0], videos[1], wavs[1]
231
- return videos[0], wavs[0], None, None
232
-
233
 
234
  def toggle_audio_src(choice):
235
  if choice == "mic":
@@ -240,9 +222,9 @@ def toggle_audio_src(choice):
240
 
241
  def toggle_diffusion(choice):
242
  if choice == "MultiBand_Diffusion":
243
- return [gr.update(visible=True)] * 2
244
  else:
245
- return [gr.update(visible=False)] * 2
246
 
247
 
248
  def ui_full(launch_kwargs):
@@ -292,14 +274,12 @@ def ui_full(launch_kwargs):
292
  temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
293
  cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
294
  with gr.Column():
295
- output = gr.Video(label="Generated Music")
296
  audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
297
- diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
298
  audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
299
- submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
300
  show_progress=False).then(predict_full, inputs=[secret_token, model, model_path, decoder, text, melody, duration, topk, topp,
301
  temperature, cfg_coef],
302
- outputs=[output, audio_output, diffusion_output, audio_diffusion])
303
  radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
304
 
305
  gr.Markdown(
@@ -349,71 +329,6 @@ def ui_full(launch_kwargs):
349
 
350
  interface.queue().launch(**launch_kwargs)
351
 
352
-
353
- def ui_batched(launch_kwargs):
354
- with gr.Blocks() as demo:
355
- gr.Markdown(
356
- """
357
- # MusicGen
358
-
359
- This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md),
360
- a simple and controllable model for music generation
361
- presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
362
- <br/>
363
- <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
364
- style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
365
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
366
- src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
367
- for longer sequences, more control and no queue.</p>
368
- """
369
- )
370
- with gr.Row():
371
- with gr.Column():
372
- with gr.Row():
373
- text = gr.Text(label="Describe your music", lines=2, interactive=True)
374
- with gr.Column():
375
- radio = gr.Radio(["file", "mic"], value="file",
376
- label="Condition on a melody (optional) File or Mic")
377
- melody = gr.Audio(source="upload", type="numpy", label="File",
378
- interactive=True, elem_id="melody-input")
379
- with gr.Row():
380
- submit = gr.Button("Generate")
381
- with gr.Column():
382
- output = gr.Video(label="Generated Music")
383
- audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
384
- submit.click(predict_batched, inputs=[text, melody],
385
- outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
386
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
387
-
388
- gr.Markdown("""
389
- ### More details
390
-
391
- The model will generate 15 seconds of audio based on the description you provided.
392
- The model was trained with description from a stock music catalog, descriptions that will work best
393
- should include some level of details on the instruments present, along with some intended use case
394
- (e.g. adding "perfect for a commercial" can somehow help).
395
-
396
- You can optionally provide a reference audio from which a broad melody will be extracted.
397
- The model will then try to follow both the description and melody provided.
398
- For best results, the melody should be 30 seconds long (I know, the samples we provide are not...)
399
-
400
- You can access more control (longer generation, more models etc.) by clicking
401
- the <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
402
- style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
403
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
404
- src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
405
- (you will then need a paid GPU from HuggingFace).
406
- If you have a GPU, you can run the gradio demo locally (click the link to our repo below for more info).
407
- Finally, you can get a GPU for free from Google
408
- and run the demo in [a Google Colab.](https://ai.honu.io/red/musicgen-colab).
409
-
410
- See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
411
- for more details. All samples are generated with the `stereo-melody` model.
412
- """)
413
-
414
- demo.queue(max_size=8 * 4).launch(**launch_kwargs)
415
-
416
-
417
  if __name__ == "__main__":
418
  parser = argparse.ArgumentParser()
419
  parser.add_argument(
@@ -458,9 +373,6 @@ if __name__ == "__main__":
458
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
459
 
460
  # Show the interface
461
- if IS_BATCHED:
462
- global USE_DIFFUSION
463
- USE_DIFFUSION = False
464
- ui_batched(launch_kwargs)
465
- else:
466
- ui_full(launch_kwargs)
 
32
 
33
  MODEL = None # Last used model
34
  SPACE_ID = os.environ.get('SPACE_ID', '')
35
+ IS_BATCHED = False # <- we hardcode it
 
36
  MAX_BATCH_SIZE = 12
37
  BATCHED_DURATION = 15
38
  INTERRUPTING = False
 
81
 
82
  file_cleaner = FileCleaner()
83
 
 
 
 
 
 
 
 
 
 
 
 
84
  def load_model(version='facebook/musicgen-melody'):
85
  global MODEL
86
  print("Loading model", version)
 
141
  outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
142
  outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
143
  outputs = outputs.detach().cpu().float()
 
144
  out_wavs = []
145
  for output in outputs:
146
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
147
  audio_write(
148
  file.name, output, MODEL.sample_rate, strategy="loudness",
149
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
 
150
  out_wavs.append(file.name)
151
  file_cleaner.add(file.name)
152
+
 
 
153
  print("batch finished", len(texts), time.time() - be)
154
  print("Tempfiles currently stored: ", len(file_cleaner.files))
155
+ return out_wavs
156
 
157
 
158
  def predict_batched(texts, melodies):
159
  max_text_length = 512
160
  texts = [text[:max_text_length] for text in texts]
161
  load_model('facebook/musicgen-stereo-melody')
162
+ return _do_predictions(texts, melodies, BATCHED_DURATION)
 
163
 
164
 
165
  def predict_full(secret_token, model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
 
205
  raise gr.Error("Interrupted.")
206
  MODEL.set_custom_progress_callback(_progress)
207
 
208
+ wavs = _do_predictions(
209
  [text], [melody], duration, progress=True,
210
  top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
211
  gradio_progress=progress)
212
  if USE_DIFFUSION:
213
+ return wavs[1]
214
+ return wavs[0]
 
215
 
216
  def toggle_audio_src(choice):
217
  if choice == "mic":
 
222
 
223
  def toggle_diffusion(choice):
224
  if choice == "MultiBand_Diffusion":
225
+ return [gr.update(visible=True)]
226
  else:
227
+ return [gr.update(visible=False)]
228
 
229
 
230
  def ui_full(launch_kwargs):
 
274
  temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
275
  cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
276
  with gr.Column():
 
277
  audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
 
278
  audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
279
+ submit.click(toggle_diffusion, decoder, [audio_diffusion], queue=False,
280
  show_progress=False).then(predict_full, inputs=[secret_token, model, model_path, decoder, text, melody, duration, topk, topp,
281
  temperature, cfg_coef],
282
+ outputs=[audio_output])
283
  radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
284
 
285
  gr.Markdown(
 
329
 
330
  interface.queue().launch(**launch_kwargs)
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  if __name__ == "__main__":
333
  parser = argparse.ArgumentParser()
334
  parser.add_argument(
 
373
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
374
 
375
  # Show the interface
376
+ # we preload the model to avoid a timeout on the first request
377
+ load_model('facebook/musicgen-stereo-large')
378
+ ui_full(launch_kwargs)