Ar4ikov committed
Commit b85572b
1 Parent(s): a391a89

Upload config

Files changed (2)
  1. audio_text_multimodal.py +215 -0
  2. config.json +293 -0
audio_text_multimodal.py ADDED
@@ -0,0 +1,215 @@
from dataclasses import dataclass
from typing import Union, Type

import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    WavLMConfig,
    BertConfig,
    WavLMModel,
    BertModel,
    Wav2Vec2Config,
    Wav2Vec2Model
)


class MultiModalConfig(PretrainedConfig):
    """Base class for multimodal configs"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class WavLMBertConfig(MultiModalConfig):
    ...


class BaseClassificationModel(PreTrainedModel):
    config: Type[Union[PretrainedConfig, None]] = None

    def compute_loss(self, logits, labels):
        """Compute loss

        Args:
            logits (torch.FloatTensor): logits
            labels (torch.LongTensor): labels

        Returns:
            torch.FloatTensor: loss

        Raises:
            ValueError: Invalid number of labels
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1:
                self.config.problem_type = "single_label_classification"
            else:
                raise ValueError("Invalid number of labels: {}".format(self.num_labels))

        if self.config.problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        elif self.config.problem_type == "multi_label_classification":
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))

        elif self.config.problem_type == "regression":
            loss_fct = torch.nn.MSELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1))
        else:
            raise ValueError("Problem_type {} not supported".format(self.config.problem_type))

        return loss

    @staticmethod
    def merged_strategy(
            hidden_states,
            mode="mean"
    ):
        """Merged strategy for pooling

        Args:
            hidden_states (torch.FloatTensor): hidden states
            mode (str, optional): pooling mode. Defaults to "mean".

        Returns:
            torch.FloatTensor: pooled hidden states
        """
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        return outputs


class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
    config_class = MultiModalConfig

    def __init__(self, config):
        """
        Args:
            config (MultiModalConfig): config

        Attributes:
            config (MultiModalConfig): config
            num_labels (int): number of labels
            audio_config (Union[PretrainedConfig, None]): audio config
            text_config (Union[PretrainedConfig, None]): text config
            audio_model (Union[PreTrainedModel, None]): audio model
            text_model (Union[PreTrainedModel, None]): text model
            classifier (Union[torch.nn.Linear, None]): classifier
        """
        super().__init__(config)
        self.config = config
        self.num_labels = self.config.num_labels
        self.audio_config: Union[PretrainedConfig, None] = None
        self.text_config: Union[PretrainedConfig, None] = None
        self.audio_model: Union[PreTrainedModel, None] = None
        self.text_model: Union[PreTrainedModel, None] = None
        self.classifier: Union[torch.nn.Linear, None] = None

    def forward(
            self,
            input_ids=None,
            input_values=None,
            text_attention_mask=None,
            audio_attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True,
    ):
        """Forward method for the multimodal sequence classification model (text + audio)

        Args:
            input_ids (torch.LongTensor, optional): input ids. Defaults to None.
            input_values (torch.FloatTensor, optional): input values. Defaults to None.
            text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None.
            audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None.
            token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None.
            position_ids (torch.LongTensor, optional): position ids. Defaults to None.
            head_mask (torch.FloatTensor, optional): head mask. Defaults to None.
            inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None.
            labels (torch.LongTensor, optional): labels. Defaults to None.
            output_attentions (bool, optional): output attentions. Defaults to None.
            output_hidden_states (bool, optional): output hidden states. Defaults to None.
            return_dict (bool, optional): return dict. Defaults to True.

        Returns:
            torch.FloatTensor: logits
        """
        audio_output = self.audio_model(
            input_values=input_values,
            attention_mask=audio_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        text_output = self.text_model(
            input_ids=input_ids,
            attention_mask=text_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode="mean")

        pooled_output = torch.cat(
            (audio_mean, text_output.pooler_output), dim=1
        )
        logits = self.classifier(pooled_output)
        loss = None

        if labels is not None:
            loss = self.compute_loss(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits
        )


class WavLMBertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
    """
    WavLMBertForSequenceClassification is a model for sequence classification tasks
    (e.g. sentiment analysis, text classification, etc.)

    Args:
        config (WavLMBertConfig): config

    Attributes:
        config (WavLMBertConfig): config
        audio_config (WavLMConfig): WavLM config
        text_config (BertConfig): BERT config
        audio_model (WavLMModel): WavLM model
        text_model (BertModel): BERT model
        classifier (torch.nn.Linear): classifier
    """
    def __init__(self, config):
        super().__init__(config)
        self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
        self.text_config = BertConfig.from_dict(self.config.BertModel)
        self.audio_model = WavLMModel(self.audio_config)
        self.text_model = BertModel(self.text_config)
        self.classifier = torch.nn.Linear(
            self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels
        )
        self.init_weights()
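
Taken together, the module defines a thin WavLMBertConfig, a base class holding the loss and pooling logic, and WavLMBertForSequenceClassification, which builds a WavLM encoder and a BERT encoder from the nested WavLMModel/BertModel entries of the config and classifies over the concatenation of the mean-pooled audio states (1024-d) and the BERT pooler output (768-d). As a minimal sketch of how the pieces compose (not part of this commit), assuming the two base checkpoints named in config.json and num_labels=7 from its label maps:

# Sketch only (not part of this commit): compose the two sub-configs into a
# WavLMBertConfig and instantiate the multimodal classifier defined above.
from transformers import AutoConfig

from audio_text_multimodal import WavLMBertConfig, WavLMBertForSequenceClassification

# Base checkpoints taken from the _name_or_path fields in config.json.
audio_cfg = AutoConfig.from_pretrained("jonatasgrosman/exp_w2v2t_ru_wavlm_s363")
text_cfg = AutoConfig.from_pretrained("DeepPavlov/rubert-base-cased")

config = WavLMBertConfig(
    WavLMModel=audio_cfg.to_dict(),  # nested dict read back by WavLMConfig.from_dict(...)
    BertModel=text_cfg.to_dict(),    # nested dict read back by BertConfig.from_dict(...)
    num_labels=7,                    # anger ... sadness, per the label maps in config.json
    problem_type="single_label_classification",
)

model = WavLMBertForSequenceClassification(config)  # encoders are randomly initialized here
print(model.classifier)  # Linear(in_features=1792, out_features=7): 1024 (WavLM) + 768 (BERT)
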
config.json ADDED
@@ -0,0 +1,293 @@
{
  "BertModel": {
    "_name_or_path": "DeepPavlov/rubert-base-cased",
    "add_cross_attention": false,
    "architectures": [
      "BertModel"
    ],
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": null,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "directionality": "bidi",
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 512,
    "min_length": 0,
    "model_type": "bert",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_past": true,
    "output_scores": false,
    "pad_token_id": 0,
    "pooler_fc_size": 768,
    "pooler_num_attention_heads": 12,
    "pooler_num_fc_layers": 3,
    "pooler_size_per_head": 128,
    "pooler_type": "first_token_transform",
    "position_embedding_type": "absolute",
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.27.4",
    "type_vocab_size": 2,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 119547
  },
  "WavLMModel": {
    "_name_or_path": "jonatasgrosman/exp_w2v2t_ru_wavlm_s363",
    "activation_dropout": 0.05,
    "adapter_kernel_size": 3,
    "adapter_stride": 2,
    "add_adapter": false,
    "add_cross_attention": false,
    "apply_spec_augment": true,
    "architectures": [
      "WavLMModel"
    ],
    "attention_dropout": 0.05,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 1,
    "chunk_size_feed_forward": 0,
    "classifier_proj_size": 256,
    "codevector_dim": 768,
    "contrastive_logits_temperature": 0.1,
    "conv_bias": false,
    "conv_dim": [
      512,
      512,
      512,
      512,
      512,
      512,
      512
    ],
    "conv_kernel": [
      10,
      3,
      3,
      3,
      3,
      2,
      2
    ],
    "conv_stride": [
      5,
      2,
      2,
      2,
      2,
      2,
      2
    ],
    "cross_attention_hidden_size": null,
    "ctc_loss_reduction": "sum",
    "ctc_zero_infinity": false,
    "decoder_start_token_id": null,
    "diversity_loss_weight": 0.1,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "do_stable_layer_norm": true,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "feat_extract_activation": "gelu",
    "feat_extract_dropout": 0.0,
    "feat_extract_norm": "layer",
    "feat_proj_dropout": 0.05,
    "feat_quantizer_dropout": 0.0,
    "final_dropout": 0.05,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout": 0.05,
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "layerdrop": 0.05,
    "length_penalty": 1.0,
    "mask_channel_length": 10,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_feature_length": 10,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 10,
    "mask_time_min_masks": 2,
    "mask_time_min_space": 1,
    "mask_time_other": 0.0,
    "mask_time_prob": 0.05,
    "mask_time_selection": "static",
    "max_bucket_distance": 800,
    "max_length": 20,
    "min_length": 0,
    "model_type": "wavlm",
    "no_repeat_ngram_size": 0,
    "num_adapter_layers": 3,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_buckets": 320,
    "num_codevector_groups": 2,
    "num_codevectors_per_group": 320,
    "num_conv_pos_embedding_groups": 16,
    "num_conv_pos_embeddings": 128,
    "num_ctc_classes": 80,
    "num_feat_extract_layers": 7,
    "num_hidden_layers": 24,
    "num_negatives": 100,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_size": 1024,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 0,
    "prefix": null,
    "problem_type": null,
    "proj_codevector_dim": 768,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "replace_prob": 0.5,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "tdnn_dilation": [
      1,
      2,
      3,
      1,
      1
    ],
    "tdnn_dim": [
      512,
      512,
      512,
      512,
      1500
    ],
    "tdnn_kernel": [
      5,
      3,
      3,
      1,
      1
    ],
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": "Wav2Vec2CTCTokenizer",
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "float32",
    "torchscript": false,
    "transformers_version": "4.27.4",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_weighted_layer_sum": false,
    "vocab_size": 40,
    "xvector_output_dim": 512
  },
  "_name_or_path": "Ar4ikov/wavlm-bert-base-multimodal-emotion-russian-resd",
  "architectures": [
    "WavLMBertForSequenceClassification"
  ],
  "auto_map": {
    "AutoConfig": "audio_text_multimodal.WavLMBertConfig",
    "AutoModel": "audio_text_multimodal.WavLMBertForSequenceClassification"
  },
  "id2label": {
    "0": "anger",
    "1": "disgust",
    "2": "enthusiasm",
    "3": "fear",
    "4": "happiness",
    "5": "neutral",
    "6": "sadness"
  },
  "label2id": {
    "anger": 0,
    "disgust": 1,
    "enthusiasm": 2,
    "fear": 3,
    "happiness": 4,
    "neutral": 5,
    "sadness": 6
  },
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.27.4"
}
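
Since config.json registers the custom classes under auto_map, the uploaded pair is evidently meant to be loaded through the Auto classes with trust_remote_code=True. Below is a hedged inference sketch, not part of the commit: the choice of preprocessors (the WavLM base checkpoint's Wav2Vec2FeatureExtractor at 16 kHz and the rubert-base-cased tokenizer) and the dummy inputs are assumptions, and the repository may ship its own processor files.

# Hedged inference sketch (not part of this commit). Assumptions: the repo loads
# via auto_map with trust_remote_code, audio is 16 kHz mono, and the base
# checkpoints' preprocessors are reused.
import torch
from transformers import AutoModel, AutoTokenizer, Wav2Vec2FeatureExtractor

repo = "Ar4ikov/wavlm-bert-base-multimodal-emotion-russian-resd"

model = AutoModel.from_pretrained(repo, trust_remote_code=True)  # WavLMBertForSequenceClassification
model.eval()

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("jonatasgrosman/exp_w2v2t_ru_wavlm_s363")
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

waveform = torch.zeros(16000).numpy()  # 1 s of silence as a stand-in 16 kHz waveform
audio_inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
text_inputs = tokenizer("пример текста", return_tensors="pt")  # sample Russian transcript

with torch.no_grad():  # batch of one, so padding masks are omitted
    out = model(
        input_values=audio_inputs["input_values"],
        input_ids=text_inputs["input_ids"],
        text_attention_mask=text_inputs["attention_mask"],
        token_type_ids=text_inputs["token_type_ids"],
    )

pred = out.logits.argmax(dim=-1).item()
print(model.config.id2label[pred])  # one of the seven emotion labels above
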