adefossez commited on
Commit
1633cd5
·
1 Parent(s): f8cbd41
Files changed (1) hide show
  1. audiocraft/models/musicgen.py +8 -6
audiocraft/models/musicgen.py CHANGED
@@ -270,8 +270,7 @@ class MusicGen:
270
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
271
  """
272
  total_gen_len = int(self.duration * self.frame_rate)
273
-
274
- current_gen_offset = 0
275
 
276
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
277
  print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
@@ -299,7 +298,7 @@ class MusicGen:
299
  if prompt_tokens is not None:
300
  all_tokens.append(prompt_tokens)
301
 
302
- time_offset = 0
303
  while time_offset < self.duration:
304
  chunk_duration = min(self.duration - time_offset, self.max_duration)
305
  max_gen_len = int(chunk_duration * self.frame_rate)
@@ -308,12 +307,15 @@ class MusicGen:
308
  if wav_length == 0:
309
  continue
310
  # We will extend the wav periodically if it not long enough.
311
- # we have to do it here before it is too late.
 
312
  initial_position = int(time_offset * self.sample_rate)
313
- wav_target_length = int(chunk_duration * self.sample_rate)
314
  positions = torch.arange(initial_position,
315
  initial_position + wav_target_length, device=self.device)
316
- attr.wav['self_wav'] = ref_wav[:, positions % wav_length]
 
 
317
  with self.autocast:
318
  gen_tokens = self.lm.generate(
319
  prompt_tokens, attributes,
 
270
  torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
271
  """
272
  total_gen_len = int(self.duration * self.frame_rate)
273
+ current_gen_offset: int = 0
 
274
 
275
  def _progress_callback(generated_tokens: int, tokens_to_generate: int):
276
  print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
 
298
  if prompt_tokens is not None:
299
  all_tokens.append(prompt_tokens)
300
 
301
+ time_offset = 0.
302
  while time_offset < self.duration:
303
  chunk_duration = min(self.duration - time_offset, self.max_duration)
304
  max_gen_len = int(chunk_duration * self.frame_rate)
 
307
  if wav_length == 0:
308
  continue
309
  # We will extend the wav periodically if it not long enough.
310
+ # we have to do it here rather than in conditioners.py as otherwise
311
+ # we wouldn't have the full wav.
312
  initial_position = int(time_offset * self.sample_rate)
313
+ wav_target_length = int(self.max_duration * self.sample_rate)
314
  positions = torch.arange(initial_position,
315
  initial_position + wav_target_length, device=self.device)
316
+ attr.wav['self_wav'] = WavCondition(
317
+ ref_wav[0][:, positions % wav_length],
318
+ torch.full_like(ref_wav[1], wav_target_length))
319
  with self.autocast:
320
  gen_tokens = self.lm.generate(
321
  prompt_tokens, attributes,