sanchit-gandhi (HF staff) committed
Commit df00910
1 Parent(s): 1f63fcf

Create processing_whisper.py

Files changed (1)
  1. processing_whisper.py +143 -0
processing_whisper.py ADDED
@@ -0,0 +1,143 @@
+ import math
+
+ import numpy as np
+ from transformers import WhisperProcessor
+
+
+ class WhisperPrePostProcessor(WhisperProcessor):
+     def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
+         inputs_len = inputs.shape[0]
+         step = chunk_len - stride_left - stride_right
+
+         all_chunk_start_idx = np.arange(0, inputs_len, step)
+         num_samples = len(all_chunk_start_idx)
+
+         num_batches = math.ceil(num_samples / batch_size)
+         batch_idx = np.array_split(np.arange(num_samples), num_batches)
+
+         for i, idx in enumerate(batch_idx):
+             chunk_start_idx = all_chunk_start_idx[idx]
+
+             chunk_end_idx = chunk_start_idx + chunk_len
+
+             chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
+             processed = self.feature_extractor(
+                 chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
+             )
+
+             _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
+             is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
+             _stride_right = np.where(is_last, 0, stride_right)
+
+             chunk_lens = [chunk.shape[0] for chunk in chunks]
+             strides = [
+                 (int(chunk_l), int(_stride_l), int(_stride_r))
+                 for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
+             ]
+
+             yield {"stride": strides, **processed}
+
+     def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
+         stride = None
+         if isinstance(inputs, dict):
+             stride = inputs.pop("stride", None)
+             # Accepting `"array"` which is the key defined in `datasets` for
+             # better integration
+             if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                 raise ValueError(
+                     "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
+                     '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
+                     "containing the sampling rate associated with the audio array."
+                 )
+
+             _inputs = inputs.pop("raw", None)
+             if _inputs is None:
+                 # Remove path which will not be used from `datasets`.
+                 inputs.pop("path", None)
+                 _inputs = inputs.pop("array", None)
+             in_sampling_rate = inputs.pop("sampling_rate")
+             inputs = _inputs
+
+             if in_sampling_rate != self.feature_extractor.sampling_rate:
+                 try:
+                     import librosa
+                 except ImportError as err:
+                     raise ImportError(
+                         "To support resampling audio files, please install 'librosa' and 'soundfile'."
+                     ) from err
+
+                 inputs = librosa.resample(
+                     inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
+                 )
+                 ratio = self.feature_extractor.sampling_rate / in_sampling_rate
+             else:
+                 ratio = 1
+
+         if not isinstance(inputs, np.ndarray):
+             raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
+         if len(inputs.shape) != 1:
+             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
+
+         if stride is not None:
+             if stride[0] + stride[1] > inputs.shape[0]:
+                 raise ValueError("Stride is too large for input")
+
+             # Stride needs to get the chunk length here, it's going to get
+             # swallowed by the `feature_extractor` later, and then batching
+             # can add extra data in the inputs, so we need to keep track
+             # of the original length in the stride so we can cut properly.
+             stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
+
+         if chunk_length_s:
+             if stride_length_s is None:
+                 stride_length_s = chunk_length_s / 6
+
+             if isinstance(stride_length_s, (int, float)):
+                 stride_length_s = [stride_length_s, stride_length_s]
+
+             chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
+             stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
+             stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
+
+             if chunk_len < stride_left + stride_right:
+                 raise ValueError("Chunk length must be superior to stride length")
+
+             for item in self.chunk_iter_with_batch(
+                 inputs,
+                 chunk_len,
+                 stride_left,
+                 stride_right,
+                 batch_size,
+             ):
+                 yield item
+         else:
+             processed = self.feature_extractor(
+                 inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
+             )
+             if stride is not None:
+                 processed["stride"] = stride
+             yield processed
+
+     def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
+         # unpack the outputs from list(dict(list)) to list(dict)
+         model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
+
+         time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500
+         # Send the chunking back to seconds, it's easier to handle in whisper
+         sampling_rate = self.feature_extractor.sampling_rate
+         for output in model_outputs:
+             if "stride" in output:
+                 chunk_len, stride_left, stride_right = output["stride"]
+                 # Go back in seconds
+                 chunk_len /= sampling_rate
+                 stride_left /= sampling_rate
+                 stride_right /= sampling_rate
+                 output["stride"] = chunk_len, stride_left, stride_right
+
+         text, optional = self.tokenizer._decode_asr(
+             model_outputs,
+             return_timestamps=return_timestamps,
+             return_language=return_language,
+             time_precision=time_precision,
+         )
+         return {"text": text, **optional}
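Usage sketch (not part of the commit): a minimal example of how the processor above might be driven, assuming the "openai/whisper-tiny.en" checkpoint, a placeholder 16 kHz silent array in place of real audio, and the module name `processing_whisper` taken from the file added here; the chunk and batch sizes are illustrative, not prescribed by the commit.

import numpy as np

from processing_whisper import WhisperPrePostProcessor

processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-tiny.en")

# 60 seconds of silence standing in for a real mono 16 kHz recording
audio = {"array": np.zeros(60 * 16_000, dtype=np.float32), "sampling_rate": 16_000}

# Chunk into 30 s windows with the default 5 s strides, two chunks per batch
for batch in processor.preprocess_batch(audio, chunk_length_s=30.0, batch_size=2):
    # Each batch holds padded log-mel "input_features" plus a per-chunk
    # "stride" tuple (chunk length, left stride, right stride) in samples,
    # which postprocess later uses to stitch chunk transcriptions together.
    print(batch["input_features"].shape, batch["stride"])

Each yielded batch would then be run through the Whisper model; the generated sequences, kept alongside their strides, form the `model_outputs` that `postprocess` converts back to seconds and hands to the tokenizer's `_decode_asr` to assemble the final transcription.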