Ar4ikov committed on
Commit 0dc050d, 1 Parent(s): 7df433a

Upload config

Files changed (2)
  1. audio_text_multimodal.py +215 -0
  2. config.json +283 -0
audio_text_multimodal.py ADDED
@@ -0,0 +1,215 @@
+from dataclasses import dataclass
+from typing import Union, Type
+
+import torch
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers import (
+    PreTrainedModel,
+    PretrainedConfig,
+    WavLMConfig,
+    BertConfig,
+    WavLMModel,
+    BertModel,
+    Wav2Vec2Config,
+    Wav2Vec2Model
+)
+
+
+class MultiModalConfig(PretrainedConfig):
+    """Base class for multimodal configs"""
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class Wav2Vec2BertConfig(MultiModalConfig):
+    ...
+
+
+class BaseClassificationModel(PreTrainedModel):
+    config: Type[Union[PretrainedConfig, None]] = None
+
+    def compute_loss(self, logits, labels):
+        """Compute loss
+
+        Args:
+            logits (torch.FloatTensor): logits
+            labels (torch.LongTensor): labels
+
+        Returns:
+            torch.FloatTensor: loss
+
+        Raises:
+            ValueError: Invalid number of labels
+        """
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = "regression"
+            elif self.num_labels > 1:
+                self.config.problem_type = "single_label_classification"
+            else:
+                raise ValueError("Invalid number of labels: {}".format(self.num_labels))
+
+        if self.config.problem_type == "single_label_classification":
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        elif self.config.problem_type == "multi_label_classification":
+            loss_fct = torch.nn.BCEWithLogitsLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
+
+        elif self.config.problem_type == "regression":
+            loss_fct = torch.nn.MSELoss()
+            loss = loss_fct(logits.view(-1), labels.view(-1))
+        else:
+            raise ValueError("Problem_type {} not supported".format(self.config.problem_type))
+
+        return loss
+
+    @staticmethod
+    def merged_strategy(
+            hidden_states,
+            mode="mean"
+    ):
+        """Merged strategy for pooling
+
+        Args:
+            hidden_states (torch.FloatTensor): hidden states
+            mode (str, optional): pooling mode. Defaults to "mean".
+
+        Returns:
+            torch.FloatTensor: pooled hidden states
+        """
+        if mode == "mean":
+            outputs = torch.mean(hidden_states, dim=1)
+        elif mode == "sum":
+            outputs = torch.sum(hidden_states, dim=1)
+        elif mode == "max":
+            outputs = torch.max(hidden_states, dim=1)[0]
+        else:
+            raise Exception(
+                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
+
+        return outputs
+
+
+class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
+    config_class = MultiModalConfig
+
+    def __init__(self, config):
+        """
+        Args:
+            config (MultiModalConfig): config
+
+        Attributes:
+            config (MultiModalConfig): config
+            num_labels (int): number of labels
+            audio_config (Union[PretrainedConfig, None]): audio config
+            text_config (Union[PretrainedConfig, None]): text config
+            audio_model (Union[PreTrainedModel, None]): audio model
+            text_model (Union[PreTrainedModel, None]): text model
+            classifier (Union[torch.nn.Linear, None]): classifier
+        """
+        super().__init__(config)
+        self.config = config
+        self.num_labels = self.config.num_labels
+        self.audio_config: Union[PretrainedConfig, None] = None
+        self.text_config: Union[PretrainedConfig, None] = None
+        self.audio_model: Union[PreTrainedModel, None] = None
+        self.text_model: Union[PreTrainedModel, None] = None
+        self.classifier: Union[torch.nn.Linear, None] = None
+
+    def forward(
+            self,
+            input_ids=None,
+            input_values=None,
+            text_attention_mask=None,
+            audio_attention_mask=None,
+            token_type_ids=None,
+            position_ids=None,
+            head_mask=None,
+            inputs_embeds=None,
+            labels=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=True,
+    ):
+        """Forward method for multimodal model for sequence classification task (e.g. text + audio)
+
+        Args:
+            input_ids (torch.LongTensor, optional): input ids. Defaults to None.
+            input_values (torch.FloatTensor, optional): input values. Defaults to None.
+            text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None.
+            audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None.
+            token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None.
+            position_ids (torch.LongTensor, optional): position ids. Defaults to None.
+            head_mask (torch.FloatTensor, optional): head mask. Defaults to None.
+            inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None.
+            labels (torch.LongTensor, optional): labels. Defaults to None.
+            output_attentions (bool, optional): output attentions. Defaults to None.
+            output_hidden_states (bool, optional): output hidden states. Defaults to None.
+            return_dict (bool, optional): return dict. Defaults to True.
+
+        Returns:
+            SequenceClassifierOutput: loss (when labels are provided) and logits
+        """
+        audio_output = self.audio_model(
+            input_values=input_values,
+            attention_mask=audio_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+        text_output = self.text_model(
+            input_ids=input_ids,
+            attention_mask=text_attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        # Mean-pool the audio encoder states and concatenate them with BERT's pooled output
+        audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode="mean")
+
+        pooled_output = torch.cat(
+            (audio_mean, text_output.pooler_output), dim=1
+        )
+        logits = self.classifier(pooled_output)
+        loss = None
+
+        if labels is not None:
+            loss = self.compute_loss(logits, labels)
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits
+        )
+
+
+class Wav2Vec2BertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
+    """
+    Wav2Vec2BertForSequenceClassification is a model for sequence classification task
+    (e.g. sentiment analysis, text classification, etc.)
+
+    Args:
+        config (Wav2Vec2BertConfig): config
+
+    Attributes:
+        config (Wav2Vec2BertConfig): config
+        audio_config (Wav2Vec2Config): wav2vec2 config
+        text_config (BertConfig): bert config
+        audio_model (Wav2Vec2Model): wav2vec2 model
+        text_model (BertModel): bert model
+        classifier (torch.nn.Linear): classifier
+    """
+    def __init__(self, config):
+        super().__init__(config)
+        self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model)
+        self.text_config = BertConfig.from_dict(self.config.BertModel)
+        self.audio_model = Wav2Vec2Model(self.audio_config)
+        self.text_model = BertModel(self.text_config)
+        self.classifier = torch.nn.Linear(
+            self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels
+        )
+        self.init_weights()
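
For orientation, here is a minimal wiring sketch for the classes above. It is an illustration only: it assumes audio_text_multimodal.py is importable from the working directory, reuses the backbone checkpoints named in the config below purely as examples, and leaves all weights randomly initialized (num_labels=7 matches the emotion labels in config.json):

# Minimal wiring sketch (assumptions: audio_text_multimodal.py is on the import path,
# the two backbone configs are reachable, weights stay randomly initialized).
from transformers import BertConfig, Wav2Vec2Config

from audio_text_multimodal import Wav2Vec2BertConfig, Wav2Vec2BertForSequenceClassification

audio_cfg = Wav2Vec2Config.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-russian")
text_cfg = BertConfig.from_pretrained("cointegrated/rubert-tiny2")

# __init__ reads the nested sub-configs back via Wav2Vec2Config.from_dict /
# BertConfig.from_dict, so the combined config stores them as plain dicts
# under the keys Wav2Vec2Model and BertModel.
config = Wav2Vec2BertConfig(
    Wav2Vec2Model=audio_cfg.to_dict(),
    BertModel=text_cfg.to_dict(),
    num_labels=7,
)
model = Wav2Vec2BertForSequenceClassification(config)
print(model.classifier)  # Linear(1024 + 312 -> 7)

The point is that the combined config carries the two sub-configs as dicts under the keys Wav2Vec2Model and BertModel, which is exactly what Wav2Vec2BertForSequenceClassification.__init__ reads back, and exactly how config.json below is laid out.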
config.json ADDED
@@ -0,0 +1,283 @@
+{
+  "BertModel": {
+    "_name_or_path": "cointegrated/rubert-tiny2",
+    "add_cross_attention": false,
+    "architectures": [
+      "BertForPreTraining"
+    ],
+    "attention_probs_dropout_prob": 0.1,
+    "bad_words_ids": null,
+    "begin_suppress_tokens": null,
+    "bos_token_id": null,
+    "chunk_size_feed_forward": 0,
+    "classifier_dropout": null,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "early_stopping": false,
+    "emb_size": 312,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": null,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "gradient_checkpointing": false,
+    "hidden_act": "gelu",
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 312,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_range": 0.02,
+    "intermediate_size": 600,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-12,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_position_embeddings": 2048,
+    "min_length": 0,
+    "model_type": "bert",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 12,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 3,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": 0,
+    "position_embedding_type": "absolute",
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": "float32",
+    "torchscript": false,
+    "transformers_version": "4.27.4",
+    "type_vocab_size": 2,
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_cache": true,
+    "vocab_size": 83828
+  },
+  "Wav2Vec2Model": {
+    "_name_or_path": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
+    "activation_dropout": 0.05,
+    "adapter_kernel_size": 3,
+    "adapter_stride": 2,
+    "add_adapter": false,
+    "add_cross_attention": false,
+    "apply_spec_augment": true,
+    "architectures": [
+      "Wav2Vec2Model"
+    ],
+    "attention_dropout": 0.1,
+    "bad_words_ids": null,
+    "begin_suppress_tokens": null,
+    "bos_token_id": 1,
+    "chunk_size_feed_forward": 0,
+    "classifier_proj_size": 256,
+    "codevector_dim": 768,
+    "contrastive_logits_temperature": 0.1,
+    "conv_bias": true,
+    "conv_dim": [
+      512,
+      512,
+      512,
+      512,
+      512,
+      512,
+      512
+    ],
+    "conv_kernel": [
+      10,
+      3,
+      3,
+      3,
+      3,
+      2,
+      2
+    ],
+    "conv_stride": [
+      5,
+      2,
+      2,
+      2,
+      2,
+      2,
+      2
+    ],
+    "cross_attention_hidden_size": null,
+    "ctc_loss_reduction": "mean",
+    "ctc_zero_infinity": true,
+    "decoder_start_token_id": null,
+    "diversity_loss_weight": 0.1,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "do_stable_layer_norm": true,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 2,
+    "exponential_decay_length_penalty": null,
+    "feat_extract_activation": "gelu",
+    "feat_extract_dropout": 0.0,
+    "feat_extract_norm": "layer",
+    "feat_proj_dropout": 0.05,
+    "feat_quantizer_dropout": 0.0,
+    "final_dropout": 0.0,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "gelu",
+    "hidden_dropout": 0.05,
+    "hidden_size": 1024,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "layerdrop": 0.05,
+    "length_penalty": 1.0,
+    "mask_channel_length": 10,
+    "mask_channel_min_space": 1,
+    "mask_channel_other": 0.0,
+    "mask_channel_prob": 0.0,
+    "mask_channel_selection": "static",
+    "mask_feature_length": 10,
+    "mask_feature_min_masks": 0,
+    "mask_feature_prob": 0.0,
+    "mask_time_length": 10,
+    "mask_time_min_masks": 2,
+    "mask_time_min_space": 1,
+    "mask_time_other": 0.0,
+    "mask_time_prob": 0.05,
+    "mask_time_selection": "static",
+    "max_length": 20,
+    "min_length": 0,
+    "model_type": "wav2vec2",
+    "no_repeat_ngram_size": 0,
+    "num_adapter_layers": 3,
+    "num_attention_heads": 16,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_codevector_groups": 2,
+    "num_codevectors_per_group": 320,
+    "num_conv_pos_embedding_groups": 16,
+    "num_conv_pos_embeddings": 128,
+    "num_feat_extract_layers": 7,
+    "num_hidden_layers": 24,
+    "num_negatives": 100,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_size": 1024,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": 0,
+    "prefix": null,
+    "problem_type": null,
+    "proj_codevector_dim": 768,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "tdnn_dilation": [
+      1,
+      2,
+      3,
+      1,
+      1
+    ],
+    "tdnn_dim": [
+      512,
+      512,
+      512,
+      512,
+      1500
+    ],
+    "tdnn_kernel": [
+      5,
+      3,
+      3,
+      1,
+      1
+    ],
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "transformers_version": "4.27.4",
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_weighted_layer_sum": false,
+    "vocab_size": 39,
+    "xvector_output_dim": 512
+  },
+  "_name_or_path": "Ar4ikov/wav2vec2-bert-tiny2-2-multimodal-emotion-russian-resd",
+  "architectures": [
+    "Wav2Vec2BertForSequenceClassification"
+  ],
+  "auto_map": {
+    "AutoConfig": "audio_text_multimodal.Wav2Vec2BertConfig",
+    "AutoModel": "audio_text_multimodal.Wav2Vec2BertForSequenceClassification"
+  },
+  "id2label": {
+    "0": "anger",
+    "1": "disgust",
+    "2": "enthusiasm",
+    "3": "fear",
+    "4": "happiness",
+    "5": "neutral",
+    "6": "sadness"
+  },
+  "label2id": {
+    "anger": 0,
+    "disgust": 1,
+    "enthusiasm": 2,
+    "fear": 3,
+    "happiness": 4,
+    "neutral": 5,
+    "sadness": 6
+  },
+  "problem_type": "single_label_classification",
+  "torch_dtype": "float32",
+  "transformers_version": "4.27.2"
+}
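
For completeness, a hedged end-to-end sketch of how this config is consumed: auto_map routes AutoConfig and AutoModel to the classes in audio_text_multimodal.py once trust_remote_code=True is passed. The repo id and backbone preprocessors below come from the config above; the waveform and transcript are dummy placeholders, and the sketch assumes the repository also hosts the trained weights:

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, Wav2Vec2FeatureExtractor

repo = "Ar4ikov/wav2vec2-bert-tiny2-2-multimodal-emotion-russian-resd"

# auto_map resolves these to audio_text_multimodal.Wav2Vec2BertConfig /
# Wav2Vec2BertForSequenceClassification once remote code is trusted.
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()

# Preprocessors of the two backbones named in the nested configs.
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
extractor = Wav2Vec2FeatureExtractor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-russian")

speech = torch.zeros(16000).numpy()  # placeholder: 1 s of silence at 16 kHz
text = "пример текста"               # placeholder Russian transcript

audio_inputs = extractor(speech, sampling_rate=16000, return_tensors="pt")
text_inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    out = model(
        input_ids=text_inputs["input_ids"],
        text_attention_mask=text_inputs["attention_mask"],
        token_type_ids=text_inputs.get("token_type_ids"),
        input_values=audio_inputs["input_values"],
        audio_attention_mask=audio_inputs.get("attention_mask"),
    )

pred_id = int(out.logits.argmax(dim=-1))
print(config.id2label[pred_id])  # one of the seven emotion labels above

Note the .get call for the audio attention mask: the wav2vec2 feature extractor only returns one when its saved configuration enables it, so the sketch passes None otherwise, which the forward method accepts.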