AlexHung29629 committed on
Commit
66a79e5
1 Parent(s): a7d35c3

Update mllama_audio_model.py

Browse files
Files changed (1) hide show
  1. mllama_audio_model.py +1 -1
mllama_audio_model.py CHANGED
@@ -15,7 +15,7 @@ class MllamaAudioModel(MllamaPreTrainedModel):
15
  super().__init__(config)
16
  assert config.add_adapter is True, f'{type(self).__name__} requires add adapter to be true.'
17
  #assert config.output_hidden_size == text_embedding.weight.shape[1], f'Output hidden size({config.output_hidden_size}) of audio model and text embedding({text_embedding.weight.shape[1]}) must match!'
18
- asseert config.output_hidden_size == text_config.hidden_size
19
  self.text_embedding = nn.Embedding(text_config.vocab_size + 8, text_config.hidden_size, text_config.pad_token_id)
20
  self.audio_embedding = Wav2Vec2BertModel(config)
21
  self.start_of_audio = nn.Parameter(data=torch.mean(text_embedding.weight, dim=0).unsqueeze(0), requires_grad=True)
 
15
  super().__init__(config)
16
  assert config.add_adapter is True, f'{type(self).__name__} requires add adapter to be true.'
17
  #assert config.output_hidden_size == text_embedding.weight.shape[1], f'Output hidden size({config.output_hidden_size}) of audio model and text embedding({text_embedding.weight.shape[1]}) must match!'
18
+ assert config.output_hidden_size == text_config.hidden_size
19
  self.text_embedding = nn.Embedding(text_config.vocab_size + 8, text_config.hidden_size, text_config.pad_token_id)
20
  self.audio_embedding = Wav2Vec2BertModel(config)
21
  self.start_of_audio = nn.Parameter(data=torch.mean(text_embedding.weight, dim=0).unsqueeze(0), requires_grad=True)