anthonyrusso committed on
Commit
66051a3
1 Parent(s): 5186d69

Create app.py

Files changed (1)
  1. app.py +454 -0
app.py ADDED
@@ -0,0 +1,454 @@
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
from pathlib import Path
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import typing as tp
import warnings

import torch
import gradio as gr

from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen, MultiBandDiffusion


MODEL = None  # Last used model
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
print(IS_BATCHED)
MAX_BATCH_SIZE = 12
BATCHED_DURATION = 15
INTERRUPTING = False
MBD = None
USE_DIFFUSION = False  # Toggled by predict_full(); read by _do_predictions().
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes.
pool = ProcessPoolExecutor(4)
pool.__enter__()
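
# gr.make_waveform shells out to ffmpeg via subprocess.call (hence the patch above),
# and video encoding is CPU-bound, so it is farmed out to a small process pool to
# keep the GPU generation loop free. Illustrative pattern (hypothetical path):
#
#     future = pool.submit(make_waveform, "/tmp/out.wav")
#     video_path = future.result()  # blocks until ffmpeg finishes rendering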


def interrupt():
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        # self.files is in insertion (hence time) order, so we can stop at the
        # first entry that is still within its lifetime.
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                if path.exists():
                    path.unlink()
                self.files.pop(0)
            else:
                break


file_cleaner = FileCleaner()
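
# Usage sketch (illustrative): every temp file handed to the UI is registered here,
# and anything older than file_lifetime seconds (one hour by default) is unlinked
# lazily on the next add() call:
#
#     file_cleaner.add("/tmp/out.wav")  # hypothetical path; removed ~1h later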


def make_waveform(*args, **kwargs):
    # Further remove some warnings.
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
        print("Making the video took", time.time() - be)
        return out


def load_model(version='facebook/musicgen-melody'):
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = MusicGen.get_pretrained(version)


def load_diffusion():
    global MBD
    if MBD is None:
        print("loading MBD")
        MBD = MultiBandDiffusion.get_mbd_musicgen()
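
# MBD decodes the same EnCodec tokens that MODEL.generate(..., return_tokens=True)
# produces, which is how _do_predictions() below can render both the default and
# the diffusion decoding of a single generation for side-by-side comparison.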


def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    # The melody conditioner expects 32 kHz mono audio, trimmed to the requested duration.
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    if any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
            return_tokens=USE_DIFFUSION
        )
    else:
        outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
    if USE_DIFFUSION:
        # Decode the same tokens a second time with MBD and stack the results, so
        # each prompt yields both a default and a diffusion-decoded sample.
        outputs_diffusion = MBD.tokens_to_wav(outputs[1])
        outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
    outputs = outputs.detach().cpu().float()
    pending_videos = []
    out_wavs = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            pending_videos.append(pool.submit(make_waveform, file.name))
            out_wavs.append(file.name)
            file_cleaner.add(file.name)
    out_videos = [pending_video.result() for pending_video in pending_videos]
    for video in out_videos:
        file_cleaner.add(video)
    print("batch finished", len(texts), time.time() - be)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    return out_videos, out_wavs


def predict_batched(texts, melodies):
    max_text_length = 512
    texts = [text[:max_text_length] for text in texts]
    load_model('facebook/musicgen-melody')
    res = _do_predictions(texts, melodies, BATCHED_DURATION)
    return res
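
# Because the batched UI wires this handler up with batch=True and
# max_batch_size=MAX_BATCH_SIZE (see ui_batched below), Gradio passes lists of
# queued inputs here, so one forward pass serves up to MAX_BATCH_SIZE requests.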


def predict_full(model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
    global INTERRUPTING
    global USE_DIFFUSION
    INTERRUPTING = False
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    if decoder == "MultiBand_Diffusion":
        USE_DIFFUSION = True
        load_diffusion()
    else:
        USE_DIFFUSION = False
    load_model(model)

    def _progress(generated, to_generate):
        progress((min(generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    videos, wavs = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    if USE_DIFFUSION:
        return videos[0], wavs[0], videos[1], wavs[1]
    return videos[0], wavs[0], None, None
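
# Note on interruption: MODEL.set_custom_progress_callback invokes _progress with
# (generated, to_generate) token counts from inside the generation loop, so raising
# gr.Error there is what lets the Interrupt button abort a run already in progress.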


def toggle_audio_src(choice):
    if choice == "mic":
        return gr.update(source="microphone", value=None, label="Microphone")
    else:
        return gr.update(source="upload", value=None, label="File")


def toggle_diffusion(choice):
    if choice == "MultiBand_Diffusion":
        return [gr.update(visible=True)] * 2
    else:
        return [gr.update(visible=False)] * 2


def ui_full(launch_kwargs):
    with gr.Blocks() as interface:
        gr.Markdown(
            """
            # MusicGen
            This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
            a simple and controllable model for music generation
            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Input Text", interactive=True)
                    with gr.Column():
                        radio = gr.Radio(["file", "mic"], value="file",
                                         label="Condition on a melody (optional) File or Mic")
                        melody = gr.Audio(source="upload", type="numpy", label="File",
                                          interactive=True, elem_id="melody-input")
                with gr.Row():
                    submit = gr.Button("Submit")
                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                with gr.Row():
                    model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
                                      "facebook/musicgen-large"],
                                     label="Model", value="facebook/musicgen-melody", interactive=True)
                with gr.Row():
                    decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
                                       label="Decoder", value="Default", interactive=True)
                with gr.Row():
                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
            with gr.Column():
                output = gr.Video(label="Generated Music")
                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
        submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
                     show_progress=False).then(predict_full, inputs=[model, decoder, text, melody, duration, topk, topp,
                                                                     temperature, cfg_coef],
                                               outputs=[output, audio_output, diffusion_output, audio_diffusion])
        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)

        gr.Examples(
            fn=predict_full,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                    "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                    "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                    "facebook/musicgen-medium",
                    "Default"
                ],
                [
                    "a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions",
                    "./assets/bach.mp3",
                    "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                    "facebook/musicgen-medium",
                    "Default"
                ],
                [
                    "Punk rock with loud drum and power guitar",
                    None,
                    "facebook/musicgen-medium",
                    "MultiBand_Diffusion"
                ],
            ],
            inputs=[text, melody, model, decoder],
            outputs=[output]
        )
        gr.Markdown(
            """
            ### More details

            The model will generate a short music extract based on the description you provided.
            The model can generate up to 30 seconds of audio in one pass. It is now possible
            to extend the generation by feeding back the end of the previous chunk of audio.
            This can take a long time, and the model might lose consistency. The model might also
            decide at arbitrary positions that the song ends.

            **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
            An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
            are generated each time.
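
            For example (rough arithmetic): a 120-second request is one 30-second pass
            plus five 18-second extensions, since 30 + 5 × 18 = 120.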

            We present 4 model variations:
            1. facebook/musicgen-melody -- a music generation model capable of generating music conditioned
                on text and melody inputs. **Note**, you can also use text only.
            2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only.
            3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only.
            4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only.

            We also present two ways of decoding the audio tokens:
            1. Use the default GAN-based compression model.
            2. Use the MultiBand Diffusion decoder.

            When using `facebook/musicgen-melody`, you can optionally provide a reference audio from
            which a broad melody will be extracted. The model will then try to follow both
            the description and melody provided.

            You can also use your own GPU or a Google Colab by following the instructions on our repo.
            See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
            for more details.
            """
        )

    interface.queue().launch(**launch_kwargs)


def ui_batched(launch_kwargs):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # MusicGen

            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
            a simple and controllable model for music generation
            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
            <br/>
            <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
                style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
                <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
                    src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
            for longer sequences, more control and no queue.
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
                    with gr.Column():
                        radio = gr.Radio(["file", "mic"], value="file",
                                         label="Condition on a melody (optional) File or Mic")
                        melody = gr.Audio(source="upload", type="numpy", label="File",
                                          interactive=True, elem_id="melody-input")
                with gr.Row():
                    submit = gr.Button("Generate")
            with gr.Column():
                output = gr.Video(label="Generated Music")
                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
        submit.click(predict_batched, inputs=[text, melody],
                     outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
        gr.Examples(
            fn=predict_batched,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                ],
                [
                    "a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions bpm: 130",
                    "./assets/bach.mp3",
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                ],
            ],
            inputs=[text, melody],
            outputs=[output]
        )
        gr.Markdown("""
        ### More details

        The model will generate 15 seconds of audio based on the description you provided.
        You can optionally provide a reference audio from which a broad melody will be extracted.
        The model will then try to follow both the description and melody provided.
        All samples are generated with the `melody` model.

        You can also use your own GPU or a Google Colab by following the instructions on our repo.

        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
        for more details.
        """)

    demo.queue(max_size=8 * 4).launch(**launch_kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )

    args = parser.parse_args()

    launch_kwargs = {}
    launch_kwargs['server_name'] = args.listen

    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = args.share
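
    # Example invocation (illustrative; all flags are defined above):
    #     python app.py --listen 0.0.0.0 --server_port 7860 --share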

    # Show the interface
    if IS_BATCHED:
        USE_DIFFUSION = False
        ui_batched(launch_kwargs)
    else:
        ui_full(launch_kwargs)