adefossez commited on
Commit
86d0f16
1 Parent(s): c1546d4
audiocraft/models/musicgen.py CHANGED
@@ -115,7 +115,7 @@ class MusicGen:
115
  should we extend the audio each time. Larger values will mean less context is
116
  preserved, and shorter value will require extra computations.
117
  """
118
- assert extend_stride <= self.max_duration - 5, "Keep at least 5 seconds of overlap!"
119
  self.extend_stride = extend_stride
120
  self.duration = duration
121
  self.generation_params = {
 
115
  should we extend the audio each time. Larger values will mean less context is
116
  preserved, and shorter value will require extra computations.
117
  """
118
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
119
  self.extend_stride = extend_stride
120
  self.duration = duration
121
  self.generation_params = {
tests/models/test_musicgen.py CHANGED
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
13
  class TestSEANetModel:
14
  def get_musicgen(self):
15
  mg = MusicGen.get_pretrained(name='debug', device='cpu')
16
- mg.set_generation_params(duration=2.0)
17
  return mg
18
 
19
  def test_base(self):
@@ -51,7 +51,7 @@ class TestSEANetModel:
51
 
52
  def test_generate_long(self):
53
  mg = self.get_musicgen()
54
- mg.set_generation_params(duration=4.)
55
  wav = mg.generate(
56
  ['youpi', 'lapin dort'])
57
  assert list(wav.shape) == [2, 1, 32000 * 4]
 
13
  class TestSEANetModel:
14
  def get_musicgen(self):
15
  mg = MusicGen.get_pretrained(name='debug', device='cpu')
16
+ mg.set_generation_params(duration=2.0, stride_extend=2.)
17
  return mg
18
 
19
  def test_base(self):
 
51
 
52
  def test_generate_long(self):
53
  mg = self.get_musicgen()
54
+ mg.set_generation_params(duration=4., stride_extend=2.)
55
  wav = mg.generate(
56
  ['youpi', 'lapin dort'])
57
  assert list(wav.shape) == [2, 1, 32000 * 4]