Spaces:
Running
Running
File size: 6,490 Bytes
9b43ccd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import math
import numpy as np
from transformers import WhisperProcessor
class WhisperPrePostProcessor(WhisperProcessor):
    """``WhisperProcessor`` with batched, chunked pre-processing and ASR post-processing.

    ``preprocess_batch`` turns (possibly long) audio into batches of overlapping
    chunks run through the feature extractor. ``postprocess`` merges the
    per-chunk model outputs back into one transcription via the tokenizer's
    ``_decode_asr``.
    """

    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        """Yield batches of overlapping audio chunks run through the feature extractor.

        Args:
            inputs: 1-D audio array, already at the feature extractor's sampling rate.
            chunk_len: Length of every chunk, in samples.
            stride_left: Samples of overlap with the previous chunk (dropped
                for the first chunk).
            stride_right: Samples of overlap with the next chunk (dropped for
                the last chunk).
            batch_size: Upper bound on the number of chunks per yielded batch;
                chunks are spread over ``ceil(n_chunks / batch_size)`` batches
                of near-equal size. ``None`` means one chunk per batch.

        Yields:
            The feature extractor output for one batch of chunks, plus a
            ``"stride"`` list holding one ``(chunk_len, stride_left,
            stride_right)`` tuple of ints (in samples) per chunk.

        Raises:
            ValueError: If ``chunk_len`` is not strictly larger than
                ``stride_left + stride_right`` (the chunks would never advance).
        """
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right
        if step <= 0:
            raise ValueError("Chunk length must be superior to stride length.")
        if batch_size is None:
            batch_size = 1

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        all_chunk_end_idx = all_chunk_start_idx + chunk_len
        # With a right stride, a chunk ending exactly at the end of the audio
        # would still have its right stride trimmed, so the following chunk is
        # needed to cover that tail; hence the strict comparison in that case.
        if stride_right > 0:
            all_is_last = all_chunk_end_idx > inputs_len
        else:
            all_is_last = all_chunk_end_idx >= inputs_len

        # Any chunk starting after the first "last" chunk only covers audio that
        # chunk already covers (with no right stride), so it would duplicate
        # content in the merged transcription. Keep up to the first last chunk.
        last_positions = np.flatnonzero(all_is_last)
        if last_positions.size > 0:
            keep = last_positions[0] + 1
            all_chunk_start_idx = all_chunk_start_idx[:keep]
            all_chunk_end_idx = all_chunk_end_idx[:keep]
            all_is_last = all_is_last[:keep]

        num_samples = len(all_chunk_start_idx)
        if num_samples == 0:
            # Empty audio: nothing to feed the model.
            return

        num_batches = math.ceil(num_samples / batch_size)
        for idx in np.array_split(np.arange(num_samples), num_batches):
            chunk_start_idx = all_chunk_start_idx[idx]
            chunk_end_idx = all_chunk_end_idx[idx]
            # The final chunk may be shorter than chunk_len (slicing clips it).
            chunks = [inputs[start:end] for start, end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            _stride_right = np.where(all_is_last[idx], 0, stride_right)
            # Record the true (pre-padding) chunk length so the stride can be
            # converted back to seconds correctly in `postprocess`.
            strides = [
                (int(chunk.shape[0]), int(left), int(right))
                for chunk, left, right in zip(chunks, _stride_left, _stride_right)
            ]
            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
        """Turn raw audio into model-ready feature batches.

        Args:
            inputs: Either a 1-D numpy array at the feature extractor's sampling
                rate, or a dict with a ``"sampling_rate"`` key and a ``"raw"`` or
                ``"array"`` key holding the audio. The dict may also carry a
                ``"stride"`` pair. Audio at a different sampling rate is
                resampled with librosa. The caller's dict is not modified.
            chunk_length_s: Chunk length in seconds. A falsy value disables
                chunking and the whole input is processed at once.
            stride_length_s: Left/right chunk overlap in seconds, as a single
                number or a ``[left, right]`` pair. Defaults to
                ``chunk_length_s / 6``.
            batch_size: Maximum number of chunks per batch in chunked mode.
                ``None`` means one chunk per batch.

        Yields:
            Feature extractor outputs (numpy tensors). In chunked mode each
            item also carries a per-chunk ``"stride"`` list. Otherwise a single
            item is yielded, with ``"stride"`` set only when the input dict
            supplied one.

        Raises:
            ValueError: On a malformed input dict, a non-ndarray or
                multi-channel input, a stride larger than the input, or a
                chunk length not strictly larger than the total stride.
            ImportError: If resampling is required but librosa is unavailable.
        """
        stride = None
        if isinstance(inputs, dict):
            # Shallow copy so the pops below don't empty the caller's dict
            # (e.g. a `datasets` row that may be passed in again later).
            inputs = dict(inputs)
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
                    "containing the sampling rate associated with the audio array."
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(
                f"We expect a single channel audio input for the Flax Whisper API, got {len(inputs.shape)} channels."
            )

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input.")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)

            # Equality must also be rejected: it leaves a step of zero samples
            # between chunks, so chunking would never advance.
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length.")

            yield from self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            )
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                processed["stride"] = stride
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        """Merge batched model outputs into a single transcription.

        Args:
            model_outputs: Iterable of dicts, each mapping a key to a sequence
                with one entry per sample in the batch. Presumably the generated
                tokens plus the ``"stride"`` list yielded by ``preprocess_batch``
                (TODO confirm against the caller).
            return_timestamps: Forwarded to the tokenizer's ``_decode_asr``.
            return_language: Forwarded to the tokenizer's ``_decode_asr``.

        Returns:
            A dict with the decoded ``"text"`` plus any extra fields returned by
            ``_decode_asr``.
        """
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

        time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                # Go back in seconds
                output["stride"] = (
                    chunk_len / sampling_rate,
                    stride_left / sampling_rate,
                    stride_right / sampling_rate,
                )

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}
|