saikrishna32 committed
Commit 4aa6431
1 Parent(s): 40fe490

added requirements

Files changed (5)
  1. adapter.py +73 -0
  2. fold_1.pt +3 -0
  3. requirements.txt +124 -0
  4. utils.py +321 -0
  5. wavlm_plus.py +253 -0
adapter.py ADDED
@@ -0,0 +1,73 @@
+ # --------------------------------------------------------
+ # References:
+ # https://github.com/jxhe/unify-parameter-efficient-tuning
+ # --------------------------------------------------------
+
+ import math
+ import torch
+ import torch.nn as nn
+
+
+ class Adapter(nn.Module):
+     def __init__(
+         self,
+         config=None,
+         d_model=768,
+         bottleneck=None,
+         dropout=0.0,
+         init_option="lora",
+         adapter_scalar="1.0",
+         adapter_layernorm_option="none"
+     ):
+         super().__init__()
+         self.n_embd = config.d_model if d_model is None else d_model
+         self.down_size = config.attn_bn if bottleneck is None else bottleneck
+
+         #_before
+         self.adapter_layernorm_option = adapter_layernorm_option
+
+         self.adapter_layer_norm_before = None
+         if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
+             self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)
+
+         if adapter_scalar == "learnable_scalar":
+             self.scale = nn.Parameter(torch.ones(1))
+         else:
+             self.scale = float(adapter_scalar)
+
+         self.down_proj = nn.Linear(self.n_embd, self.down_size)
+         self.non_linear_func = nn.ReLU()
+         self.up_proj = nn.Linear(self.down_size, self.n_embd)
+
+         self.dropout = dropout
+         if init_option == "bert":
+             raise NotImplementedError
+         elif init_option == "lora":
+             with torch.no_grad():
+                 nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
+                 nn.init.zeros_(self.up_proj.weight)
+                 nn.init.zeros_(self.down_proj.bias)
+                 nn.init.zeros_(self.up_proj.bias)
+
+     def forward(self, x, add_residual=True, residual=None):
+         residual = x if residual is None else residual
+         if self.adapter_layernorm_option == 'in':
+             x = self.adapter_layer_norm_before(x)
+
+         down = self.down_proj(x)
+
+         down = self.non_linear_func(down)
+         down = nn.functional.dropout(down, p=self.dropout, training=self.training)
+         up = self.up_proj(down)
+
+         up = up * self.scale
+
+         if self.adapter_layernorm_option == 'out':
+             up = self.adapter_layer_norm_before(up)
+
+         if add_residual:
+             output = up + residual
+         else:
+             output = up
+
+         return output
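
For orientation, here is a minimal sketch of how this Adapter bottleneck could be exercised on its own. The tensor shapes and the 128-dim bottleneck are illustrative values (128 matches the --adapter_hidden_dim default in utils.py below); this snippet is not part of the commit itself:

    import torch
    from adapter import Adapter

    # Down-project 768-d hidden states to a 128-d bottleneck, apply ReLU,
    # project back up, and add the residual connection.
    adapter = Adapter(d_model=768, bottleneck=128, dropout=0.1, adapter_scalar="1.0")
    hidden = torch.randn(2, 50, 768)           # (batch, frames, hidden)
    out = adapter(hidden, add_residual=True)   # same shape as the input
    print(out.shape)                           # torch.Size([2, 50, 768])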
fold_1.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a05b751ec5c090af650e7aa96278e5d2d77226321be92abb4976103102a2d99
+ size 379212283
requirements.txt ADDED
@@ -0,0 +1,124 @@
+ aif360==0.5.0
+ absl-py==1.4.0
+ aiohttp==3.8.1
+ aiosignal==1.3.1
+ antlr4-python3-runtime==4.9.3
+ appdirs==1.4.4
+ async-timeout==4.0.2
+ attrs==22.2.0
+ audiomentations==0.28.0
+ audioread==3.0.0
+ blinker==1.5
+ Bottleneck==1.3.5
+ brotlipy==0.7.0
+ cachetools==5.3.0
+ certifi==2023.5.7
+ cffi==1.15.1
+ charset-normalizer==2.0.4
+ click==8.1.3
+ cloudpickle==2.2.1
+ cmake==3.26.4
+ contourpy==1.0.7
+ cryptography==38.0.4
+ cvxopt==1.3.0
+ cvxpy==1.3.0
+ cycler==0.11.0
+ cylp==0.91.5
+ decorator==5.1.1
+ ecos==2.0.12
+ exceptiongroup==1.1.0
+ ffmpeg-python==0.2.0
+ filelock==3.9.0
+ flit_core==3.6.0
+ fonttools==4.38.0
+ frozenlist==1.3.3
+ fsspec==2023.6.0
+ future==0.18.3
+ google-auth==2.16.0
+ google-auth-oauthlib==0.4.6
+ grpcio==1.42.0
+ holisticai==0.3.0
+ huggingface-hub==0.15.1
+ HyperPyYAML==1.2.1
+ idna==3.4
+ importlib-metadata==6.0.0
+ importlib-resources==5.10.2
+ iniconfig==2.0.0
+ joblib==1.2.0
+ kiwisolver==1.4.4
+ librosa==0.9.2
+ lit==16.0.6
+ llvmlite==0.39.1
+ loralib==0.1.1
+ Markdown==3.4.1
+ matplotlib==3.7.0
+ memory-profiler==0.61.0
+ more-itertools==9.1.0
+ multidict==6.0.2
+ numba==0.56.4
+ numexpr==2.8.4
+ numpy==1.23.5
+ oauthlib==3.2.2
+ omegaconf==2.3.0
+ osqp==0.6.2.post8
+ packaging==22.0
+ pandas==1.5.2
+ Pillow==9.3.0
+ pip==22.3.1
+ platformdirs==3.1.1
+ pluggy==1.0.0
+ pooch==1.6.0
+ protobuf==3.15.8
+ psutil==5.9.4
+ pyasn1==0.4.8
+ pyasn1-modules==0.2.7
+ pycparser==2.21
+ PyJWT==2.6.0
+ pyOpenSSL==22.0.0
+ pyparsing==3.1.0
+ PySocks==1.7.1
+ pytest==7.2.1
+ python-dateutil==2.8.2
+ pytz==2022.7
+ pyu2f==0.1.5
+ PyYAML==6.0
+ qdldl==0.1.5.post3
+ regex==2022.7.9
+ requests==2.28.1
+ requests-oauthlib==1.3.1
+ resampy==0.4.2
+ rsa==4.9
+ ruamel.yaml==0.17.28
+ ruamel.yaml.clib==0.2.7
+ s3prl==0.4.10
+ safetensors==0.3.1
+ scikit-learn==1.2.2
+ scipy==1.10.0
+ scs==3.2.2
+ seaborn==0.12.2
+ sentencepiece==0.1.99
+ setuptools==59.5.0
+ shap==0.41.0
+ six==1.16.0
+ slicer==0.0.7
+ soundfile==0.12.0
+ speechbrain==0.5.14
+ tempeh==0.1.12
+ threadpoolctl==3.1.0
+ tiktoken==0.3.1
+ tomli==2.0.1
+ torch==1.12.1
+ torchaudio==0.12.1
+ torchvision==0.13.1
+ tqdm==4.64.1
+ transformers==4.30.2
+ triton==2.0.0
+ typing_extensions==4.4.0
+ urllib3==1.26.14
+ Werkzeug==2.1.2
+ wheel==0.37.1
+ whisper==1.1.10
+ yarl==1.7.2
+ zipp==3.13.0
+ bokeh==2.4.3
+ streamlit_bokeh_events
utils.py ADDED
@@ -0,0 +1,321 @@
+ import json
+ import torch
+ import random
+ import numpy as np
+ import transformers
+ import argparse, logging
+
+
+ transformers.logging.set_verbosity(40)
+
+ logging.basicConfig(
+     format='%(asctime)s %(levelname)-3s ==> %(message)s',
+     level=logging.INFO,
+     datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ def set_seed(seed):
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+
+ def get_results(input_dict):
+     return_dict = dict()
+     return_dict["uar"] = input_dict["uar"]
+     return_dict["acc"] = input_dict["acc"]
+     return_dict["loss"] = input_dict["loss"]
+     return return_dict
+
+ def log_epoch_result(
+     result_hist_dict: dict,
+     epoch: int,
+     train_result: dict,
+     dev_result: dict,
+     test_result: dict,
+     log_dir: str,
+     fold_idx: int
+ ):
+     # read result
+     result_hist_dict[epoch] = dict()
+     result_hist_dict[epoch]["train"] = get_results(train_result)
+     result_hist_dict[epoch]["dev"] = get_results(dev_result)
+     result_hist_dict[epoch]["test"] = get_results(test_result)
+
+     # dump the dictionary
+     jsonString = json.dumps(result_hist_dict, indent=4)
+     jsonFile = open(str(log_dir.joinpath(f'fold_{fold_idx}.json')), "w")
+     jsonFile.write(jsonString)
+     jsonFile.close()
+
+
+ def log_best_result(
+     result_hist_dict: dict,
+     epoch: int,
+     best_dev_uar: float,
+     best_dev_acc: float,
+     best_test_uar: float,
+     best_test_acc: float,
+     log_dir: str,
+     fold_idx: int
+ ):
+     # log best result
+     result_hist_dict["best"] = dict()
+     result_hist_dict["best"]["dev"], result_hist_dict["best"]["test"] = dict(), dict()
+     result_hist_dict["best"]["dev"]["uar"] = best_dev_uar
+     result_hist_dict["best"]["dev"]["acc"] = best_dev_acc
+     result_hist_dict["best"]["test"]["uar"] = best_test_uar
+     result_hist_dict["best"]["test"]["acc"] = best_test_acc
+
+     # save results for this fold
+     jsonString = json.dumps(result_hist_dict, indent=4)
+     jsonFile = open(str(log_dir.joinpath(f'fold_{fold_idx}.json')), "w")
+     jsonFile.write(jsonString)
+     jsonFile.close()
+
+ def parse_finetune_args():
+     # parser
+     parser = argparse.ArgumentParser(description='emo2vec finetune experiments')
+     parser.add_argument(
+         '--data_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/audio',
+         type=str,
+         help='raw audio path'
+     )
+
+     parser.add_argument(
+         '--model_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/model',
+         type=str,
+         help='model save path'
+     )
+
+     parser.add_argument(
+         '--split_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/train_split',
+         type=str,
+         help='train split path'
+     )
+
+     parser.add_argument(
+         '--log_dir',
+         default='log/finetune',
+         type=str,
+         help='model save path'
+     )
+
+     parser.add_argument(
+         '--uar_dir',
+         default='log/uar',
+         type=str,
+         help='model uar history'
+     )
+
+     parser.add_argument(
+         '--attack_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/attack',
+         type=str,
+         help='attack data'
+     )
+
+     parser.add_argument(
+         '--privacy_attack_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/privacy',
+         type=str,
+         help='privacy attack method data'
+     )
+
+     parser.add_argument(
+         '--privacy_attack',
+         default='gender',
+         type=str,
+         help='Privacy attack method'
+     )
+
+     parser.add_argument(
+         '--fairness_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/fairness',
+         type=str,
+         help='model save path'
+     )
+
+     parser.add_argument(
+         '--sustainability_dir',
+         default='/media/data/projects/speech-privacy/trust-ser/sustainability',
+         type=str,
+         help='model save path'
+     )
+
+     parser.add_argument(
+         '--attack_method',
+         default='pgd',
+         type=str,
+         help='attack method'
+     )
+
+     parser.add_argument(
+         '--pretrain_model',
+         default='wav2vec2_0',
+         type=str,
+         help="pretrained model type"
+     )
+
+     parser.add_argument(
+         '--finetune',
+         default='frozen',
+         type=str,
+         help="partial finetune or not"
+     )
+
+     parser.add_argument(
+         '--learning_rate',
+         default=0.0002,
+         type=float,
+         help="learning rate",
+     )
+
+     parser.add_argument(
+         '--num_epochs',
+         default=50,
+         type=int,
+         help="total training rounds",
+     )
+
+     parser.add_argument(
+         '--optimizer',
+         default='adam',
+         type=str,
+         help="optimizer",
+     )
+
+     parser.add_argument(
+         '--dataset',
+         default="iemocap",
+         type=str,
+         help="Dataset name",
+     )
+
+     parser.add_argument(
+         '--audio_duration',
+         default=6,
+         type=int,
+         help="audio length for training"
+     )
+
+     parser.add_argument(
+         '--downstream_model',
+         default='rnn',
+         type=str,
+         help="model type"
+     )
+
+     parser.add_argument(
+         '--num_layers',
+         default=1,
+         type=int,
+         help="num of layers",
+     )
+
+     parser.add_argument(
+         '--snr',
+         default=45,
+         type=int,
+         help="SNR of the audio",
+     )
+
+     parser.add_argument(
+         '--conv_layers',
+         default=3,
+         type=int,
+         help="num of conv layers",
+     )
+
+     parser.add_argument(
+         '--hidden_size',
+         default=256,
+         type=int,
+         help="hidden size",
+     )
+
+     parser.add_argument(
+         '--pooling',
+         default='att',
+         type=str,
+         help="pooling method: att, average",
+     )
+
+     parser.add_argument(
+         '--norm',
+         default='nonorm',
+         type=str,
+         help="normalization or not",
+     )
+
+     parser.add_argument(
+         '--finetune_method',
+         default='finetune',
+         type=str,
+         help='finetune method: adapter, embedding prompt, input prompt'
+     )
+
+     parser.add_argument(
+         '--adapter_hidden_dim',
+         default=128,
+         type=int,
+         help='adapter dimension'
+     )
+
+     parser.add_argument(
+         '--finetune_emb',
+         default="all",
+         type=str,
+         help='adapter dimension'
+     )
+
+     parser.add_argument(
+         '--embedding_prompt_dim',
+         default=5,
+         type=int,
+         help='adapter dimension'
+     )
+
+     parser.add_argument(
+         '--lora_rank',
+         default=16,
+         type=int,
+         help='lora rank'
+     )
+
+     parser.add_argument(
+         '--LPF',
+         default=False,
+         type=bool,
+         help='need Low pass filter on Audio'
+     )
+
+     parser.add_argument(
+         '--HPF',
+         default=False,
+         type=bool,
+         help='need High pass filter on Audio'
+     )
+
+     args = parser.parse_args()
+     if args.finetune_method == "adapter" or args.finetune_method == "adapter_l":
+         setting = f'lr{str(args.learning_rate).replace(".", "")}_ep{args.num_epochs}_{args.finetune_method}_{args.adapter_hidden_dim}'
+     elif args.finetune_method == "embedding_prompt":
+         setting = f'lr{str(args.learning_rate).replace(".", "")}_ep{args.num_epochs}_{args.finetune_method}_{args.embedding_prompt_dim}'
+     elif args.finetune_method == "lora":
+         setting = f'lr{str(args.learning_rate).replace(".", "")}_ep{args.num_epochs}_{args.finetune_method}_{args.lora_rank}'
+     elif args.finetune_method == "finetune":
+         setting = f'lr{str(args.learning_rate).replace(".", "")}_ep{args.num_epochs}_{args.finetune_method}'
+     elif args.finetune_method == "combined":
+         setting = f'lr{str(args.learning_rate).replace(".", "")}_ep{args.num_epochs}_{args.finetune_method}_{args.adapter_hidden_dim}_{args.embedding_prompt_dim}_{args.lora_rank}'
+     args.setting = setting
+     if args.finetune_emb != "all":
+         args.setting = args.setting + "_avgtok"
+
+     return args
+
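
As a usage note, parse_finetune_args also derives an experiment tag in args.setting from the parsed flags. A small sketch follows; the sys.argv override merely stands in for a real command line, and the script name is a placeholder:

    import sys
    from utils import parse_finetune_args

    # Hypothetical invocation: adapter finetuning with the defaults from this file.
    sys.argv = ["finetune.py", "--finetune_method", "adapter",
                "--adapter_hidden_dim", "128",
                "--learning_rate", "0.0002", "--num_epochs", "50"]
    args = parse_finetune_args()
    print(args.setting)  # lr00002_ep50_adapter_128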
wavlm_plus.py ADDED
@@ -0,0 +1,253 @@
+ # part of the code was referenced from SUPERB: https://github.com/s3prl/s3prl
+ # and https://github.com/wngh1187/IPET/blob/main/Speechcommands_V2/W2V2/models/W2V2.py
+ import os
+ import pdb
+ import copy
+ import torch
+ import argparse
+ import numpy as np
+ import loralib as lora
+ import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
+ import transformers.models.wavlm.modeling_wavlm as wavlm
+
+ from functools import lru_cache
+ from torchaudio.compliance import kaldi
+
+ from torch import nn
+ from adapter import Adapter
+ from collections import OrderedDict
+ from typing import Optional, Callable
+ from torch.nn import functional as F
+ from torch.nn.functional import normalize
+ from transformers import WavLMModel
+
+ class WavLMEncoderLayer(nn.Module):
+     def __init__(self, config, has_relative_position_bias: bool = True):
+         super().__init__()
+         self.attention = wavlm.WavLMAttention(
+             embed_dim=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             dropout=config.attention_dropout,
+             num_buckets=config.num_buckets,
+             max_distance=config.max_bucket_distance,
+             has_relative_position_bias=has_relative_position_bias,
+         )
+         self.dropout = nn.Dropout(config.hidden_dropout)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.feed_forward = wavlm.WavLMFeedForward(config)
+         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.config = config
+
+         if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
+             self.embed_prompt = nn.Parameter(torch.randn([1, self.config.embedding_prompt_dim, 768]))
+             nn.init.xavier_uniform_(self.embed_prompt)
+         if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
+             self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
+             self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)
+
+         if self.config.finetune_method == "adapter" or self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
+             self.adapter = Adapter(
+                 config,
+                 dropout=0.1,
+                 bottleneck=config.adapter_hidden_dim,
+                 adapter_scalar=0.1
+             )
+
+     def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
+         if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
+             hidden_states = torch.cat((self.embed_prompt.repeat(hidden_states.size(0), 1, 1), hidden_states), dim=1)
+
+         attn_residual = hidden_states
+         hidden_states, attn_weights, position_bias = self.attention(
+             hidden_states,
+             attention_mask=attention_mask,
+             position_bias=position_bias,
+             output_attentions=output_attentions,
+             index=index,
+         )
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = attn_residual + hidden_states
+
+         # Adapter
+         if self.config.finetune_method == "adapter":
+             adapt_h = self.adapter(hidden_states)
+
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = hidden_states + self.feed_forward(hidden_states)
+         if self.config.finetune_method == "adapter":
+             hidden_states = hidden_states + adapt_h
+         if self.config.finetune_method == "adapter_l" or self.config.finetune_method == "combined":
+             hidden_states = hidden_states + self.adapter(hidden_states)
+         hidden_states = self.final_layer_norm(hidden_states)
+         if self.config.finetune_method == "embedding_prompt" or self.config.finetune_method == "combined":
+             hidden_states = hidden_states[:, self.config.embedding_prompt_dim:, :]
+         outputs = (hidden_states, position_bias)
+
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+ class WavLMWrapper(nn.Module):
+     def __init__(
+         self,
+         args,
+         hidden_dim=256,
+         output_class_num=7
+     ):
+         super(WavLMWrapper, self).__init__()
+         # 1. We Load the model first with weights
+         self.args = args
+         self.backbone_model = WavLMModel.from_pretrained(
+             "microsoft/wavlm-base-plus",
+             output_hidden_states=True
+         )
+         state_dict = self.backbone_model.state_dict()
+         # 2. Read the model config
+         self.model_config = self.backbone_model.config
+         self.model_config.finetune_method = args.finetune_method
+         self.model_config.adapter_hidden_dim = args.adapter_hidden_dim
+         self.model_config.embedding_prompt_dim = args.embedding_prompt_dim
+         self.model_config.lora_rank = args.lora_rank
+
+         # 3. Config encoder layers with adapter or embedding prompt
+         # pdb.set_trace()
+         self.backbone_model.encoder.layers = nn.ModuleList(
+             [WavLMEncoderLayer(self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
+         )
+         # 4. Load the weights back
+         msg = self.backbone_model.load_state_dict(state_dict, strict=False)
+         # 5. Freeze the weights
+         if self.args.finetune_method == "adapter" or self.args.finetune_method == "adapter_l" or self.args.finetune_method == "embedding_prompt" or self.args.finetune_method == "finetune" or self.args.finetune_method == "lora" or self.args.finetune_method == "combined":
+             for name, p in self.backbone_model.named_parameters():
+                 if name in msg.missing_keys: p.requires_grad = True
+                 else: p.requires_grad = False
+         self.finetune_method = self.args.finetune_method
+
+         # 6. Downstream models
+         self.model_seq = nn.Sequential(
+             nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
+             nn.ReLU(),
+             nn.Dropout(p=0.1),
+             nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
+             nn.ReLU(),
+             nn.Dropout(p=0.1),
+             nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
+         )
+         self.weights = nn.Parameter(torch.zeros(self.model_config.num_hidden_layers))
+
+         # self.out_layer = nn.Sequential(
+         #     nn.Linear(hidden_dim, hidden_dim),
+         #     nn.ReLU(),
+         #     nn.Linear(hidden_dim, output_class_num),
+         # )
+         self.out_layer = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Linear(hidden_dim, 2),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x, length=None):
+         # 1. feature extraction and projections
+         with torch.no_grad():
+             x = self.backbone_model.feature_extractor(x)
+             x = x.transpose(1, 2)  # New version of huggingface
+             x, _ = self.backbone_model.feature_projection(x)  # New version of huggingface
+
+         # 2. get length and mask
+         if length is not None:
+             length = self.get_feat_extract_output_lengths(length.detach().cpu())
+             length = length.cuda()
+
+         # 3. transformer encoding features
+         x = self.backbone_model.encoder(
+             x, output_hidden_states=True
+         ).hidden_states
+
+         # 4. stacked feature
+         stacked_feature = torch.stack(x, dim=0)[1:]
+
+         # 5. Weighted sum
+         _, *origin_shape = stacked_feature.shape
+         # Return transformer enc outputs [num_enc_layers, B, T, D]
+         stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
+         norm_weights = F.softmax(self.weights, dim=-1)
+
+         # Perform weighted average
+         weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
+         features = weighted_feature.view(*origin_shape)
+
+         # 6. Pass the weighted average to point-wise 1D Conv
+         # B x T x D
+         features = features.transpose(1, 2)
+         features = self.model_seq(features)
+         features = features.transpose(1, 2)
+
+         # 7. Pooling
+         if length is not None:
+             masks = torch.arange(features.size(1)).expand(length.size(0), -1).cuda() < length.unsqueeze(1)
+             masks = masks.float()
+             features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
+         else:
+             features = torch.mean(features, dim=1)
+
+         # 8. Output predictions
+         # B x D
+         predicted = self.out_layer(features)
+         return predicted
+
+     # From huggingface
+     def get_feat_extract_output_lengths(self, input_length):
+         """
+         Computes the output length of the convolutional layers
+         """
+         def _conv_out_length(input_length, kernel_size, stride):
+             # 1D convolutional layer output length formula taken
+             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+             return (input_length - kernel_size) // stride + 1
+         for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
+             input_length = _conv_out_length(input_length, kernel_size, stride)
+         return input_length
+
+ def prepare_mask(length, shape, dtype):
+     # Modified from huggingface
+     mask = torch.zeros(
+         shape, dtype=dtype
+     )
+     # these two operations makes sure that all values
+     # before the output lengths indices are attended to
+     mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
+     mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+     return mask
+
+
+ if __name__ == '__main__':
+
+     parser = argparse.ArgumentParser(description='emo2vec finetune experiments')
+     parser.add_argument(
+         '--finetune_method',
+         default='none',
+         type=str,
+         help='finetune method: adapter, embedding prompt, input prompt'
+     )
+
+     parser.add_argument(
+         '--adapter_hidden_dim',
+         default=128,
+         type=int,
+         help='adapter dimension'
+     )
+
+     parser.add_argument(
+         '--embedding_prompt_dim',
+         default=5,
+         type=int,
+         help='adapter dimension'
+     )
+
+     args = parser.parse_args()
+     model = WavLMWrapper(args)
+     data = torch.zeros([1, 16000])
+     output = model(data)
+     print(output.shape)
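
Beyond the smoke test above, a hedged sketch of how the wrapper might be driven by parse_finetune_args from utils.py for an adapter run. Note that WavLMWrapper also reads args.lora_rank, which the parser in this file's __main__ block does not define; using parse_finetune_args (which does) sidesteps that. The script name and sys.argv override are placeholders, and loading the fold_1.pt checkpoint is omitted:

    import sys
    import torch
    from utils import parse_finetune_args
    from wavlm_plus import WavLMWrapper

    # Hypothetical glue code, assuming utils.py and wavlm_plus.py are importable.
    sys.argv = ["demo.py", "--finetune_method", "adapter", "--adapter_hidden_dim", "128"]
    args = parse_finetune_args()

    model = WavLMWrapper(args)           # backbone frozen except the adapter modules
    waveform = torch.zeros([1, 16000])   # one second of 16 kHz audio
    print(model(waveform).shape)         # torch.Size([1, 2]) from the sigmoid head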