Update mllama_audio_model.py
mllama_audio_model.py (+1 -1)
@@ -14,7 +14,7 @@ class Llama3Embedding(Wav2Vec2BertPreTrainedModel):
         assert config.output_hidden_size == text_config.hidden_size
         self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
         self.audio_embedding = Wav2Vec2BertModel(config)
-        assert self.text_embeddings.weight.
+        assert self.text_embeddings.data.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
         self.text_config = text_config
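For reference, a minimal standalone sketch of the check the added assertion appears to perform: that the text embedding table's last dimension matches the text model's hidden size. The config values below are made up for illustration, and note that nn.Embedding exposes its table as .weight, so the idiomatic access is text_embeddings.weight.size(-1) rather than going through a .data attribute on the module:

import torch
import torch.nn as nn

# Hypothetical values for illustration only; in the model these come from text_config.
vocab_size, hidden_size, pad_token_id = 128_256, 4096, 0

text_embeddings = nn.Embedding(vocab_size, hidden_size, pad_token_id)

# The embedding table has shape (vocab_size, hidden_size), so the last
# dimension is the one that must line up with the text model's hidden size.
assert text_embeddings.weight.size(-1) == hidden_size, (
    f"{text_embeddings.weight.shape=}, {hidden_size=}, {vocab_size=}"
)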