Surn commited on
Commit
eef4b32
1 Parent(s): 7ec97f3

Fix LM assert error

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. audiocraft/models/lm.py +2 -1
app.py CHANGED
@@ -266,14 +266,14 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
266
  descriptions=[text],
267
  melody_wavs=melody,
268
  melody_sample_rate=sr,
269
- progress=True
270
  )
271
  # All output_segments are populated, so we can break the loop or set duration to 0
272
  break
273
  else:
274
  #output = MODEL.generate(descriptions=[text], progress=False)
275
  if not output_segments:
276
- next_segment = MODEL.generate(descriptions=[text], progress=True)
277
  duration -= segment_duration
278
  else:
279
  last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
 
266
  descriptions=[text],
267
  melody_wavs=melody,
268
  melody_sample_rate=sr,
269
+ progress=False
270
  )
271
  # All output_segments are populated, so we can break the loop or set duration to 0
272
  break
273
  else:
274
  #output = MODEL.generate(descriptions=[text], progress=False)
275
  if not output_segments:
276
+ next_segment = MODEL.generate(descriptions=[text], progress=False)
277
  duration -= segment_duration
278
  else:
279
  last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
audiocraft/models/lm.py CHANGED
@@ -456,7 +456,8 @@ class LMModel(StreamingModule):
456
 
457
  B, K, T = prompt.shape
458
  start_offset = T
459
- assert start_offset < max_gen_len
 
460
 
461
  pattern = self.pattern_provider.get_pattern(max_gen_len)
462
  # this token is used as default value for codes that are not generated yet
 
456
 
457
  B, K, T = prompt.shape
458
  start_offset = T
459
+ print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
460
+ assert start_offset <= max_gen_len
461
 
462
  pattern = self.pattern_provider.get_pattern(max_gen_len)
463
  # this token is used as default value for codes that are not generated yet