Commit 2a5b05e by Ar4ikov (0 parents)

Duplicate from Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,3 @@
+ ---
+ duplicated_from: Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1
+ ---
audio_text_multimodal.py ADDED
@@ -0,0 +1,296 @@
+ from typing import Union, Type
+ import torch
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers import (
+     PreTrainedModel,
+     PretrainedConfig,
+     WavLMConfig,
+     BertConfig,
+     WavLMModel,
+     BertModel,
+     Wav2Vec2Config,
+     Wav2Vec2Model
+ )
+
+
+ class MultiModalConfig(PretrainedConfig):
+     """Base class for multimodal configs"""
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+
+ class WavLMBertConfig(MultiModalConfig):
+     ...
+
+
+ class BaseClassificationModel(PreTrainedModel):
+     config: Type[Union[PretrainedConfig, None]] = None
+
+     def compute_loss(self, logits, labels):
+         """Compute loss
+
+         Args:
+             logits (torch.FloatTensor): logits
+             labels (torch.LongTensor): labels
+
+         Returns:
+             torch.FloatTensor: loss
+
+         Raises:
+             ValueError: Invalid number of labels
+         """
+         if self.config.problem_type is None:
+             if self.num_labels == 1:
+                 self.config.problem_type = "regression"
+             elif self.num_labels > 1:
+                 self.config.problem_type = "single_label_classification"
+             else:
+                 raise ValueError("Invalid number of labels: {}".format(self.num_labels))
+
+         if self.config.problem_type == "single_label_classification":
+             loss_fct = torch.nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         elif self.config.problem_type == "multi_label_classification":
+             loss_fct = torch.nn.BCEWithLogitsLoss(weight=torch.tensor([1.4411, 2.1129, 0.9927, 1.6995, 0.9038, 0.4126, 1.4150]).to("cuda"))  # NOTE: class weights and device are hard-coded
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
+
+         elif self.config.problem_type == "regression":
+             loss_fct = torch.nn.MSELoss()
+             loss = loss_fct(logits.view(-1), labels.view(-1))
+         else:
+             raise ValueError("Problem_type {} not supported".format(self.config.problem_type))
+
+         return loss
+
+     @staticmethod
+     def merged_strategy(
+             hidden_states,
+             mode="mean"
+     ):
+         """Merged strategy for pooling
+
+         Args:
+             hidden_states (torch.FloatTensor): hidden states
+             mode (str, optional): pooling mode. Defaults to "mean".
+
+         Returns:
+             torch.FloatTensor: pooled hidden states
+         """
+         if mode == "mean":
+             outputs = torch.mean(hidden_states, dim=1)
+         elif mode == "sum":
+             outputs = torch.sum(hidden_states, dim=1)
+         elif mode == "max":
+             outputs = torch.max(hidden_states, dim=1)[0]
+         else:
+             raise Exception(
+                 "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
+
+         return outputs
+
+
+ class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
+     config_class = MultiModalConfig
+
+     def __init__(self, config):
+         """
+         Args:
+             config (MultiModalConfig): config
+
+         Attributes:
+             config (MultiModalConfig): config
+             num_labels (int): number of labels
+             audio_config (Union[PretrainedConfig, None]): audio config
+             text_config (Union[PretrainedConfig, None]): text config
+             audio_model (Union[PreTrainedModel, None]): audio model
+             text_model (Union[PreTrainedModel, None]): text model
+             classifier (Union[torch.nn.Linear, None]): classifier
+         """
+         super().__init__(config)
+         self.config = config
+         self.num_labels = self.config.num_labels
+         self.audio_config: Union[PretrainedConfig, None] = None
+         self.text_config: Union[PretrainedConfig, None] = None
+         self.audio_model: Union[PreTrainedModel, None] = None
+         self.text_model: Union[PreTrainedModel, None] = None
+         self.classifier: Union[torch.nn.Linear, None] = None
+
+
+ class FusionModuleQ(torch.nn.Module):
+     def __init__(self, audio_dim, text_dim, num_heads, dropout=0.1):
+         super().__init__()
+
+         self.dimension = min(audio_dim, text_dim)
+
+         # attention modules
+         self.a_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads)
+         self.t_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads)
+
+         # layer norm
+         self.audio_norm = torch.nn.LayerNorm(self.dimension)
+         self.text_norm = torch.nn.LayerNorm(self.dimension)
+
+     def forward(self, audio_output, text_output):
+         # Multihead cross-attention: each modality queries the other (query from one branch, key/value from the other)
+         audio_attn, _ = self.a_self_attention(audio_output, text_output, text_output)
+         text_attn, _ = self.t_self_attention(text_output, audio_output, audio_output)
+
+         # Add & Norm (residual connection; the dropout argument is not applied here)
+         audio_add = self.audio_norm(audio_output + audio_attn)
+         text_add = self.text_norm(text_output + text_attn)
+
+         return audio_add, text_add
+
+
+ class AudioTextFusionModelForSequenceClassification(AudioTextModelForSequenceBaseClassification):
+     def __init__(self, config):
+         """
+         Args:
+             config (MultiModalConfig): config
+
+         Attributes:
+             fusion_module_1 (FusionModuleQ): Fusion Module Q 1
+             fusion_module_2 (FusionModuleQ): Fusion Module Q 2
+             audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds
+             text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds
+             audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block)
+             text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block)
+         """
+         super().__init__(config)
+
+         self.fusion_module_1: Union[FusionModuleQ, None] = None
+         self.fusion_module_2: Union[FusionModuleQ, None] = None
+         self.audio_projector: Union[torch.nn.Linear, None] = None
+         self.text_projector: Union[torch.nn.Linear, None] = None
+         self.audio_avg_pool: Union[torch.nn.AvgPool1d, None] = None
+         self.text_avg_pool: Union[torch.nn.AvgPool1d, None] = None
+
+
+ class WavLMBertForSequenceClassification(AudioTextFusionModelForSequenceClassification):
+     """
+     WavLMBertForSequenceClassification is a multimodal (audio + text) model for sequence
+     classification tasks such as speech emotion recognition, intended for fine-tuning.
+
+     Args:
+         config (WavLMBertConfig): config
+
+     Attributes:
+         config (WavLMBertConfig): config
+         audio_config (WavLMConfig): wavlm config
+         text_config (BertConfig): bert config
+         audio_model (WavLMModel): wavlm model
+         text_model (BertModel): bert model
+         fusion_module_1 (FusionModuleQ): Fusion Module Q 1
+         fusion_module_2 (FusionModuleQ): Fusion Module Q 2
+         audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds
+         text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds
+         audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block)
+         text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block)
+         classifier (torch.nn.Linear): classifier
+     """
+     def __init__(self, config):
+         super().__init__(config)
+         self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
+         self.text_config = BertConfig.from_dict(self.config.BertModel)
+         self.audio_model = WavLMModel(self.audio_config)
+         self.text_model = BertModel(self.text_config)
+
+         # fusion modules with the V3 strategy (a single projection at the entry, no projections between fusion blocks)
+         self.fusion_module_1 = FusionModuleQ(self.audio_config.hidden_size, self.text_config.hidden_size,
+                                              self.config.num_heads, self.config.f_dropout)
+         self.fusion_module_2 = FusionModuleQ(self.audio_config.hidden_size, self.text_config.hidden_size,
+                                              self.config.num_heads, self.config.f_dropout)
+
+         self.audio_projector = torch.nn.Linear(self.audio_config.hidden_size, self.text_config.hidden_size)
+         self.text_projector = torch.nn.Linear(self.text_config.hidden_size, self.text_config.hidden_size)
+
+         # Avg Pool
+         self.audio_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size)
+         self.text_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size)
+
+         # hidden sizes are 1024 (WavLM) and 768 (BERT); both branches are projected to the smaller size, then concatenated for the classifier
+         cls_dim = min(self.audio_config.hidden_size, self.text_config.hidden_size)
+         self.classifier = torch.nn.Linear(
+             (cls_dim * 2) // self.config.kernel_size, self.config.num_labels
+         )
+         self.init_weights()
+
+     def forward(
+             self,
+             input_ids=None,
+             input_values=None,
+             text_attention_mask=None,
+             audio_attention_mask=None,
+             token_type_ids=None,
+             position_ids=None,
+             head_mask=None,
+             inputs_embeds=None,
+             labels=None,
+             output_attentions=None,
+             output_hidden_states=None,
+             return_dict=True,
+     ):
+         """Forward method for multimodal model for sequence classification task (e.g. text + audio)
+
+         Args:
+             input_ids (torch.LongTensor, optional): input ids. Defaults to None.
+             input_values (torch.FloatTensor, optional): input values. Defaults to None.
+             text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None.
+             audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None.
+             token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None.
+             position_ids (torch.LongTensor, optional): position ids. Defaults to None.
+             head_mask (torch.FloatTensor, optional): head mask. Defaults to None.
+             inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None.
+             labels (torch.LongTensor, optional): labels. Defaults to None.
+             output_attentions (bool, optional): output attentions. Defaults to None.
+             output_hidden_states (bool, optional): output hidden states. Defaults to None.
+             return_dict (bool, optional): return dict. Defaults to True.
+
+         Returns:
+             torch.FloatTensor: logits
+         """
+         audio_output = self.audio_model(
+             input_values=input_values,
+             attention_mask=audio_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+         text_output = self.text_model(
+             input_ids=input_ids,
+             attention_mask=text_attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # Pool the audio hidden states over time (pooling_mode from the config)
+         audio_avg = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode)
+
+         # Projection
+         audio_proj = self.audio_projector(audio_avg)
+         text_proj = self.text_projector(text_output.pooler_output)
+
+         audio_mha, text_mha = self.fusion_module_1(audio_proj, text_proj)
+         audio_mha, text_mha = self.fusion_module_2(audio_mha, text_mha)
+
+         audio_avg = self.audio_avg_pool(audio_mha)
+         text_avg = self.text_avg_pool(text_mha)
+
+         fusion_output = torch.concat((audio_avg, text_avg), dim=1)
+
+         logits = self.classifier(fusion_output)
+         loss = None
+
+         if labels is not None:
+             loss = self.compute_loss(logits, labels)
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits
+         )
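
For reference, a minimal usage sketch (not part of the commit): it assumes the duplicated-from checkpoint `Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1` is the repo being loaded, that `trust_remote_code=True` is acceptable, and that a real 16 kHz mono recording and its transcript would replace the placeholders; argument names follow the `forward()` signature above, and the processors come from the config files added below.

```python
import torch
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel, AutoTokenizer

repo = "Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1"  # source repo this commit duplicates

config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()  # resolved via auto_map in config.json
tokenizer = AutoTokenizer.from_pretrained(repo)                 # BertTokenizer (rubert-base-cased)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo)  # Wav2Vec2FeatureExtractor, 16 kHz

waveform = torch.randn(16000)   # placeholder: 1 s of noise standing in for a 16 kHz mono recording
transcript = "пример текста"    # placeholder transcript ("example text")

audio_inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, padding=True,
                                 return_attention_mask=True, return_tensors="pt")
text_inputs = tokenizer(transcript, return_tensors="pt")

with torch.no_grad():
    out = model(
        input_ids=text_inputs["input_ids"],
        text_attention_mask=text_inputs["attention_mask"],
        token_type_ids=text_inputs.get("token_type_ids"),
        input_values=audio_inputs["input_values"],
        audio_attention_mask=audio_inputs["attention_mask"],
    )

pred = out.logits.softmax(dim=-1).argmax(dim=-1).item()
print(config.id2label[pred])  # one of the seven emotion labels from config.json
```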
config.json ADDED
@@ -0,0 +1,297 @@
+ {
+   "BertModel": {
+     "_name_or_path": "DeepPavlov/rubert-base-cased",
+     "add_cross_attention": false,
+     "architectures": [
+       "BertModel"
+     ],
+     "attention_probs_dropout_prob": 0.1,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "classifier_dropout": null,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "directionality": "bidi",
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-12,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 512,
+     "min_length": 0,
+     "model_type": "bert",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 12,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_past": true,
+     "output_scores": false,
+     "pad_token_id": 0,
+     "pooler_fc_size": 768,
+     "pooler_num_attention_heads": 12,
+     "pooler_num_fc_layers": 3,
+     "pooler_size_per_head": 128,
+     "pooler_type": "first_token_transform",
+     "position_embedding_type": "absolute",
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.27.4",
+     "type_vocab_size": 2,
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": true,
+     "vocab_size": 119547
+   },
+   "WavLMModel": {
+     "_name_or_path": "jonatasgrosman/exp_w2v2t_ru_wavlm_s363",
+     "activation_dropout": 0.05,
+     "adapter_kernel_size": 3,
+     "adapter_stride": 2,
+     "add_adapter": false,
+     "add_cross_attention": false,
+     "apply_spec_augment": true,
+     "architectures": [
+       "WavLMModel"
+     ],
+     "attention_dropout": 0.05,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 1,
+     "chunk_size_feed_forward": 0,
+     "classifier_proj_size": 256,
+     "codevector_dim": 768,
+     "contrastive_logits_temperature": 0.1,
+     "conv_bias": false,
+     "conv_dim": [
+       512,
+       512,
+       512,
+       512,
+       512,
+       512,
+       512
+     ],
+     "conv_kernel": [
+       10,
+       3,
+       3,
+       3,
+       3,
+       2,
+       2
+     ],
+     "conv_stride": [
+       5,
+       2,
+       2,
+       2,
+       2,
+       2,
+       2
+     ],
+     "cross_attention_hidden_size": null,
+     "ctc_loss_reduction": "sum",
+     "ctc_zero_infinity": false,
+     "decoder_start_token_id": null,
+     "diversity_loss_weight": 0.1,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "do_stable_layer_norm": true,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 2,
+     "exponential_decay_length_penalty": null,
+     "feat_extract_activation": "gelu",
+     "feat_extract_dropout": 0.0,
+     "feat_extract_norm": "layer",
+     "feat_proj_dropout": 0.05,
+     "feat_quantizer_dropout": 0.0,
+     "final_dropout": 0.05,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "gradient_checkpointing": false,
+     "hidden_act": "gelu",
+     "hidden_dropout": 0.05,
+     "hidden_size": 1024,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 4096,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-05,
+     "layerdrop": 0.05,
+     "length_penalty": 1.0,
+     "mask_channel_length": 10,
+     "mask_channel_min_space": 1,
+     "mask_channel_other": 0.0,
+     "mask_channel_prob": 0.0,
+     "mask_channel_selection": "static",
+     "mask_feature_length": 10,
+     "mask_feature_min_masks": 0,
+     "mask_feature_prob": 0.0,
+     "mask_time_length": 10,
+     "mask_time_min_masks": 2,
+     "mask_time_min_space": 1,
+     "mask_time_other": 0.0,
+     "mask_time_prob": 0.05,
+     "mask_time_selection": "static",
+     "max_bucket_distance": 800,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "wavlm",
+     "no_repeat_ngram_size": 0,
+     "num_adapter_layers": 3,
+     "num_attention_heads": 16,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_buckets": 320,
+     "num_codevector_groups": 2,
+     "num_codevectors_per_group": 320,
+     "num_conv_pos_embedding_groups": 16,
+     "num_conv_pos_embeddings": 128,
+     "num_ctc_classes": 80,
+     "num_feat_extract_layers": 7,
+     "num_hidden_layers": 24,
+     "num_negatives": 100,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_size": 1024,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 0,
+     "prefix": null,
+     "problem_type": null,
+     "proj_codevector_dim": 768,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "replace_prob": 0.5,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "tdnn_dilation": [
+       1,
+       2,
+       3,
+       1,
+       1
+     ],
+     "tdnn_dim": [
+       512,
+       512,
+       512,
+       512,
+       1500
+     ],
+     "tdnn_kernel": [
+       5,
+       3,
+       3,
+       1,
+       1
+     ],
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": "Wav2Vec2CTCTokenizer",
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "float32",
+     "torchscript": false,
+     "transformers_version": "4.27.4",
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_weighted_layer_sum": false,
+     "vocab_size": 40,
+     "xvector_output_dim": 512
+   },
+   "_name_or_path": "Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1",
+   "architectures": [
+     "WavLMBertForSequenceClassification"
+   ],
+   "auto_map": {
+     "AutoConfig": "audio_text_multimodal.WavLMBertConfig",
+     "AutoModel": "audio_text_multimodal.WavLMBertForSequenceClassification"
+   },
+   "f_dropout": 0.1,
+   "id2label": {
+     "0": "anger",
+     "1": "disgust",
+     "2": "enthusiasm",
+     "3": "fear",
+     "4": "happiness",
+     "5": "neutral",
+     "6": "sadness"
+   },
+   "kernel_size": 1,
+   "label2id": {
+     "anger": 0,
+     "disgust": 1,
+     "enthusiasm": 2,
+     "fear": 3,
+     "happiness": 4,
+     "neutral": 5,
+     "sadness": 6
+   },
+   "num_heads": 8,
+   "pooling_mode": "mean",
+   "problem_type": "single_label_classification",
+   "torch_dtype": "float32",
+   "transformers_version": "4.27.4"
+ }
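
The top-level keys above (`num_heads`, `f_dropout`, `kernel_size`, `pooling_mode`, plus the nested `WavLMModel` / `BertModel` dicts) are exactly what `WavLMBertForSequenceClassification.__init__` reads. A small sketch of inspecting them, assuming the config is loaded from the duplicated-from repo as in the earlier example:

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1",
                                 trust_remote_code=True)  # resolves to WavLMBertConfig via auto_map

# fusion hyperparameters consumed by the model code
print(cfg.num_heads, cfg.f_dropout, cfg.kernel_size, cfg.pooling_mode)  # 8 0.1 1 mean

# nested sub-configs are stored as plain dicts and rebuilt with WavLMConfig/BertConfig.from_dict
print(cfg.WavLMModel["hidden_size"], cfg.BertModel["hidden_size"])      # 1024 768

# the seven emotion classes
print(cfg.id2label)
```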
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "do_normalize": true,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "Wav2Vec2Processor",
+   "return_attention_mask": true,
+   "sampling_rate": 16000
+ }
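
The feature extractor above expects 16 kHz mono audio and returns an attention mask. A short sketch of preparing `input_values` from a file, assuming `torchaudio` is available and `speech.wav` is a hypothetical local recording:

```python
import torchaudio
from transformers import Wav2Vec2FeatureExtractor

fe = Wav2Vec2FeatureExtractor.from_pretrained("Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1")

waveform, sr = torchaudio.load("speech.wav")   # shape: (channels, samples)
waveform = waveform.mean(dim=0)                # downmix to mono
if sr != fe.sampling_rate:                     # resample to the 16 kHz declared in this config
    waveform = torchaudio.functional.resample(waveform, sr, fe.sampling_rate)

batch = fe(waveform.numpy(), sampling_rate=fe.sampling_rate,
           padding=True, return_attention_mask=True, return_tensors="pt")
print(batch["input_values"].shape, batch["attention_mask"].shape)
```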
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f71c4d1f4ec6a54ed2b9f7ada9a87f0bb52cc5f235acef15af3c2237e34b025
+ size 2016857477
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": false,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "special_tokens_map_file": "/home/ar4ikov/.cache/huggingface/hub/models--DeepPavlov--rubert-base-cased/snapshots/4036cab694767a299f2b9e6492909664d9414229/special_tokens_map.json",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
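
The text branch uses the cased `BertTokenizer` from DeepPavlov/rubert-base-cased configured above (`do_lower_case` is false). A brief sketch with an arbitrary Russian sentence:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Ar4ikov/wavlm-bert-base-fusion-k-2-s-resd-1")

enc = tok("Мне это совсем не нравится.", return_tensors="pt")  # "I don't like this at all."
print(tok.convert_ids_to_tokens(enc["input_ids"][0].tolist()))  # [CLS] ... [SEP] WordPiece tokens over the cased vocab
```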
vocab.json ADDED
@@ -0,0 +1,42 @@
+ {
+   "'": 5,
+   "-": 6,
+   "</s>": 2,
+   "<pad>": 0,
+   "<s>": 1,
+   "<unk>": 3,
+   "|": 4,
+   "а": 7,
+   "б": 8,
+   "в": 9,
+   "г": 10,
+   "д": 11,
+   "е": 12,
+   "ж": 13,
+   "з": 14,
+   "и": 15,
+   "й": 16,
+   "к": 17,
+   "л": 18,
+   "м": 19,
+   "н": 20,
+   "о": 21,
+   "п": 22,
+   "р": 23,
+   "с": 24,
+   "т": 25,
+   "у": 26,
+   "ф": 27,
+   "х": 28,
+   "ц": 29,
+   "ч": 30,
+   "ш": 31,
+   "щ": 32,
+   "ъ": 33,
+   "ы": 34,
+   "ь": 35,
+   "э": 36,
+   "ю": 37,
+   "я": 38,
+   "ё": 39
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff