adefossez commited on
Commit
c1546d4
1 Parent(s): 1633cd5
audiocraft/models/musicgen.py CHANGED
@@ -79,7 +79,7 @@ class MusicGen:
79
  # used only for unit tests
80
  compression_model = get_debug_compression_model(device)
81
  lm = get_debug_lm_model(device)
82
- return MusicGen(name, compression_model, lm)
83
 
84
  if name not in HF_MODEL_CHECKPOINTS_MAP:
85
  raise ValueError(
@@ -270,13 +270,14 @@ class MusicGen:
270
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
271
  """
272
  total_gen_len = int(self.duration * self.frame_rate)
 
273
  current_gen_offset: int = 0
274
 
275
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
276
  print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
277
 
278
  if prompt_tokens is not None:
279
- assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
280
  "Prompt is longer than audio to generate"
281
 
282
  callback = None
 
79
  # used only for unit tests
80
  compression_model = get_debug_compression_model(device)
81
  lm = get_debug_lm_model(device)
82
+ return MusicGen(name, compression_model, lm, max_duration=3.)
83
 
84
  if name not in HF_MODEL_CHECKPOINTS_MAP:
85
  raise ValueError(
 
270
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
271
  """
272
  total_gen_len = int(self.duration * self.frame_rate)
273
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
274
  current_gen_offset: int = 0
275
 
276
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
277
  print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
278
 
279
  if prompt_tokens is not None:
280
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
281
  "Prompt is longer than audio to generate"
282
 
283
  callback = None
tests/models/test_musicgen.py CHANGED
@@ -48,3 +48,10 @@ class TestSEANetModel:
48
  wav = mg.generate(
49
  ['youpi', 'lapin dort'])
50
  assert list(wav.shape) == [2, 1, 64000]
 
 
 
 
 
 
 
 
48
  wav = mg.generate(
49
  ['youpi', 'lapin dort'])
50
  assert list(wav.shape) == [2, 1, 64000]
51
+
52
+ def test_generate_long(self):
53
+ mg = self.get_musicgen()
54
+ mg.set_generation_params(duration=4.)
55
+ wav = mg.generate(
56
+ ['youpi', 'lapin dort'])
57
+ assert list(wav.shape) == [2, 1, 32000 * 4]