yangwang825 committed
Commit 310fbde · verified · 1 Parent(s): 69859f1

Upload feature extractor
feature_extraction_whisper_spkreg.py ADDED
@@ -0,0 +1,310 @@
+ """
+ Feature extractor class for Whisper
+ """
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ from transformers import is_torch_available
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.utils import TensorType, logging
+
+
+ if is_torch_available():
+     import torch
+
+ logger = logging.get_logger(__name__)
+
+
+ class WhisperSpkRegFeatureExtractor(SequenceFeatureExtractor):
+     r"""
+     Constructs a Whisper feature extractor.
+
+     This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+     most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+     This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the Short-Time
+     Fourier Transform (STFT), which should match PyTorch's `torch.stft`.
+
+     Args:
+         feature_size (`int`, *optional*, defaults to 80):
+             The feature dimension of the extracted features.
+         sampling_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+         hop_length (`int`, *optional*, defaults to 160):
+             Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
+         chunk_length (`int`, *optional*, defaults to 30):
+             The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
+             sequences.
+         n_fft (`int`, *optional*, defaults to 400):
+             Size of the Fourier transform.
+         padding_value (`float`, *optional*, defaults to 0.0):
+             Padding value used to pad the audio. Should correspond to silences.
+     """
+
+     model_input_names = ["input_features"]
+
+     def __init__(
+         self,
+         feature_size=80,
+         sampling_rate=16000,
+         hop_length=160,
+         chunk_length=30,
+         n_fft=400,
+         padding_value=0.0,
+         return_attention_mask=False,  # pad inputs to max length with silence token (zero) and no attention mask
+         **kwargs,
+     ):
+         super().__init__(
+             feature_size=feature_size,
+             sampling_rate=sampling_rate,
+             padding_value=padding_value,
+             return_attention_mask=return_attention_mask,
+             **kwargs,
+         )
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.chunk_length = chunk_length
+         self.n_samples = chunk_length * sampling_rate
+         self.nb_max_frames = self.n_samples // hop_length
+         self.sampling_rate = sampling_rate
+         self.mel_filters = mel_filter_bank(
+             num_frequency_bins=1 + n_fft // 2,
+             num_mel_filters=feature_size,
+             min_frequency=0.0,
+             max_frequency=8000.0,
+             sampling_rate=sampling_rate,
+             norm="slaney",
+             mel_scale="slaney",
+         )
+
+     def _np_extract_fbank_features(self, waveform_batch: np.ndarray, device: str) -> np.ndarray:
+         """
+         Compute the log-mel spectrogram of the provided audio. This gives results similar to Whisper's original torch
+         implementation to within a 1e-5 tolerance.
+         """
+         if device != "cpu":
+             raise ValueError(
+                 f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
+                 "devices requires torch, which is not installed. Either set `device='cpu'`, or "
+                 "install torch according to the official instructions: https://pytorch.org/get-started/locally/"
+             )
+         log_spec_batch = []
+         for waveform in waveform_batch:
+             log_spec = spectrogram(
+                 waveform,
+                 window_function(self.n_fft, "hann"),
+                 frame_length=self.n_fft,
+                 hop_length=self.hop_length,
+                 power=2.0,
+                 mel_filters=self.mel_filters,
+                 log_mel="log10",
+             )
+             log_spec = log_spec[:, :-1]
+             log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
+             log_spec = (log_spec + 4.0) / 4.0
+             log_spec_batch.append(log_spec)
+         log_spec_batch = np.array(log_spec_batch)
+         return log_spec_batch
+
+     def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
+         """
+         Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
+         yielding results similar to the CPU implementation to within a 1e-5 tolerance.
+         """
+         waveform = torch.from_numpy(waveform).type(torch.float32)
+
+         window = torch.hann_window(self.n_fft)
+         if device != "cpu":
+             waveform = waveform.to(device)
+             window = window.to(device)
+         stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
+         magnitudes = stft[..., :-1].abs() ** 2
+
+         mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
+         if device != "cpu":
+             mel_filters = mel_filters.to(device)
+         mel_spec = mel_filters.T @ magnitudes
+
+         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+         if waveform.dim() == 2:
+             max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
+             log_spec = torch.maximum(log_spec, max_val - 8.0)
+         else:
+             log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+         log_spec = (log_spec + 4.0) / 4.0
+         if device != "cpu":
+             log_spec = log_spec.detach().cpu()
+         return log_spec.numpy()
+
+     @staticmethod
+     # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
+     def zero_mean_unit_var_norm(
+         input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
+     ) -> List[np.ndarray]:
+         """
+         Every array in the list is normalized to have zero mean and unit variance.
+         """
+         if attention_mask is not None:
+             attention_mask = np.array(attention_mask, np.int32)
+             normed_input_values = []
+
+             for vector, length in zip(input_values, attention_mask.sum(-1)):
+                 normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+                 if length < normed_slice.shape[0]:
+                     normed_slice[length:] = padding_value
+
+                 normed_input_values.append(normed_slice)
+         else:
+             normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
+
+         return normed_input_values
+
+     def __call__(
+         self,
+         raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+         truncation: bool = True,
+         pad_to_multiple_of: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         return_attention_mask: Optional[bool] = None,
+         padding: Optional[str] = "max_length",
+         max_length: Optional[int] = None,
+         sampling_rate: Optional[int] = None,
+         do_normalize: Optional[bool] = None,
+         device: Optional[str] = "cpu",
+         return_token_timestamps: Optional[bool] = None,
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Main method to featurize and prepare one or several sequence(s) for the model. The implementation uses PyTorch
+         for the STFT computation if available, otherwise a slower NumPy-based one.
+
+         Args:
+             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+                 values, a list of numpy arrays or a list of lists of float values. Must be mono channel audio, not
+                 stereo, i.e. a single float per timestep.
+             truncation (`bool`, *optional*, defaults to `True`):
+                 Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+             pad_to_multiple_of (`int`, *optional*, defaults to `None`):
+                 If set, will pad the sequence to a multiple of the provided value.
+
+                 This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                 `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+             return_attention_mask (`bool`, *optional*):
+                 Whether to return the attention mask. If left to the default, will return the attention mask according
+                 to the specific feature_extractor's default.
+
+                 [What are attention masks?](../glossary#attention-mask)
+
+                 <Tip>
+
+                 For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle
+                 bugs.
+
+                 </Tip>
+
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors instead of lists of python integers. Acceptable values are:
+
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return Numpy `np.ndarray` objects.
+             sampling_rate (`int`, *optional*):
+                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+                 `sampling_rate` at the forward call to prevent silent errors and to allow the automatic speech
+                 recognition pipeline to work correctly.
+             padding_value (`float`, *optional*, defaults to 0.0):
+                 The value that is used to fill the padding values / vectors.
+             do_normalize (`bool`, *optional*, defaults to `False`):
+                 Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
+                 improve the performance of the model.
+             device (`str`, *optional*, defaults to `'cpu'`):
+                 Specifies the device for computation of the log-mel spectrogram of audio signals in the
+                 `_torch_extract_fbank_features` method (e.g., "cpu", "cuda").
+             return_token_timestamps (`bool`, *optional*, defaults to `None`):
+                 Whether or not to return the number of frames of the input raw_speech. These num_frames can be used by
+                 the model to compute word level timestamps.
+         """
+
+         if sampling_rate is not None:
+             if sampling_rate != self.sampling_rate:
+                 raise ValueError(
+                     f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
+                     f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
+                     f" was sampled with {self.sampling_rate} and not {sampling_rate}."
+                 )
+         else:
+             logger.warning(
+                 "It is strongly recommended to pass the `sampling_rate` argument to this function. "
+                 "Failing to do so can result in silent errors that might be hard to debug."
+             )
+
+         is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+         if is_batched_numpy and len(raw_speech.shape) > 2:
+             raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+         is_batched = is_batched_numpy or (
+             isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+         )
+
+         if is_batched:
+             raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
+         elif not is_batched and not isinstance(raw_speech, np.ndarray):
+             raw_speech = np.asarray(raw_speech, dtype=np.float32)
+         elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+             raw_speech = raw_speech.astype(np.float32)
+
+         # always return batch
+         if not is_batched:
+             raw_speech = [np.asarray([raw_speech]).T]
+
+         batched_speech = BatchFeature({"input_features": raw_speech})
+
+         # convert into correct format for padding
+         padded_inputs = self.pad(
+             batched_speech,
+             padding=padding,
+             max_length=max_length if max_length else self.n_samples,
+             truncation=truncation,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_attention_mask=return_attention_mask or do_normalize,
+         )
+
+         # zero-mean and unit-variance normalization
+         if do_normalize:
+             padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
+                 padded_inputs["input_features"],
+                 attention_mask=padded_inputs["attention_mask"],
+                 padding_value=self.padding_value,
+             )
+             padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
+
+         # make sure list is in array format
+         input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
+
+         extract_fbank_features = (
+             self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
+         )
+         input_features = extract_fbank_features(input_features[0], device)
+
+         if isinstance(input_features[0], List):
+             padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+         else:
+             padded_inputs["input_features"] = input_features
+
+         if return_attention_mask:
+             # rescale from raw audio samples (480000) to feature frames (3000)
+             padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
+
+         if return_token_timestamps is not None:
+             padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
+
+         if return_tensors is not None:
+             padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+         return padded_inputs
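
For reference, a minimal usage sketch of the extractor above, assuming the file has been saved locally as `feature_extraction_whisper_spkreg.py` and that a 16 kHz mono waveform is available (the random array below is just a stand-in for real audio):

```python
import numpy as np

from feature_extraction_whisper_spkreg import WhisperSpkRegFeatureExtractor

feature_extractor = WhisperSpkRegFeatureExtractor()

# Stand-in for a real 16 kHz mono recording: 5 seconds of random samples.
waveform = np.random.randn(5 * 16000).astype(np.float32)

inputs = feature_extractor(
    waveform,
    sampling_rate=16000,  # always pass this to avoid silent sampling-rate bugs
    return_tensors="np",
)

# Inputs are padded/trimmed to 30 s of audio, so the log-mel features
# come out with shape (batch, feature_size, nb_max_frames) = (1, 80, 3000).
print(inputs["input_features"].shape)
```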
preprocessor_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "auto_map": {
+     "AutoFeatureExtractor": "feature_extraction_whisper_spkreg.WhisperSpkRegFeatureExtractor"
+   },
+   "chunk_length": 30,
+   "feature_extractor_type": "WhisperSpkRegFeatureExtractor",
+   "feature_size": 80,
+   "hop_length": 160,
+   "n_fft": 400,
+   "n_samples": 480000,
+   "nb_max_frames": 3000,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "WhisperProcessor",
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
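
Because the config maps `AutoFeatureExtractor` to the custom class via `auto_map`, the extractor can also be loaded directly from the Hub. A sketch, using a hypothetical repo id (substitute the actual repository this commit belongs to):

```python
from transformers import AutoFeatureExtractor

# trust_remote_code is required because the class lives in the repo itself,
# not in the transformers library. The repo id below is a placeholder.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "yangwang825/whisper-spkreg",  # hypothetical repo id
    trust_remote_code=True,
)
print(type(feature_extractor).__name__)  # WhisperSpkRegFeatureExtractor
```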