Commit 0581116 by mariagrandury (parent: 8fefff1)
Use torchaudio instead of librosa

run_wav2vec2_pretrain_flax.py (CHANGED)
@@ -13,7 +13,7 @@ from tqdm import tqdm
 import flax
 import jax
 import jax.numpy as jnp
-import librosa
+import torchaudio
 import optax
 from flax import jax_utils, traverse_util
 from flax.training import train_state
@@ -320,14 +320,18 @@ def main():
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
     )
 
-
-
-
+    resampler = torchaudio.transforms.Resample(48_000, 16_000)
+
+    # Preprocessing the datasets.
+    # We need to read the audio files as arrays and tokenize the targets.
+    def speech_file_to_array_fn(batch):
+        speech_array, sampling_rate = torchaudio.load(batch["path"])
+        batch["speech"] = resampler(speech_array).squeeze().numpy()
         return batch
 
     # load audio files into numpy arrays
     vectorized_datasets = datasets.map(
-
+        speech_file_to_array_fn, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names,
     )
 
     # filter audio files that are too long
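
For context, the added preprocessing boils down to the standalone sketch below. It assumes 48 kHz mono source audio, matching the Resample(48_000, 16_000) hard-coded in the commit; the helper name load_speech and the example path are illustrative and not part of the diff.

# Minimal sketch of the torchaudio-based loading/resampling introduced above.
# Assumes 48 kHz, mono source files; stereo input would need an explicit mixdown.
import numpy as np
import torchaudio

# Resample from the assumed 48 kHz source rate to the 16 kHz that Wav2Vec2 expects.
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)

def load_speech(path: str) -> np.ndarray:
    # torchaudio.load returns a (channels, num_frames) float tensor and the file's sampling rate.
    speech_array, sampling_rate = torchaudio.load(path)
    # Resample and drop the channel dimension so the feature extractor sees a 1-D array.
    return resampler(speech_array).squeeze().numpy()

# Example (hypothetical file):
# speech = load_speech("clip_48k.wav")  # -> 1-D np.ndarray of 16 kHz samples

Note that, as in the diff, the sampling_rate returned by torchaudio.load is ignored; a more defensive variant would build the resampler from that value (or from feature_extractor.sampling_rate) instead of assuming every file is recorded at 48 kHz.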