mattricesound committed on
Commit
e316ba7
1 Parent(s): c2a47f9

Remove unnecessary components. Add the examples back in. Add the ability to swap the local model for the MusicGen Gradio API.

Browse files
Files changed (1) hide show
  1. app.py +113 -108
app.py CHANGED
@@ -16,29 +16,35 @@ from tempfile import NamedTemporaryFile
16
  import time
17
  import typing as tp
18
  import warnings
 
19
 
20
  import torch
21
  import gradio as gr
22
-
23
  from audiocraft.data.audio_utils import convert_audio
24
- from audiocraft.data.audio import audio_write
25
  from audiocraft.models import MusicGen
26
 
27
  from demucs import pretrained
28
  from demucs.apply import apply_model
29
  from demucs.audio import convert_audio
 
 
 
30
 
31
 
32
  MODEL = None # Last used model
33
  DEMUCS_MODEL = None
34
  MAX_BATCH_SIZE = 12
35
  INTERRUPTING = False
 
36
  # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
37
  _old_call = sp.call
38
 
39
  stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
40
  stem_idx = torch.LongTensor([stem2idx['vocal'], stem2idx['other'], stem2idx['bass']])
41
 
 
42
 
43
 
44
  def _call_nostderr(*args, **kwargs):
@@ -94,14 +100,19 @@ def make_waveform(*args, **kwargs):
94
 
95
  def load_model(version='melody'):
96
  global MODEL, DEMUCS_MODEL
97
- print("Loading model", version)
98
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
99
- if MODEL is None or MODEL.name != version:
100
- # If gpu is not available, we'll use cpu.
101
- MODEL = MusicGen.get_pretrained(version, device=device)
 
 
102
  if DEMUCS_MODEL is None:
103
  DEMUCS_MODEL = pretrained.get_model('htdemucs').to(device)
104
 
 
 
 
 
105
 
106
  def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
107
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
@@ -149,31 +160,19 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
149
  demucs_output = demucs_output.cpu()
150
 
151
  # Naming
152
- filename = f"temp/{texts[0][:10]}.wav"
153
- d_filename = f"temp/{texts[0][:10]}_no_drums.wav"
154
 
155
  # If path exists, add number. If number exists, update number.
156
  i = 1
157
- while Path(filename).exists():
158
- filename = f"{texts[0][:10]}_{i}.wav"
159
- d_filename = f"{texts[0][:10]}_{i}_no_drums.wav"
160
  i += 1
161
 
