mariagrandury commited on
Commit
0581116
1 Parent(s): 8fefff1

Use torchaudio instead of librosa

Browse files
Files changed (1) hide show
  1. run_wav2vec2_pretrain_flax.py +9 -5
run_wav2vec2_pretrain_flax.py CHANGED
@@ -13,7 +13,7 @@ from tqdm import tqdm
13
  import flax
14
  import jax
15
  import jax.numpy as jnp
16
- import librosa
17
  import optax
18
  from flax import jax_utils, traverse_util
19
  from flax.training import train_state
@@ -320,14 +320,18 @@ def main():
320
  model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
321
  )
322
 
323
- def prepare_dataset(batch):
324
- # check that all files have the correct sampling rate
325
- batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
 
 
 
 
326
  return batch
327
 
328
  # load audio files into numpy arrays
329
  vectorized_datasets = datasets.map(
330
- prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
331
  )
332
 
333
  # filter audio files that are too long
13
  import flax
14
  import jax
15
  import jax.numpy as jnp
16
+ import torchaudio
17
  import optax
18
  from flax import jax_utils, traverse_util
19
  from flax.training import train_state
320
  model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
321
  )
322
 
323
+ resampler = torchaudio.transforms.Resample(48_000, 16_000)
324
+
325
+ # Preprocessing the datasets.
326
+ # We need to read the aduio files as arrays and tokenize the targets.
327
+ def speech_file_to_array_fn(batch):
328
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
329
+ batch["speech"] = resampler(speech_array).squeeze().numpy()
330
  return batch
331
 
332
  # load audio files into numpy arrays
333
  vectorized_datasets = datasets.map(
334
+ speech_file_to_array_fn, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names,
335
  )
336
 
337
  # filter audio files that are too long