mattricesound committed on
Commit
cbe698e
Parent: 260c2a0

Added gradio app.py

Files changed (2)
  1. app.py +399 -0
  2. setup.py +2 -1
app.py ADDED
@@ -0,0 +1,399 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py,
+ # also released under the MIT license.
+
+ import argparse
+ from concurrent.futures import ProcessPoolExecutor
+ import os
+ from pathlib import Path
+ import subprocess as sp
+ from tempfile import NamedTemporaryFile
+ import time
+ import typing as tp
+ import warnings
+
+ import torch
+ import gradio as gr
+
+ from audiocraft.data.audio import audio_write
+ from audiocraft.models import MusicGen
+
+ from demucs import pretrained
+ from demucs.apply import apply_model
+ # demucs' convert_audio shares audiocraft's (wav, from_rate, to_rate, channels)
+ # signature and handles all resampling in this file.
+ from demucs.audio import convert_audio
+
+
+ MODEL = None  # Last used model
+ IS_BATCHED = False
+ MAX_BATCH_SIZE = 12
+ BATCHED_DURATION = 15
+ INTERRUPTING = False
+ # We wrap subprocess calls to keep ffmpeg from cluttering the logs when using gr.make_waveform.
+ _old_call = sp.call
+
+ stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
+ stem_idx = torch.LongTensor([stem2idx['vocal'], stem2idx['other'], stem2idx['bass']])
+ demucs_model = pretrained.get_model('htdemucs')
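+ # htdemucs separates audio into the four sources listed in stem2idx. Selecting the
+ # vocal, other, and bass stems and later summing them (see _do_predictions) keeps
+ # everything except the drum stem, i.e. the generated audio is de-drummed.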
+
+
+ def _call_nostderr(*args, **kwargs):
+     # Avoid ffmpeg vomiting on the logs.
+     kwargs['stderr'] = sp.DEVNULL
+     kwargs['stdout'] = sp.DEVNULL
+     return _old_call(*args, **kwargs)  # preserve sp.call's return code for callers
+
+
+ sp.call = _call_nostderr
+ # Preallocating the pool of processes.
+ pool = ProcessPoolExecutor(4)
+ pool.__enter__()
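+ # The pool is entered once at import time and never explicitly exited; its worker
+ # processes are reclaimed when the interpreter shuts down. make_waveform jobs are
+ # submitted to it so video rendering runs off the main Gradio thread.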
+
+
+ def interrupt():
+     global INTERRUPTING
+     INTERRUPTING = True
+
+
+ class FileCleaner:
+     def __init__(self, file_lifetime: float = 3600):
+         self.file_lifetime = file_lifetime
+         self.files = []
+
+     def add(self, path: tp.Union[str, Path]):
+         self._cleanup()
+         self.files.append((time.time(), Path(path)))
+
+     def _cleanup(self):
+         now = time.time()
+         for time_added, path in list(self.files):
+             if now - time_added > self.file_lifetime:
+                 if path.exists():
+                     path.unlink()
+                 self.files.pop(0)
+             else:
+                 break
+
+
+ file_cleaner = FileCleaner()
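+ # FileCleaner deletes temporary outputs lazily: each add() first drops any tracked
+ # file older than file_lifetime (an hour by default). Since files are appended in
+ # chronological order, the scan can stop at the first entry that is still fresh.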
+
+
+ def make_waveform(*args, **kwargs):
+     # Further remove some warnings.
+     be = time.time()
+     with warnings.catch_warnings():
+         warnings.simplefilter('ignore')
+         out = gr.make_waveform(*args, **kwargs)
+         print("Making a waveform video took", time.time() - be)
+         return out
+
+
+ def load_model(version='melody'):
+     global MODEL
+     print("Loading model", version)
+     if MODEL is None or MODEL.name != version:
+         MODEL = MusicGen.get_pretrained(version)
+
+
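+ # _do_predictions runs the full pipeline for one batch: generate audio with MusicGen
+ # (optionally conditioned on a melody), strip the drum stem with demucs, write a
+ # loudness-normalized wav, and render a waveform video for each output via the pool.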
+ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
+     MODEL.set_generation_params(duration=duration, **gen_kwargs)
+     print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+     be = time.time()
+     processed_melodies = []
+     target_sr = 32000
+     target_ac = 1
+     for melody in melodies:
+         if melody is None:
+             processed_melodies.append(None)
+         else:
+             sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+             if melody.dim() == 1:
+                 melody = melody[None]
+             melody = melody[..., :int(sr * duration)]
+             melody = convert_audio(melody, sr, target_sr, target_ac)
+             processed_melodies.append(melody)
+
+     if any(m is not None for m in processed_melodies):
+         outputs = MODEL.generate_with_chroma(
+             descriptions=texts,
+             melody_wavs=processed_melodies,
+             melody_sample_rate=target_sr,
+             progress=progress,
+         )
+     else:
+         outputs = MODEL.generate(texts, progress=progress)
+
+     outputs = outputs.detach().cpu().float()
+
+     out_files = []
+     for output in outputs:
+         # Demucs: drop the drum stem from the generated audio.
+         print("Running demucs")
+         wav = convert_audio(output, MODEL.sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
+         wav = wav.unsqueeze(0)
+         stems = apply_model(demucs_model, wav)
+         stems = stems[:, stem_idx]  # extract the vocal, other, and bass stems
+         stems = stems.sum(1)  # merge extracted stems
+         stems = convert_audio(stems, demucs_model.samplerate, MODEL.sample_rate, 1)
+         output = stems[0]
+         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+             audio_write(
+                 file.name, output, MODEL.sample_rate, strategy="loudness",
+                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+             out_files.append(pool.submit(make_waveform, file.name))
+             file_cleaner.add(file.name)
+     res = [out_file.result() for out_file in out_files]
+     for file in res:
+         file_cleaner.add(file)
+     print("batch finished", len(texts), time.time() - be)
+     print("Tempfiles currently stored: ", len(file_cleaner.files))
+     return res
+
+
+ def predict_batched(texts, melodies):
+     max_text_length = 512
+     texts = [text[:max_text_length] for text in texts]
+     load_model('melody')
+     res = _do_predictions(texts, melodies, BATCHED_DURATION)
+     return [res]
+
+
+ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+     global INTERRUPTING
+     INTERRUPTING = False
+     if temperature < 0:
+         raise gr.Error("Temperature must be >= 0.")
+     if topk < 0:
+         raise gr.Error("Topk must be non-negative.")
+     if topp < 0:
+         raise gr.Error("Topp must be non-negative.")
+
+     topk = int(topk)
+     load_model(model)
+
+     def _progress(generated, to_generate):
+         progress((generated, to_generate))
+         if INTERRUPTING:
+             raise gr.Error("Interrupted.")
+     MODEL.set_custom_progress_callback(_progress)
+
+     outs = _do_predictions(
+         [text], [melody], duration, progress=True,
+         top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+     return outs[0]
+
+
+ def toggle_audio_src(choice):
+     if choice == "mic":
+         return gr.update(source="microphone", value=None, label="Microphone")
+     else:
+         return gr.update(source="upload", value=None, label="File")
+
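+ # Note: gr.Audio(source=...) and gr.update(source=...) follow the Gradio 3.x API this
+ # app was written against; newer Gradio releases renamed the parameter to `sources`.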
+
+ def ui_full(launch_kwargs):
+     with gr.Blocks() as interface:
+         gr.Markdown(
+             """
+             # MusicGen
+             This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+             a simple and controllable model for music generation,
+             presented in the paper ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     text = gr.Text(label="Input Text", interactive=True)
+                     with gr.Column():
+                         radio = gr.Radio(["file", "mic"], value="file",
+                                          label="Condition on a melody (optional) File or Mic")
+                         melody = gr.Audio(source="upload", type="numpy", label="File",
+                                           interactive=True, elem_id="melody-input")
+                 with gr.Row():
+                     submit = gr.Button("Submit")
+                     # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+                     _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
+                 with gr.Row():
+                     model = gr.Radio(["melody", "medium", "small", "large"],
+                                      label="Model", value="melody", interactive=True)
+                 with gr.Row():
+                     duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
+                 with gr.Row():
+                     topk = gr.Number(label="Top-k", value=250, interactive=True)
+                     topp = gr.Number(label="Top-p", value=0, interactive=True)
+                     temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
+                     cfg_coef = gr.Number(label="Classifier-Free Guidance", value=3.0, interactive=True)
+             with gr.Column():
+                 output = gr.Video(label="Generated Music")
+         submit.click(predict_full,
+                      inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef],
+                      outputs=[output])
+         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+         gr.Markdown(
+             """
+             ### More details
+
+             The model will generate a short music extract based on the description you provided.
+             The model can generate up to 30 seconds of audio in one pass. It is now possible
+             to extend the generation by feeding back the end of the previous chunk of audio.
+             This can take a long time, and the model might lose consistency. The model might also
+             decide at arbitrary positions that the song ends.
+
+             **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
+             An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
+             are generated each time.
+
+             We present 4 model variations:
+             1. Melody -- a music generation model capable of generating music conditioned
+                on text and melody inputs. **Note**, you can also use text only.
+             2. Small -- a 300M transformer decoder conditioned on text only.
+             3. Medium -- a 1.5B transformer decoder conditioned on text only.
+             4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences).
+
+             When using `melody`, you can optionally provide a reference audio from
+             which a broad melody will be extracted. The model will then try to follow both
+             the description and melody provided.
+
+             You can also use your own GPU or a Google Colab by following the instructions on our repo.
+             See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+             for more details.
+             """
+         )
+
+         interface.queue().launch(**launch_kwargs)
+
+
+ def ui_batched(launch_kwargs):
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """
+             # MusicGen
+
+             This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+             a simple and controllable model for music generation,
+             presented in the paper ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+             <br/>
+             <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
+                 style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+                 <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
+                     src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+             for longer sequences, more control and no queue.
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                     with gr.Column():
+                         radio = gr.Radio(["file", "mic"], value="file",
+                                          label="Condition on a melody (optional) File or Mic")
+                         melody = gr.Audio(source="upload", type="numpy", label="File",
+                                           interactive=True, elem_id="melody-input")
+                 with gr.Row():
+                     submit = gr.Button("Generate")
+             with gr.Column():
+                 output = gr.Video(label="Generated Music")
+         submit.click(predict_batched, inputs=[text, melody],
+                      outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+         # gr.Examples(
+         #     fn=predict_batched,
+         #     examples=[
+         #         [
+         #             "An 80s driving pop song with heavy drums and synth pads in the background",
+         #             "./assets/bach.mp3",
+         #         ],
+         #         [
+         #             "A cheerful country song with acoustic guitars",
+         #             "./assets/bolero_ravel.mp3",
+         #         ],
+         #         [
+         #             "90s rock song with electric guitar and heavy drums",
+         #             None,
+         #         ],
+         #         [
+         #             "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+         #             "./assets/bach.mp3",
+         #         ],
+         #         [
+         #             "lofi slow bpm electro chill with organic samples",
+         #             None,
+         #         ],
+         #     ],
+         #     inputs=[text, melody],
+         #     outputs=[output]
+         # )
+         gr.Markdown("""
+         ### More details
+
+         The model will generate 15 seconds of audio based on the description you provided.
+         You can optionally provide a reference audio from which a broad melody will be extracted.
+         The model will then try to follow both the description and melody provided.
+         All samples are generated with the `melody` model.
+
+         You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+         See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+         for more details.
+         """)
+
+         demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '--listen',
+         type=str,
+         default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
+         help='IP to listen on for connections to Gradio',
+     )
+     parser.add_argument(
+         '--username', type=str, default='', help='Username for authentication'
+     )
+     parser.add_argument(
+         '--password', type=str, default='', help='Password for authentication'
+     )
+     parser.add_argument(
+         '--server_port',
+         type=int,
+         default=0,
+         help='Port to run the server listener on',
+     )
+     parser.add_argument(
+         '--inbrowser', action='store_true', help='Open in browser'
+     )
+     parser.add_argument(
+         '--share', action='store_true', help='Share the gradio UI'
+     )
+
+     args = parser.parse_args()
+
+     launch_kwargs = {}
+     launch_kwargs['server_name'] = args.listen
+
+     if args.username and args.password:
+         launch_kwargs['auth'] = (args.username, args.password)
+     if args.server_port:
+         launch_kwargs['server_port'] = args.server_port
+     if args.inbrowser:
+         launch_kwargs['inbrowser'] = args.inbrowser
+     if args.share:
+         launch_kwargs['share'] = args.share
+
+     # Show the interface
+     if IS_BATCHED:
+         ui_batched(launch_kwargs)
+     else:
+         ui_full(launch_kwargs)
setup.py CHANGED
@@ -34,7 +34,8 @@ setup(
          "soundfile",
          "flask",
          "flask-socketio",
-         "audiocraft@git+https://github.com/facebookresearch/audiocraft"
+         "audiocraft@git+https://github.com/facebookresearch/audiocraft",
+         "gradio"
      ],
      include_package_data=True,
  )
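With these changes, the demo can be installed and launched locally. As a rough sketch (the flags come from app.py's argument parser above, and the editable install assumes a checkout of this repo): run `pip install -e .`, then `python app.py --listen 127.0.0.1 --server_port 7860 --inbrowser`. Passing `--share` creates a public Gradio link, and `--username`/`--password` together enable basic authentication.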