# Source: stts / whisper_jax / pipeline.py (uploaded by Afrinetwork7)
# Commit b1ffeca (verified): "Rename whisper_jax/FlaxWhisperPipline.py to whisper_jax/pipeline.py"
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import jax
import jax.numpy as jnp
import numpy as np
import requests
from flax import jax_utils
from flax.core.frozen_dict import freeze
from flax.training.common_utils import shard
from jax.sharding import PartitionSpec as P
from transformers import WhisperProcessor, is_tokenizers_available, WhisperFeatureExtractor, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, WhisperTokenizer
from transformers.pipelines.audio_utils import ffmpeg_read
from transformers.utils import logging
from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
from .partitioner import PjitPartitioner
from .train_state import InferenceState
logger = logging.get_logger(__name__)
# Logical-axis -> mesh-axis rules used for data parallelism (DP): only the
# logical "batch" axis is mapped to the "data" mesh axis, every other logical
# axis is mapped to `None` (presumably meaning "not partitioned" - the exact
# interpretation is up to `PjitPartitioner`).
logical_axis_rules_dp = (
    ("batch", "data"),
    ("mlp", None),
    ("heads", None),
    ("vocab", None),
    ("embed", None),
    ("joined_kv", None),
    ("kv", None),
    ("length", None),
    ("num_mel", None),
    ("channels", None),
)


