farzadab committed
Commit 921aed7
1 Parent(s): 15fb37c

Create ultravox_processing.py

Files changed (1)
  1. ultravox_processing.py +172 -0
ultravox_processing.py ADDED
@@ -0,0 +1,172 @@
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import transformers


class UltravoxProcessor(transformers.ProcessorMixin):
    """
    Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.

    Args:
        audio_processor: The audio processor for the audio encoder.
        tokenizer: The tokenizer for the language model.
    """

    attributes = ["audio_processor", "tokenizer"]
    audio_processor_class = (
        "Wav2Vec2Processor",
        "SeamlessM4TFeatureExtractor",
        "WhisperProcessor",
    )
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    )

    tokenizer: transformers.PreTrainedTokenizerBase
    audio_processor: transformers.ProcessorMixin

    def __init__(
        self,
        audio_processor=None,
        tokenizer=None,
        audio_padding: str = "longest",
        encoder_ds_factor: int = 320,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model.
            audio_padding: The padding strategy for the audio encoder.
            encoder_ds_factor: The downsample factor of the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            audio_placeholder: The placeholder for the audio in the text.
        """
        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
        self.audio_token_replacement = tokenizer.eos_token
        assert (
            self.audio_token_replacement is not None
        ), "The tokenizer has no EOS token. Cannot recover."
        super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare one text sequence and one piece of audio for the model. This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not
        `None` to encode the text. To prepare the audio, it forwards the `audio`, `sampling_rate` and `kwargs`
        arguments to the audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to
        the docstrings of the two methods above for more information.

        Args:
            text (`str`):
                The sequence to be encoded. Only a single (non-batched) string is supported for now.
            audio (`np.ndarray`, `torch.Tensor`):
                The audio to be prepared, as a NumPy array or a PyTorch tensor of shape (C, T), where C is the
                number of channels and T is the sample length of the audio.
            sampling_rate (`int`, *optional*, defaults to 16000):
                Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                you are doing.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and `text` is not
              `None`).
            - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
            - **audio_token_len** -- Predicted number of audio frames; this value is guaranteed to be a close upper
              bound. Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when
              `audio` is not `None`.
        """
        # TODO: Add support for multiple audio and text inputs.
        data = {}
        audio_embed_frames = 0
        if audio is not None and len(audio) > 0:
            if self.audio_padding == "max_length":
                # 30 seconds is the expected length for Whisper
                assert sampling_rate is not None, "Sampling rate must be provided."
                audio_len = 30 * sampling_rate
            else:
                audio_len = audio.shape[-1]
            # It's guaranteed that the number of frames is less than or equal to this amount.
            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
            data["audio_token_len"] = [audio_embed_frames]

            x = self.audio_processor(
                audio,
                sampling_rate=sampling_rate,
                padding="longest",
                max_length=audio_len,
                **kwargs,
            )
            if "input_features" in x:
                data["audio_values"] = x.input_features
            else:
                data["audio_values"] = x.input_values

        if text is not None:
            assert isinstance(
                text, str
            ), "Text must be a string. Batch mode not supported yet."
            if self.audio_placeholder in text:
                if "audio_token_len" not in data:
                    raise ValueError(
                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
                    )

                start_idx = len(
                    self.tokenizer.encode(
                        text[: text.index(self.audio_placeholder)],
                        add_special_tokens=False,
                    )
                )
                data["audio_token_start_idx"] = [start_idx]
                text = text.replace(
                    self.audio_placeholder,
                    self.audio_token_replacement * audio_embed_frames,
                )
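                # These replacement tokens only reserve positions in input_ids; downstream, the model is expected to
                # overwrite the embeddings at [audio_token_start_idx, audio_token_start_idx + audio_token_len) with
                # the projected audio features.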

            # Special tokens like BOS should already have been added by the caller.
            data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))

        return transformers.BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(set(tokenizer_input_names + audio_processor_input_names))
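
For reference, a minimal usage sketch (illustrative only, not part of this commit); the checkpoint names below are placeholders for whichever audio encoder and language-model tokenizer the Ultravox model is actually built with:

import numpy as np
import transformers

# Placeholder checkpoints -- substitute the real audio encoder / LLM tokenizer.
audio_processor = transformers.AutoProcessor.from_pretrained("openai/whisper-small")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

processor = UltravoxProcessor(audio_processor=audio_processor, tokenizer=tokenizer)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = processor(
    text="<|audio|>\nWhat is being said in this clip?",
    audio=audio,
    sampling_rate=16000,
)
# inputs contains input_ids, attention_mask, audio_values,
# audio_token_len and audio_token_start_idx.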