mariagrandury committed on
Commit
8fefff1
1 Parent(s): 2e6a6e4

Add training script

Browse files
Files changed (3) hide show
  1. run_wav2vec2_pretrain_flax.py +597 -0
  2. train.sh +22 -0
  3. train_dummy.sh +22 -0
run_wav2vec2_pretrain_flax.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import logging
3
+ import sys
4
+ import time
5
+ from dataclasses import field
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Union
8
+
9
+ import numpy as np
10
+ from datasets import DatasetDict, load_dataset
11
+ from tqdm import tqdm
12
+
13
+ import flax
14
+ import jax
15
+ import jax.numpy as jnp
16
+ import librosa
17
+ import optax
18
+ from flax import jax_utils, traverse_util
19
+ from flax.training import train_state
20
+ from flax.training.common_utils import get_metrics, onehot, shard
21
+ from transformers import (
22
+ FlaxWav2Vec2ForPreTraining,
23
+ HfArgumentParser,
24
+ TrainingArguments,
25
+ Wav2Vec2Config,
26
+ Wav2Vec2FeatureExtractor,
27
+ is_tensorboard_available,
28
+ )
29
+ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices, _sample_negative_indices
30
+
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/feature extractor we are going to pre-train.

    The fields are turned into command-line flags by `HfArgumentParser` in `main()`.
    """

    # Used to load both the `Wav2Vec2Config` and the `Wav2Vec2FeatureExtractor`.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Also passed as `cache_dir` to `load_dataset`.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    # NOTE(review): not referenced anywhere else in the visible script.
    freeze_feature_extractor: Optional[bool] = field(
        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    # Forwarded to `Wav2Vec2Config.from_pretrained(..., gradient_checkpointing=...)`.
    gradient_checkpointing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use gradient checkpointing to save memory at the expense of a slower backward pass."
        },
    )
    # When set, the script logger is switched from WARNING to DEBUG.
    verbose_logging: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to log verbose messages or not."},
    )
    # The gumbel temperature at step `s` is `max * decay ** s`, clipped from below at `min`.
    max_gumbel_temperature: Optional[float] = field(
        default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
    )
    min_gumbel_temperature: Optional[float] = field(
        default=0.1, metadata={"help": "Minimum temperature for gumbel softmax."}
    )
    gumbel_temperature_decay: Optional[float] = field(
        default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
    )
    # Resolved with `getattr(jnp, dtype)` when the model is instantiated.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
73
+
74
+
75
@flax.struct.dataclass
class DataTrainingArguments:
    """
    Command-line options describing the data used for training and evaluation.

    `HfArgumentParser` converts every field of this class into an argparse
    argument, so each option can be given on the command line.
    """

    # --- which dataset / splits to load -------------------------------------
    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )

    # --- how to read and preprocess the audio -------------------------------
    speech_file_column: Optional[str] = field(
        default="file",
        metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_duration_in_seconds: Optional[float] = field(
        default=20.0,
        metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
    )
    pad_to_multiple_of: Optional[int] = field(
        default=1024,
        metadata={
            "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
        },
    )
128
+
129
+
130
@flax.struct.dataclass
class FlaxDataCollatorForWav2Vec2Pretraining:
    """
    Data collator that will dynamically pad the inputs received and prepare masked indices
    for self-supervised pretraining.

    Args:
        model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
            The Wav2Vec2 model used for pretraining. The data collator needs to have access
            to config and ``_get_feat_extract_output_lengths`` function for correct padding.
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The feature extractor whose ``pad`` method is used to batch the examples.
        padding (:obj:`bool`, :obj:`str`, `optional`, defaults to :obj:`"longest"`):
            Padding strategy forwarded unchanged to ``feature_extractor.pad``.
        max_length (:obj:`int`, `optional`):
            Forwarded unchanged to ``feature_extractor.pad``.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value
            (forwarded unchanged to ``feature_extractor.pad``).

    Returns (from ``__call__``):
        The padded batch (NumPy tensors) with two extra entries:
        ``mask_time_indices`` and ``sampled_negative_indices``.
    """

    model: FlaxWav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        # reformat list to dict and convert to NumPy arrays
        batch = self.feature_extractor.pad(
            features,
            max_length=self.max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="np",
        )
        # number of feature-encoder output frames for the padded input length
        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])

        batch_size = batch["input_values"].shape[0]

        # Default to "no frame-level mask" so the name is always bound, and use
        # `.get` so a feature extractor that returns no `attention_mask` does not
        # raise `KeyError`.
        # NOTE(review): assumes `_compute_mask_indices` accepts `attention_mask=None` - confirm.
        attention_mask = None
        sample_attention_mask = batch.get("attention_mask")
        if sample_attention_mask is not None:
            output_lengths = self.model._get_feat_extract_output_lengths(sample_attention_mask.sum(-1))
            attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)

            # these two operations make sure that all values
            # before the output lengths indices are attended to:
            # mark the last valid frame, then a reversed cumulative sum fills everything before it
            attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")

        # sample randomly masked indices
        batch["mask_time_indices"] = _compute_mask_indices(
            (batch_size, mask_indices_seq_length),
            self.model.config.mask_time_prob,
            self.model.config.mask_time_length,
            attention_mask=attention_mask,
            min_masks=2,
        )

        # sample indices to take for negative vectors
        batch["sampled_negative_indices"] = _sample_negative_indices(
            (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
            self.model.config.num_negatives,
        )

        return batch
202
+
203
+
204
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
    """
    Set up stdout logging and pick the verbosity of this script's logger.

    The module logger is set to DEBUG when ``model_args.verbose_logging`` is
    truthy and to WARNING otherwise. ``training_args`` is accepted but not read.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[stdout_handler],
    )
    level = logging.DEBUG if model_args.verbose_logging else logging.WARNING
    logger.setLevel(level)
