Create ultravox_processing.py
ultravox_processing.py
ADDED
@@ -0,0 +1,210 @@
from typing import Optional, Union

import numpy as np
import torch
import transformers

from .ultravox_config import UltravoxConfig


class UltravoxProcessor(transformers.ProcessorMixin):
    """
    Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
    Args:
        audio_processor: The audio processor for the audio encoder.
        tokenizer: The tokenizer for the language model.
    """

    attributes = ["audio_processor", "tokenizer"]
    audio_processor_class = (
        "Wav2Vec2Processor",
        "SeamlessM4TFeatureExtractor",
        "WhisperProcessor",
    )
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    )

    tokenizer: transformers.PreTrainedTokenizerBase
    audio_processor: transformers.ProcessorMixin

    def __init__(
        self,
        audio_processor=None,
        tokenizer=None,
        audio_padding: str = "longest",
        encoder_ds_factor: int = 320,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model.
            audio_padding: The padding strategy for the audio encoder.
            encoder_ds_factor: The downsample factor of the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            audio_placeholder: The placeholder for the audio in the text.
        """
        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
        self.audio_token_replacement = tokenizer.eos_token
        assert (
            self.audio_token_replacement is not None
        ), "The tokenizer has no EOS token. Cannot recover."
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        audio_processor = transformers.AutoProcessor.from_pretrained(
            config.audio_model_id
            or config.audio_config._name_or_path
            or "facebook/wav2vec2-base-960h"
        )

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
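        # Note: decoder-only LM tokenizers often ship without a pad token; reusing EOS
        # for padding and padding on the left keeps each prompt's final token adjacent
        # to the positions generated next.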
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token

        return cls(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=config.stack_factor,
        )

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare for the model one text sequence and audio. This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
        audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
        of the above two methods for more information.
        Args:
            text (`str`, `List[str]`):
                The sequence to be encoded. Sequence can be a string or (pretokenized string).
            audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
                NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
                sample length of the audio.
            sampling_rate (`int`, *optional*, defaults to 16000):
                Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                you are doing.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
            - **audio_token_len** -- Predicted number of audio frames: this value is guaranteed to be a close upper bound.
              Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
        """
        # TODO: Add support for multiple audio and text inputs.
        data = {}
        audio_embed_frames = 0
        if audio is not None and len(audio) > 0:
            if self.audio_padding == "max_length":
                # 30 seconds is the expected length for Whisper
                assert sampling_rate is not None, "Sampling rate must be provided."
                audio_len = 30 * sampling_rate
            else:
                audio_len = audio.shape[-1]
            # It's guaranteed that the number of frames is less than or equal to this amount.
            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
            data["audio_token_len"] = [audio_embed_frames]

            # Main audio processing. The processor is model-specific.
            x = self.audio_processor(
                audio,
                sampling_rate=sampling_rate,
                padding="longest",
                max_length=audio_len,
                return_attention_mask=True,
                **kwargs,
            )
            if "input_features" in x:
                data["audio_values"] = x.input_features
            else:
                data["audio_values"] = x.input_values
            if self.audio_padding == "max_length":
                data["audio_len"] = x.attention_mask.sum(-1) - 1
            else:
                data["audio_len"] = [data["audio_values"].shape[-1]]

        if text is not None:
            assert isinstance(
                text, str
            ), "Text must be a string. Batch mode not supported yet."
            if self.audio_placeholder in text:
                if "audio_token_len" not in data:
                    raise ValueError(
                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
                    )

                start_idx = len(
                    self.tokenizer.encode(
                        text[: text.index(self.audio_placeholder)],
                        add_special_tokens=False,
                    )
                )
                data["audio_token_start_idx"] = [start_idx]

                # Replace the audio placeholder with the audio token.
                # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
                # where the number of </s> is the number of audio frames.
                text = text.replace(
                    self.audio_placeholder,
                    self.audio_token_replacement * audio_embed_frames,
                )

            # Special tokens like BOS should already have been added by the caller.
            data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))

        return transformers.BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(set(tokenizer_input_names + audio_processor_input_names))


UltravoxProcessor.register_for_auto_class()

transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
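For context on how these outputs are meant to be consumed: `audio_token_start_idx` and `audio_token_len` mark the run of placeholder tokens inside `input_ids` that the projected audio embeddings should occupy. Below is a minimal sketch of that splice; it is an illustration, not part of the file above, and the helper name and per-example shapes are assumptions.

import torch

def splice_audio_embeds(
    text_embeds: torch.Tensor,   # (seq_len, hidden): embeddings of input_ids
    audio_embeds: torch.Tensor,  # (n, hidden): projected audio features, n <= audio_token_len
    start_idx: int,              # audio_token_start_idx for this example
) -> torch.Tensor:
    # Overwrite the placeholder span with the audio embeddings. Per the comment in
    # __call__, StackAudioFrames pads the audio embeddings so an over-estimated
    # audio_token_len does not cause a size mismatch.
    merged = text_embeds.clone()
    merged[start_idx : start_idx + audio_embeds.shape[0]] = audio_embeds
    return merged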