nguyenvulebinh committed on
Commit 4180506
1 Parent(s): 1e2f9b5

Upload model_handling.py

Files changed (1)
  1. model_handling.py +226 -0
model_handling.py ADDED
@@ -0,0 +1,226 @@
+ from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model, AutoConfig
+ from torch import nn
+ import warnings
+ import torch
+ from transformers.modeling_outputs import CausalLMOutput
+ from collections import OrderedDict
+ from transformers import Wav2Vec2CTCTokenizer
+ from transformers import Wav2Vec2FeatureExtractor
+ from transformers import Wav2Vec2Processor
+
+ _HIDDEN_STATES_START_POSITION = 2
+
+
+ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.wav2vec2 = Wav2Vec2Model(config)
+         self.dropout = nn.Dropout(config.final_dropout)
+
+         # Three-block MLP (Linear -> BatchNorm1d -> LeakyReLU -> Dropout) applied on top of the encoder output
+         self.feature_transform = nn.Sequential(OrderedDict([
+             ('linear1', nn.Linear(config.hidden_size, config.hidden_size)),
+             ('bn1', nn.BatchNorm1d(config.hidden_size)),
+             ('activation1', nn.LeakyReLU()),
+             ('drop1', nn.Dropout(config.final_dropout)),
+             ('linear2', nn.Linear(config.hidden_size, config.hidden_size)),
+             ('bn2', nn.BatchNorm1d(config.hidden_size)),
+             ('activation2', nn.LeakyReLU()),
+             ('drop2', nn.Dropout(config.final_dropout)),
+             ('linear3', nn.Linear(config.hidden_size, config.hidden_size)),
+             ('bn3', nn.BatchNorm1d(config.hidden_size)),
+             ('activation3', nn.LeakyReLU()),
+             ('drop3', nn.Dropout(config.final_dropout))
+         ]))
+
+         if config.vocab_size is None:
+             raise ValueError(
+                 f"You are trying to instantiate {self.__class__} with a configuration that "
+                 "does not define the vocabulary size of the language model head. Please "
+                 "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)` "
+                 "or define `vocab_size` of your model's configuration."
+             )
+         self.output_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+         self.is_wav2vec_freeze = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def freeze_feature_extractor(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters
+         will not be updated during training.
+         """
+         warnings.warn(
+             "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+             "Please use the equivalent `freeze_feature_encoder` method instead.",
+             FutureWarning,
+         )
+         self.freeze_feature_encoder()
+
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters
+         will not be updated during training.
+         """
+         self.wav2vec2.feature_extractor._freeze_parameters()
+
+     def freeze_wav2vec(self, is_freeze=True):
+         """
+         Freeze (or unfreeze) the whole wav2vec2 encoder. When unfreezing, the convolutional feature encoder is
+         kept frozen. The total and trainable parameter counts are printed afterwards.
+         """
+         if is_freeze:
+             self.is_wav2vec_freeze = True
+             for param in self.wav2vec2.parameters():
+                 param.requires_grad = False
+         else:
+             self.is_wav2vec_freeze = False
+             for param in self.wav2vec2.parameters():
+                 param.requires_grad = True
+             self.freeze_feature_encoder()
+
+         model_total_params = sum(p.numel() for p in self.parameters())
+         model_total_params_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
+                                                                                  model_total_params_trainable))
+
+     def forward(
+             self,
+             input_values,
+             attention_mask=None,
+             output_attentions=None,
+             output_hidden_states=None,
+             return_dict=None,
+             wav=None,
+             length=None,
+             lengths=None,
+             labels=None,
+             label_hiragana=None,
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+             Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+             the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+             config.vocab_size - 1]`.
+         The extra arguments `wav`, `length`, `lengths` and `label_hiragana` are accepted but not used in this
+         forward pass.
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.wav2vec2(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         hidden_states = self.dropout(hidden_states)
+
+         # Flatten the time dimension into the batch dimension so BatchNorm1d normalizes over the feature dimension
+         B, T, F = hidden_states.size()
+         hidden_states = hidden_states.view(B * T, F)
+
+         hidden_states = self.feature_transform(hidden_states)
+
+         hidden_states = hidden_states.view(B, T, F)
+
+         logits = self.output_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+
+             if labels.max() >= self.config.vocab_size:
+                 raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+             # retrieve loss input_lengths from attention_mask
+             attention_mask = (
+                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+             )
+             input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+             # assuming that padded tokens are filled with -100
+             # when not being attended to
+             labels_mask = labels >= 0
+             target_lengths = labels_mask.sum(-1)
+             flattened_targets = labels.masked_select(labels_mask)
+
+             # ctc_loss doesn't support fp16
+             log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+             with torch.backends.cudnn.flags(enabled=False):
+                 loss = nn.functional.ctc_loss(
+                     log_probs,
+                     flattened_targets,
+                     input_lengths,
+                     target_lengths,
+                     blank=self.config.pad_token_id,
+                     reduction=self.config.ctc_loss_reduction,
+                     # zero_infinity=self.config.ctc_zero_infinity,
+                     zero_infinity=True,
+                 )
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutput(
+             loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+         )
+
+
+ def init_model(model_name_or_path, cache_dir=None):
+     tokenizer = init_tokenizer()
+
+     config = AutoConfig.from_pretrained(
+         model_name_or_path, cache_dir=cache_dir, use_auth_token=True
+     )
+     # adapt config
+     config.update(
+         {
+             "feat_proj_dropout": 0.3,
+             "attention_dropout": 0.3,
+             "hidden_dropout": 0.3,
+             "final_dropout": 0.3,
+             "mask_time_prob": 0.05,
+             "mask_time_length": 10,
+             "mask_feature_prob": 0,
+             "mask_feature_length": 10,
+             "gradient_checkpointing": True,
+             "layerdrop": 0.1,
+             "ctc_loss_reduction": "mean",
+             "pad_token_id": tokenizer.pad_token_id,
+             "vocab_size": len(tokenizer),
+             "activation_dropout": 0.3,
+         }
+     )
+
+     # create model
+     model = Wav2Vec2ForCTC.from_pretrained(
+         model_name_or_path,
+         cache_dir=cache_dir,
+         config=config, use_auth_token=True
+     )
+
+     model.freeze_wav2vec(True)
+
+     return model
+
+
+ def init_tokenizer():
+     return Wav2Vec2CTCTokenizer("./model-bin/hyper-ja/vocab.json", unk_token="<unk>", pad_token="<pad>",
+                                 word_delimiter_token="|")
+
+
+ def init_feature_extractor():
+     return Wav2Vec2FeatureExtractor.from_pretrained('./model-bin/hyper-ja/')
+
+
+ def init_processor(tokenizer, feature_extractor):
+     return Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+
+ if __name__ == "__main__":
+     print(init_model('nguyenvulebinh/wav2vec2-base-ja', './cache'))
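
A minimal inference sketch using the helpers above. It assumes the local ./model-bin/hyper-ja/ assets referenced in the file exist, that librosa is available for loading audio, and that 'example.wav' is a placeholder for a 16 kHz mono recording; none of this is part of the uploaded file.

import torch
import librosa  # assumption: any 16 kHz mono float array works; librosa is just one way to obtain it

from model_handling import init_model, init_tokenizer, init_feature_extractor, init_processor

# Build model and processor from the helpers in model_handling.py
model = init_model('nguyenvulebinh/wav2vec2-base-ja', './cache')
model.eval()  # put BatchNorm/Dropout layers in inference mode
processor = init_processor(init_tokenizer(), init_feature_extractor())

# 'example.wav' is a placeholder path for a 16 kHz mono recording
speech, _ = librosa.load('example.wav', sr=16000)
inputs = processor(speech, sampling_rate=16000, return_tensors='pt')

with torch.no_grad():
    logits = model(inputs.input_values).logits

# Greedy CTC decoding: argmax per frame, then let the tokenizer collapse repeats and blanks
pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids))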