crypto-code commited on
Commit
d9cb0bd
1 Parent(s): a000794

Update llama/m2ugen.py

Browse files
Files changed (1) hide show
  1. llama/m2ugen.py +1 -1
llama/m2ugen.py CHANGED
@@ -604,7 +604,7 @@ class M2UGen(nn.Module):
604
  def generate_music(self, embeddings, audio_length_in_s, music_caption):
605
  gen_prefix = ''.join([f'[AUD{i}]' for i in range(len(self.audio_tokens))])
606
  gen_prefx_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda:1")
607
- gen_prefix_embs = self.llama.tok_embeddings(gen_prefx_ids)
608
  if self.music_decoder == "audioldm2":
609
  gen_emb = self.output_projector(embeddings.float().to("cuda:1"), gen_prefix_embs).squeeze(dim=0) / 10
610
  prompt_embeds, generated_prompt_embeds = gen_emb[:, :128 * 1024], gen_emb[:, 128 * 1024:]
 
604
  def generate_music(self, embeddings, audio_length_in_s, music_caption):
605
  gen_prefix = ''.join([f'[AUD{i}]' for i in range(len(self.audio_tokens))])
606
  gen_prefx_ids = self.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda:1")
607
+ # gen_prefix_embs = self.llama.tok_embeddings(gen_prefx_ids)
608
  if self.music_decoder == "audioldm2":
609
  gen_emb = self.output_projector(embeddings.float().to("cuda:1"), gen_prefix_embs).squeeze(dim=0) / 10
610
  prompt_embeds, generated_prompt_embeds = gen_emb[:, :128 * 1024], gen_emb[:, 128 * 1024:]