ford442 commited on
Commit
054d10f
·
verified ·
1 Parent(s): efa2898

Update demos/musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +6 -5
demos/musicgen_app.py CHANGED
@@ -186,7 +186,7 @@ class Predictor:
186
  tokens = torch.cat([left, right])
187
  outputs_diffusion = self.mbd.tokens_to_wav(tokens)
188
  if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
189
- assert outputs_diffusion.shape[1] == 1
190
  outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
191
  outputs_diffusion = outputs_diffusion.detach().cpu()
192
  return task_id, (output, outputs_diffusion) #Return the task id.
@@ -234,7 +234,7 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
234
  # Initialize Predictor *INSIDE* the function
235
  predictor = Predictor(model)
236
 
237
- task_id = predictor.predict(
238
  text=text,
239
  melody=melody,
240
  duration=duration,
@@ -245,8 +245,6 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
245
  cfg_coef=cfg_coef,
246
  )
247
 
248
- wav, diffusion_wav = predictor.get_result(task_id)
249
-
250
  # Save and return audio files
251
  wav_paths = []
252
  video_paths = []
@@ -272,7 +270,9 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
272
  video_paths.append(make_waveform(file.name)) # Make and clean up video
273
  file_cleaner.add(file.name)
274
  # Shutdown predictor to prevent hanging processes!
275
- predictor.shutdown()
 
 
276
 
277
  if use_mbd:
278
  return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
@@ -293,6 +293,7 @@ def toggle_diffusion(choice):
293
  return [gr.update(visible=False)] * 2
294
  # --- Gradio UI ---
295
 
 
296
  def ui_full(launch_kwargs):
297
  with gr.Blocks() as interface:
298
  gr.Markdown(
 
186
  tokens = torch.cat([left, right])
187
  outputs_diffusion = self.mbd.tokens_to_wav(tokens)
188
  if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
189
+ assert outputs_diffusion.shape[1] == 1 # output is mono
190
  outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
191
  outputs_diffusion = outputs_diffusion.detach().cpu()
192
  return task_id, (output, outputs_diffusion) #Return the task id.
 
234
  # Initialize Predictor *INSIDE* the function
235
  predictor = Predictor(model)
236
 
237
+ task_id, (wav, diffusion_wav) = predictor.predict( # Unpack directly!
238
  text=text,
239
  melody=melody,
240
  duration=duration,
 
245
  cfg_coef=cfg_coef,
246
  )
247
 
 
 
248
  # Save and return audio files
249
  wav_paths = []
250
  video_paths = []
 
270
  video_paths.append(make_waveform(file.name)) # Make and clean up video
271
  file_cleaner.add(file.name)
272
  # Shutdown predictor to prevent hanging processes!
273
+
274
+ if not predictor.is_daemon: # Important!
275
+ predictor.shutdown()
276
 
277
  if use_mbd:
278
  return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
 
293
  return [gr.update(visible=False)] * 2
294
  # --- Gradio UI ---
295
 
296
+
297
  def ui_full(launch_kwargs):
298
  with gr.Blocks() as interface:
299
  gr.Markdown(