mattricesound committed
Commit d9755fb
1 Parent(s): 8180c66

Add demucs output, load melody model on launch

Files changed (1)
  1. app.py +29 -68
app.py CHANGED
@@ -102,7 +102,7 @@ def load_model(version='melody'):
     MODEL = MusicGen.get_pretrained(version, device=device)
 
 
-def _do_predictions(texts, melodies, duration, progress=False, drums=True, **gen_kwargs):
+def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
     MODEL.set_generation_params(duration=duration, **gen_kwargs)
     print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
     be = time.time()
@@ -135,22 +135,25 @@ def _do_predictions(texts, melodies, duration, progress=False, drums=True, **gen
     out_files = []
     for output in outputs:
         # Demucs
-        if not drums:
-            print("Running demucs")
-            wav = convert_audio(output, MODEL.sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
-            wav = wav.unsqueeze(0)
-            stems = apply_model(demucs_model, wav)
-            stems = stems[:, stem_idx]  # extract stem
-            stems = stems.sum(1)  # merge extracted stems
-            stems = convert_audio(stems, demucs_model.samplerate, MODEL.sample_rate, 1)
-            output = stems[0]
+        print("Running demucs")
+        wav = convert_audio(output, MODEL.sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
+        wav = wav.unsqueeze(0)
+        stems = apply_model(demucs_model, wav)
+        stems = stems[:, stem_idx]  # extract stem
+        stems = stems.sum(1)  # merge extracted stems
+        stems = convert_audio(stems, demucs_model.samplerate, MODEL.sample_rate, 1)
+        demucs_output = stems[0]
+
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
             audio_write(
                 file.name, output, MODEL.sample_rate, strategy="loudness",
                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-
-
-
+        out_files.append(pool.submit(make_waveform, file.name))
+        file_cleaner.add(file.name)
+        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+            audio_write(
+                file.name, demucs_output, MODEL.sample_rate, strategy="loudness",
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
         out_files.append(pool.submit(make_waveform, file.name))
         file_cleaner.add(file.name)
     res = [out_file.result() for out_file in out_files]
@@ -169,7 +172,7 @@ def predict_batched(texts, melodies):
     return [res]
 
 
-def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, drums, progress=gr.Progress()):
+def predict_full(text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
     global INTERRUPTING
     INTERRUPTING = False
     if temperature < 0:
@@ -180,7 +183,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
         raise gr.Error("Topp must be non-negative.")
 
     topk = int(topk)
-    load_model(model)
+    # load_model(model)
 
     def _progress(generated, to_generate):
         progress((generated, to_generate))
@@ -190,11 +193,9 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
 
     outs = _do_predictions(
         [text], [melody], duration, progress=True,
-        top_k=topk, top_p=topp, temperature=temperature, drums=drums, cfg_coef=cfg_coef)
-
-
+        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
 
-    return outs[0]
+    return outs[0], outs[1]
 
 
 def toggle_audio_src(choice):
@@ -219,9 +220,6 @@ def ui_full(launch_kwargs):
                     submit = gr.Button("Submit")
                     # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                     _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
-                with gr.Row():
-                    model = gr.Radio(["melody", "medium", "small", "large"],
-                                     label="Model", value="melody", interactive=True)
                 with gr.Row():
                     duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                 with gr.Row():
@@ -229,13 +227,15 @@ def ui_full(launch_kwargs):
                     topp = gr.Number(label="Top-p", value=0, interactive=True)
                     temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                     cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
-                with gr.Row():
-                    drums = gr.Checkbox(label="Drums", value=True, interactive=True)
             with gr.Column():
-                output = gr.Video(label="Generated Music")
+                with gr.Row():
+                    output_normal = gr.Video(label="Generated Music")
+                with gr.Row():
+                    output_without_drum = gr.Video(label="Removed drums")
+
         submit.click(predict_full,
-                     inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef, drums],
-                     outputs=[output])
+                     inputs=[text, melody, duration, topk, topp, temperature, cfg_coef],
+                     outputs=[output_normal, output_without_drum])
         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
         gr.Markdown(
             """
@@ -251,20 +251,6 @@ def ui_full(launch_kwargs):
             An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
             are generated each time.
 
-            We present 4 model variations:
-            1. Melody -- a music generation model capable of generating music condition
-                on text and melody inputs. **Note**, you can also use text only.
-            2. Small -- a 300M transformer decoder conditioned on text only.
-            3. Medium -- a 1.5B transformer decoder conditioned on text only.
-            4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
-
-            When using `melody`, you can optionaly provide a reference audio from
-            which a broad melody will be extracted. The model will then try to follow both
-            the description and melody provided.
-
-            You can also use your own GPU or a Google Colab by following the instructions on our repo.
-            See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-            for more details.
             """
         )
 
@@ -304,33 +290,6 @@ def ui_batched(launch_kwargs):
         submit.click(predict_batched, inputs=[text, melody],
                      outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
-        # gr.Examples(
-        #     fn=predict_batched,
-        #     examples=[
-        #         [
-        #             "An 80s driving pop song with heavy drums and synth pads in the background",
-        #             "./assets/bach.mp3",
-        #         ],
-        #         [
-        #             "A cheerful country song with acoustic guitars",
-        #             "./assets/bolero_ravel.mp3",
-        #         ],
-        #         [
-        #             "90s rock song with electric guitar and heavy drums",
-        #             None,
-        #         ],
-        #         [
-        #             "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
-        #             "./assets/bach.mp3",
-        #         ],
-        #         [
-        #             "lofi slow bpm electro chill with organic samples",
-        #             None,
-        #         ],
-        #     ],
-        #     inputs=[text, melody],
-        #     outputs=[output]
-        # )
         gr.Markdown("""
         ### More details
 
@@ -389,6 +348,8 @@ if __name__ == "__main__":
     if args.share:
         launch_kwargs['share'] = args.share
 
+    # Load melody model
+    load_model()
     # Show the interface
     if IS_BATCHED:
         ui_batched(launch_kwargs)
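For context on the Demucs step in `_do_predictions`: `apply_model` splits the generated clip into stems (drums, bass, other, vocals), `stem_idx` selects the non-drum sources, and summing them re-mixes a drum-free track. Below is a minimal standalone sketch of the same pipeline; the `get_model("htdemucs")` call and the `remove_drums` helper are assumptions for illustration, since this commit does not show how `demucs_model` and `stem_idx` are defined elsewhere in app.py.

import torch
from demucs.apply import apply_model
from demucs.pretrained import get_model
from audiocraft.data.audio_utils import convert_audio

# Assumed setup: app.py presumably builds these elsewhere;
# "htdemucs" is a hypothetical choice of pretrained Demucs model.
demucs_model = get_model("htdemucs")
stem_idx = torch.LongTensor(
    [i for i, name in enumerate(demucs_model.sources) if name != "drums"])

def remove_drums(output: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Hypothetical helper: return mono audio with the drum stem removed."""
    # Demucs expects its own sample rate/channel count and a batch dimension.
    wav = convert_audio(output, sample_rate,
                        demucs_model.samplerate, demucs_model.audio_channels)
    wav = wav.unsqueeze(0)
    stems = apply_model(demucs_model, wav)  # (batch, sources, channels, time)
    stems = stems[:, stem_idx]              # keep every source except drums
    mix = stems.sum(1)                      # merge kept sources into one track
    mix = convert_audio(mix, demucs_model.samplerate, sample_rate, 1)
    return mix[0]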
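The UI side of the change wires a single callback to two `gr.Video` components; Gradio matches the callback's return tuple to the `outputs` list positionally, which is why `predict_full` now returns `outs[0], outs[1]`. A stripped-down sketch of that pattern, with a hypothetical `make_videos` standing in for `predict_full`:

import gradio as gr

# Hypothetical stand-in for predict_full: returns (normal video, drum-removed video).
def make_videos(text: str):
    # The real app returns waveform videos built by make_waveform;
    # None is returned here only to keep the sketch self-contained.
    return None, None

with gr.Blocks() as demo:
    text = gr.Text(label="Input Text")
    submit = gr.Button("Submit")
    with gr.Column():
        with gr.Row():
            output_normal = gr.Video(label="Generated Music")
        with gr.Row():
            output_without_drum = gr.Video(label="Removed drums")
    # One click event feeds both outputs; the return tuple is matched in order.
    submit.click(make_videos, inputs=[text],
                 outputs=[output_normal, output_without_drum])

# demo.launch()  # uncomment to run locally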