farzadab committed
Commit bd99b9d
Parent: 6917624

Upload whisper_model_modified.py with huggingface_hub
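
The upload itself was done with the huggingface_hub client; a minimal sketch of that kind of push (the repo_id below is a placeholder, not taken from this page) is:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="whisper_model_modified.py",
    path_in_repo="whisper_model_modified.py",
    repo_id="<user>/<model-repo>",  # placeholder: the target repo is not shown here
    commit_message="Upload whisper_model_modified.py with huggingface_hub",
)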

whisper_model_modified.py ADDED
@@ -0,0 +1,141 @@
+# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
+# see this issue for the commentary: https://github.com/huggingface/transformers/issues/25744
+#
+# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+import transformers
+import transformers.modeling_outputs
+from transformers.models.whisper import modeling_whisper as whisper
+
+
+class WhisperEncoder(whisper.WhisperEncoder):
+    """
+    Encoder portion of OpenAI's Whisper model.
+
+    This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
+    1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
+    2. allow less than 30 seconds of audio padding to be passed in:
+        - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
+        - embed_pos is now sliced to match the length of `inputs_embeds`
+
+    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
+    """
+
+    base_model_prefix = "model.encoder"
+
+    def forward(
+        self,
+        input_features,
+        attention_mask=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        expected_seq_length = (
+            self.config.max_source_positions
+            * self.conv1.stride[0]
+            * self.conv2.stride[0]
+        )
+        if input_features.shape[-1] > expected_seq_length:
+            raise ValueError(
+                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
+            )
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+        inputs_embeds = inputs_embeds.permute(0, 2, 1)
+        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layers)
+            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        None,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        None,
+                        layer_head_mask=(
+                            head_mask[idx] if head_mask is not None else None
+                        ),
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, encoder_states, all_attentions]
+                if v is not None
+            )
+        return transformers.modeling_outputs.BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
+        )
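
For context, a minimal usage sketch of this encoder follows. The checkpoint name, the 5-second dummy audio, and the padding="longest" feature-extractor call are illustrative assumptions, not part of the file; the point is that the encoder loads directly from a full Whisper checkpoint (via the updated base_model_prefix) and accepts mel features shorter than the 30-second window (via the relaxed length check and the sliced embed_pos).

# Usage sketch; "openai/whisper-tiny" and padding="longest" are illustrative choices.
import torch
import transformers

from whisper_model_modified import WhisperEncoder

# base_model_prefix = "model.encoder" lets the encoder weights be loaded
# straight from a full Whisper checkpoint.
encoder = WhisperEncoder.from_pretrained("openai/whisper-tiny").eval()

feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(
    "openai/whisper-tiny"
)

# 5 seconds of (dummy) 16 kHz audio instead of a 30-second padded window.
audio = torch.zeros(16_000 * 5).numpy()
features = feature_extractor(
    audio,
    sampling_rate=16_000,
    padding="longest",  # do not pad out to the full 30-second window
    return_tensors="pt",
).input_features

with torch.no_grad():
    out = encoder(features)

# Fewer than the usual 1500 positions, since embed_pos is sliced to the input length.
print(out.last_hidden_state.shape)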