214
+
215
+
216
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Write the accumulated training metrics to ``summary_writer``.

    ``train_time`` is logged once at ``step``. The per-step metrics collected
    since the last call are combined with ``get_metrics`` and written so that
    the last value of each series lands on ``step`` and earlier values on the
    steps immediately before it.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for name, values in stacked_metrics.items():
        tag = f"train_{name}"
        # global step of the first buffered value for this metric
        first_step = step - len(values) + 1
        for offset, value in enumerate(values):
            summary_writer.scalar(tag, value, first_step + offset)
224
+
225
+
226
def write_eval_metric(summary_writer, eval_metrics, step):
    """
    Write every entry of ``eval_metrics`` to ``summary_writer`` at ``step``.

    Each metric is logged under its name prefixed with ``eval_``.
    """
    tagged_values = ((f"eval_{name}", value) for name, value in eval_metrics.items())
    for tag, value in tagged_values:
        summary_writer.scalar(tag, value, step)
229
+
230
+
231
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    """
    Split ``samples_idx`` into consecutive batches of exactly ``batch_size`` indices.

    Trailing indices that do not fill a complete batch are dropped.

    Args:
        samples_idx: one-dimensional array of sample indices.
        batch_size: number of indices per batch.

    Returns:
        A list of arrays, one per full batch, as produced by ``np.split``.
        The list is empty when there are fewer than ``batch_size`` indices.
    """
    num_batches, leftover = divmod(len(samples_idx), batch_size)

    # `np.split` cannot split into zero sections; with fewer samples than one
    # batch there is simply nothing to iterate over.
    if num_batches == 0:
        return []

    if leftover != 0:
        samples_idx = samples_idx[:-leftover]
    return np.split(samples_idx, num_batches)
240
+
241
+
242
def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    """
    Compute the summed contrastive loss over the masked time steps.

    For every position the candidate set is the true quantized vector (index 0)
    followed by ``num_negatives`` sampled quantized vectors. Cosine similarities
    between ``transformer_features`` and the candidates, divided by
    ``logits_temp``, are used as logits of a softmax cross-entropy whose target
    is index 0. Positions where ``mask_time_indices`` is 0 get a negative target
    and are zeroed out of the loss.

    ``quantized_features`` is unpacked as ``(batch, sequence, hidden)``.
    Returns a scalar (sum over all contributing positions).
    """
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # gather the sampled negatives from the flattened (batch * sequence, hidden) table ...
    flat_quantized = quantized_features.reshape(-1, hidden_size)
    gathered_negatives = flat_quantized[negative_indices.reshape(-1)]
    # ... and arrange them as (num_negatives, batch, sequence, hidden)
    quantized_negatives = gathered_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size)
    quantized_negatives = quantized_negatives.transpose(2, 0, 1, 3)

    # candidate 0 is the positive, candidates 1.. are the negatives
    candidates = jnp.concatenate([quantized_features[None, :], quantized_negatives], axis=0)
    similarity = optax.cosine_similarity(transformer_features, candidates)
    logits = similarity / logits_temp

    # a sampled negative that equals the positive vector must not be counted;
    # the positive row itself is never flagged
    negative_equals_positive = (quantized_features == quantized_negatives).all(-1)
    never_flagged = jnp.full((1,) + logits.shape[1:], False)
    invalid_candidates = jnp.concatenate([never_flagged, negative_equals_positive], axis=0)

    # make sure incorrectly sampled vectors don't contribute to loss
    logits = jnp.where(invalid_candidates, -1e9, logits)

    num_candidates = logits.shape[0]
    predictions = logits.transpose(2, 1, 0).reshape(-1, num_candidates)
    # masked positions -> target 0 (the positive); unmasked positions -> -100
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()

    keep = jnp.where(targets >= 0, 1.0, 0.0)
    one_hot_targets = onehot(targets, predictions.shape[-1])
    per_position_loss = optax.softmax_cross_entropy(predictions, one_hot_targets) * keep

    return per_position_loss.sum()
