import math

import numpy as np
from transformers import WhisperProcessor


class WhisperPrePostProcessor(WhisperProcessor):
    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
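        # Split a long audio array into overlapping windows and yield batches of log-mel features,
        # together with per-chunk (chunk_len, stride_left, stride_right) metadata that is used later
        # to stitch the per-chunk transcriptions back together.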
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right

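        # Consecutive chunks start `step` samples apart, so neighboring chunks overlap by
        # stride_left + stride_right samples.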
        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)

        num_batches = math.ceil(num_samples / batch_size)
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

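        # Featurize the chunks in batches of at most `batch_size` to bound peak memory use.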
        for idx in batch_idx:
            chunk_start_idx = all_chunk_start_idx[idx]

            chunk_end_idx = chunk_start_idx + chunk_len

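            # NumPy slicing truncates past the end of the array, so the final chunk may be shorter than chunk_len.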
            chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

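            # The first chunk has no left context and the last chunk has no right context, so their
            # left/right strides are zeroed out before being recorded.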
            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            chunk_lens = [chunk.shape[0] for chunk in chunks]
            strides = [
                (int(chunk_l), int(_stride_l), int(_stride_r))
                for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
            ]

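            # Yield the extracted features alongside the stride metadata; the strides are used
            # downstream (see `postprocess`) to drop the overlapping regions when merging chunks.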
            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
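        # Normalize the raw audio input (a plain numpy array, or a dict carrying the array and its
        # sampling rate), resample it to the feature extractor's rate if needed, and yield batches
        # of features. With `chunk_length_s=0` the whole audio is featurized in a single pass.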
        stride = None
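        # Dict-style inputs (e.g. examples loaded with `datasets`) carry the audio array together
        # with its sampling rate, and optionally a precomputed "stride".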
        if isinstance(inputs, dict):
            stride = inputs.pop("stride", None)

            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
                    "containing the sampling rate associated with the audio array."
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs

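            # Resample to the sampling rate the feature extractor expects (16 kHz for Whisper checkpoints).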
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(
                f"We expect a single channel audio input for the Flax Whisper API, got {len(inputs.shape)} channels."
            )

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input.")

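            # The stride was given in samples at the original sampling rate; rescale it by the
            # resampling ratio so it lines up with the (possibly resampled) audio, and record the
            # input length in samples alongside it.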
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

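        # When `chunk_length_s` is set, split the audio into overlapping chunks and featurize them
        # batch by batch; if no stride is given, it defaults to one sixth of the chunk length on each side.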
        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)

            if chunk_len < stride_left + stride_right:
                raise ValueError("Chunk length must be greater than stride length.")

            for item in self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            ):
                yield item
        else:
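            # No chunking requested: featurize the entire audio clip in a single pass.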
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                processed["stride"] = stride
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
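        # Each element of `model_outputs` is a dict of batched arrays; flatten it into one dict per
        # sample, then convert the stride metadata from samples back to seconds so the tokenizer can
        # offset each chunk's timestamps when stitching the final transcription together.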
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

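        # The Whisper encoder produces 1500 frames for a 30 s chunk, so each timestamp position
        # corresponds to chunk_length / 1500 = 0.02 s.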
        time_precision = self.feature_extractor.chunk_length / 1500

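        # Convert each chunk's stride from samples to seconds before decoding.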
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                chunk_len /= sampling_rate
                stride_left /= sampling_rate
                stride_right /= sampling_rate
                output["stride"] = chunk_len, stride_left, stride_right

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}
|