mariagrandury committed
Commit: 062f3fc
1 parent: 8bb6457

Update training script

Files changed:
- run_wav2vec2_pretrain_flax.py +3 -0
- train.sh +6 -7
run_wav2vec2_pretrain_flax.py CHANGED

@@ -174,6 +174,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
 
         batch_size = batch["input_values"].shape[0]
 
+        attention_mask = None
         if batch["attention_mask"] is not None:
             output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
             attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
@@ -196,6 +197,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
         batch["sampled_negative_indices"] = _sample_negative_indices(
             (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
             self.model.config.num_negatives,
+            attention_mask=attention_mask,
         )
 
         return batch
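
Note on the two hunks above: previously, `attention_mask` was only assigned inside the `if` branch, and `_sample_negative_indices` was called without it. With the `None` default added up front, the mask is now forwarded to the sampler, presumably so that distractors are drawn only from real, non-padded frames. A minimal NumPy sketch of that idea (the function name and exact sampling scheme here are illustrative, not the transformers implementation):

    import numpy as np

    def sample_negatives_sketch(seq_len, num_negatives, attention_mask=None, seed=0):
        """Illustrative only: draw num_negatives distractor indices per time step,
        restricted to non-padded positions when an attention mask is given."""
        rng = np.random.default_rng(seed)
        valid = int(attention_mask.sum()) if attention_mask is not None else seq_len
        negatives = np.empty((seq_len, num_negatives), dtype=np.int64)
        for t in range(seq_len):
            if t < valid:
                candidates = np.delete(np.arange(valid), t)  # never pick the anchor itself
            else:
                candidates = np.arange(valid)  # padded anchor; its loss is masked downstream
            negatives[t] = rng.choice(candidates, size=num_negatives, replace=True)
        return negatives

    # with 3 real frames and 2 padded ones, negatives are drawn from indices 0..2 only
    print(sample_negatives_sketch(5, 2, attention_mask=np.array([1, 1, 1, 0, 0])))

Without the mask, padded positions could be sampled as negatives and dilute the contrastive loss.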
@@ -342,6 +344,7 @@ def main():
     def normalize(batch):
         return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
 
+    batch_size = 64
     # normalize and transform to `BatchFeatures`
     vectorized_datasets = vectorized_datasets.map(
         normalize,
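
The third hunk defines `batch_size = 64` just before the `datasets.map` call; the call's remaining arguments fall outside the diff context, but the variable is presumably passed as `map`'s `batch_size` to cap how many examples each `normalize` call receives. A tiny, self-contained illustration of that `datasets` pattern (the toy data and the doubling function are made up):

    from datasets import Dataset

    ds = Dataset.from_dict({"speech": [[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]})

    def normalize(batch):
        # stand-in for the feature-extractor call in the script
        return {"input_values": [[2 * x for x in ex] for ex in batch["speech"]]}

    # with batched=True, batch_size controls how many examples each normalize()
    # call receives; smaller batches bound peak memory during preprocessing
    ds = ds.map(normalize, batched=True, batch_size=64)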
train.sh CHANGED

@@ -1,22 +1,21 @@
 #!/usr/bin/env bash
-./
-  --output_dir="./
+./run_wav2vec2_pretrain_flax.py \
+  --output_dir="./wav2vec2-spanish" \
   --num_train_epochs="5" \
-  --per_device_train_batch_size="
-  --per_device_eval_batch_size="
+  --per_device_train_batch_size="16" \
+  --per_device_eval_batch_size="16" \
   --learning_rate="5e-4" \
   --weight_decay="0.01" \
-  --warmup_steps="
+  --warmup_steps="1000" \
   --model_name_or_path="./" \
   --dataset_name="common_voice" \
   --dataset_config_name="es" \
-  --preprocessing_num_workers="
+  --preprocessing_num_workers="32" \
   --max_duration_in_seconds="10.0" \
   --adam_beta1="0.9" \
   --adam_beta2="0.98" \
   --pad_to_multiple_of="16384" \
   --validation_split_percentage="5" \
   --speech_file_column="path" \
-  --dtype="bfloat16" \
   --cache_dir="./data_cache" \
   --push_to_hub
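
For the shell script, the commit pins concrete values: the entry point is `./run_wav2vec2_pretrain_flax.py`, output goes to `./wav2vec2-spanish`, batch sizes are 16 per device, warmup is 1000 steps, preprocessing uses 32 workers, and the `--dtype="bfloat16"` flag is dropped, so the script falls back to its default dtype (presumably float32). Under data parallelism, the per-device batch size multiplies by the number of devices; a quick sketch, assuming an 8-device host such as a TPU v3-8 (the device count is an assumption, not something train.sh pins down):

    import jax

    per_device_train_batch_size = 16  # from the updated train.sh
    # each device processes its own shard, so the effective global batch is
    global_batch_size = per_device_train_batch_size * jax.device_count()
    print(global_batch_size)  # 128 on an 8-device host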