272
+
273
+
274
def main():
    """
    Pre-train a Flax Wav2Vec2 model from the command line.

    Steps, in order: parse the three argument dataclasses, load the dataset
    (carving a validation split off the train split when none exists), load
    and normalize the audio, build the model, optimizer and pmapped train/eval
    steps, then run the epoch loop. Every epoch trains, evaluates, logs to
    TensorBoard (process 0 only) and saves a checkpoint (process 0 only).
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

    # Honor `--validation_split_name` (defaults to "validation") instead of a hard-coded name.
    validation_split_name = data_args.validation_split_name

    if validation_split_name not in datasets.keys():
        # make sure only "validation" and "train" keys remain
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
        )
    else:
        # make sure only "validation" and "train" keys remain
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=validation_split_name,
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}",
            cache_dir=model_args.cache_dir,
        )

    # only normalized-inputs-training is supported
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
    )

    def prepare_dataset(batch):
        # load the audio resampled to the feature extractor's sampling rate
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
    )

    # filter audio files that are too long
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    )

    def normalize(batch):
        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
        )

    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
    )

    # Enable tensorboard only on the master node. The "not installed" warning is
    # emitted only when the package really is missing, not on every non-master process.
    has_tensorboard = is_tensorboard_available()
    summary_writer = None
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    # Derive distinct keys for dropout and gumbel noise; splitting the same key
    # twice would give both consumers identical random streams.
    rng, dropout_key, gumbel_key = jax.random.split(rng, 3)
    dropout_rngs = jax.random.split(dropout_key, jax.local_device_count())
    gumbel_rngs = jax.random.split(gumbel_key, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule: linear warmup followed by linear decay to 0
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            # the negative indices are consumed by the loss, not by the model
            negative_indices = batch.pop("sampled_negative_indices")

            # exponentially decayed temperature, bounded below by the minimum;
            # the bound is passed positionally because the keyword name differs across JAX versions
            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
                model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # summarize metrics
        metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(vectorized_datasets["train"])
        steps_per_epoch = num_train_samples // train_batch_size
        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs
            )
            train_metrics.append(train_metric)

            cur_step = epoch * steps_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                # Add only the time elapsed since the previous measurement; without
                # moving `train_start` forward the same interval is counted repeatedly.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        # `jax.tree_util.tree_map` is the stable spelling; the top-level alias was removed from newer JAX
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
        )

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            # evaluation happens after this epoch's updates, so log it at the end-of-epoch step
            cur_step = (epoch + 1) * steps_per_epoch
            write_eval_metric(summary_writer, eval_metrics, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # take the first replica of every parameter and move it to host memory
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
594
+
595
+
596
# Run the pre-training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
train.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ ./preprocess_dataset.py \
3
+ --output_dir="./output" \
4
+ --num_train_epochs="5" \
5
+ --per_device_train_batch_size="32" \
6
+ --per_device_eval_batch_size="32" \
7
+ --learning_rate="5e-4" \
8
+ --weight_decay="0.01" \
9
+ --warmup_steps="2000" \
10
+ --model_name_or_path="./" \
11
+ --dataset_name="common_voice" \
12
+ --dataset_config_name="es" \
13
+ --preprocessing_num_workers="64" \
14
+ --max_duration_in_seconds="10.0" \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --pad_to_multiple_of="16384" \
18
+ --validation_split_percentage="5" \
19
+ --speech_file_column="path" \
20
+ --dtype="bfloat16" \
21
+ --cache_dir="./data_cache" \
22
+ --push_to_hub
train_dummy.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ ./run_wav2vec2_pretrain_flax.py \
3
+ --output_dir="./dummy" \
4
+ --num_train_epochs="1" \
5
+ --per_device_train_batch_size="4" \
6
+ --per_device_eval_batch_size="4" \
7
+ --learning_rate="5e-4" \
8
+ --weight_decay="0.01" \
9
+ --warmup_steps="200" \
10
+ --model_name_or_path="./" \
11
+ --dataset_name="common_voice" \
12
+ --dataset_config_name="cnh" \
13
+ --preprocessing_num_workers="96" \
14
+ --max_duration_in_seconds="20.0" \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --pad_to_multiple_of="16384" \
18
+ --validation_split_percentage="50" \
19
+ --speech_file_column="path" \
20
+ --dtype="bfloat16" \
21
+ --cache_dir="./data_cache_dummy/" \
22
+ --push_to_hub