adefossez committed on
Commit
6457900
1 Parent(s): 4cf6900
Files changed (3)
  1. README.md +6 -5
  2. app.py +25 -10
  3. audiocraft/models/musicgen.py +6 -2
README.md CHANGED
@@ -38,11 +38,12 @@ pip install -e . # or if you cloned the repo locally
 
 ## Usage
 We offer a number of ways to interact with MusicGen:
-1. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
-2. You can use the gradio demo locally by running `python app.py`.
-3. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
-4. Finally, you can run the [Gradio demo with a Colab GPU](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing),
-as adapted from [@camenduru Colab](https://github.com/camenduru/MusicGen-colab).
+1. A demo is available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to the HF team for their support).
+2. You can run the extended demo on a Colab: [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
+3. You can use the gradio demo locally by running `python app.py`.
+4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
+5. Finally, check out the [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab), which is regularly
+updated with contributions from @camenduru and the community.
 
 ## API
 
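All of these entry points wrap the same programmatic API covered in the `## API` section that follows. A minimal sketch of that API, following the `audiocraft.models.MusicGen` calls visible elsewhere in this commit; the loader and checkpoint names are assumptions based on this release and may differ in later versions:

```python
from audiocraft.models import MusicGen

# Assumed loader name for this release; later versions rename checkpoints.
model = MusicGen.get_pretrained('melody')
model.set_generation_params(duration=15)  # seconds of audio to generate
# Returns a (batch, channels, samples) tensor at model.sample_rate.
wavs = model.generate(['lofi hip hop beat'], progress=True)
```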
app.py CHANGED
@@ -4,6 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
+# also released under the MIT license.
+
 import argparse
 from concurrent.futures import ProcessPoolExecutor
 import os
@@ -22,8 +25,9 @@ from audiocraft.models import MusicGen
 
 MODEL = None  # Last used model
 IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
-MAX_BATCH_SIZE = 12
+MAX_BATCH_SIZE = 8
 BATCHED_DURATION = 15
+INTERRUPTING = False
 # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
 _old_call = sp.call
 
@@ -37,10 +41,14 @@ def _call_nostderr(*args, **kwargs):
 
 sp.call = _call_nostderr
 # Preallocating the pool of processes.
-pool = ProcessPoolExecutor(3)
+pool = ProcessPoolExecutor(4)
 pool.__enter__()
 
 
+def interrupt():
+    global INTERRUPTING
+    INTERRUPTING = True
+
 def make_waveform(*args, **kwargs):
     # Further remove some warnings.
     be = time.time()
@@ -59,9 +67,6 @@ def load_model(version='melody'):
 
 
 def _do_predictions(texts, melodies, duration, **gen_kwargs):
-    if duration > MODEL.lm.cfg.dataset.segment_duration:
-        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
-
     MODEL.set_generation_params(duration=duration, **gen_kwargs)
     print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
     be = time.time()
@@ -84,10 +89,10 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
             descriptions=texts,
             melody_wavs=processed_melodies,
             melody_sample_rate=target_sr,
-            progress=False
+            progress=True
         )
     else:
-        outputs = MODEL.generate(texts, progress=False)
+        outputs = MODEL.generate(texts, progress=True)
 
     outputs = outputs.detach().cpu().float()
     out_files = []
@@ -110,9 +115,16 @@ def predict_batched(texts, melodies):
     return [res]
 
 
-def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
+def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+    global INTERRUPTING
+    INTERRUPTING = False
     topk = int(topk)
     load_model(model)
+    def _progress(generated, to_generate):
+        progress((generated, to_generate))
+        if INTERRUPTING:
+            raise gr.Error("Interrupted.")
+    MODEL.set_custom_progress_callback(_progress)
 
     outs = _do_predictions(
         [text], [melody], duration,
@@ -136,6 +148,8 @@ def ui_full(launch_kwargs):
                     melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
                 with gr.Row():
                     submit = gr.Button("Submit")
+                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                 with gr.Row():
                     model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
                 with gr.Row():
@@ -190,7 +204,8 @@ def ui_full(launch_kwargs):
             This can take a long time, and the model might lose consistency. The model might also
             decide at arbitrary positions that the song ends.
 
-            **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
+            **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min). An overlap of 12 seconds
+            is kept with the previously generated chunk, and 18 "new" seconds are generated each time.
 
             We present 4 model variations:
             1. Melody -- a music generation model capable of generating music conditioned on text and melody inputs. **Note**, you can also use text only.
@@ -207,7 +222,7 @@ def ui_full(launch_kwargs):
             """
         )
 
-    interface.queue().launch(**launch_kwargs, max_threads=1)
+    interface.queue().launch(**launch_kwargs)
 
 
 def ui_batched(launch_kwargs):
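The heart of the app.py change is the interrupt mechanism: a module-level flag, flipped by the new "Interrupt" button, is checked inside the progress callback so that a running generation aborts between decoding steps. A standalone sketch of the pattern with the Gradio specifics stubbed out (app.py reports through `gr.Progress` and raises `gr.Error` instead of `RuntimeError`):

```python
# Minimal sketch of the interrupt-via-progress-callback pattern above.
INTERRUPTING = False

def interrupt():
    # Called from the UI thread (the "Interrupt" button in app.py).
    global INTERRUPTING
    INTERRUPTING = True

def make_progress(report):
    # `report` is any callable taking (generated, to_generate);
    # app.py passes a gr.Progress instance here.
    def _progress(generated: int, to_generate: int):
        report((generated, to_generate))
        if INTERRUPTING:
            # Raising unwinds out of the generation loop; app.py uses gr.Error.
            raise RuntimeError("Interrupted.")
    return _progress

# Usage: MODEL.set_custom_progress_callback(make_progress(print))
```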
audiocraft/models/musicgen.py CHANGED
@@ -99,7 +99,7 @@ class MusicGen:
     def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                               top_p: float = 0.0, temperature: float = 1.0,
                               duration: float = 30.0, cfg_coef: float = 3.0,
-                              two_step_cfg: bool = False, extend_stride: float = 15):
+                              two_step_cfg: bool = False, extend_stride: float = 18):
         """Set the generation parameters for MusicGen.
 
         Args:
@@ -129,6 +129,7 @@ class MusicGen:
         }
 
     def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
         self._progress_callback = progress_callback
 
     def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
@@ -280,9 +281,11 @@
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
             generated_tokens += current_gen_offset
             if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
                 self._progress_callback(generated_tokens, total_gen_len)
             else:
-                print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
 
         if prompt_tokens is not None:
             assert max_prompt_len >= prompt_tokens.shape[-1], \
@@ -326,6 +329,7 @@
                 # we wouldn't have the full wav.
                 initial_position = int(time_offset * self.sample_rate)
                 wav_target_length = int(self.max_duration * self.sample_rate)
+                print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
                 positions = torch.arange(initial_position,
                                          initial_position + wav_target_length, device=self.device)
                 attr.wav['self_wav'] = WavCondition(
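The `extend_stride` default moving from 15 to 18 seconds is what the new app.py warning describes: with the model's 30-second window (`max_duration`), each extension step re-consumes 30 - 18 = 12 seconds of the previous chunk as context and generates 18 new seconds. A rough, illustrative sketch of the resulting window schedule in seconds (the real implementation works on token offsets, not this helper):

```python
def extension_schedule(total: float, max_duration: float = 30.0, stride: float = 18.0):
    """Illustrative (start, end) generation windows, in seconds."""
    windows = [(0.0, min(max_duration, total))]
    offset = stride
    while windows[-1][1] < total:
        windows.append((offset, min(offset + max_duration, total)))
        offset += stride
    return windows

# extension_schedule(60.0) -> [(0.0, 30.0), (18.0, 48.0), (36.0, 60.0)]
# Consecutive windows overlap by max_duration - stride = 12 seconds.
```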