alex-ht
commited on
Commit
·
91d53d9
1
Parent(s):
0524ea2
batch
Browse files- ultravox_processing.py +38 -35
ultravox_processing.py
CHANGED
@@ -138,22 +138,25 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
138 |
if self.audio_padding == "max_length":
|
139 |
# 30 seconds is the expected length for Whisper
|
140 |
assert sampling_rate is not None, "Sampling rate must be provided."
|
141 |
-
|
142 |
else:
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
150 |
|
151 |
# Main audio processing. The processor is model-specific.
|
152 |
x = self.audio_processor(
|
153 |
audio,
|
154 |
sampling_rate=sampling_rate,
|
155 |
padding="longest",
|
156 |
-
max_length=
|
157 |
return_attention_mask=True,
|
158 |
**kwargs,
|
159 |
)
|
@@ -161,39 +164,39 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
161 |
data["audio_values"] = x.input_features
|
162 |
else:
|
163 |
data["audio_values"] = x.input_values
|
164 |
-
|
165 |
-
data["audio_len"] = x.attention_mask.sum(-1) - 1
|
166 |
-
else:
|
167 |
-
data["audio_len"] = [data["audio_values"].shape[-1]]
|
168 |
|
169 |
if text is not None:
|
170 |
assert isinstance(
|
171 |
-
text,
|
172 |
-
), "Text must be a
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
)
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
183 |
)
|
184 |
-
|
185 |
-
data["audio_token_start_idx"] = [start_idx]
|
186 |
-
|
187 |
-
# Replace the audio placeholder with the audio token.
|
188 |
-
# e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
|
189 |
-
# where the number of </s> is the number of audio frames.
|
190 |
-
text = text.replace(
|
191 |
-
self.audio_placeholder,
|
192 |
-
self.audio_token_replacement * audio_embed_frames,
|
193 |
-
)
|
194 |
|
195 |
# Special tokens like BOS should already have been added by the caller.
|
196 |
-
data.update(self.tokenizer(
|
197 |
|
198 |
return transformers.BatchFeature(data=data, tensor_type=return_tensors)
|
199 |
|
|
|
138 |
if self.audio_padding == "max_length":
|
139 |
# 30 seconds is the expected length for Whisper
|
140 |
assert sampling_rate is not None, "Sampling rate must be provided."
|
141 |
+
max_audio_len = 30 * sampling_rate
|
142 |
else:
|
143 |
+
max_audio_len = max([len(a) for a in audio])
|
144 |
+
|
145 |
+
data["audio_token_len"] = []
|
146 |
+
for a in audio:
|
147 |
+
# It's guaranteed that the number of frames is less than or equal to this amount.
|
148 |
+
# For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
|
149 |
+
# Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
|
150 |
+
nb_encoder_frames = int(round(min(len(a), max_audio_len)/ self.encoder_ds_factor + 1e-4))
|
151 |
+
audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
|
152 |
+
data["audio_token_len"].append(audio_embed_frames)
|
153 |
|
154 |
# Main audio processing. The processor is model-specific.
|
155 |
x = self.audio_processor(
|
156 |
audio,
|
157 |
sampling_rate=sampling_rate,
|
158 |
padding="longest",
|
159 |
+
max_length=max_audio_len,
|
160 |
return_attention_mask=True,
|
161 |
**kwargs,
|
162 |
)
|
|
|
164 |
data["audio_values"] = x.input_features
|
165 |
else:
|
166 |
data["audio_values"] = x.input_values
|
167 |
+
data["audio_len"] = x.attention_mask.sum(-1) - 1
|
|
|
|
|
|
|
168 |
|
169 |
if text is not None:
|
170 |
assert isinstance(
|
171 |
+
text, list
|
172 |
+
), "Text must be a list."
|
173 |
+
processed_text = []
|
174 |
+
for t in text:
|
175 |
+
if self.audio_placeholder in t:
|
176 |
+
if "audio_token_len" not in data:
|
177 |
+
raise ValueError(
|
178 |
+
f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
|
179 |
+
)
|
180 |
+
|
181 |
+
start_idx = len(
|
182 |
+
self.tokenizer.encode(
|
183 |
+
t[: text.t(self.audio_placeholder)],
|
184 |
+
add_special_tokens=False,
|
185 |
+
)
|
186 |
)
|
187 |
+
data["audio_token_start_idx"] = [start_idx]
|
188 |
+
|
189 |
+
# Replace the audio placeholder with the audio token.
|
190 |
+
# e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
|
191 |
+
# where the number of </s> is the number of audio frames.
|
192 |
+
t = t.replace(
|
193 |
+
self.audio_placeholder,
|
194 |
+
self.audio_token_replacement * audio_embed_frames,
|
195 |
)
|
196 |
+
processed_text.append(t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
# Special tokens like BOS should already have been added by the caller.
|
199 |
+
data.update(self.tokenizer(processed_text, add_special_tokens=False, **kwargs))
|
200 |
|
201 |
return transformers.BatchFeature(data=data, tensor_type=return_tensors)
|
202 |
|