Update audio_text_multimodal.py

audio_text_multimodal.py  (+9 -9)
@@ -21,7 +21,7 @@ class MultiModalConfig(PretrainedConfig):
         super().__init__(**kwargs)


-class
+class WavLMBertConfig(MultiModalConfig):
     ...


@@ -187,27 +187,27 @@ class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
         )


-class
+class WavLMBertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
     """
-
+    WavLMBertForSequenceClassification is a model for sequence classification task
     (e.g. sentiment analysis, text classification, etc.)

     Args:
-        config (
+        config (WavLMBertConfig): config

     Attributes:
-        config (
-        audio_config (
+        config (WavLMBertConfig): config
+        audio_config (WavLMConfig): wav2vec2 config
         text_config (BertConfig): bert config
-        audio_model (
+        audio_model (WavLMModel): wav2vec2 model
         text_model (BertModel): bert model
         classifier (torch.nn.Linear): classifier
     """
     def __init__(self, config):
         super().__init__(config)
-        self.audio_config =
+        self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
         self.text_config = BertConfig.from_dict(self.config.BertModel)
-        self.audio_model =
+        self.audio_model = WavLMModel(self.audio_config)
         self.text_model = BertModel(self.text_config)
         self.classifier = torch.nn.Linear(
             self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels