from collections import defaultdict |
|
from typing import TYPE_CHECKING, Dict, Optional, Union |
|
|
|
import numpy as np |
|
import requests |
|
|
|
from ..modelcard import ModelCard |
|
from ..tokenization_utils import PreTrainedTokenizer |
|
from ..utils import is_torch_available, is_torchaudio_available, logging |
|
from .audio_utils import ffmpeg_read |
|
from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model |
|
|
|
|
|
if TYPE_CHECKING: |
|
from pyctcdecode import BeamSearchDecoderCTC |
|
|
|
from ..feature_extraction_sequence_utils import SequenceFeatureExtractor |
|
from ..modeling_utils import PreTrainedModel |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES |
|
|
|
|
|
def rescale_stride(stride, ratio): |
|
""" |
|
Rescales the stride values from audio space to tokens/logits space. |
|
|
|
(160_000, 16_000, 16_000) -> (2000, 200, 200) for instance. |
|
""" |
|
|
|
|
|
|
|
new_strides = [] |
|
for input_n, left, right in stride: |
|
token_n = int(round(input_n * ratio)) |
|
left = int(round(left / input_n * token_n)) |
|
right = int(round(right / input_n * token_n)) |
|
new_stride = (token_n, left, right) |
|
new_strides.append(new_stride) |
|
|
|
return new_strides |
|
|
|
|
|
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None): |
|
inputs_len = inputs.shape[0] |
|
step = chunk_len - stride_left - stride_right |
|
for chunk_start_idx in range(0, inputs_len, step): |
|
chunk_end_idx = chunk_start_idx + chunk_len |
|
chunk = inputs[chunk_start_idx:chunk_end_idx] |
|
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") |
|
if dtype is not None: |
|
processed = processed.to(dtype=dtype) |
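        # The first chunk gets no left stride and the last chunk gets no right stride,
        # so the edges of the audio are decoded exactly once.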
|
_stride_left = 0 if chunk_start_idx == 0 else stride_left |
|
|
|
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len |
|
_stride_right = 0 if is_last else stride_right |
|
|
|
chunk_len = chunk.shape[0] |
|
stride = (chunk_len, _stride_left, _stride_right) |
|
if "input_features" in processed: |
|
processed_len = processed["input_features"].shape[-1] |
|
elif "input_values" in processed: |
|
processed_len = processed["input_values"].shape[-1] |
|
if processed_len != chunk.shape[-1] and rescale: |
|
ratio = processed_len / chunk_len |
|
stride = rescale_stride([stride], ratio)[0] |
|
if chunk.shape[0] > _stride_left: |
|
yield {"is_last": is_last, "stride": stride, **processed} |
|
if is_last: |
|
break |
|
|
|
|
|
def _fast_find_longest_common_sequence(sequence_left, sequence_right): |
|
seq_len_left = len(sequence_left) |
|
seq_len_right = len(sequence_right) |
|
counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)] |
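    # Classic dynamic-programming table for the longest common substring (contiguous match)
    # between the two token sequences.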
|
longest = 0 |
|
for i in range(seq_len_left): |
|
for j in range(seq_len_right): |
|
if sequence_left[i] == sequence_right[j]: |
|
previous_counter = counter[i][j] + 1 |
|
counter[i + 1][j + 1] = previous_counter |
|
if previous_counter > longest: |
|
longest = previous_counter |
|
|
|
counter = np.array(counter) |
|
|
|
index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1 |
|
index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1 |
|
return index_left, index_right, longest |
|
|
|
|
|
def _find_longest_common_sequence(sequences, tokenizer): |
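    # Merge the token sequences of overlapping chunks: slide the start of each new sequence
    # over the tail of the running sequence and keep the split with the best match.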
|
|
|
|
|
|
|
|
|
|
|
|
|
sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids] |
|
for new_seq in sequences[1:]: |
|
new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids] |
|
|
|
index = 0 |
|
max_ = 0.0 |
|
for i in range(1, len(new_sequence) + 1): |
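            # `eps` slightly favors longer overlaps when the match ratios are otherwise tied.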
|
|
|
eps = i / 10000.0 |
|
matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i])) |
|
matching = matches / i + eps |
|
if matches > 1 and matching > max_: |
|
index = i |
|
max_ = matching |
|
sequence.extend(new_sequence[index:]) |
|
return np.array(sequence) |
|
|
|
|
|
class AutomaticSpeechRecognitionPipeline(ChunkPipeline): |
|
""" |
|
Pipeline that aims at extracting spoken text contained within some audio. |
|
|
|
    The input can be either a raw waveform or an audio file. In the case of an audio file, ffmpeg should be installed
    to support multiple audio formats.
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import pipeline |
|
|
|
>>> transcriber = pipeline(model="openai/whisper-base") |
|
>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac") |
|
    {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fattened sauce.'}
|
``` |
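
    For long audio files, chunking can be enabled (the values below are illustrative):

    ```python
    >>> transcriber = pipeline(model="openai/whisper-base", chunk_length_s=30, stride_length_s=5)
    >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")  # doctest: +SKIP
    ```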
|
|
|
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) |
|
|
|
Arguments: |
|
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): |
|
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from |
|
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow. |
|
tokenizer ([`PreTrainedTokenizer`]): |
|
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from |
|
[`PreTrainedTokenizer`]. |
|
feature_extractor ([`SequenceFeatureExtractor`]): |
|
The feature extractor that will be used by the pipeline to encode waveform for the model. |
|
chunk_length_s (`float`, *optional*, defaults to 0): |
|
            The input length, in seconds, for each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
|
|
|
<Tip> |
|
|
|
For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking |
|
blog post](https://huggingface.co/blog/asr-chunking). |
|
|
|
</Tip> |
|
|
|
stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`): |
|
            The length of stride, in seconds, on the left and right of each chunk. Used only with `chunk_length_s > 0`.
            This enables the model to *see* more context and infer letters better than it could without this context,
            but the pipeline discards the stride bits at the end to make the final reconstitution as perfect as possible.
|
|
|
<Tip> |
|
|
|
For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking |
|
blog post](https://huggingface.co/blog/asr-chunking). |
|
|
|
</Tip> |
|
|
|
framework (`str`, *optional*): |
|
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be |
|
installed. If no framework is specified, will default to the one currently installed. If no framework is |
|
specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if |
|
no model is provided. |
|
device (Union[`int`, `torch.device`], *optional*): |
|
            Device ordinal for CPU/GPU support. Setting this to `None` or a negative integer will run the model on
            CPU, while a non-negative integer will run it on the associated CUDA device id.
|
decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*): |
|
[PyCTCDecode's |
|
BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180) |
|
can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information. |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
model: "PreTrainedModel", |
|
feature_extractor: Union["SequenceFeatureExtractor", str] = None, |
|
tokenizer: Optional[PreTrainedTokenizer] = None, |
|
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None, |
|
modelcard: Optional[ModelCard] = None, |
|
framework: Optional[str] = None, |
|
task: str = "", |
|
args_parser: ArgumentHandler = None, |
|
device: Union[int, "torch.device"] = None, |
|
torch_dtype: Optional[Union[str, "torch.dtype"]] = None, |
|
binary_output: bool = False, |
|
**kwargs, |
|
): |
|
if framework is None: |
|
framework, model = infer_framework_load_model(model, config=model.config) |
|
|
|
self.task = task |
|
self.model = model |
|
self.tokenizer = tokenizer |
|
self.feature_extractor = feature_extractor |
|
self.modelcard = modelcard |
|
self.framework = framework |
|
|
|
|
|
hf_device_map = getattr(self.model, "hf_device_map", None) |
|
|
|
if hf_device_map is not None and device is not None: |
|
raise ValueError( |
|
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please " |
|
"discard the `device` argument when creating your pipeline object." |
|
) |
|
|
|
if self.framework == "tf": |
|
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.") |
|
|
|
|
|
if device is not None and not (isinstance(device, int) and device < 0): |
|
self.model.to(device) |
|
|
|
if device is None: |
|
if hf_device_map is not None: |
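                # The model was dispatched with `accelerate`; run the pipeline on the first device it uses.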
|
|
|
device = next(iter(hf_device_map.values())) |
|
else: |
|
device = -1 |
|
|
|
if is_torch_available() and self.framework == "pt": |
|
if isinstance(device, torch.device): |
|
self.device = device |
|
elif isinstance(device, str): |
|
self.device = torch.device(device) |
|
elif device < 0: |
|
self.device = torch.device("cpu") |
|
else: |
|
self.device = torch.device(f"cuda:{device}") |
|
else: |
|
self.device = device if device is not None else -1 |
|
self.torch_dtype = torch_dtype |
|
self.binary_output = binary_output |
|
|
|
|
|
task_specific_params = self.model.config.task_specific_params |
|
if task_specific_params is not None and task in task_specific_params: |
|
self.model.config.update(task_specific_params.get(task)) |
|
if self.model.can_generate(): |
|
self.model.generation_config.update(**task_specific_params.get(task)) |
|
|
|
self.call_count = 0 |
|
self._batch_size = kwargs.pop("batch_size", None) |
|
self._num_workers = kwargs.pop("num_workers", None) |
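
        # Determine the pipeline variant: Whisper seq2seq, generic seq2seq, CTC decoded with an
        # external language model (`pyctcdecode`), or plain CTC with greedy decoding.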
|
|
|
|
|
if self.model.config.model_type == "whisper": |
|
self.type = "seq2seq_whisper" |
|
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values(): |
|
self.type = "seq2seq" |
|
elif ( |
|
feature_extractor._processor_class |
|
and feature_extractor._processor_class.endswith("WithLM") |
|
and decoder is not None |
|
): |
|
self.decoder = decoder |
|
self.type = "ctc_with_lm" |
|
else: |
|
self.type = "ctc" |
|
|
|
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) |
|
|
|
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy() |
|
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES) |
|
self.check_model_type(mapping) |
|
|
|
def __call__( |
|
self, |
|
inputs: Union[np.ndarray, bytes, str], |
|
**kwargs, |
|
): |
|
""" |
|
Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`] |
|
documentation for more information. |
|
|
|
Args: |
|
inputs (`np.ndarray` or `bytes` or `str` or `dict`): |
|
                The inputs can be either:
|
- `str` that is either the filename of a local audio file, or a public URL address to download the |
|
audio file. The file will be read at the correct sampling rate to get the waveform using |
|
*ffmpeg*. This requires *ffmpeg* to be installed on the system. |
|
                    - `bytes`: the content of an audio file, which is interpreted by *ffmpeg* in the same way.
|
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`) |
|
Raw audio at the correct sampling rate (no further check will be done) |
|
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this |
|
pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw": |
|
                      np.array}`, optionally with a `"stride": (left: int, right: int)` entry that asks the pipeline
                      to ignore the first `left` samples and the last `right` samples in decoding (they are still used
                      at inference to provide more context to the model). Only use `stride` with CTC models.
|
return_timestamps (*optional*, `str` or `bool`): |
|
Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for |
|
other sequence-to-sequence models. |
|
|
|
For CTC models, timestamps can take one of two formats: |
|
- `"char"`: the pipeline will return timestamps along the text for every character in the text. For |
|
instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7, |
|
0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before |
|
`0.6` seconds. |
|
- `"word"`: the pipeline will return timestamps along the text for every word in the text. For |
|
instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": |
|
(1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and |
|
before `0.9` seconds. |
|
|
|
For the Whisper model, timestamps can take one of two formats: |
|
- `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted |
|
                      through the *dynamic time warping (DTW)* algorithm, which approximates word-level timestamps
                      by inspecting the cross-attention weights.
|
- `True`: the pipeline will return timestamps along the text for *segments* of words in the text. |
|
For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the |
|
model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds. |
|
Note that a segment of text refers to a sequence of one or more words, rather than individual |
|
words as with word-level timestamps. |
|
generate_kwargs (`dict`, *optional*): |
|
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a |
|
complete overview of generate, check the [following |
|
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). |
|
max_new_tokens (`int`, *optional*): |
|
                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
|
|
|
Return: |
|
`Dict`: A dictionary with the following keys: |
|
- **text** (`str`): The recognized text. |
|
                - **chunks** (*optional*, `List[Dict]`)
|
When using `return_timestamps`, the `chunks` will become a list containing all the various text |
|
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": |
|
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing |
|
`"".join(chunk["text"] for chunk in output["chunks"])`. |
|
""" |
|
return super().__call__(inputs, **kwargs) |
|
|
|
def _sanitize_parameters( |
|
self, |
|
chunk_length_s=None, |
|
stride_length_s=None, |
|
ignore_warning=None, |
|
decoder_kwargs=None, |
|
return_timestamps=None, |
|
return_language=None, |
|
generate_kwargs=None, |
|
max_new_tokens=None, |
|
): |
|
|
|
preprocess_params = {} |
|
if chunk_length_s is not None: |
|
if self.type == "seq2seq" and not ignore_warning: |
|
logger.warning( |
|
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily" |
|
" be entirely accurate and will have caveats. More information:" |
|
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...," |
|
" ignore_warning=True)" |
|
) |
|
preprocess_params["chunk_length_s"] = chunk_length_s |
|
if stride_length_s is not None: |
|
preprocess_params["stride_length_s"] = stride_length_s |
|
|
|
forward_params = defaultdict(dict) |
|
if max_new_tokens is not None: |
|
forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens |
|
if generate_kwargs is not None: |
|
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs: |
|
raise ValueError( |
|
"`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use" |
|
" only 1 version" |
|
) |
|
forward_params["generate_kwargs"].update(generate_kwargs) |
|
|
|
postprocess_params = {} |
|
if decoder_kwargs is not None: |
|
postprocess_params["decoder_kwargs"] = decoder_kwargs |
|
if return_timestamps is not None: |
|
|
|
if self.type == "seq2seq" and return_timestamps: |
|
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!") |
|
if self.type == "ctc_with_lm" and return_timestamps != "word": |
|
raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`") |
|
if self.type == "ctc" and return_timestamps not in ["char", "word"]: |
|
raise ValueError( |
|
"CTC can either predict character level timestamps, or word level timestamps." |
|
"Set `return_timestamps='char'` or `return_timestamps='word'` as required." |
|
) |
|
if self.type == "seq2seq_whisper" and return_timestamps == "char": |
|
raise ValueError( |
|
"Whisper cannot return `char` timestamps, only word level or segment level timestamps. " |
|
"Use `return_timestamps='word'` or `return_timestamps=True` respectively." |
|
) |
|
forward_params["return_timestamps"] = return_timestamps |
|
postprocess_params["return_timestamps"] = return_timestamps |
|
if return_language is not None: |
|
if self.type != "seq2seq_whisper": |
|
raise ValueError("Only Whisper can return language for now.") |
|
postprocess_params["return_language"] = return_language |
|
|
|
return preprocess_params, forward_params, postprocess_params |
|
|
|
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): |
|
if isinstance(inputs, str): |
|
if inputs.startswith("http://") or inputs.startswith("https://"): |
|
|
|
|
|
inputs = requests.get(inputs).content |
|
else: |
|
with open(inputs, "rb") as f: |
|
inputs = f.read() |
|
|
|
if isinstance(inputs, bytes): |
|
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate) |
|
|
|
stride = None |
|
extra = {} |
|
if isinstance(inputs, dict): |
|
stride = inputs.pop("stride", None) |
|
|
|
|
|
if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)): |
|
raise ValueError( |
|
"When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a " |
|
'"raw" key containing the numpy array representing the audio and a "sampling_rate" key, ' |
|
"containing the sampling_rate associated with that array" |
|
) |
|
|
|
_inputs = inputs.pop("raw", None) |
|
if _inputs is None: |
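                # `datasets`-style dicts use an "array" key (and an unused "path") instead of "raw".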
|
|
|
inputs.pop("path", None) |
|
_inputs = inputs.pop("array", None) |
|
in_sampling_rate = inputs.pop("sampling_rate") |
|
extra = inputs |
|
inputs = _inputs |
|
if in_sampling_rate != self.feature_extractor.sampling_rate: |
|
if is_torchaudio_available(): |
|
from torchaudio import functional as F |
|
else: |
|
raise ImportError( |
|
"torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. " |
|
"The torchaudio package can be installed through: `pip install torchaudio`." |
|
) |
|
|
|
inputs = F.resample( |
|
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate |
|
).numpy() |
|
ratio = self.feature_extractor.sampling_rate / in_sampling_rate |
|
else: |
|
ratio = 1 |
|
if stride is not None: |
|
if stride[0] + stride[1] > inputs.shape[0]: |
|
raise ValueError("Stride is too large for input") |
|
|
|
|
|
|
|
|
|
|
|
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio))) |
|
if not isinstance(inputs, np.ndarray): |
|
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`") |
|
if len(inputs.shape) != 1: |
|
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") |
|
|
|
if chunk_length_s: |
|
if stride_length_s is None: |
|
stride_length_s = chunk_length_s / 6 |
|
|
|
if isinstance(stride_length_s, (int, float)): |
|
stride_length_s = [stride_length_s, stride_length_s] |
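
            # Align chunk and stride lengths with the model's `inputs_to_logits_ratio` (defaults to 1
            # for models without it, e.g. Whisper) so that strides land on logit-frame boundaries.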
|
|
|
|
|
|
|
|
|
align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1) |
|
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to) |
|
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to) |
|
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to) |
|
|
|
if chunk_len < stride_left + stride_right: |
|
raise ValueError("Chunk length must be superior to stride length") |
|
|
|
rescale = self.type != "seq2seq_whisper" |
|
|
|
for item in chunk_iter( |
|
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype |
|
): |
|
yield item |
|
else: |
|
processed = self.feature_extractor( |
|
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" |
|
) |
|
if self.torch_dtype is not None: |
|
processed = processed.to(dtype=self.torch_dtype) |
|
if stride is not None: |
|
if self.type == "seq2seq": |
|
raise ValueError("Stride is only usable with CTC models, try removing it !") |
|
|
|
processed["stride"] = stride |
|
yield {"is_last": True, **processed, **extra} |
|
|
|
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None): |
|
if generate_kwargs is None: |
|
generate_kwargs = {} |
|
|
|
attention_mask = model_inputs.pop("attention_mask", None) |
|
stride = model_inputs.pop("stride", None) |
|
is_last = model_inputs.pop("is_last") |
|
|
|
if self.type in {"seq2seq", "seq2seq_whisper"}: |
|
encoder = self.model.get_encoder() |
|
|
|
|
|
if "input_features" in model_inputs: |
|
inputs = model_inputs.pop("input_features") |
|
elif "input_values" in model_inputs: |
|
inputs = model_inputs.pop("input_values") |
|
else: |
|
raise ValueError( |
|
"Seq2Seq speech recognition model requires either a " |
|
f"`input_features` or `input_values` key, but only has {model_inputs.keys()}" |
|
) |
|
|
|
|
|
if return_timestamps and self.type == "seq2seq_whisper": |
|
generate_kwargs["return_timestamps"] = return_timestamps |
|
if return_timestamps == "word": |
|
generate_kwargs["return_token_timestamps"] = True |
|
|
|
if stride is not None: |
|
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length |
|
|
|
tokens = self.model.generate( |
|
encoder_outputs=encoder(inputs, attention_mask=attention_mask), |
|
attention_mask=attention_mask, |
|
**generate_kwargs, |
|
) |
|
if return_timestamps == "word" and self.type == "seq2seq_whisper": |
|
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]} |
|
else: |
|
out = {"tokens": tokens} |
|
if self.type == "seq2seq_whisper": |
|
if stride is not None: |
|
out["stride"] = stride |
|
|
|
else: |
|
input_values = model_inputs.pop("input_values") |
|
outputs = self.model(input_values=input_values, attention_mask=attention_mask) |
|
logits = outputs.logits |
|
|
|
if self.type == "ctc_with_lm": |
|
out = {"logits": logits} |
|
else: |
|
out = {"tokens": logits.argmax(dim=-1)} |
|
if stride is not None: |
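                # Send the stride on to `postprocess`, rescaled from audio samples to logit frames,
                # so the overlapping pieces can be cut out before the chunks are concatenated.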
|
|
|
|
|
|
|
ratio = 1 / self.model.config.inputs_to_logits_ratio |
|
if isinstance(stride, tuple): |
|
out["stride"] = rescale_stride([stride], ratio)[0] |
|
else: |
|
out["stride"] = rescale_stride(stride, ratio) |
|
|
|
extra = model_inputs |
|
return {"is_last": is_last, **out, **extra} |
|
|
|
def postprocess( |
|
self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None |
|
): |
|
|
|
optional = {} |
|
|
|
final_items = [] |
|
key = "logits" if self.type == "ctc_with_lm" else "tokens" |
|
stride = None |
|
for outputs in model_outputs: |
|
items = outputs[key].numpy() |
|
stride = outputs.get("stride", None) |
|
if stride is not None and self.type in {"ctc", "ctc_with_lm"}: |
|
total_n, left, right = stride |
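                # `total_n` may be smaller than `logits.shape[1]` because of padding, so cut using the
                # recorded stride rather than the tensor shape.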
|
|
|
|
|
|
|
|
|
right_n = total_n - right |
|
items = items[:, left:right_n] |
|
final_items.append(items) |
|
|
|
if stride and self.type == "seq2seq": |
|
items = _find_longest_common_sequence(final_items, self.tokenizer) |
|
elif self.type == "seq2seq_whisper": |
|
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions |
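
            # Convert the strides back from audio samples to seconds; the Whisper tokenizer merges
            # the chunked outputs in the time domain.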
|
|
|
sampling_rate = self.feature_extractor.sampling_rate |
|
for output in model_outputs: |
|
if "stride" in output: |
|
chunk_len, stride_left, stride_right = output["stride"] |
|
|
|
chunk_len /= sampling_rate |
|
stride_left /= sampling_rate |
|
stride_right /= sampling_rate |
|
output["stride"] = chunk_len, stride_left, stride_right |
|
|
|
text, optional = self.tokenizer._decode_asr( |
|
model_outputs, |
|
return_timestamps=return_timestamps, |
|
return_language=return_language, |
|
time_precision=time_precision, |
|
) |
|
else: |
|
items = np.concatenate(final_items, axis=1) |
|
items = items.squeeze(0) |
|
|
|
if self.type == "ctc_with_lm": |
|
if decoder_kwargs is None: |
|
decoder_kwargs = {} |
|
beams = self.decoder.decode_beams(items, **decoder_kwargs) |
|
text = beams[0][0] |
|
if return_timestamps: |
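                    # Cast pyctcdecode's (word, (start_frame, end_frame)) output to the wav2vec2
                    # offset format so the shared chunk-building code below can be reused.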
|
|
|
|
|
chunk_offset = beams[0][2] |
|
offsets = [] |
|
for word, (start_offset, end_offset) in chunk_offset: |
|
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset}) |
|
elif self.type != "seq2seq_whisper": |
|
skip_special_tokens = self.type != "ctc" |
|
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens) |
|
if return_timestamps: |
|
offsets = self.tokenizer.decode( |
|
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True |
|
)["char_offsets"] |
|
if return_timestamps == "word": |
|
offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char) |
|
|
|
if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}: |
|
chunks = [] |
|
for item in offsets: |
|
start = item["start_offset"] * self.model.config.inputs_to_logits_ratio |
|
start /= self.feature_extractor.sampling_rate |
|
|
|
stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio |
|
stop /= self.feature_extractor.sampling_rate |
|
|
|
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)}) |
|
optional["chunks"] = chunks |
|
|
|
extra = defaultdict(list) |
|
for output in model_outputs: |
|
output.pop("tokens", None) |
|
output.pop("logits", None) |
|
output.pop("is_last", None) |
|
output.pop("stride", None) |
|
output.pop("token_timestamps", None) |
|
for k, v in output.items(): |
|
extra[k].append(v) |
|
return {"text": text, **optional, **extra} |
|
|
|
|
|
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions): |
|
""" |
|
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since |
|
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only |
|
    iterate over them. We keep track of `time`, which indicates the actual starting time of the chunk being
    processed. We need to offset the timestamp tokens by `time` in order for the tokenizer to properly compute the
    final `offset`.
|
""" |
|
|
|
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 |
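    # Token ids at or above `timestamp_begin` encode timestamps; `<|notimestamps|>` is the last
    # non-timestamp token in Whisper's vocabulary.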
|
items = [] |
|
|
|
time_precision = feature_extractor.chunk_length / max_source_positions |
|
time = 0 |
|
for seq_idx, item in enumerate(sequences): |
|
sequence, stride = item |
|
if isinstance(sequence, list): |
|
sequence = np.array(sequence) |
|
chunk_len, stride_left, stride_right = stride |
|
sequence = sequence.squeeze(0) |
|
|
|
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0 |
|
sequence = sequence[begin_idx:] |
|
|
|
timestamp_tokens = sequence >= timestamp_begin |
|
if seq_idx != 0 and sum(timestamp_tokens) > 0: |
|
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 |
|
last_timestamp = np.where(timestamp_tokens)[0][-1] |
|
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive |
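            # Rewind `time` by the overlapping strides so the timestamp offset lines up with the
            # chunks that have already been merged.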
|
time -= stride_left + stride_right |
|
offset = int((time / feature_extractor.sampling_rate) / time_precision) |
|
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision) |
|
|
|
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0] |
|
if relevant_timestamp.shape[0] > 0: |
|
relevant_timestamp = ( |
|
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0] |
|
) |
|
|
|
best_match = 0 |
|
sliced_sequence = [] |
|
for idx, previous_sequence in enumerate(reversed(items)): |
|
previous_tokens = previous_sequence[1:-1] |
|
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0: |
|
break |
|
if len(previous_tokens) > 0: |
|
|
|
index_left, index_right, match_length = _fast_find_longest_common_sequence( |
|
sequence[1:relevant_timestamp], previous_tokens |
|
) |
|
|
|
if match_length > 1 and match_length > best_match: |
|
best_match = match_length |
|
best_idx = idx |
|
end_of_curr_sequence_idx = ( |
|
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1 |
|
) |
|
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left |
|
|
|
if index_left == 0 and match_length == len(previous_tokens): |
|
sliced_sequence = np.insert( |
|
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0] |
|
) |
|
sliced_sequence[-1] = previous_sequence[-1] |
|
|
|
elif index_left >= 0: |
|
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx] |
|
|
|
previous_slice = ( |
|
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]] |
|
) |
|
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice) |
|
sliced_sequence[-1] += offset |
|
|
|
if len(sliced_sequence) > 0: |
|
items[len(items) - best_idx - 1] = sliced_sequence |
|
items = items[: len(items) - best_idx] |
|
sequence = sequence[end_of_curr_sequence_idx:] |
|
|
|
|
|
timestamp_tokens = sequence >= timestamp_begin |
|
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 |
|
if sum(timestamp_tokens) > 0: |
|
last_timestamp = np.where(timestamp_tokens)[0][-1] |
|
consecutive = ( |
|
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive |
|
) |
|
|
|
if len(consecutive) > 0: |
|
last_slice = 0 |
|
for current_slice in consecutive: |
|
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0] |
|
sliced_tokens = sequence[last_slice:current_slice] |
|
duration = sliced_tokens[-1] - sliced_tokens[0] |
|
sliced_tokens[0] = actual_offset |
|
sliced_tokens[-1] = actual_offset + duration |
|
items.append(sliced_tokens) |
|
last_slice = current_slice |
|
|
|
time += chunk_len |
|
result = [] |
|
for i in range(len(items)): |
|
result += items[i].tolist() |
|
return result |
|
|