Manjushri committed on
Commit
6f38fd2
1 Parent(s): e5c5aed

Create app.py

Files changed (1)
  1. app.py +408 -0
app.py ADDED
@@ -0,0 +1,408 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py,
+ # also released under the MIT license.
+
+ import argparse
+ from concurrent.futures import ProcessPoolExecutor
+ import os
+ from pathlib import Path
+ import subprocess as sp
+ from tempfile import NamedTemporaryFile
+ import time
+ import typing as tp
+ import warnings
+
+ import torch
+ import gradio as gr
+
+ from audiocraft.data.audio_utils import convert_audio
+ from audiocraft.data.audio import audio_write
+ from audiocraft.models import MusicGen
+
+
+ MODEL = None  # Last used model
+ IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
+ MAX_BATCH_SIZE = 6
+ BATCHED_DURATION = 15
+ INTERRUPTING = False
+ # Wrap subprocess calls to keep the logs clean when using gr.make_waveform.
+ _old_call = sp.call
+
+
+ def _call_nostderr(*args, **kwargs):
+     # Avoid ffmpeg vomiting on the logs, while preserving the return code.
+     kwargs['stderr'] = sp.DEVNULL
+     kwargs['stdout'] = sp.DEVNULL
+     return _old_call(*args, **kwargs)
+
+
+ sp.call = _call_nostderr
+ # Preallocating the pool of processes.
+ pool = ProcessPoolExecutor(3)
+ pool.__enter__()
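+ # gr.make_waveform renders its video by shelling out to ffmpeg through subprocess,
+ # which is why sp.call is patched before the pool is created; on fork-based start
+ # methods the workers inherit the silenced call as well.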
+
+
+ def interrupt():
+     global INTERRUPTING
+     INTERRUPTING = True
+
+
+ class FileCleaner:
+     def __init__(self, file_lifetime: float = 3600):
+         self.file_lifetime = file_lifetime
+         self.files = []
+
+     def add(self, path: tp.Union[str, Path]):
+         self._cleanup()
+         self.files.append((time.time(), Path(path)))
+
+     def _cleanup(self):
+         now = time.time()
+         for time_added, path in list(self.files):
+             if now - time_added > self.file_lifetime:
+                 if path.exists():
+                     path.unlink()
+                 self.files.pop(0)
+             else:
+                 break
+
+
+ file_cleaner = FileCleaner()
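+ # Cleanup is lazy: expired files are only pruned when add() registers a new
+ # output, so the last files of a session can outlive file_lifetime.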
+
+
+ def make_waveform(*args, **kwargs):
+     # Suppress the additional warnings emitted by gr.make_waveform.
+     be = time.time()
+     with warnings.catch_warnings():
+         warnings.simplefilter('ignore')
+         out = gr.make_waveform(*args, **kwargs)
+         print("Making a video took", time.time() - be)
+         return out
+
+
+ def load_model(version='melody'):
+     global MODEL
+     print("Loading model", version)
+     if MODEL is None or MODEL.name != version:
+         MODEL = MusicGen.get_pretrained(version)
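+ # load_model is a no-op when the requested checkpoint is already resident, so
+ # repeated requests against the same model variant stay cheap.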
+
+
+ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
+     MODEL.set_generation_params(duration=duration, **gen_kwargs)
+     print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+     be = time.time()
+     processed_melodies = []
+     target_sr = 32000
+     target_ac = 1
+     for melody in melodies:
+         if melody is None:
+             processed_melodies.append(None)
+         else:
+             sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+             if melody.dim() == 1:
+                 melody = melody[None]
+             melody = melody[..., :int(sr * duration)]
+             melody = convert_audio(melody, sr, target_sr, target_ac)
+             processed_melodies.append(melody)
+
+     if any(m is not None for m in processed_melodies):
+         outputs = MODEL.generate_with_chroma(
+             descriptions=texts,
+             melody_wavs=processed_melodies,
+             melody_sample_rate=target_sr,
+             progress=progress,
+         )
+     else:
+         outputs = MODEL.generate(texts, progress=progress)
+
+     outputs = outputs.detach().cpu().float()
+     out_files = []
+     for output in outputs:
+         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+             audio_write(
+                 file.name, output, MODEL.sample_rate, strategy="loudness",
+                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+             out_files.append(pool.submit(make_waveform, file.name))
+             file_cleaner.add(file.name)
+     res = [out_file.result() for out_file in out_files]
+     for file in res:
+         file_cleaner.add(file)
+     print("batch finished", len(texts), time.time() - be)
+     print("Tempfiles currently stored: ", len(file_cleaner.files))
+     return res
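+ # Sketch of a direct (non-UI) call, assuming a model was loaded first:
+ #     load_model('small')
+ #     videos = _do_predictions(["lofi beat"], [None], duration=8)
+ # Each returned path is a waveform video produced by gr.make_waveform.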
+
+
+ def predict_batched(texts, melodies):
+     max_text_length = 512
+     texts = [text[:max_text_length] for text in texts]
+     load_model('melody')
+     res = _do_predictions(texts, melodies, BATCHED_DURATION)
+     return [res]
+
+
+ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+     global INTERRUPTING
+     INTERRUPTING = False
+     if temperature < 0:
+         raise gr.Error("Temperature must be >= 0.")
+     if topk < 0:
+         raise gr.Error("Topk must be non-negative.")
+     if topp < 0:
+         raise gr.Error("Topp must be non-negative.")
+
+     topk = int(topk)
+     load_model(model)
+
+     def _progress(generated, to_generate):
+         progress((generated, to_generate))
+         if INTERRUPTING:
+             raise gr.Error("Interrupted.")
+     MODEL.set_custom_progress_callback(_progress)
+
+     outs = _do_predictions(
+         [text], [melody], duration, progress=True,
+         top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+     return outs[0]
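+ # The progress callback doubles as the interruption point: once the Interrupt
+ # button sets INTERRUPTING, the next callback raises gr.Error and aborts the
+ # ongoing generation.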
+
+
+ def toggle_audio_src(choice):
+     if choice == "mic":
+         return gr.update(source="microphone", value=None, label="Microphone")
+     else:
+         return gr.update(source="upload", value=None, label="File")
+
+
+ def ui_full(launch_kwargs):
+     with gr.Blocks() as interface:
+         gr.Markdown(
+             """
+             # MusicGen
+             This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+             a simple and controllable model for music generation
+             presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     text = gr.Text(label="Input Text", interactive=True)
+                     with gr.Column():
+                         radio = gr.Radio(["file", "mic"], value="file",
+                                          label="Condition on a melody (optional) File or Mic")
+                         melody = gr.Audio(source="upload", type="numpy", label="File",
+                                           interactive=True, elem_id="melody-input")
+                 with gr.Row():
+                     submit = gr.Button("Submit")
+                     # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+                     _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
+                 with gr.Row():
+                     model = gr.Radio(["melody", "medium", "small", "large"],
+                                      label="Model", value="melody", interactive=True)
+                 with gr.Row():
+                     duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
+                 with gr.Row():
+                     topk = gr.Number(label="Top-k", value=250, interactive=True)
+                     topp = gr.Number(label="Top-p", value=0, interactive=True)
+                     temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
+                     cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
+             with gr.Column():
+                 output = gr.Video(label="Generated Music")
+         submit.click(predict_full,
+                      inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef],
+                      outputs=[output])
+         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+         gr.Examples(
+             fn=predict_full,
+             examples=[
+                 [
+                     "An 80s driving pop song with heavy drums and synth pads in the background",
+                     "./bach.mp3",
+                     "melody"
+                 ],
+                 [
+                     "A cheerful country song with acoustic guitars",
+                     "./bolero_ravel.mp3",
+                     "melody"
+                 ],
+                 [
+                     "90s rock song with electric guitar and heavy drums",
+                     None,
+                     "medium"
+                 ],
+                 [
+                     "a light and cheery EDM track, with syncopated drums, airy pads, and strong emotions",
+                     "./bach.mp3",
+                     "melody"
+                 ],
+                 [
+                     "lofi slow bpm electro chill with organic samples",
+                     None,
+                     "medium",
+                 ],
+             ],
+             inputs=[text, melody, model],
+             outputs=[output]
+         )
+         gr.Markdown(
+             """
+             ### More details
+
+             The model will generate a short music extract based on the description you provided.
+             The model can generate up to 30 seconds of audio in one pass. It is now possible
+             to extend the generation by feeding back the end of the previous chunk of audio.
+             This can take a long time, and the model might lose consistency. The model might also
+             decide at arbitrary positions that the song ends.
+
+             **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
+             An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
+             are generated each time.
+
+             We present 4 model variations:
+             1. Melody -- a music generation model capable of generating music conditioned
+                 on text and melody inputs. **Note**: you can also use text only.
+             2. Small -- a 300M transformer decoder conditioned on text only.
+             3. Medium -- a 1.5B transformer decoder conditioned on text only.
+             4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences).
+
+             When using `melody`, you can optionally provide a reference audio from
+             which a broad melody will be extracted. The model will then try to follow both
+             the description and melody provided.
+
+             You can also use your own GPU or a Google Colab by following the instructions on our repo.
+             See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+             for more details.
+             """
+         )
+
+     interface.queue().launch(**launch_kwargs)
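+ # queue() is what drives the gr.Progress updates; the Interrupt button itself is
+ # registered with queue=False so the flag can be set while a job is running.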
+
+
+ def ui_batched(launch_kwargs):
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """
+             # MusicGen
+
+             This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+             a simple and controllable model for music generation
+             presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+             <br/>
+             <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
+                 style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+                 <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
+                     src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+             for longer sequences, more control and no queue.
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                     with gr.Column():
+                         radio = gr.Radio(["file", "mic"], value="file",
+                                          label="Condition on a melody (optional) File or Mic")
+                         melody = gr.Audio(source="upload", type="numpy", label="File",
+                                           interactive=True, elem_id="melody-input")
+                 with gr.Row():
+                     submit = gr.Button("Generate")
+             with gr.Column():
+                 output = gr.Video(label="Generated Music")
+         submit.click(predict_batched, inputs=[text, melody],
+                      outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+         gr.Examples(
+             fn=predict_batched,
+             examples=[
+                 [
+                     "An 80s driving pop song with heavy drums and synth pads in the background",
+                     "./assets/bach.mp3",
+                 ],
+                 [
+                     "A cheerful country song with acoustic guitars",
+                     "./assets/bolero_ravel.mp3",
+                 ],
+                 [
+                     "90s rock song with electric guitar and heavy drums",
+                     None,
+                 ],
+                 [
+                     "a light and cheery EDM track, with syncopated drums, airy pads, and strong emotions bpm: 130",
+                     "./assets/bach.mp3",
+                 ],
+                 [
+                     "lofi slow bpm electro chill with organic samples",
+                     None,
+                 ],
+             ],
+             inputs=[text, melody],
+             outputs=[output]
+         )
+         gr.Markdown("""
+         ### More details
+
+         The model will generate 15 seconds of audio based on the description you provided.
+         You can optionally provide a reference audio from which a broad melody will be extracted.
+         The model will then try to follow both the description and melody provided.
+         All samples are generated with the `melody` model.
+
+         You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+         See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+         for more details.
+         """)
+
+     demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '--listen',
+         type=str,
+         default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
+         help='IP to listen on for connections to Gradio',
+     )
+     parser.add_argument(
+         '--username', type=str, default='', help='Username for authentication'
+     )
+     parser.add_argument(
+         '--password', type=str, default='', help='Password for authentication'
+     )
+     parser.add_argument(
+         '--server_port',
+         type=int,
+         default=0,
+         help='Port to run the server listener on',
+     )
+     parser.add_argument(
+         '--inbrowser', action='store_true', help='Open in browser'
+     )
+     parser.add_argument(
+         '--share', action='store_true', help='Share the gradio UI'
+     )
+
+     args = parser.parse_args()
+
+     launch_kwargs = {}
+     launch_kwargs['server_name'] = args.listen
+
+     if args.username and args.password:
+         launch_kwargs['auth'] = (args.username, args.password)
+     if args.server_port:
+         launch_kwargs['server_port'] = args.server_port
+     if args.inbrowser:
+         launch_kwargs['inbrowser'] = args.inbrowser
+     if args.share:
+         launch_kwargs['share'] = args.share
+
+     # Show the interface
+     if IS_BATCHED:
+         ui_batched(launch_kwargs)
+     else:
+         ui_full(launch_kwargs)
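+ # Example invocation using the flags defined above (hypothetical local run):
+ #     python app.py --listen 127.0.0.1 --server_port 7860 --inbrowser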