Surn commited on
Commit
a16c003
1 Parent(s): 20a0fad

Update to fix Collab launch

Browse files
Files changed (1) hide show
  1. audiocraft/models/musicgen.py +32 -0
audiocraft/models/musicgen.py CHANGED
@@ -412,6 +412,38 @@ class MusicGen:
412
  gen_audio = self.compression_model.decode(gen_tokens, None)
413
  return gen_audio
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  def to(self, device: str):
416
  self.compression_model.to(device)
417
  self.lm.to(device)
 
412
  gen_audio = self.compression_model.decode(gen_tokens, None)
413
  return gen_audio
414
 
415
+ #def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
416
+ # prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
417
+ # """Generate discrete audio tokens given audio prompt and/or conditions.
418
+
419
+ # Args:
420
+ # attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
421
+ # prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
422
+ # progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
423
+ # Returns:
424
+ # torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
425
+ # """
426
+ # def _progress_callback(generated_tokens: int, tokens_to_generate: int):
427
+ # print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
428
+
429
+ # if prompt_tokens is not None:
430
+ # assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
431
+ # "Prompt is longer than audio to generate"
432
+
433
+ # callback = None
434
+ # if progress:
435
+ # callback = _progress_callback
436
+
437
+ # # generate by sampling from LM
438
+ # with self.autocast:
439
+ # gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
440
+
441
+ # # generate audio
442
+ # assert gen_tokens.dim() == 3
443
+ # with torch.no_grad():
444
+ # gen_audio = self.compression_model.decode(gen_tokens, None)
445
+ # return gen_audio
446
+
447
  def to(self, device: str):
448
  self.compression_model.to(device)
449
  self.lm.to(device)