# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union

import numpy as np
import requests

from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import is_torch_available, is_torchaudio_available, logging
from .audio_utils import ffmpeg_read
from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model


if TYPE_CHECKING:
    from pyctcdecode import BeamSearchDecoderCTC

    from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
def rescale_stride(stride, ratio):
    """
    Convert strides expressed in audio samples into token/logit positions.

    Each entry of `stride` is a `(total_length, left, right)` triple. The total length is scaled by `ratio` and
    rounded; `left` and `right` are then re-expressed as the same fraction of the rounded total, e.g.
    `(160_000, 16_000, 16_000) -> (2000, 200, 200)`.

    Args:
        stride: iterable of `(input_n, left, right)` triples.
        ratio: multiplicative factor going from the input length to the token/logit length.

    Returns:
        A list of `(token_n, left, right)` tuples of Python ints, one per input triple.
    """

    # Shape is [B, SEQ] for tokens
    # [B, SEQ, V] for logits
    def _convert(entry):
        input_n, left, right = entry
        token_n = int(round(input_n * ratio))
        # left/right keep their proportion of the (already rounded) token count
        return (
            token_n,
            int(round(left / input_n * token_n)),
            int(round(right / input_n * token_n)),
        )

    return [_convert(entry) for entry in stride]
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
    """
    Yield overlapping chunks of `inputs`, each run through `feature_extractor`.

    Consecutive chunks start `chunk_len - stride_left - stride_right` samples apart, so neighbouring chunks overlap
    by the stride regions.

    Args:
        inputs: 1-D array-like exposing `.shape` and slicing (the raw audio).
        feature_extractor: callable invoked as `feature_extractor(chunk, sampling_rate=..., return_tensors="pt")`;
            must expose a `sampling_rate` attribute and return a mapping containing `"input_features"` or
            `"input_values"`.
        chunk_len: number of samples per chunk.
        stride_left: number of samples of left context in every chunk but the first.
        stride_right: number of samples of right context in every chunk but the last.
        rescale: when `True` and the processed length differs from the chunk length, the stride is rescaled to the
            processed length with `rescale_stride`.
        dtype: if not `None`, the processed features are cast with `processed.to(dtype=dtype)`.

    Yields:
        `{"is_last": bool, "stride": (length, left, right), **processed}` for every chunk that contains more than
        just its left stride.

    Raises:
        ValueError: if `chunk_len` does not exceed `stride_left + stride_right`, or if the feature extractor output
            contains neither `"input_features"` nor `"input_values"`.
    """
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
    if step <= 0:
        # A zero step would make `range` fail with an unrelated message and a negative one would silently yield
        # nothing, so fail early with an explicit error.
        raise ValueError(
            "`chunk_len` must be strictly greater than `stride_left + stride_right`, got "
            f"chunk_len={chunk_len}, stride_left={stride_left}, stride_right={stride_right}"
        )

    for chunk_start_idx in range(0, inputs_len, step):
        chunk_end_idx = chunk_start_idx + chunk_len
        chunk = inputs[chunk_start_idx:chunk_end_idx]
        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
        if dtype is not None:
            processed = processed.to(dtype=dtype)

        _stride_left = 0 if chunk_start_idx == 0 else stride_left
        # all right strides must be full, otherwise it is the last item
        is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
        _stride_right = 0 if is_last else stride_right

        # The final chunk may be shorter than `chunk_len`; keep its real length in a separate name so the
        # `chunk_len` argument is never overwritten.
        actual_chunk_len = chunk.shape[0]
        stride = (actual_chunk_len, _stride_left, _stride_right)

        if "input_features" in processed:
            processed_len = processed["input_features"].shape[-1]
        elif "input_values" in processed:
            processed_len = processed["input_values"].shape[-1]
        else:
            raise ValueError(
                "The feature extractor output must contain either `input_features` or `input_values`, but only "
                f"has {list(processed.keys())}"
            )

        if processed_len != chunk.shape[-1] and rescale:
            ratio = processed_len / actual_chunk_len
            stride = rescale_stride([stride], ratio)[0]

        # Skip chunks made only of left context: they carry no new audio.
        if actual_chunk_len > _stride_left:
            yield {"is_last": is_last, "stride": stride, **processed}

        # Stop if this was the last chunk
        if is_last:
            break