162
- # with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
163
- audio_write(
164
- filename, output, MODEL.sample_rate, strategy="loudness",
165
- loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
166
- # out_files.append(pool.submit(make_waveform, filename))
167
- out_files.append(filename)
168
- file_cleaner.add(filename)
169
- # with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
170
  audio_write(
171
  d_filename, demucs_output, MODEL.sample_rate, strategy="loudness",
172
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
173
  out_files.append(d_filename)
174
- # out_files.append(pool.submit(make_waveform, d_filename))
175
  file_cleaner.add(d_filename)
176
- # res = [out_file.result() for out_file in out_files]
177
  res = [out_file for out_file in out_files]
178
  for file in res:
179
  file_cleaner.add(file)
@@ -183,18 +182,10 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
183
 
184
 
185
 
186
- def predict_full(text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
187
  global INTERRUPTING
188
  INTERRUPTING = False
189
- if temperature < 0:
190
- raise gr.Error("Temperature must be >= 0.")
191
- if topk < 0:
192
- raise gr.Error("Topk must be non-negative.")
193
- if topp < 0:
194
- raise gr.Error("Topp must be non-negative.")
195
-
196
- topk = int(topk)
197
-
198
  def _progress(generated, to_generate):
199
  progress((generated, to_generate))
200
  if INTERRUPTING:
@@ -202,77 +193,114 @@ def predict_full(text, melody, duration, topk, topp, temperature, cfg_coef, prog
202
  MODEL.set_custom_progress_callback(_progress)
203
 
204
  outs = _do_predictions(
205
- [text], [melody], duration, progress=True,
206
- top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
207
 
208
- return outs[0], outs[1], outs[0], outs[1]
209
 
210
 
211
- def toggle_audio_src(choice):
212
- if choice == "mic":
213
- return gr.update(source="microphone", value=None, label="Microphone")
214
- else:
215
- return gr.update(source="upload", value=None, label="File")
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  def ui_full(launch_kwargs):
219
  with gr.Blocks() as interface:
 
 
 
 
 
220
  with gr.Row():
221
  with gr.Column():
222
  with gr.Row():
223
  text = gr.Text(label="Input Text", interactive=True)
224
  with gr.Column():
225
- radio = gr.Radio(["file", "mic"], value="file",
226
- label="Condition on a melody (optional) File or Mic")
227
- melody = gr.Audio(source="upload", type="numpy", label="File",
228
- interactive=True, elem_id="melody-input")
 
 
 
 
229
  with gr.Row():
230
  submit = gr.Button("Submit")
231
  # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
232
- _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
233
- with gr.Row():
234
- duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
235
- with gr.Row():
236
- topk = gr.Number(label="Top-k", value=250, interactive=True)
237
- topp = gr.Number(label="Top-p", value=0, interactive=True)
238
- temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
239
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
240
  with gr.Column():
241
- with gr.Row():
242
- # output_normal = gr.Video(label="Generated Music")
243
- output_normal = gr.Audio(label="Generated Music")
244
- with gr.Row():
245
- file_download = gr.File(label="Download")
246
- with gr.Row():
247
- # output_without_drum = gr.Video(label="Removed drums")
248
- output_without_drum = gr.Audio(label="Removed drums")
249
- with gr.Row():
250
- file_download_no_drum = gr.File(label="Download")
251
- with gr.Row():
252
  gr.Markdown(
253
  """
254
  Note that the files will be deleted after 10 minutes, so make sure to download!
255
  """
256
  )
257
- submit.click(predict_full,
258
- inputs=[text, melody, duration, topk, topp, temperature, cfg_coef],
259
- outputs=[output_normal, output_without_drum, file_download, file_download_no_drum])
260
- radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
261
- gr.Markdown(
262
- """
263
- ### More details
264
-
265
- The model will generate a short music extract based on the description you provided.
266
- The model can generate up to 30 seconds of audio in one pass. It is now possible
267
- to extend the generation by feeding back the end of the previous chunk of audio.
268
- This can take a long time, and the model might lose consistency. The model might also
269
- decide at arbitrary positions that the song ends.
270
-
271
- **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
272
- An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
273
- are generated each time.
274
-
275
- """
 
 
 
 
276
  )
277
 
278
  interface.queue().launch(**launch_kwargs)
@@ -286,41 +314,18 @@ if __name__ == "__main__":
286
  default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
287
  help='IP to listen on for connections to Gradio',
288
  )
289
- parser.add_argument(
290
- '--username', type=str, default='', help='Username for authentication'
291
- )
292
- parser.add_argument(
293
- '--password', type=str, default='', help='Password for authentication'
294
- )
295
- parser.add_argument(
296
- '--server_port',
297
- type=int,
298
- default=0,
299
- help='Port to run the server listener on',
300
- )
301
- parser.add_argument(
302
- '--inbrowser', action='store_true', help='Open in browser'
303
- )
304
- parser.add_argument(
305
- '--share', action='store_true', help='Share the gradio UI'
306
- )
307
 
308
  args = parser.parse_args()
309
 
310
  launch_kwargs = {}
311
  launch_kwargs['server_name'] = args.listen
312
 
313
- if args.username and args.password:
314
- launch_kwargs['auth'] = (args.username, args.password)
315
- if args.server_port:
316
- launch_kwargs['server_port'] = args.server_port
317
- if args.inbrowser:
318
- launch_kwargs['inbrowser'] = args.inbrowser
319
- if args.share:
320
- launch_kwargs['share'] = args.share
321
-
322
  # Load melody model
323
  load_model()
 
 
324
  if not os.path.exists("temp"):
325
  os.mkdir("temp")
326
  # Show the interface
 
16
  import time
17
  import typing as tp
18
  import warnings
19
+ import glob
20
 
21
  import torch
22
  import gradio as gr
23
+ import numpy as np
24
  from audiocraft.data.audio_utils import convert_audio
25
+ from audiocraft.data.audio import audio_write, audio_read
26
  from audiocraft.models import MusicGen
27
 
28
  from demucs import pretrained
29
  from demucs.apply import apply_model
30
  from demucs.audio import convert_audio
31
+ from gradio_client import Client
32
+
33
+ LOCAL = False
34
 
35
 
36
  MODEL = None # Last used model
37
  DEMUCS_MODEL = None
38
  MAX_BATCH_SIZE = 12
39
  INTERRUPTING = False
40
+ client = None
41
  # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
42
  _old_call = sp.call
43
 
44
  stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
45
  stem_idx = torch.LongTensor([stem2idx['vocal'], stem2idx['other'], stem2idx['bass']])
46
 
47
+ melody_files = glob.glob('clips/**/*.mp3', recursive=True)
48
 
49
 
50
  def _call_nostderr(*args, **kwargs):
 
100
 
101
  def load_model(version='melody'):
102
  global MODEL, DEMUCS_MODEL
 
103
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
104
+ if LOCAL:
105
+ if MODEL is None or MODEL.name != version:
106
+ print("Loading model", version)
107
+ # If gpu is not available, we'll use cpu.
108
+ MODEL = MusicGen.get_pretrained(version, device=device)
109
  if DEMUCS_MODEL is None:
110
  DEMUCS_MODEL = pretrained.get_model('htdemucs').to(device)
111
 
112
def connect_to_endpoint(url="https://facebook-musicgen--44zzp.hf.space/"):
    """Create the shared Gradio client used for remote MusicGen inference.

    Args:
        url: Address of the hosted MusicGen Gradio app. Defaults to the
            endpoint this app has always used, so existing zero-argument
            callers behave exactly as before.

    Side effects:
        Rebinds the module-level ``client``.
    """
    global client
    client = Client(url)
115
+
116
 
117
  def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
118
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
 
160
  demucs_output = demucs_output.cpu()
161
 
162
  # Naming
163
+ d_filename = f"temp/{texts[0][:10]}.wav"
 
164
 
165
  # If path exists, add number. If number exists, update number.
166
  i = 1
167
+ while Path(d_filename).exists():
168
+ d_filename = f"temp/{texts[0][:10]}_{i}.wav"
 
169
  i += 1
170
 
 
 
 
 
 
 
 
 
171
  audio_write(
172
  d_filename, demucs_output, MODEL.sample_rate, strategy="loudness",
173
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
174
  out_files.append(d_filename)
 
175
  file_cleaner.add(d_filename)
 
176
  res = [out_file for out_file in out_files]
177
  for file in res:
178
  file_cleaner.add(file)
 
182
 
183
 
184
 
185
+ def predict_full(text, melody, progress=gr.Progress()):
186
  global INTERRUPTING
187
  INTERRUPTING = False
188
+ print("Running local model")
 
 
 
 
 
 
 
 
189
  def _progress(generated, to_generate):
190
  progress((generated, to_generate))
191
  if INTERRUPTING:
 
193
  MODEL.set_custom_progress_callback(_progress)
194
 
195
  outs = _do_predictions(
196
+ [text], [melody], duration=10, progress=True)
 
197
 
198
+ return outs[0], gr.File.update(value=outs[0], visible=True)
199
 
200
 
 
 
 
 
 
201
 
202
def select_new_melody():
    """Pick a random bundled melody clip for the melody audio input.

    Returns:
        A Gradio update that sets the component's source to ``"upload"`` and
        its value to a path drawn at random from ``melody_files``. When
        ``melody_files`` is empty, an empty update is returned so the current
        melody is left as it is.
    """
    # np.random.choice raises ValueError on an empty sequence; with no clips
    # available there is simply nothing to switch to.
    if not melody_files:
        return gr.update()
    # np.random.choice returns a numpy string scalar; pass Gradio a plain str.
    new_melody_file = str(np.random.choice(melody_files))
    return gr.update(source="upload", value=new_melody_file)
205
+
206
def _unique_wav_path(text, directory="temp"):
    """Return an unused ``<directory>/<stem>[_N].wav`` path derived from *text*."""
    import re

    # The prompt is user input: keep only word characters, spaces and dashes so
    # path separators or ".." cannot move the write outside ``directory``.
    # Fall back to a fixed stem when nothing usable is left (e.g. empty text).
    stem = re.sub(r"[^\w \-]", "_", text[:10]).strip() or "output"
    candidate = os.path.join(directory, f"{stem}.wav")
    # If the path exists, add a number; if the number exists, bump it.
    suffix = 1
    while Path(candidate).exists():
        candidate = os.path.join(directory, f"{stem}_{suffix}.wav")
        suffix += 1
    return candidate


def run_remote_model(text, melody):
    """Generate music through the remote Gradio endpoint and post-process it.

    The result returned by ``client.predict`` is converted with ffmpeg to a
    mono 32 kHz WAV, separated with ``DEMUCS_MODEL``, and the stems selected by
    ``stem_idx`` (vocal, other and bass according to ``stem2idx``) are summed
    and written back over that WAV file.

    Args:
        text: Prompt describing the music; its first ten characters (sanitized)
            also form the output file name.
        melody: File path or URL of the conditioning melody, passed to the
            endpoint unchanged.

    Returns:
        A tuple ``(path, file_update)``: the path of the written WAV file and a
        ``gr.File`` update that makes the download component visible with it.

    Raises:
        subprocess.CalledProcessError: If the ffmpeg conversion fails.
    """
    sample_rate = 32000  # rate used for both the ffmpeg output and the final file

    print("Running Audiocraft API model with text", text, "and melody", melody)
    result = client.predict(
        text,  # str in 'Describe your music' Textbox component
        melody,  # str (filepath or URL to file) in 'File' Audio component
        fn_index=0,
    )

    d_filename = _unique_wav_path(text)

    # Convert the endpoint's media file (mp4, per the original note) to wav.
    # check=True surfaces a failed conversion here instead of as a confusing
    # read error below; -y keeps ffmpeg from waiting on an overwrite prompt.
    sp.run(
        ["ffmpeg", "-y", "-i", result, "-vn", "-acodec", "pcm_s16le",
         "-ar", str(sample_rate), "-ac", "1", d_filename],
        check=True,
    )
    output, sr = audio_read(d_filename)

    print("Running demucs")
    wav = convert_audio(output, sr, DEMUCS_MODEL.samplerate, DEMUCS_MODEL.audio_channels)
    stems = apply_model(DEMUCS_MODEL, wav.unsqueeze(0))
    stems = stems[:, stem_idx].sum(1)  # extract the wanted stems and merge them
    stems = convert_audio(stems, DEMUCS_MODEL.samplerate, sample_rate, 1)
    demucs_output = stems[0].cpu()

    # Overwrite the intermediate ffmpeg output with the separated mix.
    audio_write(
        d_filename, demucs_output, sample_rate, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
    file_cleaner.add(d_filename)
    print("Finished", text)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    return d_filename, gr.File.update(value=d_filename, visible=True)
247
 
248
  def ui_full(launch_kwargs):
249
  with gr.Blocks() as interface:
250
+ gr.Markdown(
251
+ """
252
+ # Soundsauce Melody Playground
253
+ """
254
+ )
255
  with gr.Row():
256
  with gr.Column():
257
  with gr.Row():
258
  text = gr.Text(label="Input Text", interactive=True)
259
  with gr.Column():
260
+ # previously, type="numpy"
261
+ if LOCAL:
262
+ audio_type="numpy"
263
+ else:
264
+ audio_type="filepath"
265
+ melody = gr.Audio(type=audio_type, label="File",
266
+ interactive=True, elem_id="melody-input", value="clips/chipmunk.wav")
267
+ new_melody = gr.Button("New Melody", interactive=True)
268
  with gr.Row():
269
  submit = gr.Button("Submit")
270
  # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
271
+ # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
272
+
 
 
 
 
 
 
273
  with gr.Column():
274
+ output_without_drum = gr.Audio(label="Output")
275
+ file_download_no_drum = gr.File(label="Download", visible=False)
 
 
 
 
 
 
 
 
 
276
  gr.Markdown(
277
  """
278
  Note that the files will be deleted after 10 minutes, so make sure to download!
279
  """
280
  )
281
+ if LOCAL:
282
+ submit.click(predict_full,
283
+ inputs=[text, melody],
284
+ outputs=[output_without_drum, file_download_no_drum])
285
+ else:
286
+ submit.click(run_remote_model, inputs=[text, melody], outputs=[output_without_drum, file_download_no_drum])
287
+ new_melody.click(select_new_melody, outputs=[melody])
288
+ gr.Examples(
289
+ fn=predict_full,
290
+ examples=[
291
+ ["Enchanting Flute Trills amidst Misty String Section"],
292
+ ["Gliding Mellotron Strings over Vibrant Phrases"],
293
+ ["Synth Brass Melody Floating over Airy Wind Chimes"],
294
+ ["Echoing Electric Guitar Licks with Ethereal Vocal Chops"],
295
+ ["Rhythmic Acoustic Guitar Licks with Echoing Layers"],
296
+ ["Whimsical Flute Flourishes in a Mystical Forest Glade"],
297
+ ["Airy Piccolo Trills accompanied by Floating Harp Arpeggios"],
298
+ ["Dreamy Harp Glissandos accompanied by Distant Celesta"],
299
+ ["Hypnotic Synth Pads layered with Enigmatic Guitar Progressions"],
300
+ ["Enchanting Kalimba Melodies atop Mystical Atmosphere"],
301
+ ],
302
+ inputs=[text],
303
+ outputs=[output_without_drum, file_download_no_drum]
304
  )
305
 
306
  interface.queue().launch(**launch_kwargs)
 
314
  default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
315
  help='IP to listen on for connections to Gradio',
316
  )
317
+ parser.add_argument("--local", action="store_true", help="Run locally instead of using API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
  args = parser.parse_args()
320
 
321
  launch_kwargs = {}
322
  launch_kwargs['server_name'] = args.listen
323
 
324
+ LOCAL = args.local
 
 
 
 
 
 
 
 
325
  # Load melody model
326
  load_model()
327
+ if not LOCAL:
328
+ connect_to_endpoint()
329
  if not os.path.exists("temp"):
330
  os.mkdir("temp")
331
  # Show the interface