nguyenvulebinh committed on
Commit 1f1a1ba
1 Parent(s): 6f7801f

add model file

Files changed (1)
  1. model_handling.py +165 -0
model_handling.py ADDED
@@ -0,0 +1,165 @@
+ from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model
+ from torch import nn
+ import warnings
+ import torch
+ from transformers.modeling_outputs import CausalLMOutput
+ from collections import OrderedDict
+
+ _HIDDEN_STATES_START_POSITION = 2
+
+
+ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.wav2vec2 = Wav2Vec2Model(config)
+         self.dropout = nn.Dropout(config.final_dropout)
+
+         # Projection head on top of the wav2vec2 encoder output:
+         # three Linear -> BatchNorm1d -> LeakyReLU -> Dropout blocks.
+         self.feature_transform = nn.Sequential(OrderedDict([
+             ('linear1', nn.Linear(config.hidden_size, config.hidden_size)),
+             ('bn1', nn.BatchNorm1d(config.hidden_size)),
+             ('activation1', nn.LeakyReLU()),
+             ('drop1', nn.Dropout(config.final_dropout)),
+             ('linear2', nn.Linear(config.hidden_size, config.hidden_size)),
+             ('bn2', nn.BatchNorm1d(config.hidden_size)),
+             ('activation2', nn.LeakyReLU()),
+             ('drop2', nn.Dropout(config.final_dropout)),
+             ('linear3', nn.Linear(config.hidden_size, config.hidden_size)),
+             ('bn3', nn.BatchNorm1d(config.hidden_size)),
+             ('activation3', nn.LeakyReLU()),
+             ('drop3', nn.Dropout(config.final_dropout))
+         ]))
+
+         if config.vocab_size is None:
+             raise ValueError(
+                 f"You are trying to instantiate {self.__class__} with a configuration that "
+                 "does not define the vocabulary size of the language model head. Please "
+                 "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
+                 "or define `vocab_size` in your model's configuration."
+             )
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+         self.is_wav2vec_freeze = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def freeze_feature_extractor(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters
+         will not be updated during training.
+         """
+         warnings.warn(
+             "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+             "Please use the equivalent `freeze_feature_encoder` method instead.",
+             FutureWarning,
+         )
+         self.freeze_feature_encoder()
+
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters
+         will not be updated during training.
+         """
+         self.wav2vec2.feature_extractor._freeze_parameters()
+
+     def freeze_wav2vec(self, is_freeze=True):
+         """
+         Calling this function will disable (or, with `is_freeze=False`, re-enable) the gradient computation for
+         the whole wav2vec2 model so that its parameters will (or will not) be updated during training. The
+         convolutional feature encoder is kept frozen in both cases.
+         """
+         if is_freeze:
+             self.is_wav2vec_freeze = True
+             for param in self.wav2vec2.parameters():
+                 param.requires_grad = False
+         else:
+             self.is_wav2vec_freeze = False
+             for param in self.wav2vec2.parameters():
+                 param.requires_grad = True
+             self.freeze_feature_encoder()
+
+         model_total_params = sum(p.numel() for p in self.parameters())
+         model_total_params_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
+                                                                                 model_total_params_trainable))
+
+     def forward(
+         self,
+         input_values,
+         attention_mask=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         path=None,
+         length=None,
+         labels=None,
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+             Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal
+             to the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+             All labels set to `-100` are ignored (masked); the loss is only computed for labels in `[0, ...,
+             config.vocab_size - 1]`.
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.wav2vec2(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         hidden_states = self.dropout(hidden_states)
+
+         # BatchNorm1d expects (batch, features), so flatten the time dimension
+         # before the projection head and restore it afterwards.
+         B, T, F = hidden_states.size()
+         hidden_states = hidden_states.view(B * T, F)
+
+         hidden_states = self.feature_transform(hidden_states)
+
+         hidden_states = hidden_states.view(B, T, F)
+
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+
+             if labels.max() >= self.config.vocab_size:
+                 raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+             # retrieve loss input_lengths from attention_mask
+             attention_mask = (
+                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+             )
+             input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+             # assuming that padded tokens are filled with -100
+             # when not being attended to
+             labels_mask = labels >= 0
+             target_lengths = labels_mask.sum(-1)
+             flattened_targets = labels.masked_select(labels_mask)
+
+             # ctc_loss doesn't support fp16
+             log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+             with torch.backends.cudnn.flags(enabled=False):
+                 loss = nn.functional.ctc_loss(
+                     log_probs,
+                     flattened_targets,
+                     input_lengths,
+                     target_lengths,
+                     blank=self.config.pad_token_id,
+                     reduction=self.config.ctc_loss_reduction,
+                     zero_infinity=self.config.ctc_zero_infinity,
+                 )
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutput(
+             loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+         )
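
For reference, a minimal usage sketch for the class added above (not part of this commit). It assumes the surrounding repository ships a standard Wav2Vec2Processor and weights compatible with this class; the checkpoint id below is a placeholder, not a name taken from this commit.

import torch
from transformers import Wav2Vec2Processor
from model_handling import Wav2Vec2ForCTC

# Placeholder checkpoint id; substitute the repository this file belongs to.
checkpoint = "your-org/your-wav2vec2-ctc-checkpoint"

processor = Wav2Vec2Processor.from_pretrained(checkpoint)
model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
model.eval()

# One second of silence at 16 kHz stands in for real audio.
dummy_audio = torch.zeros(16000).numpy()
inputs = processor(dummy_audio, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_ids = logits.argmax(dim=-1)
print(processor.batch_decode(predicted_ids))

Greedy argmax decoding is shown only to exercise the forward pass; because of model.eval(), the BatchNorm layers in the projection head use their running statistics rather than batch statistics.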