Ar4ikov committed on
Commit
7a19cc0
1 Parent(s): 2f67ef5

Update audio_text_multimodal.py

Browse files
Files changed (1) hide show
  1. audio_text_multimodal.py +9 -9
audio_text_multimodal.py CHANGED
@@ -21,7 +21,7 @@ class MultiModalConfig(PretrainedConfig):
21
  super().__init__(**kwargs)
22
 
23
 
24
- class Wav2Vec2BertConfig(MultiModalConfig):
25
  ...
26
 
27
 
@@ -187,27 +187,27 @@ class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
187
  )
188
 
189
 
190
- class Wav2Vec2BertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
191
  """
192
- Wav2Vec2BertForSequenceClassification is a model for sequence classification task
193
  (e.g. sentiment analysis, text classification, etc.)
194
 
195
  Args:
196
- config (Wav2Vec2BertConfig): config
197
 
198
  Attributes:
199
- config (Wav2Vec2BertConfig): config
200
- audio_config (Wav2Vec2Config): wav2vec2 config
201
  text_config (BertConfig): bert config
202
- audio_model (Wav2Vec2Model): wav2vec2 model
203
  text_model (BertModel): bert model
204
  classifier (torch.nn.Linear): classifier
205
  """
206
  def __init__(self, config):
207
  super().__init__(config)
208
- self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model)
209
  self.text_config = BertConfig.from_dict(self.config.BertModel)
210
- self.audio_model = Wav2Vec2Model(self.audio_config)
211
  self.text_model = BertModel(self.text_config)
212
  self.classifier = torch.nn.Linear(
213
  self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels
 
21
  super().__init__(**kwargs)
22
 
23
 
24
+ class WavLMBertConfig(MultiModalConfig):
25
  ...
26
 
27
 
 
187
  )
188
 
189
 
190
+ class WavLMBertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
191
  """
192
+ WavLMBertForSequenceClassification is a model for sequence classification task
193
  (e.g. sentiment analysis, text classification, etc.)
194
 
195
  Args:
196
+ config (WavLMBertConfig): config
197
 
198
  Attributes:
199
+ config (WavLMBertConfig): config
200
+ audio_config (WavLMConfig): WavLM config
201
  text_config (BertConfig): bert config
202
+ audio_model (WavLMModel): WavLM model
203
  text_model (BertModel): bert model
204
  classifier (torch.nn.Linear): classifier
205
  """
206
  def __init__(self, config):
207
  super().__init__(config)
208
+ self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
209
  self.text_config = BertConfig.from_dict(self.config.BertModel)
210
+ self.audio_model = WavLMModel(self.audio_config)
211
  self.text_model = BertModel(self.text_config)
212
  self.classifier = torch.nn.Linear(
213
  self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels