mariagrandury committed on
Commit
062f3fc
1 Parent(s): 8bb6457

Update training script

Browse files
Files changed (2) hide show
  1. run_wav2vec2_pretrain_flax.py +3 -0
  2. train.sh +6 -7
run_wav2vec2_pretrain_flax.py CHANGED
@@ -174,6 +174,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
174
 
175
  batch_size = batch["input_values"].shape[0]
176
 
 
177
  if batch["attention_mask"] is not None:
178
  output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
179
  attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
@@ -196,6 +197,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
196
  batch["sampled_negative_indices"] = _sample_negative_indices(
197
  (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
198
  self.model.config.num_negatives,
 
199
  )
200
 
201
  return batch
@@ -342,6 +344,7 @@ def main():
342
  def normalize(batch):
343
  return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
344
 
 
345
  # normalize and transform to `BatchFeatures`
346
  vectorized_datasets = vectorized_datasets.map(
347
  normalize,
174
 
175
  batch_size = batch["input_values"].shape[0]
176
 
177
+ attention_mask = None
178
  if batch["attention_mask"] is not None:
179
  output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
180
  attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
197
  batch["sampled_negative_indices"] = _sample_negative_indices(
198
  (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
199
  self.model.config.num_negatives,
200
+ attention_mask=attention_mask,
201
  )
202
 
203
  return batch
344
  def normalize(batch):
345
  return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
346
 
347
+ batch_size = 64
348
  # normalize and transform to `BatchFeatures`
349
  vectorized_datasets = vectorized_datasets.map(
350
  normalize,
train.sh CHANGED
@@ -1,22 +1,21 @@
1
  #!/usr/bin/env bash
2
- ./preprocess_dataset.py \
3
- --output_dir="./output" \
4
  --num_train_epochs="5" \
5
- --per_device_train_batch_size="32" \
6
- --per_device_eval_batch_size="32" \
7
  --learning_rate="5e-4" \
8
  --weight_decay="0.01" \
9
- --warmup_steps="2000" \
10
  --model_name_or_path="./" \
11
  --dataset_name="common_voice" \
12
  --dataset_config_name="es" \
13
- --preprocessing_num_workers="64" \
14
  --max_duration_in_seconds="10.0" \
15
  --adam_beta1="0.9" \
16
  --adam_beta2="0.98" \
17
  --pad_to_multiple_of="16384" \
18
  --validation_split_percentage="5" \
19
  --speech_file_column="path" \
20
- --dtype="bfloat16" \
21
  --cache_dir="./data_cache" \
22
  --push_to_hub
1
  #!/usr/bin/env bash
2
+ ./run_wav2vec2_pretrain_flax.py \
3
+ --output_dir="./wav2vec2-spanish" \
4
  --num_train_epochs="5" \
5
+ --per_device_train_batch_size="16" \
6
+ --per_device_eval_batch_size="16" \
7
  --learning_rate="5e-4" \
8
  --weight_decay="0.01" \
9
+ --warmup_steps="1000" \
10
  --model_name_or_path="./" \
11
  --dataset_name="common_voice" \
12
  --dataset_config_name="es" \
13
+ --preprocessing_num_workers="32" \
14
  --max_duration_in_seconds="10.0" \
15
  --adam_beta1="0.9" \
16
  --adam_beta2="0.98" \
17
  --pad_to_multiple_of="16384" \
18
  --validation_split_percentage="5" \
19
  --speech_file_column="path" \
 
20
  --cache_dir="./data_cache" \
21
  --push_to_hub