def _fast_find_longest_common_sequence(sequence_left, sequence_right): | |
seq_len_left = len(sequence_left) | |
seq_len_right = len(sequence_right) | |
counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)] | |
longest = 0 | |
for i in range(seq_len_left): | |
for j in range(seq_len_right): | |
if sequence_left[i] == sequence_right[j]: | |
previous_counter = counter[i][j] + 1 | |
counter[i + 1][j + 1] = previous_counter | |
if previous_counter > longest: | |
longest = previous_counter | |
counter = np.array(counter) | |
# we return the idx of the first element of the longest common sequence in the left sequence | |
index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1 | |
index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1 | |
return index_left, index_right, longest | |
def _find_longest_common_sequence(sequences, tokenizer): | |
# TODO Use a faster algorithm this can probably be done in O(n) | |
# using suffix array. | |
# It might be tedious to do because of fault tolerance. | |
# We actually have a really good property which is that the total sequence | |
# MUST be those subsequences in order. | |
# Also the algorithm should be more tolerant to errors. | |
sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids] | |
for new_seq in sequences[1:]: | |
new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids] | |
index = 0 | |
max_ = 0.0 | |
for i in range(1, len(new_sequence) + 1): | |
# epsilon to favor long perfect matches | |
eps = i / 10000.0 | |
matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i])) | |
matching = matches / i + eps | |
if matches > 1 and matching > max_: | |
index = i | |
max_ = matching | |
sequence.extend(new_sequence[index:]) | |
return np.array(sequence) | |
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
    """
    Pipeline that aims at extracting spoken text contained within some audio.

    The input can be either a raw waveform or an audio file. In the case of an audio file, ffmpeg must be installed
    in order to support multiple audio formats.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> transcriber = pipeline(model="openai/whisper-base")
    >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    Arguments:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        feature_extractor ([`SequenceFeatureExtractor`]):
            The feature extractor that will be used by the pipeline to encode waveform for the model.
        chunk_length_s (`float`, *optional*, defaults to 0):
            The input length of each chunk, in seconds. If `chunk_length_s = 0` then chunking is disabled (default).

            <Tip>

            For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
            The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
            the model to *see* more context and infer letters better than without this context but the pipeline
            discards the stride bits at the end to make the final reconstitution as perfect as possible.

            <Tip>

            For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        framework (`str`, *optional*):
            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
            installed. If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
            no model is provided.
        device (Union[`int`, `torch.device`], *optional*):
            Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
            model on the associated CUDA device id.
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        feature_extractor: Union["SequenceFeatureExtractor", str] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: Union[int, "torch.device"] = None,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        binary_output: bool = False,
        **kwargs,
    ):
        if framework is None:
            framework, model = infer_framework_load_model(model, config=model.config)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.modelcard = modelcard
        self.framework = framework

        # `accelerate` device map
        hf_device_map = getattr(self.model, "hf_device_map", None)

        if hf_device_map is not None and device is not None:
            raise ValueError(
                "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
                "discard the `device` argument when creating your pipeline object."
            )

        if self.framework == "tf":
            raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")

        # We shouldn't call `model.to()` for models loaded with accelerate
        if device is not None and not (isinstance(device, int) and device < 0):
            self.model.to(device)

        if device is None:
            if hf_device_map is not None:
                # Take the first device used by `accelerate`.
                device = next(iter(hf_device_map.values()))
            else:
                device = -1

        if is_torch_available() and self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
            elif device < 0:
                self.device = torch.device("cpu")
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device if device is not None else -1
        self.torch_dtype = torch_dtype
        self.binary_output = binary_output

        # Update config and generation_config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))
            if self.model.can_generate():
                self.model.generation_config.update(**task_specific_params.get(task))

        self.call_count = 0
        self._batch_size = kwargs.pop("batch_size", None)
        self._num_workers = kwargs.pop("num_workers", None)

        # set the model type so we can check we have the right pre- and post-processing parameters
        if self.model.config.model_type == "whisper":
            self.type = "seq2seq_whisper"
        elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
            self.type = "seq2seq"
        elif (
            feature_extractor._processor_class
            and feature_extractor._processor_class.endswith("WithLM")
            and decoder is not None
        ):
            self.decoder = decoder
            self.type = "ctc_with_lm"
        else:
            self.type = "ctc"

        # Remaining kwargs are pipeline parameters; `self.type` must be set first because the validation depends on it.
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)

        mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
        mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
        self.check_model_type(mapping)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str],
        **kwargs,
    ):
        """
        Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
        documentation for more information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either :
                    - `str` that is either the filename of a local audio file, or a public URL address to download the
                      audio file. The file will be read at the correct sampling rate to get the waveform using
                      *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                        Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
                      np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
                      treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
                      inference to provide more context to the model). Only use `stride` with CTC models.
            return_timestamps (*optional*, `str` or `bool`):
                Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
                other sequence-to-sequence models.

                For CTC models, timestamps can take one of two formats:
                    - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
                        instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
                        0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
                        `0.6` seconds.
                    - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
                        instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
                        (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
                        before `0.9` seconds.

                For the Whisper model, timestamps can take one of two formats:
                    - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
                        through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
                        by inspecting the cross-attention weights.
                    - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
                        For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
                        model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
                        Note that a segment of text refers to a sequence of one or more words, rather than individual
                        words as with word-level timestamps.
            generate_kwargs (`dict`, *optional*):
                The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                complete overview of generate, check the [following
                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
            max_new_tokens (`int`, *optional*):
                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`): The recognized text.
                - **chunks** (*optional*, `List[Dict]`)
                    When using `return_timestamps`, the `chunks` will become a list containing all the various text
                    chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
                    "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                    `"".join(chunk["text"] for chunk in output["chunks"])`.
        """
        return super().__call__(inputs, **kwargs)

    def _sanitize_parameters(
        self,
        chunk_length_s=None,
        stride_length_s=None,
        ignore_warning=None,
        decoder_kwargs=None,
        return_timestamps=None,
        return_language=None,
        generate_kwargs=None,
        max_new_tokens=None,
    ):
        """
        Validate the user-facing parameters and split them between the three pipeline stages.

        Returns:
            `(preprocess_params, forward_params, postprocess_params)` dictionaries.

        Raises:
            ValueError: if `max_new_tokens` is given both directly and inside `generate_kwargs`, if the requested
                `return_timestamps` mode is not supported by the model type, or if `return_language` is requested for
                a non-Whisper model.
        """
        preprocess_params = {}
        if chunk_length_s is not None:
            if self.type == "seq2seq" and not ignore_warning:
                logger.warning(
                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                    " be entirely accurate and will have caveats. More information:"
                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                    " ignore_warning=True)"
                )
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s

        forward_params = defaultdict(dict)
        if max_new_tokens is not None:
            forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
        if generate_kwargs is not None:
            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
                raise ValueError(
                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
                    " only 1 version"
                )
            forward_params["generate_kwargs"].update(generate_kwargs)

        postprocess_params = {}
        if decoder_kwargs is not None:
            postprocess_params["decoder_kwargs"] = decoder_kwargs
        if return_timestamps is not None:
            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a
            # forward pass. A falsy value means "no timestamps" and is valid for every model type.
            if return_timestamps:
                if self.type == "seq2seq":
                    raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
                if self.type == "ctc_with_lm" and return_timestamps != "word":
                    raise ValueError(
                        "CTC with LM can only predict word level timestamps, set `return_timestamps='word'`"
                    )
                if self.type == "ctc" and return_timestamps not in ["char", "word"]:
                    raise ValueError(
                        "CTC can either predict character level timestamps, or word level timestamps. "
                        "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
                    )
                if self.type == "seq2seq_whisper" and return_timestamps == "char":
                    raise ValueError(
                        "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
                        "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
                    )
            forward_params["return_timestamps"] = return_timestamps
            postprocess_params["return_timestamps"] = return_timestamps
        if return_language is not None:
            if self.type != "seq2seq_whisper":
                raise ValueError("Only Whisper can return language for now.")
            postprocess_params["return_language"] = return_language

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
        """
        Load, resample and featurize the audio, optionally splitting it into overlapping chunks.

        Args:
            inputs: URL or file path (`str`), encoded audio (`bytes`), raw waveform (`np.ndarray`), or a `dict` with
                `"sampling_rate"` and `"raw"`/`"array"` keys and an optional `"stride"`. A dict argument is not
                modified.
            chunk_length_s: chunk length in seconds; a falsy value disables chunking.
            stride_length_s: stride in seconds, either a number applied to both sides or a `[left, right]` pair.
                Defaults to `chunk_length_s / 6`.

        Yields:
            Model input dictionaries containing at least `"is_last"` and the feature extractor outputs.

        Raises:
            ValueError: on a malformed dict, a stride larger than the input, a non-ndarray or multi-channel waveform,
                a chunk length that does not exceed the total stride, or a stride used with a non-CTC seq2seq model.
            ImportError: if resampling is required and torchaudio is not available.
        """
        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        stride = None
        extra = {}
        if isinstance(inputs, dict):
            # Work on a shallow copy: the keys are popped below and the caller's dictionary must stay reusable.
            inputs = inputs.copy()
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            # Whatever is left is forwarded untouched to the output.
            extra = inputs
            inputs = _inputs
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )

                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                ).numpy()
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1
            if stride is not None:
                # The stride is given in samples of the original audio: convert it to the resampled rate *before*
                # comparing it with the (already resampled) input length.
                left = int(round(stride[0] * ratio))
                right = int(round(stride[1] * ratio))
                if left + right > inputs.shape[0]:
                    raise ValueError("Stride is too large for input")

                # Stride needs to get the chunk length here, it's going to get
                # swallowed by the `feature_extractor` later, and then batching
                # can add extra data in the inputs, so we need to keep track
                # of the original length in the stride so we can cut properly.
                stride = (inputs.shape[0], left, right)
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            # XXX: Carefully, this variable will not exist in `seq2seq` setting.
            # Currently chunking is not possible at this level for `seq2seq` so
            # it's ok.
            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            sampling_rate = self.feature_extractor.sampling_rate
            chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * sampling_rate / align_to) * align_to)

            # Equality must be rejected too: it would give a zero step between chunks.
            if chunk_len <= stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            # Strides are rescaled to the processed length for every model type except Whisper.
            rescale = self.type != "seq2seq_whisper"
            for item in chunk_iter(
                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
            ):
                yield item
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
            )
            if self.torch_dtype is not None:
                processed = processed.to(dtype=self.torch_dtype)
            if stride is not None:
                if self.type == "seq2seq":
                    raise ValueError("Stride is only usable with CTC models, try removing it !")

                processed["stride"] = stride
            yield {"is_last": True, **processed, **extra}

    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
        """
        Run the model on one preprocessed item.

        Args:
            model_inputs: dictionary produced by `preprocess`; the keys consumed here are popped from it and the
                remaining ones are forwarded to the output.
            return_timestamps: timestamp mode validated by `_sanitize_parameters`.
            generate_kwargs: extra keyword arguments for `model.generate` (seq2seq models only). Not modified.

        Returns:
            Dictionary with `"is_last"`, `"tokens"` or `"logits"`, optionally `"token_timestamps"` and `"stride"`,
            plus any leftover keys of `model_inputs`.
        """
        # Copy before adding keys: the dictionary handed in may be stored on the pipeline and reused by later
        # calls, so per-call entries such as `num_frames` must not leak into it.
        generate_kwargs = dict(generate_kwargs) if generate_kwargs is not None else {}

        attention_mask = model_inputs.pop("attention_mask", None)
        stride = model_inputs.pop("stride", None)
        is_last = model_inputs.pop("is_last")

        if self.type in {"seq2seq", "seq2seq_whisper"}:
            encoder = self.model.get_encoder()
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
            if "input_features" in model_inputs:
                inputs = model_inputs.pop("input_features")
            elif "input_values" in model_inputs:
                inputs = model_inputs.pop("input_values")
            else:
                raise ValueError(
                    "Seq2Seq speech recognition model requires either a "
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )

            # custom processing for Whisper timestamps and word-level timestamps
            if return_timestamps and self.type == "seq2seq_whisper":
                generate_kwargs["return_timestamps"] = return_timestamps
                if return_timestamps == "word":
                    generate_kwargs["return_token_timestamps"] = True

                    # NOTE(review): nested under "word" here; the flattened source does not fix this nesting — confirm
                    # NOTE(review): assumes `stride` is a single tuple here, not a batched list — confirm
                    if stride is not None:
                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length

            tokens = self.model.generate(
                encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                attention_mask=attention_mask,
                **generate_kwargs,
            )
            if return_timestamps == "word" and self.type == "seq2seq_whisper":
                out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
            else:
                out = {"tokens": tokens}

            if self.type == "seq2seq_whisper":
                if stride is not None:
                    out["stride"] = stride

        else:
            input_values = model_inputs.pop("input_values")
            outputs = self.model(input_values=input_values, attention_mask=attention_mask)
            logits = outputs.logits

            if self.type == "ctc_with_lm":
                out = {"logits": logits}
            else:
                out = {"tokens": logits.argmax(dim=-1)}
            if stride is not None:
                # Send stride to `postprocess`.
                # it needs to be handled there where
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}

    def postprocess(
        self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
    ):
        """
        Merge the per-chunk model outputs and decode them into text.

        Args:
            model_outputs: list of dictionaries returned by `_forward`. The entries are modified in place (known
                keys are popped, Whisper strides are converted to seconds).
            decoder_kwargs: extra keyword arguments for `decoder.decode_beams` (CTC with LM only).
            return_timestamps: timestamp mode validated by `_sanitize_parameters`.
            return_language: forwarded to the tokenizer's `_decode_asr` (Whisper only).

        Returns:
            Dictionary with `"text"`, optionally `"chunks"`, and every leftover key of the model outputs gathered
            into lists.
        """
        # Optional return types
        optional = {}

        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
        stride = None
        for outputs in model_outputs:
            items = outputs[key].numpy()
            stride = outputs.get("stride", None)
            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
                # we need to reconstruct this information
                # This won't work with left padding (which doesn't exist right now)
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)

        if stride and self.type == "seq2seq":
            items = _find_longest_common_sequence(final_items, self.tokenizer)
        elif self.type == "seq2seq_whisper":
            time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
            # Send the chunking back to seconds, it's easier to handle in whisper
            sampling_rate = self.feature_extractor.sampling_rate
            for output in model_outputs:
                if "stride" in output:
                    chunk_len, stride_left, stride_right = output["stride"]
                    # Go back in seconds
                    chunk_len /= sampling_rate
                    stride_left /= sampling_rate
                    stride_right /= sampling_rate
                    output["stride"] = chunk_len, stride_left, stride_right

            text, optional = self.tokenizer._decode_asr(
                model_outputs,
                return_timestamps=return_timestamps,
                return_language=return_language,
                time_precision=time_precision,
            )
        else:
            items = np.concatenate(final_items, axis=1)
            items = items.squeeze(0)

        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
            beams = self.decoder.decode_beams(items, **decoder_kwargs)
            text = beams[0][0]
            if return_timestamps:
                # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                # pre-existing code later
                chunk_offset = beams[0][2]
                offsets = []
                for word, (start_offset, end_offset) in chunk_offset:
                    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
        elif self.type != "seq2seq_whisper":
            skip_special_tokens = self.type != "ctc"
            text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
            if return_timestamps:
                offsets = self.tokenizer.decode(
                    items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
                )["char_offsets"]
                if return_timestamps == "word":
                    offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)

        if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
            chunks = []
            for item in offsets:
                # Offsets are in logit frames: convert to samples, then to seconds.
                start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
                start /= self.feature_extractor.sampling_rate

                stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
                stop /= self.feature_extractor.sampling_rate

                chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
            optional["chunks"] = chunks

        extra = defaultdict(list)
        for output in model_outputs:
            output.pop("tokens", None)
            output.pop("logits", None)
            output.pop("is_last", None)
            output.pop("stride", None)
            output.pop("token_timestamps", None)
            for k, v in output.items():
                extra[k].append(v)
        return {"text": text, **optional, **extra}
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
    """
    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
    processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
    properly compute the final `offset`.

    Args:
        sequences: iterable of `(sequence, stride)` pairs, where `sequence` is a list or array that is squeezed on
            its first axis and `stride` is a `(chunk_len, stride_left, stride_right)` triple.
        tokenizer: object providing `convert_tokens_to_ids`.
        feature_extractor: object providing `chunk_length` and `sampling_rate`.
        max_source_positions: divisor used to derive the duration of one timestamp step.

    Returns:
        A flat Python list with the merged token ids.
    """
    # index of the first timestamp token
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    # every element of `items` is one segment of tokens delimited by timestamp tokens
    items = []
    # duration covered by one timestamp step: chunk_length / max_source_positions
    time_precision = feature_extractor.chunk_length / max_source_positions
    # running start position of the current chunk, in the same unit as the strides
    time = 0
    for seq_idx, item in enumerate(sequences):
        sequence, stride = item
        if isinstance(sequence, list):
            sequence = np.array(sequence)
        chunk_len, stride_left, stride_right = stride
        sequence = sequence.squeeze(0)
        # get rid of the `forced_decoder_idx` that are used to parametrize the generation
        begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
        sequence = sequence[begin_idx:]

        timestamp_tokens = sequence >= timestamp_begin
        # From the second chunk on, try to merge the beginning of this chunk with the previously stored segments.
        if seq_idx != 0 and sum(timestamp_tokens) > 0:
            # positions where a timestamp token directly follows another timestamp token
            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
            # the chunk starts before the end of the previous one because of the overlapping strides
            time -= stride_left + stride_right
            offset = int((time / feature_extractor.sampling_rate) / time_precision)
            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
            # relevant timestamps are in the overlapping part
            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
            if relevant_timestamp.shape[0] > 0:
                relevant_timestamp = (
                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
                )
                # if a big stride is used, we need to check some of the previous items for the best overlap
                best_match = 0
                sliced_sequence = []
                for idx, previous_sequence in enumerate(reversed(items)):
                    # drop the first and last token of the stored segment before matching
                    previous_tokens = previous_sequence[1:-1]
                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
                        break  # the previous sequence is too far in the past
                    if len(previous_tokens) > 0:
                        # find the longest common sequence between the overlapping parts
                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
                            sequence[1:relevant_timestamp], previous_tokens
                        )
                        # don't do anything if only 1 token was matched
                        if match_length > 1 and match_length > best_match:
                            best_match = match_length
                            best_idx = idx
                            # index just past the first timestamp token found after the match start
                            end_of_curr_sequence_idx = (
                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                            )
                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                            # if all the tokens are matched, suffix
                            if index_left == 0 and match_length == len(previous_tokens):
                                sliced_sequence = np.insert(
                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
                                )
                                sliced_sequence[-1] = previous_sequence[-1]
                            # if part of the previous sequence is not taken
                            elif index_left >= 0:
                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                # let's insert the missing part of the previous sequence
                                previous_slice = (
                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
                                )
                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                sliced_sequence[-1] += offset

                if len(sliced_sequence) > 0:
                    # replace the best matching stored segment by the merged one, drop every segment after it,
                    # and consume the merged tokens from the current sequence
                    items[len(items) - best_idx - 1] = sliced_sequence
                    items = items[: len(items) - best_idx]
                    sequence = sequence[end_of_curr_sequence_idx:]

        # sequence might have changed
        timestamp_tokens = sequence >= timestamp_begin
        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
        if sum(timestamp_tokens) > 0:
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = (
                np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
            )

        if len(consecutive) > 0:
            # cut the remaining tokens into segments and shift each so it starts where the previous one ended
            last_slice = 0
            for current_slice in consecutive:
                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                sliced_tokens = sequence[last_slice:current_slice]
                duration = sliced_tokens[-1] - sliced_tokens[0]
                # NOTE(review): `sliced_tokens` looks like a view of `sequence`, so these writes would also modify it — confirm
                sliced_tokens[0] = actual_offset
                sliced_tokens[-1] = actual_offset + duration
                items.append(sliced_tokens)
                last_slice = current_slice

        time += chunk_len
    # flatten the segments into a single list of token ids
    result = []
    for i in range(len(items)):
        result += items[i].tolist()
    return result