AlexHung29629 commited on
Commit
6b95e3e
·
verified ·
1 Parent(s): 7454ffa

Update mllama_audio_model.py

Browse files
Files changed (1) hide show
  1. mllama_audio_model.py +1 -1
mllama_audio_model.py CHANGED
@@ -14,7 +14,7 @@ class Llama3Embedding(Wav2Vec2BertPreTrainedModel):
14
  assert config.output_hidden_size == text_config.hidden_size
15
  self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
16
  self.audio_embedding = Wav2Vec2BertModel(config)
17
- assert self.text_embeddings.data.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
18
  self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
19
  self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
20
  self.text_config = text_config
 
14
  assert config.output_hidden_size == text_config.hidden_size
15
  self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
16
  self.audio_embedding = Wav2Vec2BertModel(config)
17
+ assert self.text_embeddings.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
18
  self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
19
  self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
20
  self.text_config = text_config