alex-ht commited on
Commit
91d53d9
·
1 Parent(s): 0524ea2
Files changed (1) hide show
  1. ultravox_processing.py +38 -35
ultravox_processing.py CHANGED
@@ -138,22 +138,25 @@ class UltravoxProcessor(transformers.ProcessorMixin):
138
  if self.audio_padding == "max_length":
139
  # 30 seconds is the expected length for Whisper
140
  assert sampling_rate is not None, "Sampling rate must be provided."
141
- audio_len = 30 * sampling_rate
142
  else:
143
- audio_len = max([len(a) for a in audio])
144
- # It's guaranteed that the number of frames is less than or equal to this amount.
145
- # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
146
- # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
147
- nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
148
- audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
149
- data["audio_token_len"] = [audio_embed_frames]
 
 
 
150
 
151
  # Main audio processing. The processor is model-specific.
152
  x = self.audio_processor(
153
  audio,
154
  sampling_rate=sampling_rate,
155
  padding="longest",
156
- max_length=audio_len,
157
  return_attention_mask=True,
158
  **kwargs,
159
  )
@@ -161,39 +164,39 @@ class UltravoxProcessor(transformers.ProcessorMixin):
161
  data["audio_values"] = x.input_features
162
  else:
163
  data["audio_values"] = x.input_values
164
- if self.audio_padding == "max_length":
165
- data["audio_len"] = x.attention_mask.sum(-1) - 1
166
- else:
167
- data["audio_len"] = [data["audio_values"].shape[-1]]
168
 
169
  if text is not None:
170
  assert isinstance(
171
- text, str
172
- ), "Text must be a string. Batch mode not supported yet."
173
- if self.audio_placeholder in text:
174
- if "audio_token_len" not in data:
175
- raise ValueError(
176
- f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
 
 
 
 
 
 
 
 
 
177
  )
178
-
179
- start_idx = len(
180
- self.tokenizer.encode(
181
- text[: text.index(self.audio_placeholder)],
182
- add_special_tokens=False,
 
 
 
183
  )
184
- )
185
- data["audio_token_start_idx"] = [start_idx]
186
-
187
- # Replace the audio placeholder with the audio token.
188
- # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
189
- # where the number of </s> is the number of audio frames.
190
- text = text.replace(
191
- self.audio_placeholder,
192
- self.audio_token_replacement * audio_embed_frames,
193
- )
194
 
195
  # Special tokens like BOS should already have been added by the caller.
196
- data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
197
 
198
  return transformers.BatchFeature(data=data, tensor_type=return_tensors)
199
 
 
138
  if self.audio_padding == "max_length":
139
  # 30 seconds is the expected length for Whisper
140
  assert sampling_rate is not None, "Sampling rate must be provided."
141
+ max_audio_len = 30 * sampling_rate
142
  else:
143
+ max_audio_len = max([len(a) for a in audio])
144
+
145
+ data["audio_token_len"] = []
146
+ for a in audio:
147
+ # It's guaranteed that the number of frames is less than or equal to this amount.
148
+ # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
149
+ # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
150
+ nb_encoder_frames = int(round(min(len(a), max_audio_len)/ self.encoder_ds_factor + 1e-4))
151
+ audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
152
+ data["audio_token_len"].append(audio_embed_frames)
153
 
154
  # Main audio processing. The processor is model-specific.
155
  x = self.audio_processor(
156
  audio,
157
  sampling_rate=sampling_rate,
158
  padding="longest",
159
+ max_length=max_audio_len,
160
  return_attention_mask=True,
161
  **kwargs,
162
  )
 
164
  data["audio_values"] = x.input_features
165
  else:
166
  data["audio_values"] = x.input_values
167
+ data["audio_len"] = x.attention_mask.sum(-1) - 1
 
 
 
168
 
169
  if text is not None:
170
  assert isinstance(
171
+ text, list
172
+ ), "Text must be a list."
173
+ processed_text = []
174
+ for t in text:
175
+ if self.audio_placeholder in t:
176
+ if "audio_token_len" not in data:
177
+ raise ValueError(
178
+ f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
179
+ )
180
+
181
+ start_idx = len(
182
+ self.tokenizer.encode(
183
+ t[: text.t(self.audio_placeholder)],
184
+ add_special_tokens=False,
185
+ )
186
  )
187
+ data["audio_token_start_idx"] = [start_idx]
188
+
189
+ # Replace the audio placeholder with the audio token.
190
+ # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
191
+ # where the number of </s> is the number of audio frames.
192
+ t = t.replace(
193
+ self.audio_placeholder,
194
+ self.audio_token_replacement * audio_embed_frames,
195
  )
196
+ processed_text.append(t)
 
 
 
 
 
 
 
 
 
197
 
198
  # Special tokens like BOS should already have been added by the caller.
199
+ data.update(self.tokenizer(processed_text, add_special_tokens=False, **kwargs))
200
 
201
  return transformers.BatchFeature(data=data, tensor_type=return_tensors)
202