class FlaxWhisperPipline:
    """Speech-to-text pipeline built around a Flax Whisper checkpoint.

    The pipeline loads the processor, tokenizer and model for `checkpoint`,
    splits (optionally long) audio into batched chunks, runs generation on all
    local JAX devices and decodes the predicted token ids back to text.

    After `__init__` generation runs through `jax.pmap` with replicated
    parameters; calling `shard_params` switches it to the `PjitPartitioner`.
    """

    def __init__(
        self,
        checkpoint="openai/whisper-large-v2",
        dtype=jnp.float32,
        batch_size=None,
        max_length=None,
    ):
        """
        Args:
            checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"`):
                The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face
                Hub with Flax weights.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs)
                and `jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or
                TPUs. If specified all the computation will be performed with the given `dtype`. **Note that this
                only specifies the dtype of the computation and does not influence the dtype of model parameters.**
            batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
                The batch size to be used in chunking transcription. Beneficial for transcribing long audio files.
                Passing a batch size in the `__init__` method will be superseded by any batch size passed to the
                `__call__` method.
            max_length (`int`, *optional*):
                The maximum number of tokens to generate. Defaults to `model.generation_config.max_length`.
        """
        self.checkpoint = checkpoint
        self.dtype = dtype

        self.processor = WhisperProcessor.from_pretrained(self.checkpoint)
        self.feature_extractor = self.processor.feature_extractor
        # potentially load fast tokenizer if available
        tokenizer_cls = WhisperTokenizerFast if is_tokenizers_available() else WhisperTokenizer
        self.tokenizer = tokenizer_cls.from_pretrained(checkpoint)

        self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
            self.checkpoint,
            _do_init=False,
            dtype=self.dtype,
        )

        self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
        # we need a minimum of 1 batch per-device
        self.min_batch_size = jax.local_device_count()
        self.batch_size = batch_size if batch_size is not None else self.min_batch_size

        # use pmap for DP by default - this is compatible on a Colab TPU v2
        self.params = jax_utils.replicate(self.params)
        self.p_generate = jax.pmap(
            self._generate_step,
            "input_features",
            in_axes=(0, 0, None),
            out_axes=0,
            static_broadcasted_argnums=(3,),
        )
        self.is_sharded = False

    def _generate_step(self, params, input_features, forced_decoder_ids, return_timestamps):
        """Run `model.pipeline_generate`; this is the function handed to pmap / the partitioner."""
        return self.model.pipeline_generate(
            input_features,
            params=params,
            forced_decoder_ids=forced_decoder_ids,
            return_timestamps=return_timestamps,
            max_length=self.max_length,
        )

    def shard_params(self, num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp):
        """Partition the parameters with `PjitPartitioner` and switch generation to it.

        NOTE(review): the current parameters are passed through
        `jax_utils.unreplicate`, i.e. this expects the replicated parameters
        created in `__init__`; calling it a second time looks unsafe - confirm.

        Args:
            num_mp_partitions (`int`, *optional*, defaults to 1):
                Forwarded to `PjitPartitioner` as `num_partitions`.
            logical_axis_rules (*optional*, defaults to `logical_axis_rules_dp`):
                Logical-axis to mesh-axis rules forwarded to `PjitPartitioner`.
        """

        def init_fn():
            # Dummy inputs: only used under `jax.eval_shape` to recover the axis metadata.
            input_shape = (1, self.model.config.num_mel_bins, 2 * self.model.config.max_source_positions)

            input_features = jnp.zeros(input_shape, dtype="f4")
            input_features = input_features.at[(..., -1)].set(self.model.config.eos_token_id)

            decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

            rng = jax.random.PRNGKey(0)
            return self.model.module.init(
                rng,
                input_features=input_features,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_position_ids=decoder_position_ids,
                return_dict=False,
            )

        # Axis names metadata
        param_axes = jax.eval_shape(init_fn)["params_axes"]

        # Create InferenceState, since the partitioner expects it
        state = InferenceState(
            step=jnp.array(0),
            params=freeze(self.model.params_shape_tree),
            params_axes=freeze(param_axes),
            flax_mutables=None,
            flax_mutables_axes=param_axes,
        )

        partitioner = PjitPartitioner(num_partitions=num_mp_partitions, logical_axis_rules=logical_axis_rules)

        mesh_axes = partitioner.get_mesh_axes(state)
        params_spec = mesh_axes.params

        # Note that the parameters go through `model.to_bf16` while being sharded.
        p_shard_params = partitioner.partition(self.model.to_bf16, (params_spec,), params_spec)

        # This will auto-magically run in mesh context
        self.params = p_shard_params(freeze(jax_utils.unreplicate(self.params)))
        self.is_sharded = True

        # Use pjit for generate only once we've sharded the params
        self.p_generate = partitioner.partition(
            self._generate_step,
            in_axis_resources=(params_spec, P("data"), None),
            out_axis_resources=P("data"),
            static_argnums=(3,),
        )

    def generate(self, input_features, language=None, task=None, return_timestamps=False):
        """Generate token ids for a batch of input features.

        Args:
            input_features: Batch of features to transcribe. On the pmap path it is split across devices with
                `flax.training.common_utils.shard`.
            language (`str`, *optional*): See `get_forced_decoder_ids`.
            task (`str`, *optional*): See `get_forced_decoder_ids`.
            return_timestamps (`bool`, *optional*, defaults to `False`):
                Forwarded to generation as a static argument.

        Returns:
            The generated `sequences`. On the pmap path they are reshaped to `(-1, max_length)` and fetched with
            `jax.device_get`; on the sharded path they are returned as produced.
        """
        forced_decoder_ids = self.get_forced_decoder_ids(
            language=language, task=task, return_timestamps=return_timestamps
        )
        if not self.is_sharded:
            # if we're using pmap we need to manually replicate the input data across devices and gather the
            # output tokens
            output_ids = self.p_generate(
                freeze(self.params), shard(input_features), forced_decoder_ids, return_timestamps
            ).sequences
            output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
        else:
            # pjit handles replication / gathering for us auto-magically
            output_ids = self.p_generate(
                freeze(self.params), input_features, forced_decoder_ids, return_timestamps
            ).sequences
        return output_ids

    def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
        """Build the list of `(position, token_id)` pairs forced at the start of decoding.

        Args:
            generation_config (*optional*): Defaults to `self.model.generation_config`.
            task (`str`, *optional*): Key of `generation_config.task_to_id`; `"transcribe"` is used when `None`.
            language (`str`, *optional*):
                A key of `generation_config.lang_to_id`, a value of `TO_LANGUAGE_CODE` or a key of
                `TO_LANGUAGE_CODE` (matched case-insensitively). `None` forces no language token.
            return_timestamps (`bool`, *optional*, defaults to `False`):
                When falsy, the no-timestamps token is appended after the last forced id.

        Returns:
            `list` of `(position, token_id)` tuples; empty when the config is not multilingual.

        Raises:
            ValueError: If `language` is not recognised.
        """
        if generation_config is None:
            generation_config = self.model.generation_config

        is_multilingual = getattr(generation_config, "is_multilingual", None)

        forced_decoder_ids = []

        if is_multilingual:
            if language is not None:
                language = language.lower()
                if language in generation_config.lang_to_id:
                    language_token = language
                elif language in TO_LANGUAGE_CODE.values():
                    language_token = f"<|{language}|>"
                elif language in TO_LANGUAGE_CODE:
                    language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
                else:
                    if len(language) == 2:
                        # ISO 639-1 language code
                        acceptable_languages = list(TO_LANGUAGE_CODE.values())
                    elif "<" in language or "|" in language or ">" in language:
                        # generation config language code
                        acceptable_languages = list(generation_config.lang_to_id.keys())
                    else:
                        # language passed as a string
                        acceptable_languages = list(TO_LANGUAGE_CODE.keys())
                    raise ValueError(
                        f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
                    )
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))

            task_name = task if task is not None else "transcribe"
            forced_decoder_ids.append((2, generation_config.task_to_id[task_name]))

        # NOTE(review): when nothing was forced above (non-multilingual config) the no-timestamps token is not
        # forced either - this mirrors the previous behaviour; confirm it is intended.
        if (
            not return_timestamps
            and forced_decoder_ids
            # compare the *token id* (second element) of the last forced pair, not its position
            and forced_decoder_ids[-1][1] != generation_config.no_timestamps_token_id
        ):
            next_position = forced_decoder_ids[-1][0] + 1
            forced_decoder_ids.append((next_position, generation_config.no_timestamps_token_id))

        return forced_decoder_ids

    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
        """Yield feature batches for overlapping chunks of a 1-D audio array.

        Consecutive chunks start `chunk_len - stride_left - stride_right` samples apart. Each yielded mapping holds
        the feature-extractor output for up to `batch_size` chunks plus a `"stride"` list with one
        `(chunk_length, left_stride, right_stride)` tuple per chunk; the left stride is 0 for the chunk starting
        at sample 0 and the right stride is 0 for chunks reaching the end of the input.

        Raises:
            ValueError: If the chunk length does not exceed the total stride, or if `inputs` is empty.
        """
        inputs_len = inputs.shape[0]
        step = chunk_len - stride_left - stride_right
        if step <= 0:
            # a zero step would make `np.arange` fail, a negative one would produce no chunk at all
            raise ValueError("Chunk length must be superior to stride length")
        if inputs_len == 0:
            raise ValueError("Cannot chunk an empty audio input")

        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)
        num_batches = math.ceil(num_samples / batch_size)
        sampling_rate = self.feature_extractor.sampling_rate

        for idx in np.array_split(np.arange(num_samples), num_batches):
            chunk_start_idx = all_chunk_start_idx[idx]
            chunk_end_idx = chunk_start_idx + chunk_len

            chunks = [inputs[start:end] for start, end in zip(chunk_start_idx, chunk_end_idx)]
            processed = self.feature_extractor(chunks, sampling_rate=sampling_rate, return_tensors="np")

            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
            _stride_right = np.where(is_last, 0, stride_right)

            strides = [
                (chunk.shape[0], left, right) for chunk, left, right in zip(chunks, _stride_left, _stride_right)
            ]
            yield {"stride": strides, **processed}

    def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
        """Load, validate and featurize an audio input, yielding model-ready batches.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                A URL / file name (read and decoded with ffmpeg), raw file bytes (decoded with ffmpeg), a 1-D
                array, or a dict with `"array"` and `"sampling_rate"` keys and an optional `"stride"` key.
            chunk_length_s (`float`, *optional*, defaults to 30.0): Chunk length in seconds; falsy disables chunking.
            stride_length_s (`float` or pair of `float`, *optional*, defaults to `chunk_length_s / 6`):
                Left / right stride in seconds.
            batch_size (`int`, *optional*): Number of chunks per batch. Defaults to `self.batch_size`.

        Yields:
            Mappings with the feature-extractor output and, where applicable, a `"stride"` list holding one
            `(length, left, right)` tuple (in samples) per batch element.

        Raises:
            ValueError: On malformed inputs, a stride that is too large, or a chunk length that does not exceed
                the total stride length.
            ImportError: If resampling is required and `librosa` is not installed.
        """
        target_sampling_rate = self.feature_extractor.sampling_rate

        if isinstance(inputs, np.ndarray):
            logger.warning(
                "Numpy array passed as input - no sampling rate checks will be performed. "
                "It is strongly recommended to pass the input as a dictionary with an 'array' key "
                "containing the numpy array representing the audio, and a 'sampling_rate' key "
                "containing the sampling rate associated with the audio array. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, target_sampling_rate)

        stride = None
        ratio = 1

        if isinstance(inputs, dict):
            stride = inputs.get("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and "array" in inputs):
                raise ValueError(
                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
                    "containing the numpy array representing the audio, and a 'sampling_rate' key "
                    "containing the sampling rate associated with the audio array."
                )

            in_sampling_rate = inputs.get("sampling_rate")
            inputs = inputs.get("array", None)

            if in_sampling_rate != target_sampling_rate:
                try:
                    import librosa
                except ImportError as err:
                    raise ImportError(
                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
                    ) from err

                inputs = librosa.resample(inputs, orig_sr=in_sampling_rate, target_sr=target_sampling_rate)
                ratio = target_sampling_rate / in_sampling_rate

        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for FlaxWhisperPipline")

        if stride is not None:
            # Rescale the strides to the (possibly resampled) array first so that the size check below compares
            # quantities expressed in the same sampling rate.
            scaled_left = int(round(stride[0] * ratio))
            scaled_right = int(round(stride[1] * ratio))
            if scaled_left + scaled_right > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], scaled_left, scaled_right)

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            chunk_len = round(chunk_length_s * target_sampling_rate)
            stride_left = round(stride_length_s[0] * target_sampling_rate)
            stride_right = round(stride_length_s[1] * target_sampling_rate)

            # `<=`: equal lengths would leave a step of zero samples between consecutive chunks
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            if batch_size is None:
                batch_size = self.batch_size

            yield from self.chunk_iter_with_batch(inputs, chunk_len, stride_left, stride_right, batch_size)
        else:
            processed = self.feature_extractor(inputs, sampling_rate=target_sampling_rate, return_tensors="np")
            if stride is not None:
                # one entry per batch element (a single one here), matching the chunked path, so that
                # `postprocess` can pair each stride with its token sequence
                processed["stride"] = [stride]
            yield processed

    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
        """Decode batched model outputs into the final transcription dict.

        Args:
            model_outputs (`list` of `dict`): Outputs of `forward`, one per batch.
            return_timestamps (*optional*): Forwarded to `tokenizer._decode_asr`.
            return_language (*optional*): Forwarded to `tokenizer._decode_asr`.

        Returns:
            `dict` with a `"text"` key plus whatever optional entries `tokenizer._decode_asr` returns.
        """
        # unpack the outputs from list(dict(list)) to list(dict)
        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]

        time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
        # Send the chunking back to seconds, it's easier to handle in whisper
        sampling_rate = self.feature_extractor.sampling_rate
        for output in model_outputs:
            if "stride" in output:
                chunk_len, stride_left, stride_right = output["stride"]
                output["stride"] = (
                    chunk_len / sampling_rate,
                    stride_left / sampling_rate,
                    stride_right / sampling_rate,
                )

        text, optional = self.tokenizer._decode_asr(
            model_outputs,
            return_timestamps=return_timestamps,
            return_language=return_language,
            time_precision=time_precision,
        )
        return {"text": text, **optional}

    def forward(self, model_inputs, batch_size=None, language=None, task=None, return_timestamps=False):
        """Run generation on one pre-processed batch.

        The batch is zero-padded up to `batch_size` before generation and the padded rows are dropped again
        afterwards. `model_inputs` is modified in place (`"input_features"` and `"stride"` are popped).

        Args:
            model_inputs (`dict`): One item yielded by `preprocess_batch`.
            batch_size (`int`, *optional*): Batch size to pad to. Defaults to `self.batch_size`.
            language, task, return_timestamps: Forwarded to `generate`.

        Returns:
            `dict` with `"tokens"` (predicted ids with an extra axis inserted at position 1) and, if present in
            the inputs, `"stride"`.
        """
        if batch_size is None:
            batch_size = self.batch_size

        # We need to keep track of some additional input arguments for post-processing so need to forward these
        # on after running generation
        input_features = model_inputs.pop("input_features")
        input_batch_size = input_features.shape[0]

        if input_batch_size != batch_size:
            padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
            input_features = np.concatenate([input_features, padding])

        pred_ids = self.generate(input_features, language=language, task=task, return_timestamps=return_timestamps)[
            :input_batch_size
        ]

        # tokenizer's decode method expects an extra dim - we insert it here for convenience
        out = {"tokens": pred_ids[:, None, :]}

        stride = model_inputs.pop("stride", None)
        if stride is not None:
            out["stride"] = stride

        return out

    def __call__(
        self,
        inputs,
        chunk_length_s=30.0,
        stride_length_s=None,
        batch_size=None,
        language=None,
        task=None,
        return_timestamps=None,
        generate_kwargs=None,
    ):
        """
        Transcribe an audio input sequence to a text transcription, optionally with timestamps.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either:
                    - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
                      to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                        Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
                        rate check will be done.
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
                      np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
                      ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
                      decoding (but used at inference to provide more context to the model). In general, this additional
                      stride argument is not required.
            chunk_length_s (`float`, *optional*, defaults to 30.0):
                The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the
                chunk length is set 30.0s, equal to Whisper's context window.
            stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
                The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This
                enables the model to *see* more context and infer letters better than without this context but the
                pipeline discards the stride bits at the end to make the final reconstitution as perfect as possible.

                <Tip>

                For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
                blog post](https://huggingface.co/blog/asr-chunking).

                </Tip>
            batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
                The batch size to be used in chunking transcription. Beneficial for transcribing long audio files.
                Passing a batch size in the `__call__` method will supersede any batch size passed to the `__init__`.
            task (`str`, *optional*):
                Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
            language (`str`, *optional*):
                Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
                Defaults to `None`, meaning the language is automatically inferred from the audio input.
            return_timestamps (*optional*, `bool`):
                Whether to return timestamps in the prediction. Defaults to False. If set to true, the pipeline
                will return two keys in the output dictionary: `"text"` containing the text transcription, and
                `"chunks"` containing the transcription segments chunked by their utterance-level timestamps.
            generate_kwargs (*optional*):
                Currently unused by this method.

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`) -- The recognised text.
                - **chunks** (*optional*, `List[Dict]`)
                    When using `return_timestamps`, the `chunks` will become a list containing all the various text
                    chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5, 0.9)}, {"text":
                    "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                    `"".join(chunk["text"] for chunk in output["chunks"])`.

        Raises:
            ValueError: If the batch size is not a multiple of the number of local JAX devices.
        """
        batch_size = batch_size if batch_size is not None else self.batch_size
        if batch_size % self.min_batch_size != 0:
            raise ValueError(
                f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
            )

        # `return_timestamps` is a static argument of the compiled generate function: normalise the `None`
        # default to `False` so that both spellings share a single compilation.
        if return_timestamps is None:
            return_timestamps = False

        dataloader = self.preprocess_batch(
            inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
        )

        # iterate over our chunked audio samples
        model_outputs = [
            self.forward(
                batch, batch_size=batch_size, language=language, task=task, return_timestamps=return_timestamps
            )
            for batch in dataloader
        ]

        return self.postprocess(model_outputs, return_timestamps=return_timestamps)