Ar4ikov committed
Commit 1596979
1 Parent(s): 54ab69d

Update audio_text_multimodal.py

Files changed (1):
  1. audio_text_multimodal.py +17 -2
audio_text_multimodal.py CHANGED
@@ -14,6 +14,14 @@ from transformers import (
     Wav2Vec2Model
 )
 
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2Encoder,
+    Wav2Vec2EncoderStableLayerNorm,
+    Wav2Vec2FeatureEncoder
+)
+
+from transformers.models.bert.modeling_bert import BertEncoder
+
 
 class MultiModalConfig(PretrainedConfig):
     """Base class for multimodal configs"""
@@ -170,7 +178,7 @@ class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode="mean")
+        audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode)
 
         pooled_output = torch.cat(
             (audio_mean, text_output.pooler_output), dim=1
@@ -205,6 +213,8 @@ class Wav2Vec2BertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
     """
     def __init__(self, config):
         super().__init__(config)
+        self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True)
+
        self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model)
        self.text_config = BertConfig.from_dict(self.config.BertModel)
        self.audio_model = Wav2Vec2Model(self.audio_config)
@@ -212,4 +222,9 @@ class Wav2Vec2BertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
         self.classifier = torch.nn.Linear(
             self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels
         )
-        self.init_weights()
+        self.init_weights()
+
+    @staticmethod
+    def _set_gradient_checkpointing(module, value=False):
+        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder, BertEncoder)):
+            module.gradient_checkpointing = value
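The new hook follows the pattern transformers (pre-4.35) uses for gradient checkpointing: PreTrainedModel.gradient_checkpointing_enable() applies _set_gradient_checkpointing(module, value=True) over the module tree, and the isinstance check flips the flag only on the nested Wav2Vec2 and BERT encoders. A minimal usage sketch, assuming MultiModalConfig forwards extra kwargs to PretrainedConfig and carries nested Wav2Vec2Model/BertModel dicts as this file reads them; the num_labels and pooling_mode values below are illustrative, not from this commit:

from transformers import Wav2Vec2Config, BertConfig
from audio_text_multimodal import MultiModalConfig, Wav2Vec2BertForSequenceClassification

# Hypothetical merged config; PretrainedConfig stores unknown kwargs as
# attributes, so Wav2Vec2Model/BertModel arrive as plain dicts.
config = MultiModalConfig(
    Wav2Vec2Model=Wav2Vec2Config().to_dict(),
    BertModel=BertConfig().to_dict(),
    num_labels=4,              # illustrative label count
    pooling_mode="mean",       # read by the forward pass above
    gradient_checkpointing=True,
)

model = Wav2Vec2BertForSequenceClassification(config)

# Walks submodules and calls _set_gradient_checkpointing(module, value=True),
# enabling activation recomputation in both encoders to save memory.
model.gradient_checkpointing_enable()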
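Switching mode="mean" to mode=self.config.pooling_mode makes the audio pooling strategy a checkpoint-level setting rather than a hard-coded choice. A sketch of what such a merged_strategy typically computes over the (batch, time, hidden) Wav2Vec2 hidden states; the mode names here are the common mean/sum/max trio and may not match this repo's exact set:

import torch

def merged_strategy(hidden_states: torch.Tensor, mode: str = "mean") -> torch.Tensor:
    # Collapse the time axis: (batch, time, hidden) -> (batch, hidden).
    if mode == "mean":
        return hidden_states.mean(dim=1)
    if mode == "sum":
        return hidden_states.sum(dim=1)
    if mode == "max":
        return hidden_states.max(dim=1).values
    raise ValueError(f"Unsupported pooling mode: {mode}")

# The pooled audio vector is concatenated with BERT's pooler_output, so the
# classifier sees a (batch, audio_hidden + text_hidden) feature vector.
audio_mean = merged_strategy(torch.randn(2, 49, 768), mode="mean")
print(audio_mean.shape)  # torch.Size([2, 768])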