sanchit-gandhi HF staff committed on
Commit
678bd59
1 Parent(s): afab5c5

tf2xup5z: saving weights and logs of step 10k

.gitattributes CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wandb filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,109 @@
+ {
+ "activation_dropout": 0.1,
+ "adapter_kernel_size": 3,
+ "adapter_stride": 2,
+ "add_adapter": false,
+ "apply_spec_augment": true,
+ "architectures": [
+ "Wav2Vec2ForCTC"
+ ],
+ "attention_dropout": 0.1,
+ "bos_token_id": 1,
+ "classifier_proj_size": 256,
+ "codevector_dim": 768,
+ "contrastive_logits_temperature": 0.1,
+ "conv_bias": true,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "diversity_loss_weight": 0.1,
+ "do_stable_layer_norm": true,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_dropout": 0.0,
+ "feat_extract_norm": "layer",
+ "feat_proj_dropout": 0.0,
+ "feat_quantizer_dropout": 0.0,
+ "final_dropout": 0.0,
+ "fuse_matmuls": false,
+ "gradient_checkpointing": true,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 1024,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.0,
+ "mask_feature_length": 10,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_prob": 0.1,
+ "model_type": "wav2vec2",
+ "num_adapter_layers": 3,
+ "num_attention_heads": 16,
+ "num_codevector_groups": 2,
+ "num_codevectors_per_group": 320,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 24,
+ "num_negatives": 100,
+ "output_hidden_size": 1024,
+ "pad_token_id": 0,
+ "proj_codevector_dim": 768,
+ "tdnn_dilation": [
+ 1,
+ 2,
+ 3,
+ 1,
+ 1
+ ],
+ "tdnn_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 1500
+ ],
+ "tdnn_kernel": [
+ 5,
+ 3,
+ 3,
+ 1,
+ 1
+ ],
+ "transformers_version": "4.22.0.dev0",
+ "use_scan": true,
+ "use_weighted_layer_sum": false,
+ "vocab_size": 34,
+ "xvector_output_dim": 512
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ee0d5ea5cd8989c9122114ebb26b843e5738a8ac85a029e1d30355daa62b24c
+ size 1261896350
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "do_normalize": true,
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+ "feature_size": 1,
+ "padding_side": "right",
+ "padding_value": 0.0,
+ "processor_class": "Wav2Vec2Processor",
+ "return_attention_mask": true,
+ "sampling_rate": 16000
+ }
run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1559 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script to your own CTC speech recognition task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ )
56
+ from transformers.file_utils import get_full_repo_name
57
+ from transformers.utils import check_min_version
58
+ from transformers.utils.versions import require_version
59
+
60
+
61
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
62
+ check_min_version("4.17.0.dev0")
63
+
64
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+
69
+ @flax.struct.dataclass
70
+ class ModelArguments:
71
+ """
72
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
73
+ """
74
+
75
+ model_name_or_path: str = field(
76
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
77
+ )
78
+ config_name: Optional[str] = field(
79
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
80
+ )
81
+ tokenizer_name: Optional[str] = field(
82
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
83
+ )
84
+ feature_extractor_name: Optional[str] = field(
85
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
86
+ )
87
+ cache_dir: Optional[str] = field(
88
+ default=None,
89
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
90
+ )
91
+ use_fast_tokenizer: bool = field(
92
+ default=True,
93
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
94
+ )
95
+ model_revision: str = field(
96
+ default="main",
97
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
98
+ )
99
+ use_auth_token: bool = field(
100
+ default=False,
101
+ metadata={
102
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
103
+ "with private models)."
104
+ },
105
+ )
106
+ freeze_feature_encoder: bool = field(
107
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
108
+ )
109
+ activation_dropout: float = field(
110
+ default=0.1,
111
+ metadata={
112
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
113
+ },
114
+ )
115
+ hidden_dropout: float = field(
116
+ default=0.1,
117
+ metadata={
118
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
119
+ },
120
+ )
121
+ feat_proj_dropout: float = field(
122
+ default=0.0,
123
+ metadata={
124
+ "help": "The feat proj dropout probability for feature encoder representations."
125
+ },
126
+ )
127
+ mask_time_prob: float = field(
128
+ default=0.1,
129
+ metadata={
130
+ "help": "The spec aug dropout probability for feature encoder representations."
131
+ },
132
+ )
133
+
134
+
135
+ @flax.struct.dataclass
136
+ class DataTrainingArguments:
137
+ """
138
+ Arguments pertaining to what data we are going to input our model for training and eval.
139
+ """
140
+
141
+ dataset_name: str = field(
142
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
143
+ )
144
+ dataset_config_name: Optional[str] = field(
145
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
146
+ )
147
+ text_column: Optional[str] = field(
148
+ default=None,
149
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
150
+ )
151
+ dataset_cache_dir: Optional[str] = field(
152
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
153
+ )
154
+ overwrite_cache: bool = field(
155
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
156
+ )
157
+ preprocessing_num_workers: Optional[int] = field(
158
+ default=None,
159
+ metadata={"help": "The number of processes to use for the preprocessing."},
160
+ )
161
+ max_train_samples: Optional[int] = field(
162
+ default=None,
163
+ metadata={
164
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
165
+ "value if set."
166
+ },
167
+ )
168
+ max_eval_samples: Optional[int] = field(
169
+ default=None,
170
+ metadata={
171
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
172
+ "value if set."
173
+ },
174
+ )
175
+ max_test_samples: Optional[int] = field(
176
+ default=None,
177
+ metadata={
178
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
179
+ "value if set."
180
+ },
181
+ )
182
+ audio_column_name: str = field(
183
+ default="audio",
184
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
185
+ )
186
+ text_column_name: str = field(
187
+ default="text",
188
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
189
+ )
190
+ max_duration_in_seconds: float = field(
191
+ default=20.0,
192
+ metadata={
193
+ "help": "Filter audio files in the training set that are longer than `max_duration_in_seconds` seconds"
194
+ },
195
+ )
196
+ min_duration_in_seconds: float = field(
197
+ default=0.0, metadata={"help": "Filter audio files in the training set that are shorter than `min_duration_in_seconds` seconds"}
198
+ )
199
+ max_label_length: Optional[int] = field(
200
+ default=512,
201
+ metadata={
202
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
203
+ "than this will be filtered."
204
+ },
205
+ )
206
+ min_label_length: Optional[int] = field(
207
+ default=0,
208
+ metadata={
209
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
210
+ "than this will be filtered."
211
+ },
212
+ )
213
+ max_eval_duration_in_seconds: float = field(
214
+ default=None,
215
+ metadata={
216
+ "help": "Filter audio files in the eval/test set that are longer than `max_duration_in_seconds` seconds"
217
+ },
218
+ )
219
+ pad_input_to_multiple_of: Optional[int] = field(
220
+ default=32000,
221
+ metadata={
222
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
223
+ "This is important to avoid triggering recompilations on TPU."
224
+ },
225
+ )
226
+ pad_target_to_multiple_of: Optional[int] = field(
227
+ default=None,
228
+ metadata={
229
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
230
+ "This is important to avoid triggering recompilations on TPU."
231
+ },
232
+ )
233
+ preprocessing_only: bool = field(
234
+ default=False,
235
+ metadata={
236
+ "help": "Whether to only do data preprocessing and skip training. "
237
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
238
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
239
+ "so that the cached datasets can consequently be loaded in distributed training"
240
+ },
241
+ )
242
+ train_split_name: str = field(
243
+ default="train",
244
+ metadata={
245
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
246
+ },
247
+ )
248
+ eval_split_name: str = field(
249
+ default="validation",
250
+ metadata={
251
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
252
+ },
253
+ )
254
+ do_lower_case: bool = field(
255
+ default=True,
256
+ metadata={"help": "Whether the target text should be lower cased."},
257
+ )
258
+ wandb_project: str = field(
259
+ default="flax-speech-recognition-ctc",
260
+ metadata={"help": "The name of the wandb project."},
261
+ )
262
+ wandb_name: str = field(
263
+ default=None,
264
+ metadata={"help": "The name of the wandb run."},
265
+ )
266
+ wandb_job_type: str = field(
267
+ default="CTC",
268
+ metadata={"help": "The name of the wandb job type."},
269
+ )
270
+ test_split_name: str = field(
271
+ default="test",
272
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
273
+ )
274
+ remove_punctuation: bool = field(
275
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
276
+ )
277
+
278
+
279
+ # @flax.struct.dataclass
280
+ @dataclass
281
+ class FlaxTrainingArguments(TrainingArguments):
282
+ precision: str = field(
283
+ default="full",
284
+ metadata={
285
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
286
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
287
+ },
288
+ )
289
+ matmul_precision: str = field(
290
+ default="default",
291
+ metadata={
292
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
293
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
294
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
295
+ "it only changes the behaviors of calls with no such argument provided. "
296
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
297
+ },
298
+ )
299
+ multisteps: bool = field(
300
+ default=False,
301
+ metadata={
302
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
303
+ "a custom gradient accumulation implementation will be employed."
304
+ },
305
+ )
306
+
307
+
308
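+ # Pytree-wide dtype casts used for mixed-precision training: gradients and optimizer state are
+ # kept in bfloat16 between steps and upcast to float32 for the actual optimizer update (see
+ # MixedPrecisionTrainState.apply_gradients below).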
+ def to_fp32(t):
309
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
310
+
311
+
312
+ def to_bf16(t):
313
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
314
+
315
+
316
+ class MixedPrecisionTrainState(struct.PyTreeNode):
317
+ """Train state for use with a single Optax optimizer.
318
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
319
+
320
+ Synopsis::
321
+
322
+ state = TrainState.create(
323
+ apply_fn=model.apply,
324
+ params=variables['params'],
325
+ tx=tx)
326
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
327
+ for batch in data:
328
+ grads = grad_fn(state.params, batch)
329
+ state = state.apply_gradients(grads=grads)
330
+
331
+ Args:
332
+ step: Counter starts at 0 and is incremented by every call to
333
+ `.apply_gradients()`.
334
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
335
+ convenience to have a shorter params list for the `train_step()` function
336
+ in your training loop.
337
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
338
+ tx: An Optax gradient transformation.
339
+ opt_state: The state for `tx`.
340
+ dropout_rng: PRNG key for stochastic operations.
341
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
342
+ """
343
+
344
+ step: int
345
+ apply_fn: Callable = struct.field(pytree_node=False)
346
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
347
+ params: core.FrozenDict[str, Any]
348
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
349
+ opt_state: optax.OptState
350
+ dropout_rng: jnp.ndarray
351
+ max_grad_norm: Optional[float] = 1.0
352
+
353
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
354
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
355
+
356
+ Note that internally this function calls `.tx.update()` followed by a call
357
+ to `optax.apply_updates()` to update `params` and `opt_state`.
358
+
359
+ Args:
360
+ grads: Gradients that have the same pytree structure as `.params`.
361
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
362
+
363
+ Returns:
364
+ An updated instance of `self` with `step` incremented by one, `params`
365
+ and `opt_state` updated by applying `grads`, and additional attributes
366
+ replaced as specified by `kwargs`.
367
+ """
368
+
369
+ # clip gradients by global l2 norm
370
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
371
+ g_norm = linear_algebra.global_norm(grads)
372
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
373
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
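+ # when the global norm is already <= max_grad_norm, g_norm equals max_grad_norm and the rescaling
+ # above is a no-op; otherwise gradients are scaled down so their global norm equals max_grad_norm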
374
+
375
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
376
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
377
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
378
+
379
+ new_params = optax.apply_updates(self.params, updates)
380
+ return self.replace(
381
+ step=self.step + 1,
382
+ params=new_params,
383
+ opt_state=to_dtype(new_opt_state),
384
+ **kwargs,
385
+ )
386
+
387
+ @classmethod
388
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
389
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
390
+ # downcast optimizer state to bf16 if mixed-precision training
391
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
392
+ return cls(
393
+ step=0,
394
+ apply_fn=apply_fn,
395
+ params=params,
396
+ tx=tx,
397
+ opt_state=opt_state,
398
+ **kwargs,
399
+ )
400
+
401
+ def replicate(self):
402
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
403
+
404
+
405
+ @flax.struct.dataclass
406
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
407
+ """
408
+ Data collator that will dynamically pad the inputs received.
409
+ Args:
410
+ processor ([`Wav2Vec2Processor`])
411
+ The processor used for processing the data.
412
414
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
415
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
416
+ among:
417
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
418
+ sequence is provided).
419
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
420
+ maximum acceptable input length for the model if that argument is not provided.
421
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
422
+ different lengths).
423
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
424
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
425
+ See above for details.
426
+ max_input_length (:obj:`float`, `optional`):
427
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
428
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
429
+ If set will pad the input sequence to a multiple of the provided value.
430
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
431
+ 7.5 (Volta).
432
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
433
+ If set will pad the target sequence to a multiple of the provided value.
434
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
435
+ 7.5 (Volta).
436
+ """
437
+
438
+ processor: Any
439
+ input_padding: Union[bool, str] = "longest"
440
+ label_padding: Union[bool, str] = "max_length"
441
+ pad_input_to_multiple_of: Optional[int] = None
442
+ pad_to_multiple_of_label: Optional[int] = None
443
+ max_input_length: Optional[float] = None
444
+ max_label_length: Optional[float] = None
445
+
446
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
447
+ # split inputs and labels since they have to be of different lengths and need
448
+ # different padding methods
449
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
450
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
451
+
452
+ # reformat list to dict and set to pytorch format
453
+ batch = self.processor.feature_extractor.pad(
454
+ input_features,
455
+ max_length=self.max_input_length,
456
+ padding=self.input_padding,
457
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
458
+ return_tensors="np",
459
+ )
460
+
461
+ labels_batch = self.processor.tokenizer.pad(
462
+ label_features,
463
+ max_length=self.max_label_length,
464
+ padding=self.label_padding,
465
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
466
+ return_tensors="np",
467
+ )
468
+
469
+ labels = labels_batch["input_ids"]
470
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
471
+ labels = labels.filled(fill_value=-100)
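+ # padded label positions are replaced by -100 so that they are treated as padding by the CTC loss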
472
+
473
+ batch["labels"] = labels
474
+
475
+ return batch
476
+
477
+
478
+ def get_grouped_indices(
479
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
480
+ ) -> np.array:
481
+ """
482
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
483
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
484
+ lengths. To do this, the indices are:
485
+
486
+ - randomly permuted (if a JAX rng is specified)
487
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
488
+ - sorted by length in each mega-batch
489
+
490
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
491
+ maximum length placed first, so that an OOM happens sooner rather than later.
492
+ """
493
+ lengths = dataset["input_length"]
494
+
495
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
496
+ if mega_batch_mult is None:
497
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
498
+ # Just in case, for tiny datasets
499
+ if mega_batch_mult == 0:
500
+ mega_batch_mult = 1
501
+
502
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
503
+ num_samples = len(lengths)
504
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
505
+
506
+ megabatch_size = mega_batch_mult * batch_size
507
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
508
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
509
+
510
+ # The rest is to get the biggest batch first.
511
+ # Since each megabatch is sorted by descending length, the longest element is the first
512
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
513
+ max_idx = np.argmax(megabatch_maximums).item()
514
+ # Switch to put the longest batch in first position
515
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
516
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
517
+
518
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
519
+
520
+ return megabatches
521
+
522
+
523
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
524
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
525
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
526
+ num_samples = len(samples_idx)
527
+ if drop_last:
528
+ samples_to_remove = num_samples % batch_size
529
+ if samples_to_remove != 0:
530
+ samples_idx = samples_idx[:-samples_to_remove]
531
+ sections_split = num_samples // batch_size
532
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
533
+ else:
534
+ sections_split = math.ceil(num_samples / batch_size)
535
+ samples_idx = np.array_split(samples_idx, sections_split)
536
+ return samples_idx
537
+
538
+
539
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
540
+ summary_writer.scalar("train_time", train_time, step)
541
+
542
+ train_metrics = get_metrics(train_metrics)
543
+ for key, vals in train_metrics.items():
544
+ tag = f"train_{key}"
545
+ for i, val in enumerate(vals):
546
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
547
+
548
+
549
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
550
+ for metric_name, value in eval_metrics.items():
551
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
552
+
553
+ if pred_str is not None:
554
+ # write output actual predictions for debugging
555
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
556
+
557
+
558
+ def write_wandb_log(metrics, step, prefix=None):
559
+ if jax.process_index() == 0:
560
+ log_metrics = {}
561
+ for k, v in metrics.items():
562
+ if "layer" in k:
563
+ log_metrics[f"{k}/"] = v
564
+ elif prefix is not None:
565
+ log_metrics[f"{prefix}/{k}"] = v
566
+ else:
567
+ log_metrics[k] = v
568
+ wandb.log(log_metrics, step)
569
+
570
+
571
+ def write_wandb_pred(pred_str, label_str, step, final_step=False, prefix="eval"):
572
+ if jax.process_index() == 0:
573
+ # convert str data to a wandb compatible format
574
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
575
+ if not final_step:
576
+ # we'll log the first 50 predictions for each intermediate epoch
577
+ wandb.log(
578
+ {
579
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
580
+ columns=["label_str", "pred_str"], data=str_data[:50]
581
+ )
582
+ },
583
+ step,
584
+ )
585
+ else:
586
+ # we'll log all predictions for the last epoch
587
+ wandb.log(
588
+ {
589
+ f"{prefix}/step_{int(step / 1000)}k_all": wandb.Table(
590
+ columns=["label_str", "pred_str"], data=str_data
591
+ )
592
+ },
593
+ step,
594
+ )
595
+
596
+
597
+ def create_learning_rate_fn(
598
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
599
+ ) -> Callable[[int], jnp.array]:
600
+ """Returns a linear warmup, linear_decay learning rate function."""
601
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
602
+ decay_fn = optax.linear_schedule(
603
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
604
+ )
605
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
606
+ return schedule_fn
607
+
608
+
609
+ def ctc_loss(
610
+ logits,
611
+ logits_attention_mask,
612
+ labels,
613
+ blank_id,
614
+ loss_reduction="mean",
615
+ output_emission_dict=False,
616
+ log_epsilon=-100000.0,
617
+ ):
618
+ """Computes CTC loss.
619
+ This function performs forward computation over an FSA with `N * 2` states
620
+ where `N` is the max number of labels. The states are split into two groups:
621
+ Phi states and emission states. a phi-state accepts repetition of
622
+ phi (blank)-symbols and transits to emission state when the correct label is
623
+ observed. An emission state accepts repetition of the label and transits to
624
+ the next phi states at any time (so called epsilon-transition).
625
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
626
+ and `N` denotes the time steps in `labels`.
627
+ Args:
628
+ logits: (B, T, K)-array containing log-probabilities of each class.
629
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
630
+ labels: (B, N)-array containing reference integer labels.
631
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
632
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
633
+ repetition of zeroes, followed by repetition of ones.
634
+ blank_id: Id for blank token.
635
+ loss_reduction: one of "mean", "sum", "none"
636
+ - "none": no reduction is applied.
637
+ - "mean": output loss will be divided by target lengths and then the
638
+ mean over the batch is taken.
639
+ - "sum": output loss are summed over batch
640
+ output_emission_dict: whether to output additional information about the emission probs
641
+ Returns:
642
+ A pair of `(per_seq_loss, aux)`.
643
+ per_seq_loss:
644
+ (B,)-array containing loss values for each sequence in the batch.
645
+ aux: Dictionary containing interim variables used for computing losses.
646
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
647
+ phi-state corresponding to the n-th label.
648
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
649
+ emission-state corresponding to the n-th label.
650
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
651
+ corresponding to each time frame.
652
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
653
+ corresponding to each time frame.
654
+ """
655
+ # label paddings are indicated by -100
656
+ labelpaddings = labels < 0
657
+ # logit paddings are the inverse of attention_mask
658
+ logitpaddings = ~logits_attention_mask
659
+
660
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
661
+ batchsize, unused_maxinputlen, num_classes = logits.shape
662
+ batchsize_, maxlabellen = labels.shape
663
+
664
+ logprobs = jax.nn.log_softmax(logits)
665
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
666
+
667
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
668
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
669
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
670
+
671
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
672
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
673
+
674
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
675
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
676
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
677
+
678
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
679
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
680
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
681
+
682
+ def loop_body(prev, x):
683
+ prev_phi, prev_emit = prev
684
+ # emit-to-phi epsilon transition, except if the next label is repetition
685
+ prev_phi_orig = prev_phi
686
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
687
+
688
+ logprob_emit, logprob_phi, pad = x
689
+
690
+ # phi-to-emit transition
691
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
692
+ # self-loop transition
693
+ next_phi = prev_phi + logprob_phi
694
+ # emit-to-phi blank transition only when the next label is repetition
695
+ next_phi = next_phi.at[:, 1:].set(
696
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
697
+ )
698
+
699
+ pad = pad.reshape((batchsize, 1))
700
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
701
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
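+ # padded time-steps carry the previous forward variables through unchanged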
702
+
703
+ return (next_phi, next_emit), (next_phi, next_emit)
704
+
705
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
706
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
707
+
708
+ # last row needs to be updated with the last epsilon transition
709
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
710
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
711
+
712
+ # extract per_seq_loss
713
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
714
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
715
+
716
+ if loss_reduction == "mean":
717
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
718
+ loss = (per_seq_loss / target_lengths).mean()
719
+ elif loss_reduction == "sum":
720
+ loss = per_seq_loss.sum()
721
+ else:
722
+ loss = per_seq_loss
723
+
724
+ if not output_emission_dict:
725
+ return loss
726
+
727
+ return loss, {
728
+ "logalpha_phi": logalpha_phi,
729
+ "logalpha_emit": logalpha_emit,
730
+ "logprobs_phi": logprobs_phi,
731
+ "logprobs_emit": logprobs_emit,
732
+ }
733
+
734
+
735
+ def main():
736
+ # 1. Parse input arguments
737
+ # See all possible arguments in src/transformers/training_args.py
738
+ # or by passing the --help flag to this script.
739
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
740
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
741
+
742
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
743
+ # If we pass only one argument to the script and it's the path to a json file,
744
+ # let's parse it to get our arguments.
745
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
746
+ else:
747
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
748
+
749
+ # 2. Setup logging
750
+ # Make one log on every process with the configuration for debugging.
751
+ logging.basicConfig(
752
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
753
+ datefmt="%m/%d/%Y %H:%M:%S",
754
+ handlers=[logging.StreamHandler(sys.stdout)],
755
+ )
756
+ # Set the verbosity to info of the Transformers logger.
757
+ # We only want one process per machine to log things on the screen.
758
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
759
+ if jax.process_index() == 0:
760
+ datasets.utils.logging.set_verbosity_warning()
761
+ transformers.utils.logging.set_verbosity_info()
762
+ else:
763
+ datasets.utils.logging.set_verbosity_error()
764
+ transformers.utils.logging.set_verbosity_error()
765
+
766
+ # Set up wandb run
767
+ if jax.process_index() == 0:
768
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
769
+
770
+ logger.info("Training/evaluation parameters %s", training_args)
771
+
772
+ # Set the default TPU matmul precision and display the number of devices
773
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
774
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
775
+
776
+ # 4. Load dataset
777
+ raw_datasets = DatasetDict()
778
+
779
+ if training_args.do_train:
780
+ raw_datasets["train"] = load_dataset(
781
+ data_args.dataset_name,
782
+ data_args.dataset_config_name,
783
+ split=data_args.train_split_name,
784
+ cache_dir=data_args.dataset_cache_dir,
785
+ use_auth_token=True if model_args.use_auth_token else None,
786
+ )
787
+
788
+ if training_args.do_eval:
789
+ raw_datasets["eval"] = load_dataset(
790
+ data_args.dataset_name,
791
+ data_args.dataset_config_name,
792
+ split=data_args.eval_split_name,
793
+ cache_dir=data_args.dataset_cache_dir,
794
+ use_auth_token=True if model_args.use_auth_token else None,
795
+ )
796
+
797
+ if training_args.do_predict:
798
+ test_split = data_args.test_split_name.split("+")
799
+ for split in test_split:
800
+ raw_datasets[split] = load_dataset(
801
+ data_args.dataset_name,
802
+ data_args.dataset_config_name,
803
+ split=split,
804
+ cache_dir=data_args.dataset_cache_dir,
805
+ use_auth_token=True if model_args.use_auth_token else None,
806
+ )
807
+
808
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
809
+ raise ValueError(
810
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
811
+ "training, evaluation or prediction has to be done."
812
+ )
813
+
814
+ # if not training, there is no need to run multiple epochs
815
+ if not training_args.do_train:
816
+ training_args.num_train_epochs = 1
817
+
818
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
819
+ raise ValueError(
820
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
821
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
822
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
823
+ )
824
+
825
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
826
+ raise ValueError(
827
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
828
+ "Make sure to set `--text_column_name` to the correct text column - one of "
829
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
830
+ )
831
+
832
+ # 5. Load pretrained model, tokenizer, and feature extractor
833
+ #
834
+ # Distributed training:
835
+ # The .from_pretrained methods guarantee that only one local process can concurrently
836
+ config = Wav2Vec2Config.from_pretrained(
837
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
838
+ cache_dir=model_args.cache_dir,
839
+ revision=model_args.model_revision,
840
+ use_auth_token=True if model_args.use_auth_token else None,
841
+ )
842
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
843
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
844
+ cache_dir=model_args.cache_dir,
845
+ revision=model_args.model_revision,
846
+ use_auth_token=True if model_args.use_auth_token else None,
847
+ )
848
+ tokenizer = AutoTokenizer.from_pretrained(
849
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
850
+ cache_dir=model_args.cache_dir,
851
+ revision=model_args.model_revision,
852
+ use_auth_token=True if model_args.use_auth_token else None,
853
+ )
854
+ # update config according to training args, model args, and tokenizer attributes
855
+ config.update(
856
+ {
857
+ "gradient_checkpointing": training_args.gradient_checkpointing,
858
+ "activation_dropout": model_args.activation_dropout,
859
+ "hidden_dropout": model_args.hidden_dropout,
860
+ "feat_proj_dropout": model_args.feat_proj_dropout,
861
+ "mask_time_prob": model_args.mask_time_prob,
862
+ "vocab_size": tokenizer.vocab_size,
863
+ }
864
+ )
865
+
866
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
867
+ raise ValueError(
868
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
869
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
870
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
871
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
872
+ )
873
+
874
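+ # "full_mixed": computations and optimizer state in bfloat16; "half_mixed": computations in
+ # bfloat16 but optimizer state kept in float32; any other value: full float32 training.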
+ if training_args.precision == "full_mixed":
875
+ dtype = jnp.bfloat16
876
+ training_args.mixed_precision = True
877
+ elif training_args.precision == "half_mixed":
878
+ dtype = jnp.bfloat16
879
+ training_args.mixed_precision = False
880
+ else:
881
+ dtype = jnp.float32
882
+ training_args.mixed_precision = False
883
+
884
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
885
+ model_args.model_name_or_path,
886
+ config=config,
887
+ dtype=dtype,
888
+ cache_dir=model_args.cache_dir,
889
+ revision=model_args.model_revision,
890
+ use_auth_token=True if model_args.use_auth_token else None,
891
+ )
892
+
893
+ # 6. Resample speech dataset ALWAYS
894
+ raw_datasets = raw_datasets.cast_column(
895
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
896
+ )
897
+
898
+ # 7. Preprocessing the datasets.
899
+ # We need to read the audio files as arrays and tokenize the targets.
900
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
901
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
902
+ max_eval_input_length = int(data_args.max_eval_duration_in_seconds * feature_extractor.sampling_rate) if data_args.max_eval_duration_in_seconds else None
903
+ max_target_length = data_args.max_label_length
904
+ min_target_length = data_args.min_label_length
905
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
906
+ audio_column_name = data_args.audio_column_name
907
+ num_workers = data_args.preprocessing_num_workers
908
+ text_column_name = data_args.text_column_name
909
+ model_input_name = feature_extractor.model_input_names[0]
910
+ do_lower_case = data_args.do_lower_case
911
+ dataset_name = data_args.dataset_name
912
+ tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
913
+ gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
914
+ gigaspeech_disfluencies = ["<other>", "<sil>"]
915
+ swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "[vocalized-noise]", "<a_aside>", "<b_aside>", "<e_aside>",
916
+ "[laughter-", "_1", "[laugh]", "[sigh]", "[cough]", "[mn]", "[breath]", "[lipsmack]",
917
+ "[sneeze]", "[skip]", "[pause]", "(%hesitation)", "(%HESITATION)"]
918
+ swb_punctuations = ["{", "}", "[", "]-", "]", "((", "))", "(", ")"]
919
+ earnings_disfluencies = ["<noise>", "<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>"]
920
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
921
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
922
+ ignore_segments += swb_disfluencies
923
+
924
+ if training_args.do_train and data_args.max_train_samples is not None:
925
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
926
+
927
+ if training_args.do_eval and data_args.max_eval_samples is not None:
928
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
929
+
930
+ if training_args.do_predict and data_args.max_test_samples is not None:
931
+ for split in test_split:
932
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_test_samples))
933
+
934
+ # filter data where the targets are ignored in scoring
935
+ def is_target_labels(input_str):
936
+ return input_str.lower() not in ignore_segments
937
+
938
+ raw_datasets = raw_datasets.filter(
939
+ is_target_labels,
940
+ num_proc=num_workers,
941
+ input_columns=[text_column_name],
942
+ desc="filtering data where the targets are ignored in scoring",
943
+ )
944
+
945
+ def prepare_dataset(batch):
946
+ # Pre-process audio
947
+ try:
948
+ sample = batch[audio_column_name]
949
+ except ValueError:
950
+ # E22: some samples are empty (no audio). Reading the empty audio array will trigger
951
+ # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
952
+ # They will be filtered in the subsequent filtering stage and so are
953
+ # explicitly ignored during training.
954
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
955
+
956
+ # normalise audio (mean, std) to (0, 1)
957
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
958
+ # process audio length
959
+ batch[model_input_name] = inputs.input_values[0]
960
+ batch["input_length"] = len(batch["input_values"])
961
+
962
+ # 'Error correction' of targets
963
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
964
+
965
+ # LibriSpeech ASR
966
+ if dataset_name == "librispeech_asr":
967
+ pass # no error correction necessary
968
+
969
+ # VoxPopuli
970
+ if dataset_name == "google/xtreme_s":
971
+ pass # no error correction necessary
972
+
973
+ # Common Voice 9
974
+ if dataset_name == "mozilla-foundation/common_voice_9_0":
975
+ if input_str.startswith('"') and input_str.endswith('"'):
976
+ # we can remove trailing quotation marks as they do not affect the transcription
977
+ input_str = input_str[1:-1]
978
+ # replace double quotation marks with single
979
+ input_str = input_str.replace('""', '"')
980
+
981
+ # TED-LIUM (Release 3)
982
+ if dataset_name == "LIUM/tedlium":
983
+ # delete the <unk> token from the text
984
+ input_str = input_str.replace("<unk>", "")
985
+ # replace spaced apostrophes with un-spaced (it 's -> it's)
986
+ for contraction in tedlium_contractions:
987
+ input_str = input_str.replace(contraction, contraction[1:])
988
+
989
+ # GigaSpeech
990
+ if dataset_name == "speechcolab/gigaspeech":
991
+ for disfluency in gigaspeech_disfluencies:
992
+ input_str = input_str.replace(disfluency, "")
993
+ # convert spelled out punctuation to symbolic form
994
+ for punctuation, replacement in gigaspeech_punctuation.items():
995
+ input_str = input_str.replace(punctuation, replacement)
996
+
997
+ # SWB: hide the path to the private HF dataset
998
+ if "switchboard" in dataset_name:
999
+ # In one conversation people speak some German phrases that are tagged as
1000
+ # <german (( ja wohl )) > -- we remove these
1001
+ input_str = re.sub("<[^>]*>", "", input_str)
1002
+
1003
+ # Remove junk tokens
1004
+ for disfluency in swb_disfluencies:
1005
+ input_str = input_str.replace(disfluency, "")
1006
+
1007
+ # Replace partially pronounced words (square brackets + hyphen): westmin[ster]- to westmin- or -[go]ing to -ing
1008
+ # Replace anomalous words (square brackets + slash): [lemguini/linguini] to linguini
1009
+ # Replace the combo of the two: [lem[guini]-/linguini] to lem-
1010
+ # Example: we [ah/are] -[go]ing to westmin[ster]- for [lem[guini]-/linguini]
1011
+ # Target: we ah -ing to westmin- for lem-
1012
+ # Treat anomalous words first then destroy the content of all square brackets (partially pronounced words)
1013
+
1014
+ # First treat partially pronounced anomalous words by removing correct word: [lem[guini]-/linguini] to [lem[guini]-
1015
+ input_str = re.sub(r"\-\/.*?\]", "-", input_str)
1016
+
1017
+ # Now replace anomalous words with their correct transcriptions: [lemguini/linguini] to linguini
1018
+ split_str = input_str.split("/")
1019
+ if len(split_str) > 1:
1020
+ input_str = " ".join(
1021
+ [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1022
+
1023
+ # Remove the trailing brackets on the start/end of words
1024
+ processed_str = []
1025
+ for word in input_str.split():
1026
+ if word[0] == "[":
1027
+ processed_str.append(word[1:])
1028
+ elif word[-1] == "]":
1029
+ processed_str.append(word[:-1])
1030
+ else:
1031
+ processed_str.append(word)
1032
+
1033
+ # Stick the processed words back together
1034
+ input_str = " ".join(processed_str)
1035
+
1036
+ # Now we can remove all words in square brackets: -[go]ing to -ing
1037
+ input_str = re.sub(r"\-\[(.*?)\]", "-", input_str)
1038
+
1039
+ # westmin[ster]- to westmin-
1040
+ input_str = re.sub(r"\[(.*?)\]\-", "-", input_str)
1041
+
1042
+ # tech[n]ology to tech-ology
1043
+ input_str = re.sub(r"\[(.*?)\]", "-", input_str)
1044
+
1045
+ # partially pronounced words are now done!
1046
+ # remove erroneous punctuations (curly braces, trailing square brackets, etc.)
1047
+ for punctuation in swb_punctuations:
1048
+ input_str = input_str.replace(punctuation, "")
1049
+
1050
+ # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
1051
+ if "earnings22" in dataset_name:
1052
+ for disfluency in earnings_disfluencies:
1053
+ input_str = input_str.replace(disfluency, "")
1054
+
1055
+ # SPGISpeech
1056
+ if dataset_name == "kensho/spgispeech":
1057
+ pass # no error correction necessary
1058
+
1059
+ # JIWER compliance (for WER/CER calc.)
1060
+ # remove multiple spaces
1061
+ input_str = re.sub(r"\s\s+", " ", input_str)
1062
+ # strip trailing spaces
1063
+ input_str = input_str.strip()
1064
+
1065
+ # Finally, we tokenize the processed text
1066
+ batch["labels"] = tokenizer(input_str).input_ids
1067
+ batch["labels_length"] = len(batch["labels"])
1068
+ return batch
1069
+
1070
+ vectorized_datasets = raw_datasets.map(
1071
+ prepare_dataset,
1072
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1073
+ num_proc=num_workers,
1074
+ desc="preprocess dataset",
1075
+ )
1076
+
1077
+ # filter training data with inputs longer than max_input_length
1078
+ def is_audio_in_length_range(length):
1079
+ return min_input_length < length < max_input_length
1080
+
1081
+ if training_args.do_train:
1082
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
1083
+ is_audio_in_length_range,
1084
+ num_proc=num_workers,
1085
+ input_columns=["input_length"],
1086
+ )
1087
+
1088
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1089
+ def is_labels_in_length_range(length):
1090
+ return min_target_length < length < max_target_length
1091
+
1092
+ if training_args.do_train:
1093
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
1094
+ is_labels_in_length_range,
1095
+ num_proc=num_workers,
1096
+ input_columns=["labels_length"],
1097
+ )
1098
+
1099
+ # filter data with targets shorter than 2 tokens (empty sentences)
1100
+ def is_labels_greater_than_min(length):
1101
+ return length > 2
1102
+
1103
+ vectorized_datasets = vectorized_datasets.filter(
1104
+ is_labels_greater_than_min,
1105
+ num_proc=num_workers,
1106
+ input_columns=["labels_length"],
1107
+ )
1108
+
1109
+ if max_eval_input_length is not None:
1110
+ # filter training data with inputs longer than max_input_length
1111
+ def is_eval_audio_in_length_range(length):
1112
+ return min_input_length < length < max_eval_input_length
1113
+
1114
+ if training_args.do_eval:
1115
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
1116
+ is_eval_audio_in_length_range,
1117
+ num_proc=num_workers,
1118
+ input_columns=["input_length"],
1119
+ )
1120
+
1121
+ if training_args.do_predict:
1122
+ for split in test_split:
1123
+ vectorized_datasets[split] = vectorized_datasets[split].filter(
1124
+ is_eval_audio_in_length_range,
1125
+ num_proc=num_workers,
1126
+ input_columns=["input_length"],
1127
+ )
1128
+
1129
+ # for large datasets it is advised to run the preprocessing on a
1130
+ # single machine first with `args.preprocessing_only` since there will most likely
1131
+ # be a timeout when running the script in distributed mode.
1132
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1133
+ # cached dataset
1134
+ if data_args.preprocessing_only:
1135
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1136
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1137
+ return
1138
+
1139
+ # 8. Load Metrics
1140
+ wer_metric = load_metric("wer")
1141
+ cer_metric = load_metric("cer")
1142
+
1143
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1144
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
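+ # replace the -100 loss-masking value with the tokenizer's pad token id before decoding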
1145
+
1146
+ pred_str = tokenizer.batch_decode(pred_ids)
1147
+ # we do not want to group tokens when computing the metrics
1148
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1149
+
1150
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1151
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1152
+
1153
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1154
+
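+ # Note on the decoding convention assumed above: tokenizer.batch_decode on the predicted
+ # ids applies the usual CTC collapse (merge repeated tokens, then strip the pad/blank
+ # token), whereas the references are decoded with group_tokens=False so that genuinely
+ # repeated characters in the label text are kept verbatim.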
1155
+ # 9. save feature extractor, tokenizer and config
1156
+ feature_extractor.save_pretrained(training_args.output_dir)
1157
+ tokenizer.save_pretrained(training_args.output_dir)
1158
+ config.save_pretrained(training_args.output_dir)
1159
+
1160
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1161
+
1162
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1163
+ processor=processor,
1164
+ input_padding="longest",
1165
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1166
+ max_label_length=data_args.max_label_length,
1167
+ )
1168
+
1169
+ # Enable tensorboard only on the master node
1170
+ has_tensorboard = is_tensorboard_available()
1171
+ if has_tensorboard and jax.process_index() == 0:
1172
+ try:
1173
+ from flax.metrics.tensorboard import SummaryWriter
1174
+
1175
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1176
+ except ImportError as ie:
1177
+ has_tensorboard = False
1178
+ logger.warning(
1179
+ f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
1180
+ )
1181
+ else:
1182
+ logger.warning(
1183
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1184
+ "Please run `pip install tensorboard` to enable."
1185
+ )
1186
+
1187
+ # 10. Handle the repository creation
1188
+ if training_args.push_to_hub:
1189
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1190
+ git_lfs_extensions = f.read()
1191
+ if "*.wandb" not in git_lfs_extensions:
1192
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1193
+ if training_args.hub_model_id is None:
1194
+ repo_name = get_full_repo_name(
1195
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1196
+ )
1197
+ else:
1198
+ repo_name = training_args.hub_model_id
1199
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1200
+
1201
+ # 11. Initialize our training
1202
+ rng = jax.random.PRNGKey(training_args.seed)
1203
+ rng, dropout_rng = jax.random.split(rng)
1204
+
1205
+ # Store some constants
1206
+ max_steps = int(training_args.max_steps)
1207
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1208
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1209
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1210
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1211
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1212
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1213
+
1214
+ if training_args.do_train:
1215
+ num_train_samples = len(vectorized_datasets["train"])
1216
+ steps_per_epoch = num_train_samples // batch_size_per_update
1217
+ if max_steps > 0:
1218
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1219
+ total_train_steps = max_steps
1220
+ else:
1221
+ num_epochs = int(training_args.num_train_epochs)
1222
+ total_train_steps = steps_per_epoch * num_epochs
1223
+
1224
+ # Create learning rate schedule
1226
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1227
+ total_train_steps,
1228
+ training_args.warmup_steps,
1229
+ training_args.learning_rate,
1230
+ )
1231
+
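+ # Illustrative values only, assuming create_learning_rate_fn (defined above) builds the
+ # usual linear warmup followed by linear decay to zero: with warmup_steps=5_000,
+ # learning_rate=3e-4 and total_train_steps=50_000, the schedule returns ~1.5e-4 at
+ # step 2_500 (half-way through warmup) and ~1.5e-4 again at step 27_500 (half-way
+ # through the decay).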
1232
+ # We use Optax's "masking" functionality to not apply weight decay
1233
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1234
+ # mask boolean with the same structure as the parameters.
1235
+ # The mask is True for parameters that should be decayed.
1236
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1237
+ # For FlaxT5, one should correct the layer norm parameter naming
1238
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1239
+ def decay_mask_fn(params):
1240
+ flat_params = traverse_util.flatten_dict(params)
1241
+ layer_norm_params = [
1242
+ (name, "scale")
1243
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1244
+ ]
1245
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1246
+ return traverse_util.unflatten_dict(flat_mask)
1247
+
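+ # Sketch of what decay_mask_fn returns for a small, hypothetical parameter tree:
+ #   params = {"encoder": {"layer_norm": {"scale": s, "bias": b},
+ #                         "dense": {"kernel": k, "bias": b2}}}
+ #   decay_mask_fn(params)
+ #   -> {"encoder": {"layer_norm": {"scale": False, "bias": False},
+ #                   "dense": {"kernel": True, "bias": False}}}
+ # i.e. only kernel/weight leaves are marked for weight decay.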
1248
+ if training_args.adafactor:
1249
+ # Create Adafactor optimizer
1250
+ optim = optax.adafactor(
1251
+ learning_rate=linear_decay_lr_schedule_fn,
1252
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1253
+ weight_decay_rate=training_args.weight_decay,
1254
+ weight_decay_mask=decay_mask_fn,
1255
+ )
1256
+ else:
1257
+ # Create AdamW optimizer
1258
+ optim = optax.adamw(
1259
+ learning_rate=linear_decay_lr_schedule_fn,
1260
+ b1=training_args.adam_beta1,
1261
+ b2=training_args.adam_beta2,
1262
+ eps=training_args.adam_epsilon,
1263
+ weight_decay=training_args.weight_decay,
1264
+ mask=decay_mask_fn,
1265
+ )
1266
+
1267
+ # Optax MultiSteps for gradient accumulation. We only apply this optimizer transformation
+ # if gradient accumulation is required (i.e. gradient_accumulation_steps > 1).
1268
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1269
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1270
+ else:
1271
+ num_epochs = 0
1272
+ total_train_steps = 0
1273
+ num_train_samples = 0
1274
+ optim = None
1275
+
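+ # When optax.MultiSteps is used (use_grad_mean=False), the wrapped optimizer sums the
+ # incoming gradients and emits zero updates for the first gradient_accumulation_steps - 1
+ # calls, only applying a real AdamW/Adafactor update on every
+ # gradient_accumulation_steps-th call.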
1276
+ # Setup train state
1277
+ state = MixedPrecisionTrainState.create(
1278
+ apply_fn=model.__call__,
1279
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1280
+ params=model.params,
1281
+ tx=optim,
1282
+ to_dtype=to_dtype,
1283
+ dropout_rng=dropout_rng,
1284
+ max_grad_norm=training_args.max_grad_norm,
1285
+ )
1286
+
1287
+ # Replicate the train state on each device
1288
+ state = state.replicate()
1289
+ blank_id = model.config.pad_token_id
1290
+
1291
+ # Define gradient update step fn
1292
+ def train_step(state, batch):
1293
+ # a single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1294
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1295
+
1296
+ def compute_loss(params, minibatch):
1297
+ labels = minibatch.pop("labels")
1298
+ logits = state.apply_fn(
1299
+ **minibatch,
1300
+ params=params,
1301
+ dropout_rng=dropout_rng,
1302
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1303
+ train=True,
1304
+ )[0]
1305
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1306
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1307
+
1308
+ return loss
1309
+
1310
+ grad_fn = jax.value_and_grad(compute_loss)
1311
+
1312
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1313
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1314
+
1315
+ # Custom gradient accumulation
1316
+ else:
1317
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1318
+ batch = jax.tree_map(
1319
+ lambda x: x.reshape(
1320
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1321
+ ),
1322
+ batch,
1323
+ )
1324
+
1325
+ def accum_minibatch_step(accum_grad, minibatch):
1326
+ # compute the loss and grad over the minibatch and accumulate
1327
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1328
+ return jax.tree_map(jnp.add, accum_grad, grad), loss
1329
+
1330
+ # create an initial state of zeros for accumulating the gradients
1331
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1332
+ # loop accum minibatch step over the number of gradient accumulation steps
1333
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1334
+
1335
+ # update state
1336
+ new_state = state.apply_gradients(
1337
+ grads=grad,
1338
+ dropout_rng=new_dropout_rng,
1339
+ to_dtype=to_dtype,
1340
+ )
1341
+
1342
+ # compute gradient norms over all layers and globally for detailed monitoring
1343
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1344
+ logs = {
1345
+ "layer_grad_norm": layer_grad_norm,
1346
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1347
+ }
1348
+
1349
+ # compute parameter norms over all layers and globally for detailed monitoring
1350
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1351
+ logs["layer_param_norm"] = layer_param_norm
1352
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1353
+
1354
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1355
+ metrics.update(logs)
1356
+
1357
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1358
+ # metrics = to_fp32(metrics)
1359
+
1360
+ return new_state, metrics
1361
+
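+ # Shape sketch for the custom accumulation branch (illustrative, assuming
+ # per_device_train_batch_size=8 and gradient_accumulation_steps=4): each per-device
+ # batch of shape (32, ...) is reshaped to (4, 8, ...), and jax.lax.scan folds
+ # accum_minibatch_step over the leading axis, summing the per-minibatch gradients
+ # before a single apply_gradients call.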
1362
+ # Define eval fn
1363
+ def eval_step(params, batch):
1364
+ labels = batch.pop("labels")
1365
+ logits = model(**batch, params=params, train=False)[0]
1366
+
1367
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1368
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1369
+
1370
+ pred_ids = jnp.argmax(logits, axis=-1)
1371
+
1372
+ # summarize metrics
1373
+ metrics = {"loss": loss}
1374
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1375
+ # metrics = to_fp32(metrics)
1376
+ return metrics, pred_ids
1377
+
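+ # Note: pred_ids here are per-frame argmax token ids; the CTC collapse into characters
+ # happens later, when compute_metrics calls tokenizer.batch_decode on them.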
1378
+ # Create parallel version of the train and eval step
1379
+ if training_args.do_train:
1380
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1381
+
1382
+ if training_args.do_eval or training_args.do_predict:
1383
+ p_eval_step = jax.pmap(eval_step, "batch")
1384
+
1385
+ def run_evaluation(step, final_step=False):
1386
+ if training_args.do_eval:
1387
+ # ======================== Evaluating ==============================
1388
+ eval_metrics = []
1389
+ eval_preds = []
1390
+ eval_labels = []
1391
+
1392
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1393
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1394
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1395
+
1396
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1397
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1398
+ batch = data_collator(samples)
1399
+ labels = batch["labels"]
1400
+
1401
+ try:
1402
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1403
+ except TypeError:
1404
+ continue
1405
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1406
+ eval_metrics.append(metrics)
1407
+
1408
+ eval_labels.extend(labels)
1409
+
1410
+ # normalize eval metrics
1411
+ eval_metrics = get_metrics(eval_metrics)
1412
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1413
+ eval_metrics = to_fp32(eval_metrics)
1414
+
1415
+ # always run compute metrics
1416
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1417
+ eval_metrics.update(error_rate_metric)
1418
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1419
+
1420
+ # Print metrics and update progress bar
1421
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1422
+ epochs.write(desc)
1423
+ epochs.desc = desc
1424
+
1425
+ # Save metrics
1426
+ write_wandb_log(eval_metrics, step, prefix="eval")
1427
+ write_wandb_pred(pred_str, label_str, step, final_step=final_step)
1428
+ # if has_tensorboard and jax.process_index() == 0:
1429
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1430
+
1431
+ def save_checkpoint(step):
1432
+ # save and push checkpoint to the hub
1433
+ if jax.process_index() == 0:
1434
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1435
+ model.save_pretrained(training_args.output_dir, params=params)
1436
+ tokenizer.save_pretrained(training_args.output_dir)
1437
+ if training_args.push_to_hub:
1438
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1439
+
1440
+ logger.info("***** Running training *****")
1441
+ logger.info(f" Num examples = {num_train_samples}")
1442
+ logger.info(f" Num Epochs = {num_epochs}")
1443
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1444
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1445
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1446
+ logger.info(f" Total optimization steps = {total_train_steps}")
1447
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1448
+ logger.info(f" Use scan: {config.use_scan}")
1449
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1450
+
1451
+ train_time = cur_step = 0
1452
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1453
+ for epoch in epochs:
1454
+ if training_args.do_train:
1455
+ # ======================== Training ================================
1456
+ train_start = time.time()
1457
+
1458
+ # Create sampling rng
1459
+ rng, input_rng = jax.random.split(rng)
1460
+
1461
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1462
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1463
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1464
+
1465
+ # Gather the indices for creating the batch and do a training step
1466
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1467
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1468
+ batch = data_collator(samples)
1469
+ batch = shard(batch.data)
1470
+ try:
1471
+ state, train_metric = p_train_step(state, batch)
1472
+ except TypeError as e:
1473
+ logger.warning(f"Encountered the following error: \n{e}")
1474
+
1475
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1476
+
1477
+ if cur_step % training_args.logging_steps == 0:
1478
+ # Save metrics
1479
+ train_metric = unreplicate(train_metric)
1480
+ train_time += time.time() - train_start
1481
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1482
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1483
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1484
+ # if has_tensorboard and jax.process_index() == 0:
1485
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1486
+
1487
+ epochs.write(
1488
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1489
+ )
1490
+
1491
+ if cur_step % total_train_steps == 0:
1492
+ break
1493
+
1494
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1495
+ run_evaluation(cur_step, final_step=False)
1496
+
1497
+ if cur_step % training_args.save_steps == 0:
1498
+ save_checkpoint(cur_step)
1499
+
1500
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1501
+ # run evaluation at the end of the epoch if eval steps are not specified
1502
+ run_evaluation(cur_step, final_step=False)
1503
+ save_checkpoint(cur_step)
1504
+
1505
+ if training_args.do_train:
1506
+ save_checkpoint(cur_step)
1507
+
1508
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1509
+
1510
+ if training_args.do_eval:
1511
+ run_evaluation(cur_step, final_step=True)
1512
+
1513
+ # TODO: collapse 'do_predict' into the run_evaluation function
1514
+ if training_args.do_predict:
1515
+ for split in test_split:
1516
+ # ======================== Evaluating ==============================
1517
+ eval_metrics = []
1518
+ eval_preds = []
1519
+ eval_labels = []
1520
+
1521
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1522
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1523
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1524
+
1525
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1526
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1527
+ batch = data_collator(samples)
1528
+ labels = batch["labels"]
1529
+
1530
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1531
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1532
+ eval_metrics.append(metrics)
1533
+
1534
+ eval_labels.extend(labels)
1535
+
1536
+ # normalize eval metrics
1537
+ eval_metrics = get_metrics(eval_metrics)
1538
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1539
+ eval_metrics = to_fp32(eval_metrics)
1540
+
1541
+ # always run compute metrics
1542
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1543
+ eval_metrics.update(error_rate_metric)
1544
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1545
+
1546
+ # Print metrics and update progress bar
1547
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1548
+ epochs.write(desc)
1549
+ epochs.desc = desc
1550
+
1551
+ # Save metrics
1552
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1553
+ write_wandb_pred(pred_str, label_str, cur_step, final_step=True, prefix=split)
1554
+ # if has_tensorboard and jax.process_index() == 0:
1555
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1556
+
1557
+
1558
+ if __name__ == "__main__":
1559
+ main()
run_switchboard.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ python run_flax_speech_recognition_ctc.py \
3
+ --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4
+ --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-switchboard-black-box-tokenizer" \
5
+ --dataset_name="ldc/switchboard" \
6
+ --dataset_config_name="all" \
7
+ --train_split_name="train.fisher+train.switchboard" \
8
+ --eval_split_name="validation" \
9
+ --test_split_name="test.switchboard+test.callhome" \
10
+ --text_column_name="text" \
11
+ --output_dir="./flax-wav2vec2-ctc-switchboard-fisher-black-box" \
12
+ --wandb_project="switchboard" \
13
+ --wandb_name="flax-wav2vec2-ctc-switchboard-fisher-black-box" \
14
+ --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15
+ --max_steps="50000" \
16
+ --save_steps="10000" \
17
+ --eval_steps="10000" \
18
+ --learning_rate="3e-4" \
19
+ --logging_steps="25" \
20
+ --warmup_steps="5000" \
21
+ --preprocessing_num_workers="1" \
22
+ --do_lower_case="False" \
23
+ --do_train \
24
+ --do_eval \
25
+ --do_predict \
26
+ --overwrite_output_dir \
27
+ --gradient_checkpointing \
28
+ --freeze_feature_encoder \
29
+ --push_to_hub \
30
+ --use_auth_token
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<pad>",
5
+ "unk_token": "<unk>"
6
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "do_lower_case": false,
4
+ "eos_token": "</s>",
5
+ "name_or_path": "sanchit-gandhi/wav2vec2-ctc-switchboard-black-box-tokenizer",
6
+ "pad_token": "<pad>",
7
+ "replace_word_delimiter_char": " ",
8
+ "special_tokens_map_file": null,
9
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
10
+ "unk_token": "<unk>",
11
+ "word_delimiter_token": "|"
12
+ }
vocab.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "'": 32,
3
+ "-": 6,
4
+ "1": 28,
5
+ "</s>": 2,
6
+ "<pad>": 0,
7
+ "<s>": 1,
8
+ "<unk>": 3,
9
+ "a": 20,
10
+ "b": 16,
11
+ "c": 27,
12
+ "d": 19,
13
+ "e": 7,
14
+ "f": 8,
15
+ "g": 4,
16
+ "h": 15,
17
+ "i": 21,
18
+ "j": 5,
19
+ "k": 17,
20
+ "l": 12,
21
+ "m": 23,
22
+ "n": 30,
23
+ "o": 24,
24
+ "p": 10,
25
+ "q": 33,
26
+ "r": 25,
27
+ "s": 22,
28
+ "t": 31,
29
+ "u": 14,
30
+ "v": 11,
31
+ "w": 29,
32
+ "x": 26,
33
+ "y": 18,
34
+ "z": 9,
35
+ "|": 13
36
+ }