Commit df00910
Parent(s): 1f63fcf
Create processing_whisper.py
processing_whisper.py +143 -0
processing_whisper.py
ADDED
@@ -0,0 +1,143 @@
import math

import numpy as np
from transformers import WhisperProcessor


class WhisperPrePostProcessor(WhisperProcessor):
    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)

        num_batches = math.ceil(num_samples / batch_size)
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

        for i, idx in enumerate(batch_idx):
            chunk_start_idx = all_chunk_start_idx[idx]

            chunk_end_idx = chunk_start_idx + chunk_len

            chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            chunk_lens = [chunk.shape[0] for chunk in chunks]
            strides = [
                (int(chunk_l), int(_stride_l), int(_stride_r))
                for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
            ]

            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
        stride = None
        if isinstance(inputs, dict):
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
                    "containing the sampling rate associated with the audio array."
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)

            if chunk_len < stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            for item in self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            ):
                yield item
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                processed["stride"] = stride
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

        time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                # Go back in seconds
                chunk_len /= sampling_rate
                stride_left /= sampling_rate
                stride_right /= sampling_rate
                output["stride"] = chunk_len, stride_left, stride_right

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}
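The processor slices long audio into overlapping chunks, batches them for generation, and later stitches the chunked token sequences back together via the strides it recorded. A minimal usage sketch (not part of this commit) is below: it assumes a PyTorch WhisperForConditionalGeneration as a stand-in for the generation step this processor is normally paired with, the illustrative checkpoint "openai/whisper-tiny", and a minute of silence in place of real audio.

# Usage sketch (assumption, not part of this commit).
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration

from processing_whisper import WhisperPrePostProcessor

checkpoint = "openai/whisper-tiny"  # illustrative checkpoint
processor = WhisperPrePostProcessor.from_pretrained(checkpoint)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint)

# one minute of silence at 16 kHz stands in for real audio
audio = {"array": np.zeros(16_000 * 60, dtype=np.float32), "sampling_rate": 16_000}

model_outputs = []
for batch in processor.preprocess_batch(audio, chunk_length_s=30.0, batch_size=4):
    stride = batch.pop("stride")
    pred_ids = model.generate(torch.from_numpy(batch["input_features"]))
    # keep an extra dim per sample: the tokenizer's _decode_asr indexes tokens[0]
    model_outputs.append({"stride": stride, "tokens": pred_ids.numpy()[:, None, :]})

result = processor.postprocess(model_outputs)
print(result["text"])

The strides carried through each batch are what let postprocess drop the overlapping left/right context of every chunk, so the merged transcription does not repeat text at chunk boundaries.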