adefossez committed
Commit 4cf6900
Parent: 86d0f16

kind of working

Files changed (2):
  1. app.py +12 -6
  2. audiocraft/models/musicgen.py +19 -6
app.py CHANGED
```diff
@@ -59,6 +59,9 @@ def load_model(version='melody'):
 
 
 def _do_predictions(texts, melodies, duration, **gen_kwargs):
+    if duration > MODEL.lm.cfg.dataset.segment_duration:
+        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
+
     MODEL.set_generation_params(duration=duration, **gen_kwargs)
     print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
     be = time.time()
@@ -76,7 +79,7 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
             melody = convert_audio(melody, sr, target_sr, target_ac)
             processed_melodies.append(melody)
 
-    if processed_melodies.any():
+    if any(m is not None for m in processed_melodies):
         outputs = MODEL.generate_with_chroma(
             descriptions=texts,
             melody_wavs=processed_melodies,
@@ -110,12 +113,10 @@ def predict_batched(texts, melodies):
 def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
     topk = int(topk)
     load_model(model)
-    if duration > MODEL.lm.cfg.dataset.segment_duration:
-        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
 
     outs = _do_predictions(
         [text], [melody], duration,
-        topk=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
     return outs[0]
 
 
@@ -138,7 +139,7 @@ def ui_full(launch_kwargs):
                 with gr.Row():
                     model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
                 with gr.Row():
-                    duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
+                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                 with gr.Row():
                     topk = gr.Number(label="Top-k", value=250, interactive=True)
                     topp = gr.Number(label="Top-p", value=0, interactive=True)
@@ -184,7 +185,12 @@ def ui_full(launch_kwargs):
        ### More details
 
        The model will generate a short music extract based on the description you provided.
-       You can generate up to 30 seconds of audio.
+       The model can generate up to 30 seconds of audio in one pass. It is now possible
+       to extend the generation by feeding back the end of the previous chunk of audio.
+       This can take a long time, and the model might lose consistency. The model might also
+       decide at arbitrary positions that the song ends.
+
+       **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
 
        We present 4 model variations:
        1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
```
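A note on the `processed_melodies` fix in the hunk above: the list mixes `None` entries (prompts without a melody) with waveform tensors, and a plain Python list has no `.any()` method, so the old check raised `AttributeError` before generation even started. A minimal sketch of the failure and the fix (the tensor shape is illustrative):

```python
import torch

# One text-only prompt (None) and one prompt with a melody waveform.
processed_melodies = [None, torch.zeros(1, 32000)]

# Old check: raises AttributeError, since .any() is a numpy/torch array
# method, not a list method.
#   processed_melodies.any()

# New check: the builtin any() with an explicit `is not None` test. This also
# avoids the "boolean value of a tensor is ambiguous" error that a bare
# truthiness test on the tensors would trigger.
use_chroma = any(m is not None for m in processed_melodies)
print(use_chroma)  # True -> route through MODEL.generate_with_chroma
```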
 
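The `topk=` to `top_k=` rename matters for the same reason: `_do_predictions` forwards its `**gen_kwargs` verbatim into `MODEL.set_generation_params`, whose keyword is `top_k`. A toy reproduction with a stand-in function (the stand-in below only mirrors the relevant keyword names; it is not the real method):

```python
def set_generation_params(duration, top_k=250, top_p=0.0,
                          temperature=1.0, cfg_coef=3.0):
    """Stand-in mirroring the keyword names of MusicGen.set_generation_params."""
    return dict(duration=duration, top_k=top_k, top_p=top_p,
                temperature=temperature, cfg_coef=cfg_coef)

def _do_predictions(duration, **gen_kwargs):
    # gen_kwargs is passed through untouched, so keyword names must match
    # the callee exactly -- there is no renaming layer in between.
    return set_generation_params(duration=duration, **gen_kwargs)

_do_predictions(10, top_k=250)  # ok
_do_predictions(10, topk=250)   # TypeError: unexpected keyword argument 'topk'
```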
audiocraft/models/musicgen.py CHANGED
```diff
@@ -45,6 +45,7 @@ class MusicGen:
         self.device = next(iter(lm.parameters())).device
         self.generation_params: dict = {}
         self.set_generation_params(duration=15)  # 15 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
         if self.device.type == 'cpu':
             self.autocast = TorchAutocast(enabled=False)
         else:
@@ -127,6 +128,9 @@
             'two_step_cfg': two_step_cfg,
         }
 
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        self._progress_callback = progress_callback
+
     def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
         """Generate samples in an unconditional manner.
 
@@ -274,6 +278,10 @@
         current_gen_offset: int = 0
 
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
-            print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
 
         if prompt_tokens is not None:
@@ -296,11 +304,17 @@
             # melody conditioning etc.
             ref_wavs = [attr.wav['self_wav'] for attr in attributes]
             all_tokens = []
-            if prompt_tokens is not None:
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
                 all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
 
-            time_offset = 0.
-            while time_offset < self.duration:
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
                 chunk_duration = min(self.duration - time_offset, self.max_duration)
                 max_gen_len = int(chunk_duration * self.frame_rate)
                 for attr, ref_wav in zip(attributes, ref_wavs):
@@ -321,14 +335,13 @@
                 gen_tokens = self.lm.generate(
                     prompt_tokens, attributes,
                     callback=callback, max_gen_len=max_gen_len, **self.generation_params)
-                stride_tokens = int(self.frame_rate * self.extend_stride)
                 if prompt_tokens is None:
                     all_tokens.append(gen_tokens)
                 else:
                     all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
-                prompt_tokens = gen_tokens[:, :, stride_tokens]
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
                 current_gen_offset += stride_tokens
-                time_offset += self.extend_stride
 
             gen_tokens = torch.cat(all_tokens, dim=-1)
 
```
 
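The new `set_custom_progress_callback` hook takes over from the hard-coded `print` whenever a callback is registered; it receives cumulative counts `(generated_tokens, total_gen_len)` spanning all extension windows. A minimal usage sketch, assuming the `MusicGen.get_pretrained` loader from this version of audiocraft (the model name, duration, and print format are illustrative):

```python
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('small')
model.set_generation_params(duration=60)  # longer than one 30 s window

def on_progress(generated_tokens: int, total_tokens: int) -> None:
    # Called from inside the sampling loop; offsets are cumulative, so this
    # reports progress over the whole windowed generation, not per window.
    print(f'\r{generated_tokens}/{total_tokens} tokens', end='')

model.set_custom_progress_callback(on_progress)
# progress=True makes _generate_tokens install the wrapper that calls us.
wav = model.generate(['ambient piano with soft pads'], progress=True)
```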
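Taken together, the `_generate_tokens` changes implement the long-generation scheme the app description mentions: generate at most `max_duration` seconds of tokens per window, keep everything past the stride as the prompt for the next window, and advance `current_gen_offset` by the stride until the requested length is covered. A self-contained sketch of that loop at the token level (frame rate, stride, codebook count, and the `fake_generate` stand-in for `self.lm.generate` are all illustrative):

```python
import torch

frame_rate = 50        # illustrative: MusicGen's EnCodec tokens at 50 frames/s
max_duration = 30.0    # longest window generated in one pass (seconds)
extend_stride = 18.0   # how far each new window advances (seconds)
duration = 60.0        # total requested duration (seconds)

total_gen_len = int(duration * frame_rate)       # 3000 tokens
stride_tokens = int(frame_rate * extend_stride)  # 900 tokens

def fake_generate(prompt, max_gen_len):
    """Stand-in for self.lm.generate: echoes the prompt, appends new tokens."""
    B, K = 1, 4  # batch size, codebooks
    prompt_len = 0 if prompt is None else prompt.shape[-1]
    new = torch.randint(0, 2048, (B, K, max_gen_len - prompt_len))
    return new if prompt is None else torch.cat([prompt, new], dim=-1)

all_tokens = []
prompt_tokens = None
prompt_length = 0
current_gen_offset = 0

while current_gen_offset + prompt_length < total_gen_len:
    time_offset = current_gen_offset / frame_rate
    chunk_duration = min(duration - time_offset, max_duration)
    max_gen_len = int(chunk_duration * frame_rate)
    gen_tokens = fake_generate(prompt_tokens, max_gen_len)
    if prompt_tokens is None:
        all_tokens.append(gen_tokens)
    else:
        # keep only what this window newly generated
        all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
    # the tail of the window (everything past the stride) seeds the next one
    prompt_tokens = gen_tokens[:, :, stride_tokens:]
    prompt_length = prompt_tokens.shape[-1]
    current_gen_offset += stride_tokens

tokens = torch.cat(all_tokens, dim=-1)
assert tokens.shape[-1] == total_gen_len  # 3000 frames == 60 s at 50 Hz
```

Each window re-consumes `max_duration - extend_stride` seconds of context, which is why long requests cost disproportionately more time, as the warning added to app.py notes.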