import math

import numpy as np
from transformers import WhisperProcessor


class WhisperPrePostProcessor(WhisperProcessor):
    """Whisper processor extended with chunked, batched pre-processing of raw
    audio and post-processing of model outputs into text."""

    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        """Split a 1-D audio array into overlapping chunks and yield them in batches.

        Chunk start positions advance by ``chunk_len - stride_left - stride_right``
        samples, so consecutive chunks overlap by the strides.

        Args:
            inputs: 1-D array of audio samples (indexed by slicing, ``shape[0]`` is
                its length).
            chunk_len: Chunk length, in samples.
            stride_left: Left overlap, in samples.
            stride_right: Right overlap, in samples.
            batch_size: Maximum number of chunks per yielded batch. ``None`` puts
                every chunk into a single batch.

        Yields:
            dict: the feature extractor output for one batch of chunks, plus a
            ``"stride"`` entry holding one ``(chunk_len, stride_left, stride_right)``
            tuple of ints per chunk. The left stride is 0 for the chunk starting at
            sample 0 and the right stride is 0 for chunks reaching the end of the
            input.

        Raises:
            ValueError: If the strides leave no room to advance between chunks.
        """
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right
        if step <= 0:
            # A zero/negative step cannot tile the input; fail with a clear message
            # instead of an obscure numpy error.
            raise ValueError("Chunk length must be superior to stride length.")

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)
        if batch_size is None:
            # No batch size requested: keep every chunk in one batch.
            batch_size = max(num_samples, 1)
        num_batches = math.ceil(num_samples / batch_size)
        batch_idx = np.array_split(np.arange(num_samples), num_batches)

        for idx in batch_idx:
            chunk_start_idx = all_chunk_start_idx[idx]
            chunk_end_idx = chunk_start_idx + chunk_len

            # Slicing past the end of `inputs` simply yields a shorter final chunk.
            chunks = [
                inputs[chunk_start:chunk_end]
                for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)
            ]
            processed = self.feature_extractor(
                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )

            # The first chunk has nothing to its left, the last nothing to its right.
            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            # Record the true (un-padded) length of each chunk.
            chunk_lens = [chunk.shape[0] for chunk in chunks]
            strides = [
                (int(chunk_l), int(_stride_l), int(_stride_r))
                for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
            ]

            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
        """Turn raw audio into feature-extractor outputs, optionally chunked.

        Args:
            inputs: Either a 1-D ``np.ndarray`` of audio samples, or a dict with a
                ``"sampling_rate"`` key and a ``"raw"`` or ``"array"`` key holding
                the audio array (optionally ``"stride"`` and ``"path"``). The
                caller's dict is not modified.
            chunk_length_s: Chunk length in seconds; a falsy value disables
                chunking.
            stride_length_s: Overlap in seconds, a single number or a
                ``[left, right]`` pair. Defaults to ``chunk_length_s / 6``.
            batch_size: Number of chunks per batch when chunking; ``None`` keeps
                all chunks in one batch.

        Yields:
            dict: one feature-extractor output per batch (chunked mode), or a
            single output for the whole input, with a ``"stride"`` entry when
            stride information is available.

        Raises:
            ValueError: On a malformed dict, a non-ndarray or multi-dimensional
                input, a stride larger than the input, or strides that do not
                fit inside the chunk length.
            ImportError: If resampling is required and ``librosa`` is missing.
        """
        stride = None
        if isinstance(inputs, dict):
            # Work on a shallow copy so the pops below do not empty the caller's
            # dict (which would break a second call with the same object).
            inputs = dict(inputs)
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
                    "containing the sampling rate associated with the audio array."
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(
                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
                )
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`.")
        if len(inputs.shape) != 1:
            raise ValueError(
                f"We expect a single channel audio input for the Flax Whisper API, got {len(inputs.shape)} channels."
            )

        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input.")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            # Convert seconds to sample counts at the feature extractor's rate.
            sampling_rate = self.feature_extractor.sampling_rate
            chunk_len = round(chunk_length_s * sampling_rate)
            stride_left = round(stride_length_s[0] * sampling_rate)
            stride_right = round(stride_length_s[1] * sampling_rate)

            # `<=`: equal lengths would give a zero step between chunks.
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length.")

            yield from self.chunk_iter_with_batch(
                inputs,
                chunk_len,
                stride_left,
                stride_right,
                batch_size,
            )
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
            )
            if stride is not None:
                processed["stride"] = stride
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        """Decode batched model outputs into a transcription.

        Args:
            model_outputs: List of dicts whose values are per-batch sequences
                (one entry per sample), optionally including ``"stride"``.
            return_timestamps: Forwarded to the tokenizer's ``_decode_asr``.
            return_language: Forwarded to the tokenizer's ``_decode_asr``.

        Returns:
            dict: ``{"text": ...}`` merged with the extra fields returned by
            ``_decode_asr``.
        """
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

        time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                # Go back in seconds
                output["stride"] = (
                    chunk_len / sampling_rate,
                    stride_left / sampling_rate,
                    stride_right / sampling_rate,
                )

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}